In [1]:
import pandas as pd
import numpy as np
import requests

from rpy2.robjects import pandas2ri
from rpy2.robjects.conversion import rpy2py
import rpy2.robjects as ro
import json
import jax
from collections import OrderedDict

import jax.numpy as jnp
import numpyro
from numpyro import distributions as dist
from numpyro.infer import MCMC, NUTS
import requests
from numpyro.infer import util



def model(x1, x2, x3, y):
    """Robust linear regression of ``y`` on three covariates.

    Likelihood: ``y ~ StudentT(df=5, loc=alpha + beta1*x1 + beta2*x2 + beta3*x3, scale=1)``.
    The intercept and the three slopes each receive a standard Cauchy prior.

    Args:
        x1, x2, x3: covariate values, combined linearly to form the location.
        y: observed response, attached as the observation of the "y" site.
    """
    # Standard Cauchy priors, sampled in the fixed order alpha, beta1, beta2, beta3.
    alpha, beta1, beta2, beta3 = (
        numpyro.sample(site_name, dist.Cauchy(0., 1.))
        for site_name in ("alpha", "beta1", "beta2", "beta3")
    )

    # Student-t (5 degrees of freedom) likelihood around the linear predictor.
    linear_predictor = alpha + beta1 * x1 + beta2 * x2 + beta3 * x3
    likelihood = dist.StudentT(df=5., loc=linear_predictor, scale=1.)
    numpyro.sample("y", likelihood, obs=y)





# Load and prepare the dataset
url = "https://github.com/faosorios/heavy/blob/master/data/creatinine.rda?raw=true"
with requests.get(url, timeout=60) as resp:
    # Fail loudly on HTTP errors instead of saving an error page as
    # "creatinine.rda", which would only surface later as an obscure R load error.
    resp.raise_for_status()
    with open("creatinine.rda", "wb") as f:
        f.write(resp.content)

# Load the RDA file into R, then convert the `creatinine` object to pandas.
ro.r['load']("creatinine.rda")
df = pandas2ri.rpy2py_dataframe(ro.r['creatinine'])

# Apply transformations following https://openreview.net/pdf?id=HltJfwwfhX
data_df = pd.DataFrame({
    'log_SC': np.log(df['SC']),
    'log_WT': np.log(df['WT']),
    'log_140_minus_A': np.log(140 - df['Age']),
    'log_CR': np.log(df['CR']),
})
# np.log yields NaN for missing or negative inputs but -inf for zeros, and
# dropna() alone only removes NaN. Map +/-inf to NaN first so every remaining
# row is finite.
data_df = data_df.replace([np.inf, -np.inf], np.nan).dropna()

# Convert data to JAX arrays; the keys match the keyword arguments of `model`.
data_for_numpyro = {
    'x1': jnp.array(data_df['log_SC'].values),
    'x2': jnp.array(data_df['log_WT'].values),
    'x3': jnp.array(data_df['log_140_minus_A'].values),
    'y': jnp.array(data_df['log_CR'].values),
}

# Initialize NUTS sampler
nuts_kernel = NUTS(model)

# Initialize MCMC method
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)

# Run MCMC
mcmc.run(jax.random.PRNGKey(0), **data_for_numpyro)

# Extract samples: dict mapping each latent site name to its array of draws.
mcmc_samples = mcmc.get_samples()


# `model` and `data_for_numpyro` are defined above.

# Define a function that computes the log density for a single posterior sample.
def single_sample_log_density(sample):
    """Return the log joint density of `model` evaluated at one sample.

    Args:
        sample: dict mapping the latent site names of `model` ("alpha",
            "beta1", "beta2", "beta3") to the values of a single draw.

    Returns:
        The log joint density (sum of the prior log-probabilities and the
        likelihood of the data in `data_for_numpyro`), i.e. the unnormalized
        log posterior at `sample`.
    """
    # Signature is log_density(model, model_args, model_kwargs, params).
    # `model` receives no positional args, its data as the kwargs dict, and
    # the sampled values as `params`. Passing the data via ** would be an
    # unexpected keyword and raises TypeError.
    log_joint, _ = util.log_density(model, (), data_for_numpyro, sample)
    return log_joint

# `mcmc.get_samples()` already returns a dict mapping each site name to an
# array whose leading axis indexes the draws. That is exactly the pytree layout
# `jax.vmap` maps over, so no restacking is needed. The previous restacking
# went through `jax.tree_multimap`, which has been removed from JAX and raised
# AttributeError. Both names are kept for any downstream cells that use them.
samples_array_dict = dict(mcmc_samples)
stacked_samples = samples_array_dict

# Vectorize the function using vmap (maps over axis 0 of every leaf in the dict).
vectorized_log_density = jax.vmap(single_sample_log_density)

# Compute the log posterior densities for all samples (one value per draw).
log_posterior_densities = vectorized_log_density(stacked_samples)







I0000 00:00:1695914510.458722       1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
sample: 100%|██████████| 1500/1500 [00:01<00:00, 974.59it/s, 63 steps of size 5.22e-02. acc. prob=0.94] 
