## Variational Inference

The required dataset is small, and a preprocessed copy is available here:

- https://drive.google.com/drive/folders/1Tg_3SlKbdv0pDog6k2ys0J79e1-vgRyd?usp=sharing

In [None]:
import torch

from lafomo.datasets import P53Data
from lafomo.variational.kernels import RBF
from lafomo.variational.models import SingleLinearLFM
from lafomo.variational.trainer import P53ConstrainedTrainer
from lafomo.configuration import VariationalConfiguration
from lafomo.utilities.torch import save, load
from lafomo.plot import Plotter

from matplotlib import pyplot as plt

In [None]:
# Load replicate 0 of the P53 data from the repository-relative data directory.
dataset = P53Data(replicate=0, data_dir='../../../data')

# Problem dimensions used by this notebook.
num_genes, num_tfs = 5, 1

# Time grids. t_inducing is handed to the model constructor, and t_predict is
# the denser grid (extending one unit either side of [0, 12]) used for plots.
# t_observed is not referenced again in the cells visible here; the plotting
# calls use dataset.t_observed instead.
t_inducing = torch.linspace(start=0, end=12, steps=10, dtype=torch.float64)
t_observed = torch.linspace(start=0, end=12, steps=7)
t_predict = torch.linspace(start=-1, end=13, steps=80, dtype=torch.float64)

# Quick visual check of the loaded data: the second element of the first
# dataset item, overlaid with dataset.m_observed[0, 0].
plt.figure(figsize=(4, 2))
plt.plot(dataset[0][1])
plt.plot(dataset.m_observed[0, 0])

In [None]:
# Model configuration: inducing points, kernel scale and initial conditions
# are all switched off, and 50 samples are requested.
options = VariationalConfiguration(
    preprocessing_variance=dataset.variance,
    learn_inducing=False,
    num_samples=50,
    kernel_scale=False,
    initial_conditions=False,
)

# Relative/absolute tolerances; atol is kept at one tenth of rtol. They are
# bundled into model_kwargs so the plotting cells can forward them unchanged.
rtol = 1e-1
atol = rtol/10
model_kwargs = dict(rtol=rtol, atol=atol)

# Kernel -> model -> optimiser -> trainer, plus a plotter bound to the model.
kernel = RBF(
    dataset.num_latents,
    scale=options.kernel_scale,
    dtype=torch.float64,
)
model = SingleLinearLFM(options, kernel, t_inducing, dataset)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
trainer = P53ConstrainedTrainer(model, optimizer, dataset)
plotter = Plotter(model, dataset.gene_names)

### Outputs prior to training

In [None]:
# Baseline plots from the untrained model, for comparison with the
# post-training plots further down.
plotter.plot_kinetics()

plotter.plot_outputs(
    t_predict,
    replicate=0,
    t_scatter=dataset.t_observed,
    y_scatter=dataset.m_observed,
    model_kwargs=model_kwargs,
)

plotter.plot_latents(
    t_predict,
    ylim=(-1, 3),
    plot_barenco=True,
    plot_inducing=False,
)

In [None]:
import time

# Tolerance for this training run; atol is a tenth of rtol, the same ratio
# used for the rtol/atol pair in the configuration cell.
tol = 1e-1

# Time the run with perf_counter(): it is monotonic and high-resolution, so
# the printed duration cannot be skewed by system clock adjustments the way
# a time.time() difference can.
start = time.perf_counter()

# Run 100 training iterations, reporting and plotting every 5.
output = trainer.train(100, rtol=tol, atol=tol/10,
                       report_interval=5, plot_interval=5)
end = time.perf_counter()

# Elapsed training time in seconds.
print(end - start)

### Outputs after training

In [None]:
# Loss curves, restricted to the last 100 recorded values.
plotter.plot_losses(trainer, last_x=100)

# Same output/latent plots as before training, now for the fitted model.
plotter.plot_outputs(
    t_predict,
    replicate=0,
    # ylim=(0, 3),
    t_scatter=dataset.t_observed,
    y_scatter=dataset.m_observed,
    model_kwargs=model_kwargs,
)
plotter.plot_latents(
    t_predict,
    ylim=(-2, 3.2),
    plot_barenco=False,
    plot_inducing=False,
)

plotter.plot_kinetics()
plotter.plot_convergence(trainer)

In [None]:
# Stack the factors recorded by the trainer, drop singleton dimensions and
# keep only the lower triangle of each matrix.
chol_factors = torch.stack(trainer.cholS).squeeze().tril()

# Rebuild each matrix as L @ L^T (batched over the leading dimension).
S = chol_factors @ chol_factors.transpose(1, 2)

# Show the most recently recorded matrix as a heat map.
plt.imshow(S[-1])
plt.colorbar()

In [None]:
# Heat map of the first entry of the model's Kmm, detached from the autograd
# graph so it can be handed to matplotlib.
kmm_first = model.Kmm[0].detach()
plt.imshow(kmm_first)
plt.colorbar()

In [None]:
# Inspect the model's inducing inputs (the configuration above sets
# learn_inducing=False).
print(model.inducing_inputs)

