In [1]:
import os

# XLA reads XLA_FLAGS when its backend is first initialised, so the flag must be set
# before jax/numpyro are imported. Exposing 16 virtual host "devices" lets
# MCMC(num_chains=...) run chains in parallel on a CPU-only machine.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"

import pickle

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random, vmap
from numpyro.infer import HMCECS, MCMC, NUTS

def t_mixture_model(states, data=None, subsample_size=1_000):
    """NumPyro mixture model over candidate states, fitted to ``data`` with subsampling.

    Structure of the model:
    - Mixture weights have a flat Dirichlet prior, with one weight per state.
    - Component means ``mu`` are a deterministic function of the global parameters
      ``n``, ``s`` and ``phi`` and of the two columns of ``states``.
    - Each component has its own ``scale``.
    - The likelihood is evaluated on a random subsample of ``data``, as HMCECS expects.

    NOTE(review): despite the "t" in the name, the components are Normal, not Student-t.

    Args:
        states: array with one row per state. Column 0 enters ``mu`` weighted by
            ``(1 - s)`` and column 1 weighted by ``s``.
        data: observations, indexed along axis 0 (presumably 1-D). This argument is
            required, because the subsampled "data" plate needs its size.
        subsample_size: number of observations used per likelihood evaluation. It is
            clipped to ``len(data)``.

    Raises:
        ValueError: if ``data`` is None.
    """
    if data is None:
        raise ValueError("t_mixture_model requires `data` to size the subsampled 'data' plate")

    n_states = len(states)
    n_obs = data.shape[0]

    weights = numpyro.sample("weights", dist.Dirichlet(concentration=jnp.ones(n_states)))
    n = numpyro.sample("n", dist.Beta(15, 1.5))
    s = numpyro.sample("s", dist.Beta(3, 15))
    phi = numpyro.sample("phi", dist.Gamma(2.5, 2))

    # One mean per state: log of the state-dependent ratio built from n, s and phi.
    mu = numpyro.deterministic(
        "mu",
        jnp.log(
            (2 * n + (1 - n) * ((1 - s) * states[:, 0] + s * states[:, 1]))
            / (2 * n + (1 - n) * phi)
        ),
    )

    with numpyro.plate("states", n_states):
        scale = numpyro.sample("scale", dist.Gamma(2, 2))

    # Never request a minibatch larger than the data itself.
    with numpyro.plate("data", n_obs, subsample_size=min(subsample_size, n_obs)):
        batch = numpyro.subsample(data, event_dim=0)
        numpyro.sample(
            "obs",
            dist.MixtureSameFamily(
                dist.Categorical(weights), dist.Normal(mu, scale), validate_args=False
            ),
            obs=batch,
        )

states = jnp.array([[c, c-1] for c in range(1, 6)] + 
                   [[c, c] for c in range(6)] + 
                   [[c, c+1] for c in range(6)] +
                    [[c, 2] for c in range(6)])


regions = pickle.load(open('test/luad34.regions.entropies.pkl', 'rb'))
data = jnp.array(regions[regions.chrom != 19].loc[:, 'log2_corrected'].values)

kernel = HMCECS(NUTS(t_mixture_model), num_blocks=100)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=4)
%time mcmc.run(random.PRNGKey(0), states, data=data)
mcmc.print_summary(exclude_deterministic=False)


No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

CPU times: user 25min 43s, sys: 14min 28s, total: 40min 11s
Wall time: 8min 22s

                 mean       std    median      5.0%     95.0%     n_eff     r_hat
      mu[0]     -0.02      0.02     -0.03     -0.04      0.01      2.00   1026.73
      mu[1]      0.00      0.03     -0.01     -0.02      0.05      2.00    630.89
      mu[2]      0.02      0.03      0.00     -0.00      0.08      2.00    408.32
      mu[3]      0.04      0.04      0.02      0.02      0.11      2.00    327.84
      mu[4]      0.06      0.04      0.04      0.03      0.13      2.00    287.55
      mu[5]     -0.04      0.02     -0.04     -0.06     -0.02      2.00    274.09
      mu[6]     -0.02      0.02     -0.02     -0.04      0.01      2.00    751.18
      mu[7]      0.00      0.02     -0.01     -0.02      0.05      2.00    591.29
      mu[8]      0.02      0.03      0.01      0.00      0.08      2.00    394.63
      mu[9]      0.04      0.04      0.03      0.02      0.11      2.00    318.47
     mu[10]      

