In [1]:
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import pyro
import seaborn as sns
import sys
import torch

from scipy import stats

import pyro.distributions as dist
import pyro.distributions.constraints as constraints

# Initial values handed to the pyro.param sites declared in `model`
# (A_SIGMA_INIT is also the offset scale used by generate_synthetic_samples).
A_SIGMA_INIT = 5   # a_sigma, a_s_sigma, a_d_sigma
G_ALPHA_INIT = 10  # s_g_alpha, d_g_alpha
G_BETA_INIT = 2    # s_g_beta, d_g_beta
ALPHA_INIT = 2     # sigma_g_alpha
BETA_INIT = 1      # sigma_g_beta

# evaluate model
# NOTE(review): N_SYNTH, NBINS and NUM_ARGS are not referenced anywhere in the
# visible cells - confirm they are still needed.
N_SYNTH = 200
LO = .05  # passed as `lo` to coverage()
HI = .95  # passed as `hi` to coverage()
NBINS = 20
# NUM_ARGS was previously assigned twice in this cell (5, then 4); only the
# final value ever took effect, so the dead first assignment was dropped.
NUM_ARGS = 4

In [2]:
def read_pickle(fn):
    """Load and return the object stored in the pickle file at path `fn`.

    NOTE(review): unpickling can execute arbitrary code contained in the file,
    so `fn` should only ever point at files this project produced itself.
    """
    with open(fn, 'rb') as pickled_file:
        return pickle.load(pickled_file)

def get_model_inputs(train_fn, sample_fn, drug_fn):
    """Load one data split plus the sample/drug lookups and build the model inputs.

    Despite the parameter name, `train_fn` is simply "the split to load"; the
    notebook also calls this function with the test split.

    Args:
        train_fn: path to a pickled DataFrame; the columns 's_idx', 'd_idx'
            and 'log(V_V0)' are read from it.
        sample_fn: path to a pickled mapping; its number of keys is returned
            as `n_samp`.
        drug_fn: path to a pickled mapping; its number of keys is returned
            as `n_drug`.

    Returns:
        (n_samp, n_drug, s_idx, d_idx, obs) where `s_idx` / `d_idx` are NumPy
        arrays taken from the DataFrame and `obs` is a torch tensor of the
        'log(V_V0)' column in torch's default floating dtype.
    """
    df = pd.read_pickle(train_fn)
    sample_dict = read_pickle(sample_fn)
    drug_dict = read_pickle(drug_fn)
    n_samp = len(sample_dict.keys())
    n_drug = len(drug_dict.keys())
    s_idx = df['s_idx'].to_numpy()
    d_idx = df['d_idx'].to_numpy()
    # Go through NumPy instead of handing the Series to torch directly: torch
    # treats a Series as a generic sequence and probes series[0] to infer the
    # shape, which is a *label* lookup and fails when the split's index does
    # not contain the label 0.
    obs_values = df['log(V_V0)'].to_numpy()
    obs = torch.tensor(obs_values, dtype=torch.get_default_dtype())
    return n_samp, n_drug, s_idx, d_idx, obs

