In [95]:
# enable autorreload of modules
%load_ext autoreload
import pickle
import numpy as np
import torch
import matplotlib.pyplot as plt
from new_models import *
from SDE_helper import *
import warnings
warnings.filterwarnings('always', module='.*')
import logging
logging.getLogger().setLevel(logging.ERROR)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [96]:
from new_models import SDENet, NewEncoderLSTM, NewVariationalInference

# Single device setting for the whole notebook. The comparison cells build
# their time grids with plain torch.arange / torch.tensor (i.e. on CPU), so the
# data generator and the models must live on CPU too; picking CUDA for the
# models while the data stays on CPU would mix devices on a GPU machine.
DEVICE_NAME = "cpu"
device = torch.device(DEVICE_NAME)

# NOTE(review): pickle.load can execute arbitrary code from the file; only load
# datafiles produced by this project.
with open('data/datafile_dim8.pkl', "rb") as f:
    data_gen = pickle.load(f)
dg = data_gen
dg.set_device(DEVICE_NAME)
dg.set_train_size(1000)

data_config = sim_config.dim8_config

obs_dim = data_config.obs_dim
latent_dim = data_config.latent_dim
action_dim = data_config.action_dim
t_max = data_config.t_max
step_size = data_config.step_size
encoder_latent_ratio = 2.0
encoder_output_dim = 6
normalize = True
prior = ExponentialPrior.log_density
ablate = False
condition_w = True
condition_w_only = False

encoder = NewEncoderLSTM(
    obs_dim + action_dim,
    int(obs_dim * encoder_latent_ratio),
    encoder_output_dim,
    device=device,
    normalize=normalize,
)

new_decoder = SDENet(
    input_size=(encoder_output_dim,),
    device=device,
    condition_w=condition_w,
    obs_dim=obs_dim,
    y_net_hidden_width=16,
    weight_network_sizes=(32,),  # 128
    control_network_width=16,  # 8
    ablate=ablate,
    condition_w_only=condition_w_only,
)

# Fixed test batch shared by all old-vs-new comparisons below.
# NOTE(review): the 32 shows up as the second axis of data['actions'];
# presumably the batch size - confirm against get_split's signature.
data = data_gen.get_split("test", 32, 0)
print(data['actions'].shape)

train_size -100
n_sample 1000
Running ablation study: False
Conditioning on w only: False
Number of parameters of y net: 82
Conditioning on expert ODEs: True
Conditioning expert ODEs only False
torch.Size([15, 32, 1])


In [97]:
from training_utils import *

# Hand the shared test actions to the new SDE decoder and keep a handle on its
# dose schedule. Its time axis is the old decoder's axis divided by 14, so the
# probe time 5/14 here lines up with t = 5 on the old decoder.
common_action = data['actions']
new_decoder.set_action(common_action)
new_dose_func = new_decoder.dose_at_time
probe_time = 5 / 14
new_dose_func(probe_time)

tensor([0.2676, 0.0000, 2.2849, 0.0000, 1.5599, 0.0000, 0.0971, 0.0000, 8.4267,
        0.0000, 0.0000, 0.0334, 0.0252, 0.0878, 1.2365, 0.0000, 0.0000, 0.0000,
        6.7690, 0.0000, 0.0000, 0.2288, 0.0000, 5.1879, 0.1072, 1.1693, 0.0000,
        0.0000, 0.0468, 0.0000, 0.0000, 0.0000], grad_fn=<MulBackward0>)

In [98]:
from model import *
from training_utils import *

# Reference ("old") implementation, used only to cross-check the new SDE
# decoder's dose schedule and ODE right-hand side.
# NOTE(review): old_encoder is not used by the comparisons in this notebook.
old_encoder_args = (
    obs_dim + action_dim,
    int(obs_dim * encoder_latent_ratio),
    4,
)
old_encoder = EncoderLSTM(*old_encoder_args, device=device, normalize=normalize)

# NOTE(review): positional args are (obs_dim, 4, action_dim, 14, 1); what 4, 14
# and 1 mean depends on RocheExpertDecoder's signature - confirm there.
old_decoder_options = dict(
    roche=True,
    method="dopri5",
    device=device,
    ablate=False,
)
old_decoder = RocheExpertDecoder(obs_dim, 4, action_dim, 14, 1, **old_decoder_options)

