In [1]:
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import pyro
import seaborn as sns
import sys
import torch

from scipy import stats

import pyro.distributions as dist
import pyro.distributions.constraints as constraints

# Initial values for the pyro.param sites created in model():
#   A_SIGMA_INIT                -> 'a_sigma', 'a_s_sigma', 'a_d_sigma'
#   G_ALPHA_INIT / G_BETA_INIT  -> 's_g_alpha', 'd_g_alpha' / 's_g_beta', 'd_g_beta'
#   ALPHA_INIT / BETA_INIT      -> 'sigma_g_alpha' / 'sigma_g_beta'
A_SIGMA_INIT = 5
G_ALPHA_INIT = 10
G_BETA_INIT = 2
ALPHA_INIT = 2
BETA_INIT = 1

# evaluate model
# LO / HI are the interval fractions handed to coverage().
# N_SYNTH, NBINS and NUM_ARGS are not referenced by the cells visible here.
N_SYNTH = 200
LO = .05
HI = .95
NBINS = 20
# NUM_ARGS used to be assigned twice in this cell (5, then 4); only the last
# assignment ever took effect, so the dead first one has been dropped.
NUM_ARGS = 4

In [7]:
def read_pickle(fn):
    """Load and return the single object pickled in the file at path `fn`.

    NOTE: unpickling can execute arbitrary code, so only use this on files
    produced by a trusted source.
    """
    with open(fn, 'rb') as pickled_file:
        return pickle.load(pickled_file)

def get_model_inputs(train_fn, sample_fn, drug_fn):
    """Load one data split plus the sample / drug lookup tables.

    Parameters
    ----------
    train_fn : path to a pickled pandas DataFrame providing the columns
        's_idx', 'd_idx' and 'log(V_V0)'.
    sample_fn : path to a pickled mapping; its length is used as the number
        of samples.
    drug_fn : path to a pickled mapping; its length is used as the number
        of drugs.

    Returns
    -------
    (n_samp, n_drug, s_idx, d_idx, obs) where s_idx / d_idx are NumPy arrays
    taken from the DataFrame columns and obs is a torch tensor built from the
    'log(V_V0)' column.
    """
    df = pd.read_pickle(train_fn)
    sample_dict = read_pickle(sample_fn)
    drug_dict = read_pickle(drug_fn)
    n_samp = len(sample_dict)
    n_drug = len(drug_dict)
    s_idx = df['s_idx'].to_numpy()
    d_idx = df['d_idx'].to_numpy()
    # Convert through NumPy (as for the index columns above) instead of handing
    # the Series to torch directly: torch would otherwise walk the Series via
    # the sequence protocol, where integer lookups go through the index labels
    # and can break when the frame's index is not a plain 0..n-1 range.
    obs = torch.Tensor(df['log(V_V0)'].to_numpy())
    return n_samp, n_drug, s_idx, d_idx, obs

