In [2]:
import pandas as pd
import numpy as np
import os
import pickle
import matplotlib.pyplot as plt
import pymc as pm
import arviz as az

regions = pickle.load(open('test/luad34.regions.entropies.pkl', 'rb'))
data = regions[regions.chrom != 19].loc[:, 'log2_corrected'].values

minibatches = pm.Minibatch(data, batch_size=1_000)
states = np.array([[c, c-1] for c in range(1, 6)] + 
                    [[c, c] for c in range(6)] + 
                    [[c, c+1] for c in range(6)] +
                        [[c, 2] for c in range(6)])
n_states = states.shape[0]

with pm.Model() as model:
    n = pm.Beta('n', alpha=15, beta=1.5)
    s = pm.Beta('s', alpha=3, beta=15)
    phi = pm.Gamma('phi', alpha=2.5, beta=2)
    w = pm.Dirichlet('w', a=np.ones(n_states))
    
    mu = pm.Deterministic('mu', np.log((2 * n + (1 - n) * ((1 - s) * states[:, 0] + s * states[:, 1])) / (2 * n + (1 - n) * phi)))
    
    likelihood = pm.Mixture('likelihood', w=w, comp_dists=pm.Constant.dist(mu=mu), observed=minibatches)
    
    %time trace = pm.sample(1000, tune=1000, cores=4, chains=1)


: 

In [9]:
samples = approx.sample(1000)




KeysView(Inference data with groups:
	> posterior)

In [None]:
samples
# values = {}
# for i, state in enumerate(states):
#     mu = trace['mu'][:, i].mean()
#     mu_sd = trace['mu'][:, i].std()
#     w = trace['w'][:, i].mean()
#     nu = trace['nu'][:, i].mean()
#     values[i] = [state[0], state[1], mu, mu_sd, w, nu]
# df = pd.DataFrame(values, index=['state_c', 'state_cs', 'mu', 'mu_sd', 'w', 'nu']).T
# df.sort_values(['state_c', 'state_cs'])[df.w > 0.01]

In [None]:
import numpyro
import numpyro.distributions as dist
from numpyro.infer import HMCECS, MCMC, NUTS
from numpyro.handlers import seed, trace
import blackjax
import jax
from jax import random, vmap
import jax.numpy as jnp
import pickle
from numpyro.infer.util import initialize_model
from numpyro.infer import Predictive, SVI, Trace_ELBO
from numpyro.distributions import constraints
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"

def model(states, n_states, data=None):

    weights = numpyro.sample("weights", dist.Dirichlet(concentration=jnp.ones(n_states)))
    n = numpyro.sample("n", dist.Beta(15, 1.5))
    s = numpyro.sample("s", dist.Beta(3, 15))
    phi = numpyro.sample("phi", dist.Gamma(2.5, 2))

    mu = numpyro.deterministic("mu", jnp.log((2 * n + (1 - n) * ((1 - s) * states[:, 0] + s * states[:, 1])) / (2 * n + (1 - n) * phi)))

    with numpyro.plate("states", n_states):
        nu = numpyro.sample("nu", dist.Gamma(2, 2))
        scale = numpyro.sample("scale", dist.Gamma(2, 2))

    with numpyro.plate("data", len(data)):
        numpyro.sample("obs", dist.MixtureSameFamily(dist.Categorical(weights), dist.StudentT(nu, mu, scale)), obs=data)

