In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
#| default_exp learner

In [None]:
from fastcore.test import *

# GPFA Learner

> Utilities to train and visualize a GPFA model 

In [None]:
#| export
import torch
from torch import Tensor

import gpytorch
from gpfa_imputation.core import *
from collections import namedtuple

from fastcore.foundation import *
from fastprogress.fastprogress import progress_bar, master_bar
from fastcore.foundation import patch

The first thing that we need is a Learner object to keep track of:

- input data, output data
- model
- likelihood

and that has methods to help with training/visualization

We also need a training loop: we simply wrap the example loop from the GPyTorch documentation in a method.


### Normalization

The different variables in the data can have pretty different values, so we normalize them to make them more comparable. Having values on a similar scale (roughly zero mean and unit standard deviation) should also help with the numerical accuracy of the computation.

One additional complexity is the need to backtransform not only the mean but also the standard deviation.

So we need a bit of math

$$x_{norm} = \frac{x - \mu_x}{\sigma_x}$$
then
$$x = x_{norm}\sigma_x + \mu_x $$

using properties of Gaussian distributions ^[https://cs.nyu.edu/~roweis/notes/gaussid.pdf eq. 4a]

$$p(x_{norm}) = \mathcal{N}(\mu_{norm}, \sigma^2_{norm})$$

$$p(x) = \mathcal{N}(\sigma_x\mu_{norm} + \mu_x, \sigma^2_x \sigma^2_{norm})$$

In [None]:
def normalize(x: Tensor # up to 2D tensor 
             ) -> tuple[Tensor, Tensor, Tensor]: # Tuple of `x_norm`, `x_mean` and `x_std`
    "Standardize `x` along its first axis: subtract the mean, then divide by the standard deviation"
    # statistics are reduced over axis 0 (one value per column for a 2D input)
    mean, std = x.mean(axis=0), x.std(axis=0)
    centered = x - mean
    # the statistics are returned too so the transformation can be reversed later
    return centered / std, mean, std

In [None]:
def reverse_normalize(x_norm, # Normalized array
                      x_mean, # mean used in normalization
                      x_std   # std dev used in normalization
                      ) -> Tensor:       # Array after reversing normalization
    "Undo a normalization: scale `x_norm` by `x_std`, then shift it by `x_mean`"
    rescaled = x_norm * x_std
    return rescaled + x_mean

In [None]:
def reverse_normalize_std(x_std_norm, # Normalized array of standard deviations
                          x_std       # std dev used in normalization
                          ) -> Tensor: # Array after reversing normalization
    "Bring standard deviations back to the original scale; only the scale factor applies, there is no shift"
    original_scale_std = x_std_norm * x_std
    return original_scale_std

In [None]:
x = torch.randn(20).reshape(-1,2)
x_norm, x_mean, x_std = normalize(x)
# round trip: reversing the normalization must give back the original data
test_close(x, reverse_normalize(x_norm, x_mean, x_std))

# an affine map `z * x_std + x_mean` scales the standard deviation by `x_std` and
# ignores the shift, so back-transforming the std of `z` must match the std of the
# back-transformed `z`
z = torch.randn(50, 2)
test_close(
    reverse_normalize(z, x_mean, x_std).std(axis=0),
    reverse_normalize_std(z.std(axis=0), x_std),
)

## Learner

In [None]:
#| export
class GPFALearner():
    "Keeps the normalized data, a `GPFA` model and its likelihood together, and provides a training loop"
    def __init__(self, X):
        # data first: the model constructor below needs `self.T`, `self.X` and `self.n_features`
        self.prepare_X(X)
        self.prepare_time(X)

        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()
        rbf_kernel = gpytorch.kernels.RBFKernel()
        self.model = GPFA(self.T, self.X, self.likelihood, self.n_features, rbf_kernel)

    @torch.no_grad()
    def prepare_X(self, X):
        "Normalize `X`, remember the normalization statistics and store the data flattened to a vector"
        X, self.x_mean, self.x_std = normalize(X)
        # the matrix is stored as a 1D vector
        self.X = X.reshape(-1)
        self.n_features = X.shape[1]

    @torch.no_grad()
    def prepare_time(self, X):
        "Store the time index `0 .. X.shape[0]-1` as `self.T`"
        n_steps = X.shape[0]
        self.T = torch.arange(n_steps)

    def train(self, n_iter=100, lr=0.1):
        "Run `n_iter` Adam steps with learning rate `lr`, recording the loss of every step in `self.losses`"
        # model and likelihood must both be in training mode
        self.model.train()
        self.likelihood.train()

        adam = torch.optim.Adam(self.model.parameters(), lr=lr)

        # one slot per iteration
        self.losses = torch.zeros(n_iter)
        # the loss is the negative of the exact marginal log likelihood
        marginal_ll = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self.model)

        # single-element master bar: it only serves as parent of the per-iteration progress bar
        self.pb = master_bar([1])
        for _ in self.pb:
            for step in progress_bar(range(n_iter), parent=self.pb):
                # clear gradients left over from the previous step
                adam.zero_grad()
                prediction = self.model(self.T)
                loss = -marginal_ll(prediction, self.X)
                self.losses[step] = loss.detach()
                loss.backward()
                # reporting hook, called after backprop and before the parameter update
                self.printer(step)

                adam.step()

    def printer(self, i):
        "Reporting hook called at every training iteration; does nothing here"
        pass
        