def model(n_samp, n_drug, s_idx, d_idx, obs=None, n_obs=None):
    """Pyro model for observations indexed by (sample, drug) pairs.

    The mean of observation i is
        s[s_idx[i]] * d[d_idx[i]] + a_s[s_idx[i]] + a_d[d_idx[i]] + a
    and observations are Normal around that mean with a shared scale `sigma`.

    Parameters
    ----------
    n_samp : number of samples (size of the 's_plate').
    n_drug : number of drugs (size of the 'd_plate').
    s_idx, d_idx : index arrays selecting the sample / drug of each
        observation; they are used to index the per-sample and per-drug
        latent vectors.
    obs : observed values for the 'data' site, or None to draw them.
    n_obs : number of observations (size of the 'data_plate'). Required; the
        mean vector is built from s_idx / d_idx, so it is expected to match
        their length.

    Returns
    -------
    The value of the 'data' sample site.

    Raises
    ------
    ValueError
        If n_obs is None.
    """
    if n_obs is None:
        # Previously this only printed a message and then failed further down
        # with an unrelated error; fail fast with a clear one instead.
        raise ValueError('n_obs is None: pass the number of observations')
    # create global offset
    a_sigma = pyro.param('a_sigma', torch.Tensor([A_SIGMA_INIT]), constraint=constraints.positive)
    a = pyro.sample('a', dist.Normal(torch.zeros(()), a_sigma * torch.ones(())))
    # create s
    s_g_alpha = pyro.param('s_g_alpha', torch.Tensor([G_ALPHA_INIT]), constraint=constraints.positive)
    s_g_beta = pyro.param('s_g_beta', torch.Tensor([G_BETA_INIT]), constraint=constraints.positive)
    # NOTE(review): a distribution object is passed as the *initial value* of a
    # pyro.param here. Presumably this is treated as a lazy initialiser (one
    # draw from the Gamma that then stays a fixed parameter); if a Gamma prior
    # on s_sigma was intended this should be pyro.sample instead - confirm.
    s_sigma = pyro.param('s_sigma', dist.Gamma(s_g_alpha, s_g_beta), constraint=constraints.positive)
    a_s_sigma = pyro.param('a_s_sigma', torch.Tensor([A_SIGMA_INIT]), constraint=constraints.positive)
    with pyro.plate('s_plate', n_samp):
        # per-sample offset and per-sample factor
        a_s = pyro.sample('a_s', dist.Normal(torch.zeros(n_samp), a_s_sigma * torch.ones(n_samp)))
        s = pyro.sample('s', dist.Normal(torch.zeros(n_samp), s_sigma * torch.ones(n_samp)))
    # create d
    d_g_alpha = pyro.param('d_g_alpha', torch.Tensor([G_ALPHA_INIT]), constraint=constraints.positive)
    d_g_beta = pyro.param('d_g_beta', torch.Tensor([G_BETA_INIT]), constraint=constraints.positive)
    # NOTE(review): same pyro.param-with-a-distribution pattern as s_sigma above.
    d_sigma = pyro.param('d_sigma', dist.Gamma(d_g_alpha, d_g_beta), constraint=constraints.positive)
    a_d_sigma = pyro.param('a_d_sigma', torch.Tensor([A_SIGMA_INIT]), constraint=constraints.positive)
    with pyro.plate('d_plate', n_drug):
        # per-drug offset and per-drug factor
        a_d = pyro.sample('a_d', dist.Normal(torch.zeros(n_drug), a_d_sigma * torch.ones(n_drug)))
        d = pyro.sample('d', dist.Normal(torch.zeros(n_drug), d_sigma))
    # create data
    mean = s[s_idx] * d[d_idx] + a_s[s_idx] + a_d[d_idx] + a
    sigma_g_alpha = pyro.param('sigma_g_alpha', torch.Tensor([ALPHA_INIT]), constraint=constraints.positive)
    sigma_g_beta = pyro.param('sigma_g_beta', torch.Tensor([BETA_INIT]), constraint=constraints.positive)
    sigma = pyro.sample('sigma', dist.Gamma(sigma_g_alpha, sigma_g_beta))
    with pyro.plate('data_plate', n_obs):
        data = pyro.sample('data', dist.Normal(mean, sigma * torch.ones(n_obs)), obs=obs)
    return data

def predict(mcmc_samples, s_test_idx, d_test_idx):
    """Evaluate the model mean for every posterior draw at the given pairs.

    `mcmc_samples` must provide the keys 's', 'd', 'a', 'a_s', 'a_d' and
    'sigma'; each is converted to a NumPy array whose first axis indexes the
    draws. `s_test_idx` / `d_test_idx` select the sample / drug column for
    each requested pair and must have equal length.

    Returns `(mu, sigma)`: `mu` has one row per draw and one column per pair,
    `sigma` is the array stored under 'sigma' (asserted to have shape
    (draws, 1)).
    """
    n_pairs = len(s_test_idx)
    assert len(d_test_idx) == n_pairs
    # pull every latent out of the sample dict as an ndarray
    site_names = ('s', 'd', 'a', 'a_s', 'a_d', 'sigma')
    draws = {name: np.array(mcmc_samples[name]) for name in site_names}
    n_draws = draws['s'].shape[0]
    # interaction term first, then the sample, drug and global offsets
    interaction = draws['s'][:n_draws, s_test_idx] * draws['d'][:n_draws, d_test_idx]
    mu = (
        interaction
        + draws['a_s'][:n_draws, s_test_idx]
        + draws['a_d'][:n_draws, d_test_idx]
        + draws['a']
    )
    sigma = draws['sigma']
    assert (mu.shape[0] == n_draws) and (mu.shape[1] == n_pairs)
    assert (sigma.shape[0] == n_draws) and (sigma.shape[1] == 1)
    return mu, sigma