def guide(states, n_states, data=None):
    
    # Variational parameters for the weights
    alpha_q = numpyro.param("alpha_q", jnp.ones(n_states), constraint=constraints.positive)
    weights_q = numpyro.sample("weights", dist.Dirichlet(concentration=alpha_q))
    
    # Variational parameters for n
    n_a_q = numpyro.param("n_a_q", 1.0, constraint=constraints.positive)
    n_b_q = numpyro.param("n_b_q", 1.0, constraint=constraints.positive)
    n_q = numpyro.sample("n", dist.Beta(n_a_q, n_b_q))
    
    # Variational parameters for s
    s_a_q = numpyro.param("s_a_q", 1.0, constraint=constraints.positive)
    s_b_q = numpyro.param("s_b_q", 1.0, constraint=constraints.positive)
    s_q = numpyro.sample("s", dist.Beta(s_a_q, s_b_q))
    
    # Variational parameters for phi
    phi_alpha_q = numpyro.param("phi_alpha_q", 1.0, constraint=constraints.positive)
    phi_beta_q = numpyro.param("phi_beta_q", 1.0, constraint=constraints.positive)
    phi_q = numpyro.sample("phi", dist.Gamma(phi_alpha_q, phi_beta_q))
    
    with numpyro.plate("states", n_states):
        # Variational parameters for nu
        nu_alpha_q = numpyro.param("nu_alpha_q", 1.0, constraint=constraints.positive)
        nu_beta_q = numpyro.param("nu_beta_q", 1.0, constraint=constraints.positive)
        nu_q = numpyro.sample("nu", dist.Gamma(nu_alpha_q, nu_beta_q))
        
        # Variational parameters for scale
        scale_alpha_q = numpyro.param("scale_alpha_q", 1.0, constraint=constraints.positive)
        scale_beta_q = numpyro.param("scale_beta_q", 1.0, constraint=constraints.positive)
        scale_q = numpyro.sample("scale", dist.Gamma(scale_alpha_q, scale_beta_q))


states = jnp.array([[c, c-1] for c in range(1, 6)] + 
                   [[c, c] for c in range(6)] + 
                   [[c, c+1] for c in range(6)] +
                    [[c, 2] for c in range(6)])
n_states = len(states)

rng_key = random.PRNGKey(0)

regions = pickle.load(open('test/luad34.regions.entropies.pkl', 'rb'))
data = jnp.array(regions[regions.chrom != 19].loc[:, 'log2_corrected'].values)

optimizer = numpyro.optim.Adam(step_size=0.005)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(rng_key, 1000, states, n_states, data=data)
params = svi_result.params
predictive = Predictive(guide, params=params, num_samples=1000)
posterior_samples = predictive(random.PRNGKey(1), states, n_states, data=None)
posterior_samples['n'].mean(), posterior_samples['s'].mean(), posterior_samples['phi'].mean()

In [20]:
params = svi_result.params
# inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])
# use guide to make predictive
# predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
# samples = predictive(random.PRNGKey(1), states, n_states, data=None)
# samples
# get posterior samples
predictive = Predictive(guide, params=params, num_samples=1000)
posterior_samples = predictive(random.PRNGKey(1), states, n_states, data=None)
# posterior_samples['n'].mean(), posterior_samples['s'].mean(), posterior_samples['phi'].mean()
# use posterior samples to make predictive
predictive = Predictive(model, posterior_samples, params=params, num_samples=1000)
samples = predictive(random.PRNGKey(1), states, n_states, data=None)
samples['n'].mean(), samples['s'].mean(), samples['phi'].mean()

TypeError: object of type 'NoneType' has no len()

In [None]:
import pandas as pd
import numpy as np
import os
import pickle
import concurrent.futures
import matplotlib.pyplot as plt
import pymc as pm
from pymc.sampling.jax import get_jaxified_graph
import arviz as az
import pytensor
from pymc.sampling_jax import get_jaxified_logp
import blackjax
import jax
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"

regions = pickle.load(open('test/luad34.regions.entropies.pkl', 'rb'))
data = regions[regions.chrom != 19].loc[:, 'log2_corrected'].values

states = np.array([[c, c-1] for c in range(1, 6)] + 
                    [[c, c] for c in range(6)] + 
                    [[c, c+1] for c in range(6)] +
                        [[c, 2] for c in range(6)])
n_states = states.shape[0]