def model(n_samp, n_drug, s_idx, d_idx, obs):
    """Pyro model with per-sample and per-drug latents plus additive offsets.

    Each observation's mean is
    ``s[s_idx] * d[d_idx] + a_s[s_idx] + a_d[d_idx] + a`` and `obs` is scored
    under a Normal with that mean and a single shared scale `sigma`.

    Args:
        n_samp: size of the 's_plate' plate (length of `s` and `a_s`).
        n_drug: size of the 'd_plate' plate (length of `d` and `a_d`).
        s_idx: per-observation index into `s` / `a_s`.
        d_idx: per-observation index into `d` / `a_d`.
        obs: observed values; obs.shape[0] sets the size of 'data_plate'.
    """
    # create global offset
    a_sigma = pyro.param('a_sigma', torch.Tensor([A_SIGMA_INIT]), constraint=constraints.positive)
    a = pyro.sample('a', dist.Normal(torch.zeros(()), a_sigma * torch.ones(())))   
    # create s
    s_g_alpha = pyro.param('s_g_alpha', torch.Tensor([G_ALPHA_INIT]), constraint=constraints.positive)
    s_g_beta = pyro.param('s_g_beta', torch.Tensor([G_BETA_INIT]), constraint=constraints.positive)
    # NOTE(review): s_sigma is registered with pyro.param but its initial value
    # is a Gamma distribution object, whereas `sigma` below is drawn with
    # pyro.sample from a Gamma. This looks like it may have been meant to be
    # pyro.sample - confirm the intended model before changing it.
    s_sigma = pyro.param('s_sigma', dist.Gamma(s_g_alpha, s_g_beta), constraint=constraints.positive)
    a_s_sigma = pyro.param('a_s_sigma', torch.Tensor([A_SIGMA_INIT]), constraint=constraints.positive)
    with pyro.plate('s_plate', n_samp):
        # per-sample additive offset and per-sample multiplicative factor
        a_s = pyro.sample('a_s', dist.Normal(torch.zeros(n_samp), a_s_sigma * torch.ones(n_samp)))
        s = pyro.sample('s', dist.Normal(torch.zeros(n_samp), s_sigma * torch.ones(n_samp)))
    # create d
    d_g_alpha = pyro.param('d_g_alpha', torch.Tensor([G_ALPHA_INIT]), constraint=constraints.positive)
    d_g_beta = pyro.param('d_g_beta', torch.Tensor([G_BETA_INIT]), constraint=constraints.positive)
    # NOTE(review): same pyro.param-with-a-distribution pattern as s_sigma above.
    d_sigma = pyro.param('d_sigma', dist.Gamma(d_g_alpha, d_g_beta), constraint=constraints.positive)
    a_d_sigma = pyro.param('a_d_sigma', torch.Tensor([A_SIGMA_INIT]), constraint=constraints.positive)
    with pyro.plate('d_plate', n_drug):
        # per-drug additive offset and per-drug multiplicative factor
        a_d = pyro.sample('a_d', dist.Normal(torch.zeros(n_drug), a_d_sigma * torch.ones(n_drug)))
        # unlike `s`, the scale here is passed without multiplying by torch.ones(n_drug)
        d = pyro.sample('d', dist.Normal(torch.zeros(n_drug), d_sigma))
    # create data
    # one mean per observation: sample/drug interaction plus the three offsets
    mean = s[s_idx] * d[d_idx] + a_s[s_idx] + a_d[d_idx] + a
    sigma_g_alpha = pyro.param('sigma_g_alpha', torch.Tensor([ALPHA_INIT]), constraint=constraints.positive)
    sigma_g_beta = pyro.param('sigma_g_beta', torch.Tensor([BETA_INIT]), constraint=constraints.positive)
    # observation noise scale, shared by every observation
    sigma = pyro.sample('sigma', dist.Gamma(sigma_g_alpha, sigma_g_beta))
    with pyro.plate('data_plate', obs.shape[0]):
        pyro.sample('data', dist.Normal(mean, sigma * torch.ones(obs.shape[0])), obs=obs)
        
def small_model(n_samp, n_drug, s_idx, d_idx, obs):
    """Reduced Pyro model: one standard-normal latent per sample, unit noise.

    Every observation's mean is the latent `s` of its sample (`s[s_idx]`), and
    `obs` is scored under a Normal with that mean and scale 1. `n_drug` and
    `d_idx` are accepted so the signature matches `model` but are not used.
    """
    n_obs = obs.shape[0]
    with pyro.plate('s_plate', n_samp):
        prior = dist.Normal(torch.zeros(n_samp), torch.ones(n_samp))
        sample_effect = pyro.sample('s', prior)
    per_obs_mean = sample_effect[s_idx]
    with pyro.plate('data_plate', n_obs):
        likelihood = dist.Normal(per_obs_mean, torch.ones(n_obs))
        pyro.sample('data', likelihood, obs=obs)
        
