## Variational Inference

The dataset required is small and is available preprocessed here:

- https://drive.google.com/drive/folders/1Tg_3SlKbdv0pDog6k2ys0J79e1-vgRyd?usp=sharing

In [None]:
import torch
import numpy as np
from gpytorch.optim import NGD
from torch.optim import Adam
from torch.nn import Parameter
from matplotlib import pyplot as plt
from os import path

from lafomo.datasets import P53Data, ToyTranscriptomicGenerator
from lafomo.configuration import VariationalConfiguration
from lafomo.models import OrdinaryLFM, MultiOutputGP, generate_multioutput_rbf_gp
from lafomo.plot import Plotter1d, Colours, tight_kwargs
from lafomo.trainers import VariationalTrainer, PreEstimator
from lafomo.utilities.data import p53_ground_truth
from experiments.mae import get_datasets


Let's start by importing our dataset...

In [None]:
# Data-source switch: True loads the p53 data, False takes one of the
# datasets returned by get_datasets.
p53 = True
if p53:
    dataset = P53Data(replicate=0, data_dir='../../../data')
    ground_truths = p53_ground_truth()

    class ConstrainedTrainer(VariationalTrainer):
        """Variational trainer that re-pins output 3's kinetics every epoch."""

        def after_epoch(self):
            # Constrained-space values that output index 3 is held at.
            fixed_sensitivity = torch.tensor(1.)
            fixed_decay = torch.tensor(0.8)
            to_raw = self.lfm.positivity.inverse_transform
            with torch.no_grad():
                # Overwrite in place, outside autograd.
                self.lfm.raw_sensitivity[3] = to_raw(fixed_sensitivity)
                self.lfm.raw_decay[3] = to_raw(fixed_decay)
            super().after_epoch()

else:
    datasets = get_datasets(data_dir='../../../experiments')
    dataset = datasets[5]
    # dataset.generate_single(lengthscale=1.3)
    dataset.variance = 1e-5 * torch.ones(dataset.m_observed.shape[-1], dtype=torch.float32)
    # Ground truths are read from the LFM attached to the dataset, flattened
    # to 1-d arrays in the order (basal, sensitivity, decay).
    ground_truths = [
        getattr(dataset.lfm, attribute).detach().view(-1).numpy()
        for attribute in ('basal_rate', 'sensitivity', 'decay_rate')
    ]

    class ConstrainedTrainer(VariationalTrainer):
        """Variational trainer that re-pins output 0 to the dataset's own kinetics."""

        def after_epoch(self):
            to_raw = self.lfm.positivity.inverse_transform
            with torch.no_grad():
                true_sensitivity = dataset.lfm.sensitivity[0].squeeze()
                true_decay = dataset.lfm.decay_rate[0].squeeze()
                self.lfm.raw_sensitivity[0] = to_raw(true_sensitivity)
                self.lfm.raw_decay[0] = to_raw(true_decay)
            super().after_epoch()


num_genes = 5
num_tfs = 1

# Quick look at the data: the first five dataset items plus the observed
# latent function.
plt.figure(figsize=(4, 2))
for i in range(5):
    plt.plot(dataset[i][1])
plt.plot(dataset.f_observed[0, 0])

# Last observed time point; later cells use it as the end of the time grids.
t_end = dataset.t_observed[-1]

We use the ordinary differential equation (ODE):

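`dy/dt = b + s*f(t) - d*y`

where `b` is the basal rate, `s` the sensitivity, `d` the decay rate and `f(t)` the latent force.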

`f(t) ~ GP(0, k(t, t'))`

Since this is an ODE, we inherit from the `OrdinaryLFM` class.

In [None]:
from gpytorch.constraints import Positive
class TranscriptionLFM(OrdinaryLFM):
    """Latent force model of transcription: ``dh/dt = b + s * f - d * h``.

    The basal rate ``b``, sensitivity ``s`` and decay rate ``d`` are stored as
    unconstrained ``raw_*`` parameters of shape ``(num_outputs, 1)`` and mapped
    to positive values through one shared ``Positive`` constraint.
    """

    def __init__(self, num_outputs, gp_model, config: VariationalConfiguration, **kwargs):
        """Create the model and randomly initialise the kinetic parameters.

        Args:
            num_outputs: forwarded to ``OrdinaryLFM``.
            gp_model: forwarded to ``OrdinaryLFM``.
            config: variational configuration, forwarded to ``OrdinaryLFM``.
            **kwargs: forwarded unchanged to ``OrdinaryLFM``.
        """
        super().__init__(num_outputs, gp_model, config, **kwargs)
        self.positivity = Positive()
        # Initial constrained values: decay in [0.1, 1.1), basal in [0, 0.1),
        # sensitivity in [0, 2); stored in the unconstrained (raw) space.
        self.raw_decay = Parameter(
            self.positivity.inverse_transform(0.1 + torch.rand(torch.Size([self.num_outputs, 1]), dtype=torch.float64)))
        self.raw_basal = Parameter(
            self.positivity.inverse_transform(0.1 * torch.rand(torch.Size([self.num_outputs, 1]), dtype=torch.float64)))
        self.raw_sensitivity = Parameter(
            self.positivity.inverse_transform(2*torch.rand(torch.Size([self.num_outputs, 1]), dtype=torch.float64)))

    # NOTE(review): the setters below assign the inverse-transformed value
    # straight onto the ``raw_*`` attribute. If ``value`` is a plain tensor and
    # the base class is a ``torch.nn.Module``, that assignment is expected to be
    # rejected because the attribute is a registered Parameter -- confirm, and
    # wrap in ``Parameter`` or copy in place if so.

    @property
    def decay_rate(self):
        """Positive decay rate ``d`` derived from ``raw_decay``."""
        return self.positivity.transform(self.raw_decay)

    @decay_rate.setter
    def decay_rate(self, value):
        self.raw_decay = self.positivity.inverse_transform(value)

    @property
    def basal_rate(self):
        """Positive basal rate ``b`` derived from ``raw_basal``."""
        return self.positivity.transform(self.raw_basal)

    @basal_rate.setter
    def basal_rate(self, value):
        self.raw_basal = self.positivity.inverse_transform(value)

    @property
    def sensitivity(self):
        """Positive sensitivity ``s`` derived from ``raw_sensitivity``."""
        return self.positivity.transform(self.raw_sensitivity)

    @sensitivity.setter
    def sensitivity(self, value):
        # Fixed: this used the undefined ``self.decay_constraint``; the only
        # constraint this class defines is ``self.positivity``.
        self.raw_sensitivity = self.positivity.inverse_transform(value)

    def initial_state(self):
        """Return ``basal_rate / decay_rate`` as the initial ODE state."""
        return self.basal_rate / self.decay_rate

    def odefunc(self, t, h):
        """Evaluate ``dh/dt`` at time ``t``.

        h is of shape (num_samples, num_outputs, 1)
        """
        self.nfe += 1
        # if (self.nfe % 100) == 0:
        #     print(t)
        f = self.f
        if not self.pretrain_mode:
            # Outside pre-training, take the latent-force slice at the current
            # time index and advance the index whenever the solver moves
            # forward in time.
            f = self.f[:, :, self.t_index].unsqueeze(2)
            if t > self.last_t:
                self.t_index += 1
            self.last_t = t

        dh = self.basal_rate + self.sensitivity * f - self.decay_rate * h
        return dh

In [None]:
# Variational settings: 80 samples, no initial conditions, and the dataset's
# variance as the preprocessing variance.
config = VariationalConfiguration(
    num_samples=80,
    initial_conditions=False,
    preprocessing_variance=dataset.variance,
)

num_inducing = 20  # (I x m x 1)
# Evenly spaced inducing locations on [0, t_end], repeated once per
# transcription factor and reshaped to (num_tfs, num_inducing, 1).
inducing_grid = torch.linspace(0, t_end, num_inducing)
inducing_points = inducing_grid.repeat(num_tfs, 1).view(num_tfs, num_inducing, 1)

t_predict = torch.linspace(0, t_end, 80, dtype=torch.float32)
# t_predict = torch.linspace(0, t_end+2, 80, dtype=torch.float32)
step_size = 5e-1
num_training = dataset.m_observed.shape[-1]

# use_natural also decides which optimisers are built in the next cell.
use_natural = True
gp_model = generate_multioutput_rbf_gp(
    num_tfs,
    inducing_points,
    gp_kwargs={'natural': use_natural},
)

lfm = TranscriptionLFM(num_genes, gp_model, config, num_training_points=num_training)
plotter = Plotter1d(lfm, dataset.gene_names, style='seaborn')

In [None]:
# Names of the parameters whose values the trainers record.
track_parameters = [
    'raw_basal',
    'raw_decay',
    'raw_sensitivity',
    'gp_model.covar_module.raw_lengthscale',
]

if not use_natural:
    # One Adam optimiser over every parameter, for both training phases.
    optimizers = [Adam(lfm.parameters(), lr=0.05)]
    pre_optimizers = [Adam(lfm.parameters(), lr=0.05)]
else:
    # NGD for the variational parameters, Adam for the rest; the pre-training
    # phase gets its own pair with different learning rates.
    variational_optimizer = NGD(lfm.variational_parameters(), num_data=num_training, lr=0.09)
    parameter_optimizer = Adam(lfm.nonvariational_parameters(), lr=0.02)
    optimizers = [variational_optimizer, parameter_optimizer]

    pre_variational_optimizer = NGD(lfm.variational_parameters(), num_data=num_training, lr=0.1)
    pre_parameter_optimizer = Adam(lfm.nonvariational_parameters(), lr=0.005)
    pre_optimizers = [pre_variational_optimizer, pre_parameter_optimizer]

trainer = ConstrainedTrainer(lfm, optimizers, dataset, track_parameters=track_parameters)
pre_estimator = PreEstimator(lfm, pre_optimizers, dataset, track_parameters=track_parameters)

# Handles on the variational distribution, kept for the optional inspection
# plot below.
strat = lfm.gp_model.variational_strategy.base_variational_strategy
dist = strat._variational_distribution
# plt.imshow(dist.chol_variational_covar.detach().squeeze())
# plt.colorbar()

### Outputs prior to training:

In [None]:
titles = ['Basal rates', 'Sensitivities', 'Decay rates']

# Most recent tracked raw values, mapped to the positive (constrained) space,
# stacked in the order (basal, sensitivity, decay).
kinetics = np.array([
    lfm.positivity.transform(trainer.parameter_trace[key][-1].squeeze()).numpy()
    for key in ('raw_basal', 'raw_sensitivity', 'raw_decay')
])
print(kinetics.shape)

# Fixed: compare against the ground truths chosen with the dataset (the
# `ground_truths` variable) rather than always the p53 values, so the plot is
# still correct when `p53 = False`.
plotter.plot_double_bar(kinetics,
                        ground_truths=ground_truths,
                        titles=titles)

# Model predictions before any training.
q_m = lfm.predict_m(t_predict, step_size=1e-1)
q_f = lfm.predict_f(t_predict)

plotter.plot_gp(q_m, t_predict, replicate=0,
                t_scatter=dataset.t_observed,
                y_scatter=dataset.m_observed, num_samples=0)
plotter.plot_gp(q_f, t_predict, ylim=(-1, 3))
plt.title('Latent')


In [None]:
# Switch the model into pre-training mode. The pre-estimator run itself is
# currently commented out.
lfm.pretrain(True)
# lfm.loss_fn.num_data = 61
# pre_estimator.train(50, report_interval=20);

print(num_training)

from torch.nn.functional import l1_loss
# High-resolution reference trajectories; the training loop in the next cell
# uses them as targets for the mean-absolute-error metrics.
# NOTE(review): presumably only some datasets expose the `*_highres`
# attributes -- confirm before running with the p53 data.
m_targ = dataset.m_observed_highres.squeeze().t()
f_targ = dataset.f_observed_highres.squeeze(0).t()

In [None]:
# Leave pre-training mode and alternate ten epochs of training with an
# evaluation against the high-resolution targets.
lfm.pretrain(False)
t_predict = torch.linspace(0, t_end, 111)

f_maes, m_maes = [], []
num_rounds = 700 // 10  # 700 epochs in total, evaluated every 10
for i in range(num_rounds):
    trainer.train(epochs=10, report_interval=50, step_size=5e-1)

    # Predictions on the evaluation grid.
    m_pred = lfm.predict_m(t_predict, jitter=1e-3)
    f_pred = lfm.predict_f(t_predict, jitter=1e-3)

    # Mean absolute error of the predictive means.
    m_mae = l1_loss(m_pred.mean, m_targ).mean().item()
    f_mae = l1_loss(f_pred.mean, f_targ).mean().item()
    m_maes.append(m_mae)
    f_maes.append(f_mae)

# Tensors allow element-wise arithmetic in the next cell.
f_maes = torch.tensor(f_maes)
m_maes = torch.tensor(m_maes)

In [None]:
# Pick the evaluation round with the smallest combined (latent + output) MAE.
combined_maes = f_maes + m_maes
min_index = torch.argmin(combined_maes, dim=0)
m_mae = m_maes[min_index]
f_mae = f_maes[min_index]

print(t_end)
print(m_maes)
print(f_maes)
print(m_mae, f_mae)

In [None]:
# Make sure pre-training mode is off and the loss is scaled by the number of
# training points, then train for a further 50 epochs.
lfm.pretrain(False)
lfm.loss_fn.num_data = num_training
step_size = 5e-1
trainer.train(50, report_interval=10, step_size=step_size);

In [None]:
# Evaluation grid for the post-training plots.
t_predict = torch.linspace(0, t_end, 111, dtype=torch.float32)

# plotter.plot_losses(trainer, last_x=200)
q_m = lfm.predict_m(t_predict, step_size=1e-1)
q_f = lfm.predict_f(t_predict)


labels = ['Basal rates', 'Sensitivities', 'Decay rates']
# Latest tracked raw values mapped back to the positive space, in the order
# (basal, sensitivity, decay).
kinetics = [
    lfm.positivity.transform(trainer.parameter_trace[key][-1].squeeze()).numpy()
    for key in ('raw_basal', 'raw_sensitivity', 'raw_decay')
]

# One y-tick specification per bar-chart panel.
bar_yticks = [
    np.linspace(0, 0.12, 5),
    np.linspace(0, 1.2, 4),
    np.arange(0, 1.1, 0.2),
]
plotter.plot_double_bar(kinetics, labels, ground_truths=ground_truths,
                        figsize=(6.5, 2.3),
                        yticks=bar_yticks)
plt.tight_layout()
plt.savefig('./kinetics.pdf', **tight_kwargs)

# Output and latent-force predictions against the observations.
plotter.plot_gp(q_m, t_predict,
                t_scatter=dataset.t_observed,
                y_scatter=dataset.m_observed)
plotter.plot_gp(q_f, t_predict,
                t_scatter=dataset.t_observed,
                y_scatter=dataset.f_observed)
plt.plot(t_predict, f_targ.squeeze())

In [None]:
# NOTE: this cell reads `exact_q_f` / `exact_q_m`, which are assigned in the
# exact-LFM prediction cell further down the notebook; run that cell first.
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 1.6),
                         gridspec_kw=dict(width_ratios=[1, 1, 0.3, 1.9], wspace=0))

# Upper y-limit for each of the two gene panels.
ub = [3.5, 3.5]
for i in range(2):
    ax = axes[i]
    plotter.plot_gp(q_m, t_predict, replicate=0, ax=ax,
                    color=Colours.line_color, shade_color=Colours.shade_color,
                    t_scatter=dataset.t_observed, y_scatter=dataset.m_observed,
                    num_samples=0, only_plot_index=i)
    ax.set_ylim([-0.2, ub[i]])
    ax.set_yticks([0, 3])
    ax.set_title(dataset.gene_names[i])
    ax.set_xlim(-0.4, 15)
    if i > 0:
        # Panels after the first have no y ticks and skip the 0 x-tick.
        ax.set_yticks([])
        ax.set_xticks([5, 10, 15])

# Latent force: variational posterior overlaid with the exact-LFM posterior.
plotter.plot_gp(q_f, t_predict, ax=axes[3],
                ylim=(-1, 3.2),
                num_samples=3,
                color=Colours.line2_color,
                shade_color=Colours.shade2_color)
plotter.plot_gp(exact_q_f, t_predict, ax=axes[3],
                ylim=(-1, 3.2), color=Colours.line_color,
                shade_color='red')
axes[3].set_title('Latent force (p53)')
axes[3].set_yticks([-1, 3])
axes[3].set_xlabel('Time (h)')
axes[3].set_xlim(0, 15)
axes[3].set_xticks([0, 5, 10, 15])
# The third axis only acts as a spacer between the gene and latent panels.
axes[2].set_visible(False)

# plt.savefig('./barenco-combined.pdf', **tight_kwargs)
B_exact, S_exact, D_exact = p53_ground_truth()
B_exact, S_exact, D_exact = np.array(B_exact), np.array(S_exact), np.array(D_exact)
B = lfm.basal_rate.detach().squeeze()
D = lfm.decay_rate.detach().squeeze()
# Fixed: this previously read `lfm.basal_rate`, so the sensitivity error term
# compared basal rates against the true sensitivities.
S = lfm.sensitivity.detach().squeeze()

# Convert the ground truths to tensors explicitly instead of relying on
# implicit tensor/ndarray arithmetic.
B_true = torch.as_tensor(B_exact, dtype=B.dtype)
S_true = torch.as_tensor(S_exact, dtype=S.dtype)
D_true = torch.as_tensor(D_exact, dtype=D.dtype)

mse = torch.square(B - B_true) + torch.square(D - D_true) + torch.square(S - S_true)
print(D)
print(mse.mean(), (D - D_true), B_exact.shape)
mae2 = l1_loss(q_m.mean, exact_q_m.mean)
print(mae2, q_m.mean.shape, exact_q_m.mean.shape)

In [None]:
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood

from lafomo.models import ExactLFM
from lafomo.trainers import ExactTrainer

# Exact latent force model on the same dataset, with the dataset's variance
# flattened to 1-d.
exact_lfm = ExactLFM(dataset, dataset.variance.reshape(-1))
optimizer = torch.optim.Adam(exact_lfm.parameters(), lr=0.07)

loss_fn = ExactMarginalLogLikelihood(exact_lfm.likelihood, exact_lfm)

# NOTE: this rebinds the notebook-level `track_parameters` to the exact
# model's parameter names.
track_parameters = [
    'mean_module.raw_basal',
    'covar_module.raw_decay',
    'covar_module.raw_sensitivity',
    'covar_module.raw_lengthscale',
]
exact_trainer = ExactTrainer(exact_lfm, [optimizer], dataset, loss_fn=loss_fn, track_parameters=track_parameters)
# Fixed: `Plotter` is never imported or defined in this notebook; use the
# imported `Plotter1d`, as for the variational model.
# NOTE(review): confirm Plotter1d accepts an ExactLFM.
plotter = Plotter1d(exact_lfm, dataset.gene_names)

In [None]:
# Put the exact model and its likelihood into training mode, then optimise
# for 150 epochs.
exact_lfm.train()
exact_lfm.likelihood.train()
exact_trainer.train(epochs=150, report_interval=10)

In [None]:
# Exact-model predictions on the shared `t_predict` grid; the comparison
# figure cells read `exact_q_m` and `exact_q_f`.
exact_q_m = exact_lfm.predict_m(t_predict)
exact_q_f = exact_lfm.predict_f(t_predict)
print(exact_q_f)



In [None]:
 # key in ['basal_rate', 'sensitivity', 'decay_rate']:
# Trace of the tracked basal rate for output index 3, mapped to the positive
# space.
plt.plot(lfm.positivity.transform(torch.stack(trainer.parameter_trace['raw_basal'])[:, 3]))
plt.tight_layout()

In [None]:
# Show the current (positive) basal rates of the variational model.
print(lfm.basal_rate)

In [None]:
## This stuff is safe to delete:
# 2 x 4 grid: columns 0-1 hold the four gene panels, column 2 is a spacer and
# column 3 holds the two latent-force panels.
fig, axes = plt.subplots(
    nrows=2,
    ncols=4,
    figsize=(9, 2.7),
    gridspec_kw=dict(width_ratios=[1, 1, 0.3, 1.8], wspace=0, hspace=0.8),
)

ub = [3.5] * 4
for i in range(4):
    # Two gene panels per row.
    row, col = divmod(i, 2)
    ax = axes[row, col]
    plotter.plot_gp(q_m, t_predict, replicate=0, ax=ax,
                    color=Colours.line_color, shade_color=Colours.shade_color,
                    t_scatter=dataset.t_observed, y_scatter=dataset.m_observed,
                    num_samples=0, only_plot_index=i)
    ax.set_ylim([-0.2, ub[i]])
    ax.set_yticks([0, 3])
    ax.set_title(dataset.gene_names[i])
    ax.set_xlim(-0.4, 15)
    if col > 0:
        # The second panel in each row has no y ticks and skips the 0 x-tick.
        ax.set_yticks([])
        ax.set_xticks([5, 10, 15])

# Bottom-right: variational latent force; top-right: exact-model latent force.
plotter.plot_gp(q_f, t_predict, ax=axes[1, 3],
                ylim=(-1, 3.2),
                num_samples=0,
                color=Colours.line2_color,
                shade_color=Colours.shade2_color)
plotter.plot_gp(exact_q_f, t_predict, ax=axes[0, 3],
                ylim=(-1, 3.2), color=Colours.line2_color,
                shade_color=Colours.shade2_color)

titles = ['Lawrence et al., 2007', 'ours']
for i, source in enumerate(titles):
    latent_ax = axes[i, 3]
    latent_ax.set_title(f'Latent force ({source})')
    latent_ax.set_yticks([-1, 3])
    latent_ax.set_xlim(0, 15)
    latent_ax.set_xticks([0, 5, 10, 15])
    axes[i, 2].set_visible(False)
axes[1, 3].set_xlabel('Time (h)')

plt.savefig('./barenco-combined.pdf', **tight_kwargs)

from torch.nn.functional import l1_loss
# Mean absolute difference between the two latent-force posterior means.
print(l1_loss(q_f.mean, exact_q_f.mean).item())