In [26]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import pyro
import seaborn as sns
import sys
import torch

from scipy import stats

import pyro.distributions as dist
import pyro.distributions.constraints as constraints

# Not referenced by any of the cells visible in this notebook section.
NUM_ARGS = 5

# Starting value for the positive scale parameters of the Normal offsets
# ('a_sigma', 'a_s_sigma', 'a_d_sigma'); also used directly as the scale of
# the NumPy draw of the global offset in generate_synthetic_samples.
A_SIGMA_INIT = 5

# Starting values for the Gamma(concentration, rate) parameters attached to
# the per-sample / per-drug factor scales ('s_g_*', 'd_g_*').
G_ALPHA_INIT, G_BETA_INIT = 10, 2

# Starting values for the Gamma(concentration, rate) parameters of the
# observation-noise site 'sigma' ('sigma_g_alpha', 'sigma_g_beta').
ALPHA_INIT, BETA_INIT = 2, 1

In [29]:
def read_pickle(fn):
    """Return the object deserialized from the pickle file at path ``fn``.

    The file is opened in binary mode and closed again before returning.

    SECURITY: ``pickle.load`` can execute arbitrary code while unpickling, so
    only call this on files from a trusted source.
    """
    with open(fn, mode='rb') as pkl_file:
        return pickle.load(pkl_file)

def get_model_inputs(train_fn, sample_fn, drug_fn):
    """Load one data split plus the sample/drug lookup tables.

    Args:
        train_fn: Path to a pickled pandas DataFrame that has the columns
            ``'s_idx'``, ``'d_idx'`` and ``'log(V_V0)'``.
        sample_fn: Path to a pickled mapping; its number of keys is returned
            as ``n_samp``.
        drug_fn: Path to a pickled mapping; its number of keys is returned
            as ``n_drug``.

    Returns:
        Tuple ``(n_samp, n_drug, s_idx, d_idx, obs)`` where ``s_idx`` and
        ``d_idx`` are NumPy arrays taken from the frame's index columns and
        ``obs`` is a 1-D torch tensor (default floating dtype) holding the
        ``'log(V_V0)'`` column.

    SECURITY: all three files are unpickled, which can execute arbitrary
    code; only pass paths to trusted files.
    """
    df = pd.read_pickle(train_fn)
    sample_dict = read_pickle(sample_fn)
    drug_dict = read_pickle(drug_fn)
    n_samp = len(sample_dict.keys())
    n_drug = len(drug_dict.keys())
    s_idx = df['s_idx'].to_numpy()
    d_idx = df['d_idx'].to_numpy()
    # Convert through NumPy instead of handing the Series to torch directly:
    # torch would otherwise treat the Series as a generic Python sequence,
    # which is slow and depends on the frame's index labels (a split frame
    # usually does not carry a 0..n-1 index). torch.tensor copies, so the
    # result never aliases the DataFrame's memory.
    obs = torch.tensor(df['log(V_V0)'].to_numpy(), dtype=torch.get_default_dtype())
    return n_samp, n_drug, s_idx, d_idx, obs