In [None]:
# test data: 6 time steps, 3 features
T = torch.arange(6)

# entry (t, j) is (j + 2 + t) * t, obtained by broadcasting T as a column vector
X = (torch.arange(3, dtype=torch.float32) + 2 + T[:, None]) * T[:, None]

In [None]:
X

tensor([[ 0.,  0.,  0.],
        [ 3.,  4.,  5.],
        [ 8., 10., 12.],
        [15., 18., 21.],
        [24., 28., 32.],
        [35., 40., 45.]])

In [None]:
# l for learner
l = GPFALearner(X)

In [None]:
test_eq(T, l.T)

In [None]:
test_eq(l.n_features, 3)

In [None]:
l.X

tensor([-1.0590, -1.0955, -1.1236, -0.8347, -0.8326, -0.8305, -0.4610, -0.4382,
        -0.4201,  0.0623,  0.0876,  0.1075,  0.7350,  0.7449,  0.7523,  1.5573,
         1.5337,  1.5145])

In [None]:
l.train()

torch.linalg.solve_triangular has its arguments reversed and does not return a copy of one of the inputs.
X = torch.triangular_solve(B, A).solution
should be replaced with
X = torch.linalg.solve_triangular(A, B). (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2183.)
  res = torch.triangular_solve(right_tensor, self.evaluate(), upper=self.upper).solution
 does not have profile information (Triggered internally at  ../torch/csrc/jit/codegen/cuda/graph_fuser.cpp:104.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


In [None]:
l.losses

tensor([ 1.3442,  1.3034,  1.2644,  1.2289,  1.1978,  1.1685,  1.1362,  1.1000,
         1.0615,  1.0225,  0.9837,  0.9453,  0.9072,  0.8689,  0.8301,  0.7906,
         0.7506,  0.7100,  0.6689,  0.6275,  0.5859,  0.5442,  0.5025,  0.4608,
         0.4190,  0.3772,  0.3353,  0.2932,  0.2510,  0.2086,  0.1662,  0.1237,
         0.0813,  0.0389, -0.0033, -0.0455, -0.0876, -0.1296, -0.1714, -0.2132,
        -0.2548, -0.2962, -0.3375, -0.3786, -0.4196, -0.4603, -0.5008, -0.5410,
        -0.5810, -0.6207, -0.6601, -0.6990, -0.7343, -0.7640, -0.8089, -0.8364,
        -0.8857, -0.9075, -0.9574, -0.9789, -1.0242, -1.0492, -1.0886, -1.1164,
        -1.1515, -1.1772, -1.2128, -1.2350, -1.2700, -1.2861, -1.3172, -1.3314,
        -1.3617, -1.3904, -1.4003, -1.4268, -1.4382, -1.4574, -1.4849, -1.4873,
        -1.5002, -1.5141, -1.5112, -1.5356, -1.5542, -1.5486, -1.5625, -1.5695,
        -1.5620, -1.5779, -1.5880, -1.5860, -1.6030, -1.6040, -1.5996, -1.6096,
        -1.6022, -1.5983, -1.6107, -1.61

## Predictions

add a function to get predictions from the model

In [None]:
#| export
@patch()
def predict_raw(self: GPFALearner, T):
    """Switch model and likelihood to eval mode and return `self.likelihood(self.model(T))`.

    Gradient tracking is disabled with a context manager inside the body. A
    `@torch.no_grad()` decorator stacked above `@patch()` only wraps the value
    returned by `patch`, not the function installed on `GPFALearner`, so the
    patched method would still track gradients.
    """
    self.model.eval()
    self.likelihood.eval()
    # don't calc gradients on predictions
    # NOTE(review): attributes of the returned distribution that are evaluated lazily
    # by the caller may be computed outside of this context - confirm with gpytorch
    with torch.no_grad():
        return self.likelihood(self.model(T))

In [None]:
raw_out = l.predict_raw(T)
raw_out



MultivariateNormal(loc: torch.Size([18]))

the model prediction is a distribution with `len(T)*n_features` dimensions

which is in the wrong shape and needs to be rescaled to undo the normalization

Also we don't need the full distribution but only the mean and stddev for each variable at every time step

And we can "fix" the shape by transforming back to a matrix

In [None]:
raw_stddev = raw_out.stddev.reshape(-1, l.n_features)
raw_mean = raw_out.mean.reshape(-1, l.n_features)

In [None]:
raw_stddev

tensor([[0.0250, 0.0198, 0.0224],
        [0.0242, 0.0187, 0.0214],
        [0.0239, 0.0183, 0.0211],
        [0.0239, 0.0183, 0.0211],
        [0.0242, 0.0187, 0.0214],
        [0.0250, 0.0198, 0.0224]], grad_fn=<ReshapeAliasBackward0>)

In [None]:
#| export
# Container for the parameters of the predictive distribution: `mean` and `std`.
# NOTE(review): the type name passed to namedtuple is "NormalParameters" while the
# variable is `NormParam`, so instances are displayed as `NormalParameters(...)`.
NormParam = namedtuple("NormalParameters", ["mean", "std"])

In [None]:
#| export
@patch
def predict(self: GPFALearner, T):
    """Predict mean and standard deviation at times `T`, on the scale of the original data.

    The raw model output is reshaped to `(-1, n_features)` and the normalization
    applied to the training data is reversed. Returns a `NormParam(mean, std)`.

    The whole computation runs under `torch.no_grad()`. A `@torch.no_grad()`
    decorator stacked above `@patch` does not reach the method installed on the
    class, so the context manager is used inside the body instead.
    """
    with torch.no_grad():
        raw_out = self.predict_raw(T)
        # the distribution is over a flat vector: back to one column per feature
        raw_std = raw_out.stddev.reshape(-1, self.n_features)
        raw_mean = raw_out.mean.reshape(-1, self.n_features)

        pred_mean = reverse_normalize(raw_mean, self.x_mean, self.x_std)
        # the std is only rescaled, the mean shift does not affect it
        pred_std = reverse_normalize_std(raw_std, self.x_std)
    # results were computed without gradient tracking, so no explicit detach is needed
    return NormParam(pred_mean, pred_std)

In [None]:
l.predict(T)

NormalParameters(mean=tensor([[-4.5984e-01, -1.7553e-02,  4.4573e-01],
        [ 3.0967e+00,  4.0383e+00,  4.9975e+00],
        [ 8.2710e+00,  9.9444e+00,  1.1622e+01],
        [ 1.5309e+01,  1.7973e+01,  2.0630e+01],
        [ 2.4203e+01,  2.8114e+01,  3.2012e+01],
        [ 3.4577e+01,  3.9946e+01,  4.5289e+01]]), std=tensor([[0.3351, 0.3012, 0.3816],
        [0.3239, 0.2848, 0.3657],
        [0.3200, 0.2785, 0.3599],
        [0.3198, 0.2787, 0.3599],
        [0.3239, 0.2848, 0.3659],
        [0.3351, 0.3015, 0.3822]]))

probably something is wrong here, the data is very different from the input

### Check learning is working

The idea is to use the current model to generate a dataset that can certainly be modelled by a GPFA (because it is the output of a GPFA), then train another model on it and check whether the parameters converge

In [None]:
# create a dummy GPFA with 3 features
Lt = GPFALearner(X)

In [None]:
# hand-picked values used to initialize the kernel of the reference model `Lt`
# NOTE(review): the meaning and expected shape of each entry are defined by the
# kernel in `gpfa_imputation.core` - confirm there
test_params = {
   "Lambda": torch.tensor([-1, 0.3, .8]).reshape(Lt.n_features, -1), # one row per feature
   "psi": torch.tensor([1e-5, 5e-5, 2e-5]),
   "latent_kernel.lengthscale": torch.tensor(5),
}

In [None]:
Lt.model.covar_module.initialize(**test_params)

GPFAKernel(
  (latent_kernel): RBFKernel(
    (raw_lengthscale_constraint): Positive()
  )
  (raw_psi_diag_constraint): Positive()
)

In [None]:
target_X = Lt.predict(T).mean

In [None]:
l2 = GPFALearner(target_X)

In [None]:
l2.train()

In [None]:
l2.predict(T).mean - target_X

tensor([[-4.0693e-03, -8.3923e-04, -8.5640e-04],
        [-2.9421e-03, -3.9482e-04, -1.4305e-04],
        [-9.4700e-04, -1.2779e-04, -4.9591e-05],
        [ 1.1806e-03,  1.1635e-04, -3.4332e-05],
        [ 2.8696e-03,  4.3678e-04,  2.6131e-04],
        [ 3.9005e-03,  8.0872e-04,  8.1444e-04]])

These seem to be pretty small numbers, so the model is working!

In [None]:
print("Lambda:\n", l2.model.covar_module.Lambda.detach())

print("psi: ", l2.model.covar_module.psi.detach())

print("lengthscale:", l2.model.covar_module.latent_kernel.lengthscale.item())


Lambda:
 tensor([[-1.6049],
        [ 1.6213],
        [ 1.6172]])
psi:  tensor([1.0492e-04, 3.8727e-05, 3.8805e-05])
lengthscale: 5.523871421813965


## Printer

This method gets called at each training iteration to show the progress

We want to extract all the parameters from the model. If a parameter has a constraint, transform the raw parameter to get the actual value

In [None]:
def get_parameter_value(name, param, constraint):
    "Return `(name, value)` for `param`; with a constraint the value is transformed and `raw_` is dropped from the name"
    if constraint is None:
        # unconstrained: the stored value is already the actual one
        return (name, param.data.detach())
    actual = constraint.transform(param.data.detach())
    # parameter is not raw anymore
    return (name.replace("raw_", ""), actual)

In [None]:
name = "covar_module.psi"
test_eq(l.model.covar_module.psi.detach(), get_parameter_value(name, l.model.covar_module.raw_psi_diag, l.model.covar_module.raw_psi_diag_constraint)[1])

In [None]:
def tensor_to_first_item(tensor):
    "Return the first element of `tensor` as a Python number, whatever its number of dimensions"
    first = tensor
    # keep taking index 0 until a 0-dim tensor is left
    while first.dim() > 0:
        first = first[0]
    return first.item()

In [None]:
def format_parameter(name, value):
    "Format a parameter as `<last part of name>: <first element of value with 3 decimals>`"
    value = tensor_to_first_item(value)
    # keep only what follows the last dot of the qualified name
    *_, name = name.split(".")
    return f"{name}: {value:.3f}"

In [None]:
#| export
@patch
def get_formatted_params(self: GPFALearner):
    "One-line, comma separated summary of all the model parameters"
    parts = []
    for name, value, constraint in self.model.named_parameters_and_constraints():
        name_and_value = get_parameter_value(name, value, constraint)
        parts.append(format_parameter(*name_and_value))
    return ", ".join(parts)

In [None]:
l.get_formatted_params()

'noise: 0.000, Lambda: 2.132, psi_diag: 0.000, lengthscale: 5.250'

In [None]:
@patch
def plot_loss(self: GPFALearner, i_iter):
    """Draw the losses of the first `i_iter` iterations on the graph of the master bar `self.pb`.

    Does nothing when there is no completed iteration to plot.
    """
    if i_iter <= 0: return
    x = torch.arange(0, i_iter)
    y = self.losses[:i_iter]

    # the label has to be set before drawing, otherwise the first graph does not use it
    self.pb.names = ["Training loss"]

    # x spans 0..i_iter (max of `x` plus one)
    x_bounds = [0, i_iter]
    y_low, y_high = y.min().item(), y.max().item()
    # identical (or NaN) limits are not a valid axis range: let the plot autoscale instead
    y_bounds = [y_low, y_high] if y_low < y_high else None

    self.pb.update_graph([[x, y]], x_bounds, y_bounds)

In [None]:
@patch
def printer(self: GPFALearner, i_iter):
    "Build a progress report (loss and parameter summary) every 10 iterations"
    # nothing to do on the other iterations
    if i_iter % 10 != 0: return

    update_str = f"loss: {self.losses[i_iter].item():.3f}, " + self.get_formatted_params()
    # NOTE(review): the report is built but not displayed; the calls that would show it
    # (`self.plot_loss(i_iter)` and `self.pb.write(update_str)`) are disabled

In [None]:
l.train(lr = 0.01)

In [None]:
#| hide
from nbdev import nbdev_export
nbdev_export()