with pm.Model() as model:
    n = pm.Beta('n', alpha=99, beta=1)
    s = pm.Beta('s', alpha=10, beta=90)
    phi = pm.Gamma('phi', alpha=3.5, beta=1)

    w = pm.Dirichlet('w', a = np.ones(n_states))

    mu = pm.Deterministic('mu', np.log((2 * n + (1 - n) * ((1 - s) * states[:, 0] + s * states[:, 1])) / (2 * n + (1 - n) * phi)))
    lambda_ = pm.Gamma('lambda', alpha=3, beta=1, shape=n_states)
    nu = pm.Gamma('nu', alpha=3, beta=1, shape=n_states)

    obs = pm.Mixture('obs', w, pm.StudentT.dist(mu=mu, lam=lambda_, nu=nu, shape=n_states), observed=data)


def build_pymc_loglik_logprior(model):
    loglike_fn = get_jaxified_graph(
        inputs=model.value_vars + [model.rvs_to_values[obs]], 
        outputs=[model.datalogp],
        )
    def pymc_loglikelihood(theta, x):
        return loglike_fn(*theta, x)[0]


    logp_fn = get_jaxified_graph(
        outputs=[model.varlogp],
        )
    def pymc_logprior(theta):
        return logp_fn(*theta)[0]

    return pymc_loglikelihood, pymc_logprior

In [2]:
loglikelihood_fn, logprior_fn = build_pymc_loglik_logprior(model)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [None]:
def sample_fn(rng_key, num_samples):
    key1, key2, key3 = jax.random.split(rng_key, 3)
    prob_mixture = jax.random.bernoulli(key1, p=0.5, shape=(num_samples, 1))
    mixture_1 = jax.random.normal(key2, shape=(num_samples, 1)) * sigma + mu
    mixture_2 = jax.random.normal(key3, shape=(num_samples, 1)) * sigma + gamma - mu
    return prob_mixture * mixture_1 + (1 - prob_mixture) * mixture_2
data_size = 1000

rng_key = jax.random.PRNGKey(888)
rng_key, sample_key = jax.random.split(rng_key)
X_data = sample_fn(sample_key, data_size)

In [7]:
rvs = [rv.name for rv in model.value_vars]
init_position_dict = model.initial_point()
init_position = [init_position_dict[rv] for rv in rvs]
init_position