# Same actions as the new decoder; evaluated at t = 5 for comparison with
# new_dose_func at 5/14.
old_decoder.ode.set_action(common_action)
old_dose_func = old_decoder.ode.dose_at_time
old_dose_func(5)

ml_dim: 0


tensor([0.2676, 0.0000, 2.2849, 0.0000, 1.5599, 0.0000, 0.0971, 0.0000, 8.4267,
        0.0000, 0.0000, 0.0334, 0.0252, 0.0878, 1.2365, 0.0000, 0.0000, 0.0000,
        6.7690, 0.0000, 0.0000, 0.2288, 0.0000, 5.1879, 0.1072, 1.1693, 0.0000,
        0.0000, 0.0468, 0.0000, 0.0000, 0.0000], grad_fn=<MulBackward0>)

In [99]:
# The two dose schedules must agree on the whole grid: old integer steps
# 0..14 correspond to new times 0, 1/14, ..., 1.
old_interval = torch.arange(0, 15, 1)
new_interval = old_interval / 14
for step in range(len(old_interval)):
    old_t, new_t = old_interval[step], new_interval[step]
    old, new = old_dose_func(old_t), new_dose_func(new_t)
    assert torch.allclose(old, new, atol=1e-5), f"{old_t} {new_t} {old} {new}"

In [100]:
from SDE_helper import expertODE

# Ground-truth latents of the test batch. Only y[0, :, :4] (first index of the
# leading axis, first 4 latent channels) is used below, as the shared state
# fed to both ODE implementations.
# NOTE(review): layout presumably (time, batch, latent) like data['actions'];
# confirm against the data generator.
y = data['latents']
def new_ode(t, y):
    """Expert-ODE right-hand side driven by the new SDE decoder's dose schedule.

    Args:
        t: time value, forwarded unchanged to ``expertODE``.
        y: 2-D state whose columns 0..3 are Disease, ImmuneReact, Immunity and
           Dose2; each column is passed to ``expertODE`` as a separate argument.

    Returns:
        The value returned by ``expertODE`` for this state, using the global
        ``new_dose_func`` and ``device``.
    """
    disease, immune_react, immunity, dose2 = (y[:, col] for col in range(4))
    return expertODE(
        t,
        disease,
        immune_react,
        immunity,
        dose2,
        new_dose_func,
        device=device,
    )
# The right-hand sides must agree pointwise: the old decoder's ODE at integer
# step k versus new_ode at k/14, both evaluated on the same initial state.
initial_state = y[0, :, :4]
for old_t, new_t in zip(old_interval, new_interval):
    old = old_decoder.ode(old_t, initial_state)
    new = new_ode(new_t, initial_state)
    assert torch.allclose(old, new, atol=1e-5), f"{old_t} {new_t} {old} {new}"

In [101]:
from torchdiffeq import odeint

# Integrate the old decoder's ODE on its original time axis from t = 0 to
# t = 6; index 1 of the result is the state at t = 6.
old_t_span = torch.tensor([0.0, 6.0])
old = odeint(old_decoder.ode, y[0, :, :4], old_t_span, method='dopri5')
final_old_state = old[1]
print(final_old_state)