In [None]:
# Predictive mean and covariance on the t_predict grid, computed by hand from
# the model's kernel, Cholesky factor (model.L) and variational parameters
# (model.q_m, model.S). Shape comments are the original author's annotations.
Ksm = model.kernel(t_predict, model.inducing_inputs)  # (I, T*, Tu)
# α = Ksm Kmm^{-1}, obtained via a Cholesky solve on the transposed system.
α = torch.cholesky_solve(Ksm.permute([0, 2, 1]), model.L, upper=False).permute([0, 2, 1])  # (I, T*, Tu)
m_s = torch.matmul(α, model.q_m)  # (I, T*, 1)
m_s = torch.squeeze(m_s, 2)
Kss = model.kernel(t_predict)  # (I, T*, T*) this is always scale=1
S_Kmm = model.S - model.Kmm  # (I, Tu, Tu)
AS_KA = torch.matmul(torch.matmul(α, S_Kmm), torch.transpose(α, 1, 2))  # (I, T*, T*)
S_s = (Kss + AS_KA)  # (I, T*, T*)
print(S_s.shape)
# plt.imshow(S_s.detach()[0])
# plt.colorbar()

# Standard deviation at each prediction point for the first batch entry.
std = torch.sqrt(torch.diagonal(S_s[0])).detach()
print(std.shape, std)

# Draw a flat line at 1 with a +/- one-std band. The number of points is
# taken from std rather than hard-coded, so the plot stays valid if the
# length of t_predict is changed in the setup cell.
num_points = std.shape[0]
x_axis = torch.linspace(0, 1, num_points)
baseline = torch.ones(num_points)
plt.plot(x_axis, baseline)
plt.fill_between(x_axis, baseline + std, baseline - std)

In [None]:
# NOTE(review): this entire cell is disabled. It draws samples from the model
# and assembles a (timepoints*5) x (timepoints*5) block covariance matrix,
# one block per pair of the 5 outputs, then shows it as an image.
# Before re-enabling, check two things (both look suspect, please confirm):
#   * the inner loop iterates over k but indexes samples with j+1, so every
#     off-diagonal block for a given j would use the same neighbour;
#   * the lower-triangular block is filled with the same (untransposed) slice
#     as the upper one, which would not give a symmetric matrix in general.
# import numpy as np
# timepoints = 20
# t_temp = torch.linspace(0, 12, timepoints, dtype=torch.float64)
# initial_value = torch.zeros((options.num_samples, 5, 1))
# samples = model(t_temp, initial_value, return_samples=True, rtol=1e-3, atol=1e-3)
# samples = samples.detach().numpy()
# fig, ax = plt.subplots(nrows=1) #, figsize=(10, 10))
# full_cov = np.zeros((timepoints*5, timepoints*5))
# for j in range(5):
#     x = samples[:, :, j].squeeze()
#     covxx = np.cov(x)
#     full_cov[j*timepoints:(j+1)*timepoints, j*timepoints:(j+1)*timepoints] = covxx
#
#     for k in range(j+1, 5):
#         y = samples[:,:, j+1].squeeze()
#         covxy = np.cov(x, y)
#         full_cov[j*timepoints:(j+1)*timepoints, k*timepoints:(k+1)*timepoints] = covxy[:timepoints, timepoints:]
#         full_cov[k*timepoints:(k+1)*timepoints, j*timepoints:(j+1)*timepoints] = covxy[:timepoints, timepoints:]
#
# ax.imshow(full_cov)
# plt.axis('off')

In [None]:
# NOTE(review): this cell is disabled. It relies on `samples`, `full_cov` and
# `timepoints`, which are only defined in the (also disabled) cell above, so
# both cells must be re-enabled together. It fits a MultivariateNormal to the
# sample mean and the block covariance (with 1e-1 added on the diagonal) and
# plots the mean plus 10 draws per gene.
# mu = samples.mean(axis=1).reshape(-1)
# mu = torch.tensor(mu, dtype=torch.float32)
# cov = torch.tensor(full_cov, dtype=torch.float32) + torch.eye(5*timepoints) * 1e-1
# print(mu.shape, cov.shape)
# post_dist = torch.distributions.MultivariateNormal(mu, cov)
# fig, ax = plt.subplots(nrows=num_genes, figsize=(5, 10))
# for j in range(num_genes):
#     ax[j].plot(mu.view(timepoints, num_genes)[:, j])
#     for _ in range(10):
#         sample = post_dist.sample().view(timepoints, num_genes)
#         ax[j].plot(sample[:, j])


In [None]:
# Persist the trained model under the name 'variational_linear' using the
# project's save helper; the optional reload cell below uses the same name.
save(model, 'variational_linear')

In [None]:
# Optional reload of the saved model. Disabled by default; set do_load = True
# to replace `model`, `optimizer` and `trainer` with freshly loaded ones.
do_load = False
if do_load:
    # NOTE(review): the arguments forwarded here (num_genes, num_tfs,
    # extra_points, fixed_variance) differ from the constructor call used
    # above, SingleLinearLFM(options, kernel, t_inducing, dataset) — this
    # looks stale; confirm against the signature of `load` before enabling.
    model = load('variational_linear', SingleLinearLFM, num_genes, num_tfs,
                 t_inducing, dataset, extra_points=2, fixed_variance=dataset.variance)

    # NOTE(review): lr=0.1 here versus lr=0.05 in the original setup cell —
    # confirm whether the difference is intentional.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    trainer = P53ConstrainedTrainer(model, optimizer, dataset)
# Echo the flag so the cell output shows whether a reload happened.
print(do_load)