In [5]:
import logging
import os
import sys

import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import pyro
import graphviz
import pyro.distributions as dist
import pyro.distributions.constraints as constraints

In [13]:
train = pd.read_pickle('../results/2023-05-26/clean_and_split_data/split/train.pkl')
test = pd.read_pickle('../results/2023-05-26/clean_and_split_data/split/test.pkl')

In [14]:
train

Unnamed: 0,sample,drug,log(V_V0+1)_obs,MID_list
0,HCI-023,Vehicle,"[1.8796562129077283, 2.020186220915781, 1.6340...","[164, 165, 166, 167, 168]"
1,HCI-027,Birinapant,"[0.0, 0.0, 0.0, 0.0, 0.0]","[99, 100, 101, 102, 103]"
2,HCI-015,Docetaxel,"[0.588717879959435, 0.5574157459740647, 0.7234...","[30, 31, 32]"
3,HCI-019,Birinapant,"[4.249973388277115, 3.774034821334692, 4.70383...","[119, 120, 121, 122, 123]"
4,HCI-010,Vehicle,"[1.4025516524368766, 1.6483637912782196, 1.585...","[75, 76, 77, 78, 79, 80]"
5,HCI-012,Irinotecan,"[1.847207597029819, 1.8487759374041268, 2.0093...","[211, 212, 213, 214, 215]"
6,HCI-010,Navitoclax,"[1.2674221734888569, 0.9497121725523964, 0.876...","[0, 1, 2]"
7,HCI-015,Vehicle,"[1.9266688515601555, 1.7748993970121154, 1.941...","[33, 34, 35, 36, 37, 38, 139, 140, 141, 142, 1..."
8,HCI-027,Docetaxel,"[0.0042676750704502, 0.0271016137012354, 0.0]","[63, 64, 65]"
9,HCI-002,Navitoclax,"[2.1605774655083625, 2.2689939926685523, 1.579...","[12, 13, 14]"


In [15]:
test

Unnamed: 0,sample,drug,log(V_V0+1)_obs,MID_list
0,HCI-002,Vehicle,"[2.2534170356235266, 2.570040567136758, 2.6695...","[15, 16, 17, 18, 19, 20, 149, 150, 151, 152, 1..."
1,HCI-002,Irinotecan,"[1.929492024545651, 0.9522673892158982, 1.4744...","[201, 202, 203, 204, 205]"
2,HCI-001,Birinapant,"[5.113505485685279, 4.576495959133186, 3.87578...","[109, 110, 111, 112, 113]"
3,HCI-001,Vehicle,"[2.6044063778532296, 3.2476149246326367, 2.506...","[93, 94, 95, 96, 97, 98, 144, 145, 146, 147, 148]"
4,HCI-027,Navitoclax,"[2.231724789216095, 2.698815312853063, 1.65234...","[9, 10, 11]"
5,HCI-010,Docetaxel,"[0.6966110197902743, 0.6617240371529612, 1.095...","[72, 73, 74]"
6,HCI-002,RO4929097,"[3.099328512453143, 2.698806786398127, 3.17931...","[191, 192, 193, 194]"
7,HCI-023,Birinapant + Irinotecan,"[0.0, 0.0, 0.0, 0.0, 0.0]","[226, 227, 228, 229, 230]"
8,HCI-003,Fulvestrant (40 mg/kg),"[2.6987904609803244, 1.4216693857312022]","[231, 232]"
9,HCI-003,Fulvestrant (200 mg/kg),[0.579485798718202],[234]


In [7]:
# Returns list of normally-distributed variables with means from dictionary d and constant variance "variance"
def normal_variables_from_dict(d, variance):
    var_dict = {}
    for key in d.keys():
        var_dict[key] = pyro.sample(key, dist.Normal(d[key], variance))
    return var_dict

def model(sample_list, drug_list, obs_list, sample_means, drug_means):
    num_observations = len(obs_list)
    assert len(sample_list) == num_observations
    assert len(drug_list) == num_observations
    # create variables for each sample and drug
    samples = normal_variables_from_dict(sample_means, 1)
    drugs = normal_variables_from_dict(drug_means, 1)
    # create variable for each (sample, drug) pair observed
    sigma = pyro.param("sigma", lambda: torch.ones(()), constraint=constraints.positive)
    for i in pyro.plate("data", num_observations):
        name = sample_list[i] + '_' + drug_list[i]
        mean = samples[sample_list[i]] * drugs[drug_list[i]]
        pyro.sample(name, dist.Normal(mean, sigma), obs=obs_list[i])

# given dataframe with columns 'sample', 'drug', and 'log(V_V0)_obs', return lists to pass to model
def format_for_model(d, vol_name):
    sample_list = list(d['sample'])
    drug_list = list(d['drug'])
    obs_list = []
    for obs in d[vol_name]:
        obs_list.append(torch.Tensor(obs))
    return sample_list, drug_list, obs_list

def get_means(var_list):
    means = {}
    for v in var_list:
        means[v] = 0
    return means


smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.8.4')

pyro.enable_validation(True)
pyro.set_rng_seed(1)
logging.basicConfig(format='%(message)s', level=logging.INFO)

# Set matplotlib settings
plt.style.use('default')

df = pd.read_pickle('../results/2023-05-26/clean_and_split_data/split/train.pkl')

In [8]:
df.head()

Unnamed: 0,sample,drug,log(V_V0+1)_obs,MID_list
0,HCI-023,Irinotecan,"[0.0646090814230765, 0.0, 0.2106294794980008, ...","[221, 222, 223, 224, 225]"
1,HCI-023,Birinapant + Irinotecan,"[0.0, 0.0, 0.0, 0.0, 0.0]","[226, 227, 228, 229, 230]"
2,HCI-002,Birinapant + Irinotecan,"[0.9045428407520958, 1.8598931160713488, 1.454...","[206, 207, 208, 209, 210]"
3,HCI-001,Vehicle,"[2.6044063778532296, 3.2476149246326367, 2.506...","[93, 94, 95, 96, 97, 98, 144, 145, 146, 147, 148]"
4,HCI-016,Vehicle,"[1.6353472115141618, 2.2895301652099707, 2.088...","[51, 52, 53, 54, 55, 56]"


In [9]:
vol_name = 'log(V_V0+1)_obs'
sample_means = get_means(df['sample'].unique())
drug_means = get_means(df['drug'].unique())
sample_list, drug_list, obs_list = format_for_model(df, vol_name)

In [10]:
pyro.render_model(model, 
	model_args=(sample_list, drug_list, obs_list, sample_means, drug_means), 
	render_distributions=True, 
	filename=write_dir + '/model_diagram.png')
pyro.clear_param_store()
kernel = pyro.infer.mcmc.NUTS(model, jit_compile=True)
mcmc = pyro.infer.MCMC(kernel, num_samples=500, warmup_steps=500)
mcmc.run(sample_list, drug_list, obs_list, sample_means, drug_means)

mcmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

sns.histplot(mcmc_samples['HCI-015'], ax=axes[0])
axes[0].set_xlabel('HCI-015')

sns.histplot(mcmc_samples['Navitoclax'], ax=axes[1])
axes[1].set_xlabel('Navitoclax')

sns.histplot(mcmc_samples['Vehicle'], ax=axes[2])
axes[2].set_xlabel('Vehicle')
plt.savefig(write_dir + '/model_stats.png')


NameError: name 'write_dir' is not defined