def generate_synthetic_samples(n_samp, n_drug, s_idx, d_idx):
    """Draw one synthetic observation for every (s_idx[i], d_idx[i]) pair.

    Fresh latents are drawn from NumPy's global RNG on every call: a global
    offset, per-sample and per-drug offsets (scale A_SIGMA_INIT), per-sample
    and per-drug factors, and a noise scale. The scales for the factors and
    the noise each come from ``np.random.gamma(5)``.

    Returns:
        A torch tensor with len(s_idx) entries.
    """
    n_synth = len(s_idx)
    # The RNG draws below happen in a fixed order; keep it so that seeded runs
    # keep producing the same data.
    global_offset = np.random.normal(loc=0, scale=A_SIGMA_INIT)
    # sample-side latents
    sample_scale = np.random.gamma(5)
    sample_offsets = np.random.normal(0, A_SIGMA_INIT, size=(n_samp,))
    sample_factors = np.random.normal(0, sample_scale, size=(n_samp,))
    # drug-side latents
    drug_scale = np.random.gamma(5)
    drug_offsets = np.random.normal(0, A_SIGMA_INIT, size=(n_drug,))
    drug_factors = np.random.normal(0, drug_scale, size=(n_drug,))
    # noiseless mean of every observation
    interaction = sample_factors[s_idx] * drug_factors[d_idx]
    mean = interaction + sample_offsets[s_idx] + drug_offsets[d_idx] + global_offset
    # observation noise
    noise_scale = np.random.gamma(5)
    unit_noise = np.random.normal(loc=0, scale=1, size=(n_synth,))
    return torch.Tensor(mean + noise_scale * unit_noise)

def small_synth(n_samp, n_drug, s_idx, d_idx):
    """Synthetic data for `small_model`: each observation is its sample's latent.

    Draws `n_samp` standard-normal values from NumPy's global RNG and returns
    the one selected by each entry of `s_idx` as a torch tensor. No noise is
    added; `n_drug` and `d_idx` are accepted but unused.
    """
    per_sample_values = np.random.normal(0, 1, size=(n_samp,))
    return torch.Tensor(per_sample_values[s_idx])

def predict(mcmc_samples, s_test_idx, d_test_idx):
    """Compute the predictive mean of each test pair for every posterior draw.

    Args:
        mcmc_samples: mapping with keys 's', 'd', 'a', 'a_s', 'a_d', 'sigma'.
            's', 'd', 'a_s', 'a_d' are indexed as (draw, entity); 'a' and
            'sigma' hold one value per draw, shaped either (m, 1) or (m,).
        s_test_idx: per-test-point index into the second axis of 's' / 'a_s'.
        d_test_idx: per-test-point index into the second axis of 'd' / 'a_d'.

    Returns:
        (mu, sigma): `mu` has shape (m, n) with
        ``mu = s * d + a_s + a_d + a`` evaluated per draw and test point;
        `sigma` has shape (m, 1).

    Raises:
        AssertionError: if the index arrays differ in length or the resulting
            shapes are not (m, n) / (m, 1).
    """
    assert len(s_test_idx) == len(d_test_idx)
    n = len(s_test_idx)
    # read in mcmc samples for each variable
    s = np.array(mcmc_samples['s'])
    d = np.array(mcmc_samples['d'])
    a = np.array(mcmc_samples['a'])
    a_s = np.array(mcmc_samples['a_s'])
    a_d = np.array(mcmc_samples['a_d'])
    sigma = np.array(mcmc_samples['sigma'])
    m = s.shape[0]
    # Per-draw scalars must be column vectors. A flat (m,) `a` would otherwise
    # be broadcast along the test-point axis (silently wrong when m == n), and
    # a flat `sigma` would fail the shape check with an IndexError.
    if a.ndim == 1:
        a = a[:, np.newaxis]
    if sigma.ndim == 1:
        sigma = sigma[:, np.newaxis]
    # combine above matrices to create mu
    mu = s[:, s_test_idx] * d[:, d_test_idx] + a_s[:, s_test_idx] + a_d[:, d_test_idx] + a
    assert (mu.shape[0] == m) and (mu.shape[1] == n)
    assert (sigma.shape[0] == m) and (sigma.shape[1] == 1)
    return mu, sigma

def small_predict(mcmc_samples, s_test_idx, d_test_idx):
    """Predictive means for `small_model`: the sampled `s` of each test point.

    Args:
        mcmc_samples: mapping whose 's' entry is indexed as (draw, sample).
        s_test_idx: per-test-point index into the second axis of 's'.
        d_test_idx: only its length is checked against `s_test_idx`.

    Returns:
        (mu, 1): `mu` has one row per draw and one column per test point; the
        constant 1 is returned in the position where `predict` returns sigma.
    """
    assert len(s_test_idx) == len(d_test_idx)
    draws = np.array(mcmc_samples['s'])
    n_draws, n_test = draws.shape[0], len(s_test_idx)
    # pick, for every draw, the latent of each test point's sample
    mu = draws[:, s_test_idx]
    assert (mu.shape[0] == n_draws) and (mu.shape[1] == n_test)
    return mu, 1