[array(4.59511985),
 array(-2.19722458),
 array(1.25276297),
 array([8.8817842e-16, 8.8817842e-16, 8.8817842e-16, 8.8817842e-16,
        8.8817842e-16, 8.8817842e-16, 8.8817842e-16, 8.8817842e-16,
        8.8817842e-16, 8.8817842e-16, 8.8817842e-16, 8.8817842e-16,
        8.8817842e-16, 8.8817842e-16, 8.8817842e-16, 8.8817842e-16,
        8.8817842e-16, 8.8817842e-16, 8.8817842e-16, 8.8817842e-16,
        8.8817842e-16, 8.8817842e-16]),
 array([1.09861229, 1.09861229, 1.09861229, 1.09861229, 1.09861229,
        1.09861229, 1.09861229, 1.09861229, 1.09861229, 1.09861229,
        1.09861229, 1.09861229, 1.09861229, 1.09861229, 1.09861229,
        1.09861229, 1.09861229, 1.09861229, 1.09861229, 1.09861229,
        1.09861229, 1.09861229, 1.09861229]),
 array([1.09861229, 1.09861229, 1.09861229, 1.09861229, 1.09861229,
        1.09861229, 1.09861229, 1.09861229, 1.09861229, 1.09861229,
        1.09861229, 1.09861229, 1.09861229, 1.09861229, 1.09861229,
        1.09861229, 1.09861229, 1.098

In [12]:
from fastprogress import progress_bar
import jax.numpy as jnp

import blackjax
import blackjax.sgmcmc.gradients as gradients

# Specify hyperparameters for SGLD
total_iter = 10_000
thinning_factor = 10

batch_size = 100
lr = 1e-3
temperature = 50.0


# Build the SGDL sampler
grad_fn = gradients.grad_estimator(logprior_fn, loglikelihood_fn, data.shape[0])
sgld = blackjax.sgld(grad_fn)


# Initialize and take one step using the vanilla SGLD algorithm
position = init_position
sgld_sample_list = jnp.array([])


In [20]:
jax.jit(sgld)(position, data, batch_size, lr, temperature)

TypeError: Expected a callable value, got SamplingAlgorithm(init=<function sgld.__new__.<locals>.init_fn at 0x7f4fbd160220>, step=<function sgld.__new__.<locals>.step_fn at 0x7f4fbd1609a0>)

In [25]:

pb = progress_bar(range(total_iter))
for iter_ in pb:
    rng_key, batch_key, sample_key = jax.random.split(rng_key, 3)
    data_batch = jax.random.shuffle(batch_key, data)[:batch_size, :]
    position = jax.jit(sgld.step)(sample_key, position, data_batch, lr, temperature)
    if iter_ % thinning_factor == 0:
        sgld_sample_list = jnp.append(sgld_sample_list, position)
        pb.comment = f"| position: {position: .2f}"

 |----------------------------------------| 0.00% [0/10000 00:00<?]



IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.

In [24]:
import jax
import jax.numpy as jnp
from jax import random, vmap, jit

import numpy as np
import pymc as pm
import pymc.sampling_jax
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")

from sgmcmcjax.samplers import build_sgld_sampler, build_sgldCV_sampler
import optax
from sgmcmcjax.optimizer import build_optax_optimizer

print(f"Running on PyMC v{pm.__version__}")

Running on PyMC v5.7.2


In [25]:
from pymc.sampling_jax import get_jaxified_graph

def build_pymc_loglik_logprior(model):

    loglike_fn = get_jaxified_graph(
        inputs=model.value_vars + [model.rvs_to_values[obs]], 
        outputs=[model.datalogp],
        )
    def pymc_loglikelihood(theta, x):
        return loglike_fn(*theta, x)[0]


    logp_fn = get_jaxified_graph(
        outputs=[model.varlogp],
        )
    def pymc_logprior(theta):
        return logp_fn(*theta)[0]

    return pymc_loglikelihood, pymc_logprior

In [36]:
batch_size = int(0.1*len(data))
dt = 1e-3
key = random.PRNGKey(0)

rvs = [rv.name for rv in model.value_vars]
init_position_dict = model.initial_point()
init_position = [init_position_dict[rv] for rv in rvs]
init_position
sgld_sampler = build_sgld_sampler(dt, loglikelihood_fn, logprior_fn, (data,), batch_size)

In [37]:
%time samples = sgld_sampler(key, 1_000, init_position)

  0%|          | 0/10000 [00:00<?, ?it/s]

In [None]:
import jax
import jax.numpy as jnp
import pickle
from jax.scipy import stats
from jax import random, jit, vmap
from sgmcmcjax.samplers import build_sgld_sampler
import matplotlib.pyplot as plt
from functools import partial
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"

def flatten_params(weights, n, s, phi, nu, scale):
    return (weights, n, s, phi, nu, scale)

def unflatten_params(params):
    return params

def log_prior(params):
    weights, n, s, phi, nu, scale = unflatten_params(params)
    
    logp = 0.0
    # weights ~ Dirichlet
    logp += jnp.sum(stats.dirichlet.logpdf(weights, jnp.ones(len(weights))))
    # n ~ Beta
    logp += stats.beta.logpdf(n, 99, 1)
    # s ~ Beta
    logp += stats.beta.logpdf(s, 10, 90)
    # phi ~ Gamma
    logp += stats.gamma.logpdf(phi, 3, 1)
    # nu ~ Gamma (assumed to be elementwise independent)
    logp += jnp.sum(stats.gamma.logpdf(nu, 1, 1))
    # scale ~ Gamma (assumed to be elementwise independent)
    logp += jnp.sum(stats.gamma.logpdf(scale, 1, 1))

    return logp

def categorical_logpmf(probs, sample_idx):
    return jnp.log(probs[sample_idx])

def log_likelihood(params, data):
    weights, n, s, phi, nu, scale = unflatten_params(params)
    
    mu = jnp.log((2 * n + (1 - n) * ((1 - s) * states[:, 0] + s * states[:, 1])) / (2 * n + (1 - n) * phi))
    mixture_log_probs = vmap(lambda nu_, scale_: stats.t.logpdf(data, df=nu_, loc=mu, scale=scale_))(nu, scale)
    
    # Compute the index of the highest log probability for each data point
    sample_idxs = jnp.argmax(mixture_log_probs, axis=-1)
    
    # Compute the log likelihood using the categorical distribution
    logp = jnp.sum(categorical_logpmf(weights, sample_idxs))

    return logp

states = jnp.array([[c, c-1] for c in range(1, 6)] + 
                    [[c, c] for c in range(6)] + 
                    [[c, c+1] for c in range(6)] +
                        [[c, 2] for c in range(6)])
D = states.shape[0]

# Example data and parameters setup
key = random.PRNGKey(0)
regions = pickle.load(open('test/luad34.regions.entropies.pkl', 'rb'))
X_data = jnp.array(regions[regions.chrom != 19].loc[:, 'log2_corrected'].values)
batch_size = int(0.1 * len(X_data))
dt = 1e-5

params_init = flatten_params(jnp.ones(D)/D, 0.99, 0.1, 2.0, jnp.ones(D), jnp.ones(D))

my_sampler = build_sgld_sampler(dt, log_likelihood, log_prior, (X_data,), batch_size)
my_sampler = partial(jit, static_argnums=(1,))(my_sampler)
Nsamples = 10000
samples = my_sampler(key, Nsamples, params_init)


In [5]:
import jax.numpy as jnp
import pandas as pd

def summary(samples):
    # Assuming `samples` is a tuple containing arrays for each parameter
    weights_samples, n_samples, s_samples, phi_samples, nu_samples, scale_samples = samples

    summary_data = {}

    for name, param_samples in zip(["weights", "n", "s", "phi", "nu", "scale"], samples):
        # If parameter has multiple dimensions (e.g., weights, nu, scale), handle each dimension separately
        if len(param_samples.shape) > 1:
            for dim in range(param_samples.shape[1]):
                summary_data[name + f"[{dim}]"] = {
                    "Mean": jnp.mean(param_samples[:, dim]),
                    "StdDev": jnp.std(param_samples[:, dim]),
                    "Median": jnp.median(param_samples[:, dim]),
                    "25%": jnp.percentile(param_samples[:, dim], 25),
                    "75%": jnp.percentile(param_samples[:, dim], 75)
                }
        else:
            summary_data[name] = {
                "Mean": jnp.mean(param_samples),
                "StdDev": jnp.std(param_samples),
                "Median": jnp.median(param_samples),
                "25%": jnp.percentile(param_samples, 25),
                "75%": jnp.percentile(param_samples, 75)
            }

    # Create DataFrame and print
    df = pd.DataFrame.from_dict(summary_data, orient='index')
    return df

df = summary(samples)
# show all rows in the dataframe
pd.set_option('display.max_rows', None)
df.head(100)

Unnamed: 0,Mean,StdDev,Median,25%,75%
weights[0],702.17456,247.88286,744.7227,526.60516,911.9613
weights[1],22.515112,0.2057682,22.575712,22.385656,22.675909
weights[2],24.320227,0.0857654,24.318047,24.256092,24.382143
weights[3],4.52515,0.1667575,4.5355177,4.4063683,4.66517
weights[4],7.4071264,0.07805767,7.416716,7.356103,7.4618773
weights[5],457.7731,0.09321319,457.77094,457.70544,457.8487
weights[6],6.270423,0.26262575,6.2243795,6.050833,6.351941
weights[7],5.073107,0.11321778,5.094407,5.00559,5.141355
weights[8],4.306258,0.18179582,4.2425947,4.142879,4.494214
weights[9],5.2620564,0.10695958,5.2659564,5.230654,5.3148565