def model(n_samp, n_drug, s_idx, d_idx, n_synth, obs):
    """Pyro model with a rank-1 sample x drug interaction plus additive offsets.

    For observation ``i`` the mean is
    ``s[s_idx[i]] * d[d_idx[i]] + a_s[s_idx[i]] + a_d[d_idx[i]] + a`` and the
    observation is Normal around that mean with scale ``sigma``.

    Sample sites: ``a``, ``s_sigma``, ``a_s``, ``s``, ``d_sigma``, ``a_d``,
    ``d``, ``sigma`` and ``data``. The scales ``s_sigma`` / ``d_sigma`` are
    drawn from Gamma priors whose concentration/rate are the learnable
    parameters ``*_g_alpha`` / ``*_g_beta``, the same pattern used for
    ``sigma``. (They were previously created with ``pyro.param`` and a Gamma
    distribution as the initial value, which only used the Gamma once as a
    random initializer and left the ``*_g_alpha`` / ``*_g_beta`` parameters
    without any effect on the model.)

    Args:
        n_samp: Size of the ``s_plate`` (number of per-sample latents).
        n_drug: Size of the ``d_plate`` (number of per-drug latents).
        s_idx: Index array selecting the sample latent for each observation.
        d_idx: Index array selecting the drug latent for each observation.
        n_synth: Size of the ``data_plate``; must equal the number of
            entries in ``s_idx`` / ``d_idx``.
        obs: Observed values for the ``data`` site, or ``None`` to sample it.
    """
    # Global offset shared by every observation.
    a_sigma = pyro.param('a_sigma', torch.Tensor([A_SIGMA_INIT]), constraint=constraints.positive)
    a = pyro.sample('a', dist.Normal(torch.zeros(()), a_sigma * torch.ones(())))

    # Per-sample factor `s` and per-sample offset `a_s`.
    s_g_alpha = pyro.param('s_g_alpha', torch.Tensor([G_ALPHA_INIT]), constraint=constraints.positive)
    s_g_beta = pyro.param('s_g_beta', torch.Tensor([G_BETA_INIT]), constraint=constraints.positive)
    # Latent scale with a Gamma prior (a sample site, not a parameter).
    s_sigma = pyro.sample('s_sigma', dist.Gamma(s_g_alpha, s_g_beta))
    a_s_sigma = pyro.param('a_s_sigma', torch.Tensor([A_SIGMA_INIT]), constraint=constraints.positive)
    with pyro.plate('s_plate', n_samp):
        a_s = pyro.sample('a_s', dist.Normal(torch.zeros(n_samp), a_s_sigma * torch.ones(n_samp)))
        s = pyro.sample('s', dist.Normal(torch.zeros(n_samp), s_sigma * torch.ones(n_samp)))

    # Per-drug factor `d` and per-drug offset `a_d`.
    d_g_alpha = pyro.param('d_g_alpha', torch.Tensor([G_ALPHA_INIT]), constraint=constraints.positive)
    d_g_beta = pyro.param('d_g_beta', torch.Tensor([G_BETA_INIT]), constraint=constraints.positive)
    d_sigma = pyro.sample('d_sigma', dist.Gamma(d_g_alpha, d_g_beta))
    a_d_sigma = pyro.param('a_d_sigma', torch.Tensor([A_SIGMA_INIT]), constraint=constraints.positive)
    with pyro.plate('d_plate', n_drug):
        a_d = pyro.sample('a_d', dist.Normal(torch.zeros(n_drug), a_d_sigma * torch.ones(n_drug)))
        # Scale expanded explicitly so the site matches the plate size, as for `s`.
        d = pyro.sample('d', dist.Normal(torch.zeros(n_drug), d_sigma * torch.ones(n_drug)))

    # Likelihood: one Normal observation per (sample, drug) row.
    mean = s[s_idx] * d[d_idx] + a_s[s_idx] + a_d[d_idx] + a
    sigma_g_alpha = pyro.param('sigma_g_alpha', torch.Tensor([ALPHA_INIT]), constraint=constraints.positive)
    sigma_g_beta = pyro.param('sigma_g_beta', torch.Tensor([BETA_INIT]), constraint=constraints.positive)
    sigma = pyro.sample('sigma', dist.Gamma(sigma_g_alpha, sigma_g_beta))
    with pyro.plate('data_plate', n_synth):
        pyro.sample('data', dist.Normal(mean, sigma * torch.ones(n_synth)), obs=obs)