def r_squared(mu, test):
    """Squared Pearson correlation between `test` and the column means of `mu`.

    `mu` is averaged over its first axis; the resulting vector must have the
    same length as `test`.
    """
    column_means = np.mean(mu, axis=0)
    assert column_means.shape[0] == test.shape[0]
    # off-diagonal entry of the 2x2 correlation matrix is Pearson's r
    r = np.corrcoef(test, column_means)[0, 1]
    return np.power(r, 2)

def coverage(mu, sigma, obs, hi, lo):
    """Fraction of observations falling inside a simulated predictive interval.

    For every row of `mu` one synthetic value per observation is drawn as
    `mu + sigma * N(0, 1)`. For each observation the synthetic values are
    sorted over the first axis and the entries at the `lo` and `hi` fractions
    of that axis are taken as interval bounds.

    Parameters
    ----------
    mu : 2-D array; first axis indexes draws, second axis observations.
    sigma : array broadcastable against `mu` (scale of the noise).
    obs : 1-D array of observed values, one per column of `mu`.
    hi, lo : upper and lower fractions of the interval. Note the argument
        order: `hi` comes before `lo`.

    Returns
    -------
    Fraction of `obs` strictly between the lower and upper bound.
    """
    m = mu.shape[0]
    n = mu.shape[1]
    # generate synthetic samples for each observation
    # TODO: Figure out how to get correct variance in here
    synth = mu + sigma * np.random.normal(loc=0, scale=1, size=(m, n))
    # sort synthetic samples for each observation
    sorted_synth = np.sort(synth, axis=0)
    # compute hi and lo index; clamp to the valid row range so that hi == 1.0
    # (or lo == 1.0) no longer indexes one past the end, and a negative
    # fraction cannot silently wrap around to the other end of the array
    lo_idx = min(max(int(np.ceil(lo * m)), 0), m - 1)
    hi_idx = min(max(int(np.floor(hi * m)), 0), m - 1)
    # get synthetic samples at hi and lo indices
    lo_bound = sorted_synth[lo_idx, :]
    hi_bound = sorted_synth[hi_idx, :]
    # is obs strictly inside (lo_bound, hi_bound)?
    inside = np.logical_and(lo_bound < obs, obs < hi_bound)
    frac = np.sum(inside / (1.0 * len(obs)))
    return frac

In [8]:
def get_thinning_idx(n_total, n_desired):
    """Return `n_desired` evenly spaced integer positions from 0 to `n_total`.

    Both end points are included. An AssertionError is raised when the even
    spacing does not land on whole numbers.
    """
    grid = np.linspace(0, n_total, num=n_desired)
    lands_on_integers = np.floor(grid) == np.ceil(grid)
    assert lands_on_integers.all()
    return grid.astype(int)

def thinning(mcmc_samples, keys, n_total, n_desired):
    """Keep `n_desired` evenly spaced entries of each requested sample array.

    Positions come from get_thinning_idx(n_total, n_desired + 1) with the
    final position (n_total itself) dropped, so the kept positions start at 0
    and stay below n_total. Only the entries named in `keys` are returned.
    """
    keep = get_thinning_idx(n_total, n_desired + 1)[:-1]
    return {key: mcmc_samples[key][keep] for key in keys}

In [9]:
# read in model inputs
# All four files live in the same split directory (path is relative to the
# notebook's working directory).
base_dir = '../results/2023-06-09/clean_and_split_data/split'
train_fn = base_dir + '/train.pkl'
test_fn = base_dir + '/test.pkl'
sample_fn = base_dir + '/sample_dict.pkl'
drug_fn = base_dir + '/drug_dict.pkl'