In [None]:
import pandas as pd
import numpy as np
import os
import pickle
import matplotlib.pyplot as plt
import pymc as pm
import arviz as az

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('test/luad34.regions.entropies.pkl', 'rb') as fh:
    regions = pickle.load(fh)
data = regions.loc[regions.chrom != 19, 'log2_corrected'].to_numpy()

# NOTE(review): not used by the NUTS model below (minibatching targets variational
# fits). Kept only in case another cell relies on it.
minibatches = pm.Minibatch(data, batch_size=1_000)

# Candidate states, one mixture component per row.
states = np.array(
    [[c, c - 1] for c in range(1, 6)]
    + [[c, c] for c in range(6)]
    + [[c, c + 1] for c in range(6)]
    + [[c, 2] for c in range(6)]
)
n_states = states.shape[0]

with pm.Model() as model:
    n = pm.Beta('n', alpha=15, beta=1.5)
    s = pm.Beta('s', alpha=3, beta=15)
    phi = pm.Gamma('phi', alpha=2.5, beta=2)
    w = pm.Dirichlet('w', a=np.ones(n_states))
    
    mu = pm.Deterministic('mu', np.log((2 * n + (1 - n) * ((1 - s) * states[:, 0] + s * states[:, 1])) / (2 * n + (1 - n) * phi)))
    
    likelihood = pm.Categorical('likelihood', p=w, observed=sta
    
    %time trace = pm.sample(1000, tune=1000, cores=4, chains=1)

In [26]:
# Show which sampler statistics pm.sample recorded on this trace.
# NOTE(review): .stat_names is a MultiTrace attribute; an InferenceData result has no such attribute.
trace.stat_names

{'depth',
 'diverging',
 'energy',
 'energy_error',
 'index_in_trajectory',
 'largest_eigval',
 'max_energy_error',
 'mean_tree_accept',
 'model_logp',
 'perf_counter_diff',
 'perf_counter_start',
 'process_time_diff',
 'reached_max_treedepth',
 'smallest_eigval',
 'step_size',
 'step_size_bar',
 'tree_size',
 'tune',

In [57]:
def summarize_states(states, posterior):
    """Build a per-state posterior summary table.

    Args:
        states: array with one row per mixture component. Columns 0 and 1 become
            ``state_c`` and ``state_cs``.
        posterior: object indexable by variable name. ``posterior[name]`` must return
            draws shaped (n_draws, n_states), e.g. a pymc MultiTrace.

    Returns:
        DataFrame with one row per state and columns
        state_c, state_cs, mu, mu_sd, w, nu. Each column holds the posterior mean,
        except mu_sd, which is the standard deviation of mu. ``nu`` is NaN when the
        posterior has no 'nu' variable (for example, a Normal mixture).
    """
    states = np.asarray(states, dtype=float)
    mu_draws = np.asarray(posterior['mu'])
    w_draws = np.asarray(posterior['w'])
    try:
        nu_mean = np.asarray(posterior['nu']).mean(axis=0)
    except KeyError:
        nu_mean = np.full(len(states), np.nan)

    return pd.DataFrame({
        'state_c': states[:, 0],
        'state_cs': states[:, 1],
        'mu': mu_draws.mean(axis=0),
        'mu_sd': mu_draws.std(axis=0),
        'w': w_draws.mean(axis=0),
        'nu': nu_mean,
    })


MIN_WEIGHT = 0.01  # hide components with negligible posterior-mean weight

df = summarize_states(states, trace)
# Filter before sorting. Masking the sorted frame with a mask built from the unsorted
# `df` forces pandas to reindex the mask and emit a UserWarning.
df[df.w > MIN_WEIGHT].sort_values(['state_c', 'state_cs'])

  df.sort_values(['state_c', 'state_cs'])[df.w > 0.01]


Unnamed: 0,state_c,state_cs,mu,mu_sd,w,nu
17,0.0,2.0,-0.00764,0.010755,0.207221,2.056432
12,1.0,2.0,-0.000385,0.010913,0.015526,0.908837
1,2.0,1.0,0.005392,0.01131,0.011749,1.193112
7,2.0,2.0,0.006807,0.01205,0.010975,0.731556
21,4.0,2.0,0.021003,0.016176,0.019126,1.315985
3,4.0,3.0,0.02239,0.017063,0.016499,0.446803
9,4.0,4.0,0.023774,0.017951,0.209064,1.279603
22,5.0,2.0,0.028011,0.018715,0.2075,1.349599
4,5.0,4.0,0.030757,0.020508,0.205462,1.905273
10,5.0,5.0,0.032126,0.021402,0.011359,1.053316