tensor([[2.8680e-01, 7.8167e-01, 1.3165e+00, 1.9698e-01],
        [1.3635e-01, 7.6141e-02, 1.4811e-01, 2.0367e-05],
        [4.4894e-01, 5.7167e-01, 9.9947e-01, 8.4058e-01],
        [6.7050e-02, 8.1953e-01, 1.8323e+00, 1.8453e-05],
        [5.1660e-01, 4.0297e-01, 1.0721e+00, 1.1478e+00],
        [1.4816e-01, 9.1637e-01, 1.6009e+00, 9.0819e-06],
        [6.3855e-01, 5.4999e-01, 4.8105e-01, 1.7852e-01],
        [1.0633e-03, 5.4722e-01, 2.4892e+00, 2.0732e-06],
        [4.0887e-01, 1.8838e-01, 1.0602e+00, 3.1001e+00],
        [6.3441e-02, 8.1742e-01, 1.8481e+00, 1.4171e-05],
        [5.9670e-02, 8.4037e-01, 1.8774e+00, 2.2253e-06],
        [2.7797e-01, 9.4979e-01, 1.3650e+00, 4.9180e-02],
        [2.1563e-01, 1.2327e-01, 1.2113e-01, 3.7139e-02],
        [3.0912e-01, 9.5148e-01, 1.3694e+00, 1.2920e-01],
        [7.3207e-01, 4.2195e-01, 5.3883e-01, 9.0978e-01],
        [1.9120e-03, 5.3694e-01, 2.3940e+00, 2.1197e-05],
        [5.1781e-01, 6.5941e-01, 6.9548e-01, 4.1305e-06],
        [3.615

In [107]:
# Ratio between the two time axes: the new decoder uses t_new = t_old / 14.
TIME_SCALE = 14


def new_ode1(t, y, time_scale=TIME_SCALE):
    """``new_ode`` rewritten as a derivative with respect to rescaled time.

    ``new_ode(t_new, y)`` matches the old decoder's right-hand side at
    t_old = t_new * time_scale, i.e. it is dy/dt_old. With
    t_new = t_old / time_scale, the chain rule gives
    dy/dt_new = time_scale * dy/dt_old, so integrating this function over
    [0, T / time_scale] should reproduce the old trajectory over [0, T].

    Args:
        t: time on the rescaled (new) axis.
        y: ODE state, forwarded to ``new_ode``.
        time_scale: old-to-new time ratio (default ``TIME_SCALE`` = 14).

    Returns:
        ``new_ode(t, y)`` multiplied by ``time_scale``.
    """
    return new_ode(t, y) * time_scale


# Same horizon as the old-decoder integration (t_old = 6), expressed on the new
# axis with the same scale factor used inside new_ode1.
new = odeint(
    new_ode1,
    y[0, :, :4],
    torch.tensor([0.0, 6.0 / TIME_SCALE]),
    method='dopri5',
)
print(new[1])

tensor([[2.8680e-01, 7.8167e-01, 1.3165e+00, 1.9698e-01],
        [1.3635e-01, 7.6141e-02, 1.4811e-01, 2.0367e-05],
        [4.4894e-01, 5.7167e-01, 9.9947e-01, 8.4058e-01],
        [6.7050e-02, 8.1953e-01, 1.8323e+00, 1.8453e-05],
        [5.1660e-01, 4.0296e-01, 1.0721e+00, 1.1478e+00],
        [1.4816e-01, 9.1637e-01, 1.6009e+00, 9.0819e-06],
        [6.3855e-01, 5.4999e-01, 4.8105e-01, 1.7852e-01],
        [1.0633e-03, 5.4722e-01, 2.4892e+00, 2.0732e-06],
        [4.0887e-01, 1.8838e-01, 1.0602e+00, 3.1001e+00],
        [6.3441e-02, 8.1742e-01, 1.8481e+00, 1.3181e-05],
        [5.9670e-02, 8.4037e-01, 1.8774e+00, 9.9070e-07],
        [2.7797e-01, 9.4979e-01, 1.3650e+00, 4.9180e-02],
        [2.1563e-01, 1.2327e-01, 1.2113e-01, 3.7139e-02],
        [3.0912e-01, 9.5148e-01, 1.3694e+00, 1.2920e-01],
        [7.3207e-01, 4.2195e-01, 5.3883e-01, 9.0978e-01],
        [1.9120e-03, 5.3694e-01, 2.3940e+00, 2.1197e-05],
        [5.1781e-01, 6.5941e-01, 6.9548e-01, 4.1305e-06],
        [3.615

In [103]:
# The final states of the two integrations (t_old = 6 <-> t_new = 6/14) must
# agree. On failure, report the worst absolute deviation instead of dumping
# both full trajectories into the traceback.
assert torch.allclose(old[1], new[1], atol=1e-5), (
    f"max |old - new| at final time: {(old[1] - new[1]).abs().max().item():.3e}"
)