In [34]:
def generate_synthetic_samples(n_samp, n_drug, s_idx, d_idx, n_synth, obs):
    """Run the sample x drug generative process once and return the data draw.

    The mean of observation ``i`` is
    ``s[s_idx[i]] * d[d_idx[i]] + a_s[s_idx[i]] + a_d[d_idx[i]] + a``.
    The global offset ``a`` is drawn with NumPy's global RNG using the fixed
    scale ``A_SIGMA_INIT``, so it is a plain Python float and not a Pyro
    sample site; every other latent is a Pyro sample site.

    Changes from the earlier version of this function:
      * ``s_sigma`` / ``d_sigma`` are drawn with ``pyro.sample`` from their
        Gamma priors instead of being registered through ``pyro.param`` with
        a distribution as the initial value (which used the Gamma only as a
        one-off random initializer).
      * A dead ``s_sigma = dist.Gamma(...)`` assignment that was immediately
        overwritten has been removed.
      * The value of the ``data`` site is returned instead of discarded.

    Args:
        n_samp: Size of the ``s_plate``.
        n_drug: Size of the ``d_plate``.
        s_idx: Index array selecting the sample latent for each observation.
        d_idx: Index array selecting the drug latent for each observation.
        n_synth: Size of the ``data_plate``; must equal the number of
            entries in ``s_idx`` / ``d_idx``.
        obs: Values to condition the ``data`` site on, or ``None`` to draw
            fresh synthetic values.

    Returns:
        The tensor produced by the ``data`` site (``obs`` itself when ``obs``
        is given, otherwise a new draw).
    """
    # Global offset: sampled outside Pyro, with a fixed (non-learned) scale.
    a = np.random.normal(loc=0, scale=A_SIGMA_INIT)

    # Per-sample factor `s` and per-sample offset `a_s`.
    s_g_alpha = pyro.param('s_g_alpha', torch.Tensor([G_ALPHA_INIT]), constraint=constraints.positive)
    s_g_beta = pyro.param('s_g_beta', torch.Tensor([G_BETA_INIT]), constraint=constraints.positive)
    s_sigma = pyro.sample('s_sigma', dist.Gamma(s_g_alpha, s_g_beta))
    a_s_sigma = pyro.param('a_s_sigma', torch.Tensor([A_SIGMA_INIT]), constraint=constraints.positive)
    with pyro.plate('s_plate', n_samp):
        a_s = pyro.sample('a_s', dist.Normal(torch.zeros(n_samp), a_s_sigma * torch.ones(n_samp)))
        s = pyro.sample('s', dist.Normal(torch.zeros(n_samp), s_sigma * torch.ones(n_samp)))

    # Per-drug factor `d` and per-drug offset `a_d`.
    d_g_alpha = pyro.param('d_g_alpha', torch.Tensor([G_ALPHA_INIT]), constraint=constraints.positive)
    d_g_beta = pyro.param('d_g_beta', torch.Tensor([G_BETA_INIT]), constraint=constraints.positive)
    d_sigma = pyro.sample('d_sigma', dist.Gamma(d_g_alpha, d_g_beta))
    a_d_sigma = pyro.param('a_d_sigma', torch.Tensor([A_SIGMA_INIT]), constraint=constraints.positive)
    with pyro.plate('d_plate', n_drug):
        a_d = pyro.sample('a_d', dist.Normal(torch.zeros(n_drug), a_d_sigma * torch.ones(n_drug)))
        d = pyro.sample('d', dist.Normal(torch.zeros(n_drug), d_sigma * torch.ones(n_drug)))

    # Observation model: Normal around the rank-1 interaction plus offsets.
    mean = s[s_idx] * d[d_idx] + a_s[s_idx] + a_d[d_idx] + a
    sigma_g_alpha = pyro.param('sigma_g_alpha', torch.Tensor([ALPHA_INIT]), constraint=constraints.positive)
    sigma_g_beta = pyro.param('sigma_g_beta', torch.Tensor([BETA_INIT]), constraint=constraints.positive)
    sigma = pyro.sample('sigma', dist.Gamma(sigma_g_alpha, sigma_g_beta))
    with pyro.plate('data_plate', n_synth):
        data = pyro.sample('data', dist.Normal(mean, sigma * torch.ones(n_synth)), obs=obs)
    return data

In [24]:
# Paths to the pickled train/test splits and the sample/drug lookup tables.
base_dir = '../results/2023-06-09/clean_and_split_data/split'
train_fn = f'{base_dir}/train.pkl'
test_fn = f'{base_dir}/test.pkl'
sample_fn = f'{base_dir}/sample_dict.pkl'
drug_fn = f'{base_dir}/drug_dict.pkl'

# Training split: keep the table sizes and index arrays, drop the observations.
n_samp, n_drug, s_idx, d_idx, _ = get_model_inputs(train_fn, sample_fn, drug_fn)
# Test split: only the index arrays are kept.
_, _, s_test_idx, d_test_idx, _ = get_model_inputs(test_fn, sample_fn, drug_fn)

In [31]:
# Draw 200 samples from `model` over the training index arrays with nothing
# conditioned on (empty posterior-sample dict, obs=None), then print them all.
prior_samples = pyro.infer.Predictive(
    model,
    posterior_samples={},
    num_samples=200,
)(n_samp, n_drug, s_idx, d_idx, n_synth=len(s_idx), obs=None)
print(prior_samples)

{'a': tensor([[ 2.3832e+00],
        [-7.8171e-01],
        [ 2.1455e+00],
        [-1.4701e+00],
        [ 6.9609e+00],
        [ 6.8172e-02],
        [ 2.6696e+00],
        [-2.0298e+00],
        [ 1.2653e+00],
        [ 4.0996e+00],
        [ 7.4133e+00],
        [ 9.2642e-01],
        [-7.4653e+00],
        [-9.0633e+00],
        [-6.1947e+00],
        [-4.1844e+00],
        [-6.4961e+00],
        [ 1.1506e+00],
        [-1.9524e+00],
        [-5.4245e+00],
        [ 2.7900e-02],
        [-2.2750e+00],
        [ 6.5842e+00],
        [-2.9696e+00],
        [-1.0206e+01],
        [-1.0626e+01],
        [-7.1471e+00],
        [ 4.4575e+00],
        [ 3.3710e-01],
        [ 6.5820e-01],
        [-1.1093e+01],
        [ 2.1193e+00],
        [-1.6455e+00],
        [-1.3788e+00],
        [ 9.1119e-01],
        [ 8.6699e+00],
        [ 6.8782e+00],
        [ 3.5627e+00],
        [ 7.8041e+00],
        [ 1.8480e+00],
        [ 2.8927e+00],
        [-2.5554e+00],
        [ 2.4405e+00],
     

In [None]:
# fit model

In [None]:
# evaluate model