def r_squared(mu, test):
    """Squared Pearson correlation between `test` and the column means of `mu`.

    `mu` is averaged over its first axis, giving one value per column; that
    vector must have as many entries as `test`.
    """
    column_means = np.mean(mu, axis=0)
    assert column_means.shape[0] == test.shape[0]
    # off-diagonal entry of the 2x2 correlation matrix
    r = np.corrcoef(test, column_means)[0, 1]
    return np.power(r, 2)

# function to compute coverage
def coverage(mu, sigma, obs, hi, lo):
    """Fraction of observations strictly inside an empirical predictive interval.

    One synthetic value is generated per (draw, observation) as
    ``mu + sigma * standard_normal`` using NumPy's global RNG. For each
    observation the synthetic values are sorted across draws, and the interval
    bounds are read at rows ``ceil(lo * m)`` and ``floor(hi * m)``.

    Args:
        mu: array with m rows and n columns.
        sigma: scale multiplied with the standard-normal noise; anything that
            broadcasts against an (m, n) array.
        obs: the n observed values.
        hi: upper quantile level (note: `hi` comes BEFORE `lo` in the signature).
        lo: lower quantile level.

    Returns:
        The fraction of entries of `obs` with lo_bound < obs < hi_bound.
    """
    m = mu.shape[0]
    n = mu.shape[1]
    # generate synthetic samples for each observation
    # TODO: Figure out how to get correct variance in here
    synth = mu + sigma * np.random.normal(loc=0, scale=1, size=(m, n))
    # sort synthetic samples for each observation
    sorted_synth = np.sort(synth, axis=0)
    # Row indices of the two bounds, clamped to the last row: a level of
    # exactly 1.0 would otherwise give index m and raise IndexError.
    last_row = m - 1
    lo_idx = min(int(np.ceil(lo * m)), last_row)
    hi_idx = min(int(np.floor(hi * m)), last_row)
    lo_bound = sorted_synth[lo_idx, :]
    hi_bound = sorted_synth[hi_idx, :]
    # is obs strictly between the two bounds?
    inside = np.logical_and(lo_bound < obs, obs < hi_bound)
    frac = np.sum(inside / (1.0 * len(obs)))
    return frac

In [3]:
def get_thinning_idx(n_total, n_desired):
    """Return `n_desired` evenly spaced integer positions from 0 to `n_total`.

    Both endpoints are included, so the last position equals `n_total`.

    Raises:
        AssertionError: if the even spacing does not land on whole numbers.
    """
    grid = np.linspace(0, n_total, num=n_desired)
    lands_on_integers = np.equal(np.floor(grid), np.ceil(grid))
    assert lands_on_integers.all()
    return grid.astype(int)

# thin mcmc_samples
def thinning(mcmc_samples, keys, n_total, n_desired):
    """Keep `n_desired` evenly spaced draws of each requested variable.

    Args:
        mcmc_samples: mapping from variable name to an array indexed by draw
            along its first axis.
        keys: iterable of variable names to keep.
        n_total: number of draws the spacing is computed over.
        n_desired: number of draws to keep per variable.

    Returns:
        A new dict containing, for each key, the selected draws.
    """
    # Ask for one extra position: the grid ends at n_total itself, and that
    # trailing position is dropped before indexing.
    kept = get_thinning_idx(n_total, n_desired + 1)[:-1]
    return {key: mcmc_samples[key][kept] for key in keys}

In [4]:
# read in model inputs
# All four pickles live in the same split directory (path is relative to the
# notebook's working directory).
base_dir = '../results/2023-06-09/clean_and_split_data/split'
train_fn = base_dir + '/train.pkl'
test_fn = base_dir + '/test.pkl'
sample_fn = base_dir + '/sample_dict.pkl'
drug_fn = base_dir + '/drug_dict.pkl'

