In [None]:
#| default_exp train

In [None]:
#| hide
from nbdev.showdoc import *

# train
> Some code to make training easier

In [None]:
#| export
import torch
from torch import nn

A loss function meant to be used with [Latent ODEs for Irregularly-Sampled Time Series](https://github.com/YuliaRubanova/latent_ode)'s `LatentODE`. Some modifications were applied:
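- a mask is required, since sparse (partially observed) targets are supported; only observed entries contribute to the likelihood
- the log-likelihood is averaged (rather than summed) over all the trials, as well as over every observed entry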

In [None]:
#| export
class LatentODELoss:
    "A loss function meant to be paired with Rubanova's `LatentODE`"
    
    def __init__(self,
                 noise_std: torch.Tensor, # Standard deviation of the noise assumed when computing the likelihood
                 prior: torch.distributions.normal.Normal # Prior distribution for the initial state
                ):
        
        self.noise_std = noise_std
        self.prior = prior
        
        # only used (without gradients) to report a diagnostic MSE
        self.mse_loss = nn.MSELoss()
        
    def __str__(self):
        
        return f'LatentODELoss with:\n\tnoise standard deviation = {self.noise_std}\n\tprior: {self.prior}'
    
    __repr__ = __str__
           
    def __call__(self,
        pred: torch.Tensor, # Predictions [time, trial, batch, feature]
        mean: torch.Tensor, # Mean of the posterior of the initial latent state (e.g. [1, batch, latent feature])
        std: torch.Tensor, # Standard deviation of the posterior of the initial latent state (same shape as `mean`)
        target: torch.Tensor, # Targets [batch, time, feature]
        target_mask: torch.BoolTensor, # `True` where the target is observed [batch, time, feature]
        kl_weight: float # KL divergence weight on the loss
    ) -> tuple[torch.Tensor, dict]: # Loss and some extra info (`kl_average`, `mse`)
        """Negative masked log-likelihood of the targets plus the weighted KL divergence of the initial state.

        The log-likelihood is averaged over all trials and all observed (masked-in) entries; the KL divergence
        between the posterior `Normal(mean, std)` and `self.prior` is averaged over all its elements. The
        returned `mse` is a gradient-free diagnostic computed over the observed entries only.
        If `target_mask` selects no entry, the averages (and hence the loss) are NaN.
        """

        # -------------- KL divergence
        
        # the *posterior* (observations have already been processed) distribution of the latent state at the beginning
        z0_posterior = torch.distributions.normal.Normal(mean, std)

        # elementwise KL, broadcast between posterior and prior (e.g. [1, batch, latent feature])
        kl = torch.distributions.kl.kl_divergence(z0_posterior, self.prior)

        kl_average = kl.mean()
        
        # -------------- likelihood
        
        # we'd rather have [trial, batch, time, feature]
        pred = pred.permute([1, 2, 0, 3])

        assert pred.shape[1:] == target.shape, (
            f'predictions {tuple(pred.shape[1:])} (trial dimension excluded) '
            f'and targets {tuple(target.shape)} do not match')

        # the distribution of the predictions...
        pred_distribution = torch.distributions.normal.Normal(loc=pred, scale=self.noise_std)

        # ...is used to compute the log-likelihood of every entry, [trial, batch, time, feature]
        likelihood = pred_distribution.log_prob(target)

        # `masked_select` broadcasts the [batch, time, feature] mask across trials, so only observed entries count
        observed_likelihood = torch.masked_select(likelihood, target_mask)
        
        # -------------- MSE (diagnostic only)

        with torch.no_grad():

            # restricted to the observed entries, consistently with the likelihood
            mse = self.mse_loss(
                torch.masked_select(pred, target_mask),
                torch.masked_select(target.expand_as(pred), target_mask))
            
        # --------------

        loss = - (observed_likelihood.mean() - kl_weight * kl_average)

        return loss, dict(kl_average=kl_average, mse=mse)

The parameters required to...

In [None]:
# standard normal prior for the initial state
prior = torch.distributions.normal.Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0))
# standard deviation of the observation noise assumed by the likelihood
noise_std = torch.tensor(0.01)

...instantiate the class

In [None]:
loss_func = LatentODELoss(noise_std=noise_std, prior=prior)
loss_func

LatentODELoss with:
	noise standard deviation = 0.009999999776482582
	prior: Normal(loc: 0.0, scale: 1.0)

Some random data for testing purposes

In [None]:
n_time_instants = 12
n_trials = 3
batch_size = 32
features_size = 2
latent_size = 13

# fixed seed so that the outputs below are reproducible on a fresh "Run All"
torch.manual_seed(0)

pred = torch.randn(n_time_instants, n_trials, batch_size, features_size)
mean = torch.randn(1, batch_size, latent_size)
# `Normal` requires a strictly positive scale, whereas `rand_like` can return exactly 0
std = torch.rand_like(mean) + 1e-3
target = torch.randn(batch_size, n_time_instants, features_size)
# the comparison already yields a boolean tensor; roughly half of the entries are observed
target_mask = torch.randn_like(target) > 0.
kl_weight = 0.2

The loss function is applied; it returns the loss along with a dictionary holding the average KL divergence and the MSE

In [None]:
loss_func(pred=pred, mean=mean, std=std, target=target, target_mask=target_mask, kl_weight=kl_weight)

(tensor(9830.3252), {'kl_average': tensor(1.2236), 'mse': tensor(2.0446)})

In [None]:
#| hide
from nbdev.doclinks import nbdev_export

In [None]:
#| hide
# Regenerate the library module from this notebook's `#| export` cells (nbdev).
# NOTE(review): the notebook filename is hard-coded; update it if the notebook is renamed.
nbdev_export('30_train.ipynb')