# Only the counts and the (sample, drug) index arrays are kept; the observed
# values returned as the last element are discarded for both splits.
n_samp, n_drug, s_idx, d_idx, _ = get_model_inputs(train_fn, sample_fn, drug_fn)
_, _, s_test_idx, d_test_idx, _ = get_model_inputs(test_fn, sample_fn, drug_fn)

In [10]:
# generate synthetic data: run the model once with no observations at all
# train + test (sample, drug) pairs, then split the draw back into the two parts
SYNTH_SEED = 0
pyro.set_rng_seed(SYNTH_SEED)  # make the synthetic draw reproducible across runs
n_train = len(s_idx)
n_test = len(s_test_idx)
n_obs = n_train + n_test
s_indices = np.concatenate((s_idx, s_test_idx))
d_indices = np.concatenate((d_idx, d_test_idx))
# detach: the draw depends on pyro.param tensors, so without this the data
# would carry autograd history into inference and block a later .numpy() call
obs = model(n_samp, n_drug, s_indices, d_indices, n_obs=n_obs).detach()
obs_train = obs[0:n_train]
obs_test = obs[n_train:]

In [11]:
# fit the model
# Clear the parameter store first so the pyro.param sites in model() are
# re-initialised rather than reusing values left by the earlier model() call.
pyro.clear_param_store()
kernel = pyro.infer.mcmc.NUTS(model, jit_compile=True)
# 500 warm-up steps followed by 500 retained draws
mcmc = pyro.infer.MCMC(kernel, num_samples=500, warmup_steps=500)
# condition on the training part of the synthetic data only
mcmc.run(n_samp, n_drug, s_idx, d_idx, obs=obs_train, n_obs=n_train)
# convert every sampled site to a NumPy array keyed by site name
mcmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}

  a_sigma = pyro.param('a_sigma', torch.Tensor([A_SIGMA_INIT]), constraint=constraints.positive)
  s_g_alpha = pyro.param('s_g_alpha', torch.Tensor([G_ALPHA_INIT]), constraint=constraints.positive)
  s_g_beta = pyro.param('s_g_beta', torch.Tensor([G_BETA_INIT]), constraint=constraints.positive)
  a_s_sigma = pyro.param('a_s_sigma', torch.Tensor([A_SIGMA_INIT]), constraint=constraints.positive)
  d_g_alpha = pyro.param('d_g_alpha', torch.Tensor([G_ALPHA_INIT]), constraint=constraints.positive)
  d_g_beta = pyro.param('d_g_beta', torch.Tensor([G_BETA_INIT]), constraint=constraints.positive)
  a_d_sigma = pyro.param('a_d_sigma', torch.Tensor([A_SIGMA_INIT]), constraint=constraints.positive)
  sigma_g_alpha = pyro.param('sigma_g_alpha', torch.Tensor([ALPHA_INIT]), constraint=constraints.positive)
  sigma_g_beta = pyro.param('sigma_g_beta', torch.Tensor([BETA_INIT]), constraint=constraints.positive)
  result = torch.tensor(0.0, device=self.device)
Sample: 100%|█| 1000/1000 [02:31,  6.59it/s

In [None]:
# held-out synthetic observations as a NumPy array; detach first because
# Tensor.numpy() refuses tensors that still require grad (same conversion
# chain as used for the MCMC samples)
test = obs_test.detach().cpu().numpy()
# write samples to file
with open('../mcmc_samples.pkl', 'wb') as handle:
    pickle.dump(mcmc_samples, handle)

In [None]:
# per-draw predicted means (and noise scale) at the held-out pairs
mu, sigma = predict(mcmc_samples, s_test_idx, d_test_idx)
# squared correlation between held-out values and the mean prediction
r_sq = r_squared(mu, test)
# fraction of held-out values inside the simulated [LO, HI] interval
# (coverage takes the upper fraction before the lower one)
fracs = coverage(mu, sigma, test, HI, LO)
print("fracs: " + str(fracs))
print("r_sq: " + str(r_sq))