# Training split: entity counts and per-observation indices. The observed
# values (5th return) are discarded here.
n_samp, n_drug, s_idx, d_idx, _ = get_model_inputs(train_fn, sample_fn, drug_fn)
# Test split: only the per-observation indices are kept; the counts come from
# the same sample/drug lookups as above.
_, _, s_test_idx, d_test_idx, _ = get_model_inputs(test_fn, sample_fn, drug_fn)

In [10]:
# SMALL MODEL
# Synthetic train/test observations for the reduced model.
# NOTE(review): each small_synth call draws its own latent `s`, so obs_test is
# generated from different latents than obs_train - confirm this is intended
# before interpreting the held-out metrics below.
# NOTE(review): no RNG seed is set in the visible cells, so these draws (and the
# MCMC run) differ between kernel restarts.
obs_train = small_synth(n_samp, n_drug, s_idx, d_idx)
obs_test = small_synth(n_samp, n_drug, s_test_idx, d_test_idx)

# fit model
# NOTE(review): smoke_test is assigned but not used in the visible cells.
smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.8.4')
pyro.enable_validation(True)
logging.basicConfig(format='%(message)s', level=logging.INFO)

# Set matplotlib settings
plt.style.use('default')

# Re-reads the training inputs loaded earlier; `obs` (the real observations)
# is not used in this cell - the fit below runs on the synthetic obs_train.
n_samp, n_drug, s_idx, d_idx, obs = get_model_inputs(train_fn, sample_fn, drug_fn)
# The rendered graph is not the cell's last expression, so its return value is
# discarded here.
pyro.render_model(small_model, model_args=(n_samp, n_drug, s_idx, d_idx, obs_train), render_params=True, 
                  render_distributions=True)
pyro.clear_param_store()
kernel = pyro.infer.mcmc.NUTS(small_model, jit_compile=True)
# 50000 retained draws after 500 warm-up steps; the thinning cells further down
# hard-code this same 50000.
mcmc = pyro.infer.MCMC(kernel, num_samples=50000, warmup_steps=500)
mcmc.run(n_samp, n_drug, s_idx, d_idx, obs_train)
# Convert every sampled site to a detached CPU NumPy array.
mcmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}

  result = torch.tensor(0.0, device=self.device)
Sample: 100%|█| 50500/50500 [02:25, 347.28it/s, step size=6.21e-01, acc. prob=0.


In [11]:
# Held-out synthetic observations as a NumPy array; the evaluation cells below
# read this as `test`.
test = obs_test.numpy()
# write samples to file
# (overwrites ../mcmc_samples.pkl, relative to the working directory)
with open('../mcmc_samples.pkl', 'wb') as handle:
    pickle.dump(mcmc_samples, handle)

In [12]:
# Sanity check: length of the first axis of the stored draws of 's'.
len(mcmc_samples['s'])

50000

In [13]:
# try on all samples
# Predictive means from every stored draw, then squared correlation with the
# held-out values and coverage of the [LO, HI] interval.
mu, sigma = small_predict(mcmc_samples, s_test_idx, d_test_idx)
r_sq = r_squared(mu, test)
# coverage() takes the upper level before the lower one
fracs = coverage(mu, sigma, test, HI, LO)
print("fracs: " + str(fracs))
print("r_sq: " + str(r_sq))

fracs: 0.6081081081081081
r_sq: 0.4394464938252599


In [9]:
# NOTE(review): this cell carries execution count 9, lower than the cells above
# it, so the array shown below may come from an earlier run than the current
# `mu`. Consider displaying mu.shape or a small slice instead of the full array.
mu

array([[1.5758101 , 1.5758101 , 1.5758101 , ..., 0.8813013 , 0.8813013 ,
        0.8813013 ],
       [1.7985806 , 1.7985806 , 1.7985806 , ..., 0.75204885, 0.75204885,
        0.75204885],
       [1.8482765 , 1.8482765 , 1.8482765 , ..., 2.0345535 , 2.0345535 ,
        2.0345535 ],
       ...,
       [1.47381   , 1.47381   , 1.47381   , ..., 1.6391413 , 1.6391413 ,
        1.6391413 ],
       [1.9076245 , 1.9076245 , 1.9076245 , ..., 1.0366187 , 1.0366187 ,
        1.0366187 ],
       [1.4720681 , 1.4720681 , 1.4720681 , ..., 1.6586899 , 1.6586899 ,
        1.6586899 ]], dtype=float32)

In [14]:
# try on thinned samples
# Keep 500 evenly spaced draws of 's' out of 50000 (every 100th), then repeat
# the same evaluation as on the full set of draws.
thinned_samples = thinning(mcmc_samples, ['s'], 50000, 500)
mu, sigma = small_predict(thinned_samples, s_test_idx, d_test_idx)
r_sq = r_squared(mu, test)
# coverage() takes the upper level before the lower one
fracs = coverage(mu, sigma, test, HI, LO)
print("fracs: " + str(fracs))
print("r_sq: " + str(r_sq))

fracs: 0.6351351351351352
r_sq: 0.4474179130058454


In [None]:
# try on first 500 samples
# Take the leading 500 draws as-is (no thinning) for comparison against the
# evenly thinned subset. This cell previously repeated the thinning call and
# therefore never evaluated the first 500 draws its comment describes.
first_samples = {'s': mcmc_samples['s'][:500]}
mu, sigma = small_predict(first_samples, s_test_idx, d_test_idx)
r_sq = r_squared(mu, test)
# coverage() takes the upper level before the lower one
fracs = coverage(mu, sigma, test, HI, LO)
print("fracs: " + str(fracs))
print("r_sq: " + str(r_sq))

In [None]:
# NOTE(review): apart from the commented-out line, this cell repeats the
# "try on thinned samples" evaluation verbatim - consider deleting it.
# The commented-out alternative slices [0:499], which keeps 499 indices, not 500.
#indices = np.array(np.linspace(0, 50000, num=501), dtype=int)[0:499]
thinned_samples = thinning(mcmc_samples, ['s'], 50000, 500)
mu, sigma = small_predict(thinned_samples, s_test_idx, d_test_idx)
r_sq = r_squared(mu, test)
fracs = coverage(mu, sigma, test, HI, LO)
print("fracs: " + str(fracs))
print("r_sq: " + str(r_sq))

In [None]:
# Evaluate on the first 500 draws without any thinning.
no_thin_idx = np.arange(500)
# Index the draws directly: thinning() takes (mcmc_samples, keys, n_total,
# n_desired) and cannot be given an explicit index array, so the previous
# three-argument call raised a TypeError.
no_thin_samples = {'s': mcmc_samples['s'][no_thin_idx]}
test = obs_test.numpy()
mu, sigma = small_predict(no_thin_samples, s_test_idx, d_test_idx)
r_sq = r_squared(mu, test)
# coverage() takes the upper level before the lower one
fracs = coverage(mu, sigma, test, HI, LO)
print("fracs: " + str(fracs))
print("r_sq: " + str(r_sq))

In [None]:
# OLD/OTHER

In [None]:
# NORMAL MODEL
# Synthetic train/test observations for the full model.
# NOTE(review): each generate_synthetic_samples call draws its own latents, so
# obs_test is generated from different latents than obs_train - confirm this is
# intended before using obs_test for held-out evaluation.
obs_train = generate_synthetic_samples(n_samp, n_drug, s_idx, d_idx)
obs_test = generate_synthetic_samples(n_samp, n_drug, s_test_idx, d_test_idx)

# fit model
# NOTE(review): smoke_test is assigned but not used in the visible cells.
smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.8.4')
pyro.enable_validation(True)
logging.basicConfig(format='%(message)s', level=logging.INFO)

# Set matplotlib settings
plt.style.use('default')

# Re-reads the training inputs; `obs` holds the real observations and is kept
# only so the name stays defined for later cells.
n_samp, n_drug, s_idx, d_idx, obs = get_model_inputs(train_fn, sample_fn, drug_fn)
pyro.render_model(model, model_args=(n_samp, n_drug, s_idx, d_idx, obs_train), render_params=True, 
                  render_distributions=True)
pyro.clear_param_store()
kernel = pyro.infer.mcmc.NUTS(model, jit_compile=True)
mcmc = pyro.infer.MCMC(kernel, num_samples=500, warmup_steps=500)
# Fit on the synthetic training data generated above, as the SMALL MODEL cell
# does. This previously passed the real `obs`, so the synthetic obs_train was
# rendered but never fitted.
mcmc.run(n_samp, n_drug, s_idx, d_idx, obs_train)
# Convert every sampled site to a detached CPU NumPy array.
mcmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}