# Multitask Gaussian Process Imputation

This Jupyter notebook contains all the code and examples for most things concerning Gaussian Processes.

In [None]:
from copy import deepcopy
import numpy as np
import torch
import torch.optim as optim
import models
import utils
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm

from imputation.hetvae.src.train import HETVAE
from toy_dataset import data_utils
import utils

import math
import gpytorch


%matplotlib inline
%load_ext autoreload
%autoreload 2

# Get Dataloader

In [None]:
# Load the synthetic toy dataset (from Josh), inject artificial missingness
# and build the dataloaders used by the cells below.
from toy_dataset import data_utils

name = 'toydataset_50000'
path = data_utils.datasets_dict[name]
dataset = data_utils.ToyDataDf(path)
# NOTE(review): presumably 0.6 is the MCAR missing rate; the meaning of the
# second argument (-1) is not visible here -- confirm in toy_dataset.data_utils.
dataset.create_mcar_missingness(0.6, -1)

dataloader_dict = dataset.prepare_data_mtan(batch_size=128)
# Observed (masked) loaders and their ground-truth counterparts, looked up in
# a fixed order.
train_loader, gt_train_loader, val_loader, gt_validation_loader = (
    dataloader_dict[split]
    for split in ('train', 'train_ground_truth', 'validation', 'validation_ground_truth')
)
# The validation split doubles as the test split.
test_loader = val_loader
union_tp = utils.union_time(train_loader)

---
# Define HadamardGP

In [None]:
# Implementation of Hadamard Multitask GP based on 
# https://docs.gpytorch.ai/en/stable/examples/03_Multitask_Exact_GPs/Hadamard_Multitask_GP_Regression.html

import math
import torch
import gpytorch
import matplotlib.pyplot as plt
import os

from gpytorch.models import ExactGP
from gpytorch.means import ConstantMean, MultitaskMean
from gpytorch.kernels import Kernel, IndexKernel, MaternKernel, AdditiveKernel
from gpytorch.likelihoods import _GaussianLikelihoodBase
from gpytorch.likelihoods.noise_models import _HomoskedasticNoiseBase
from gpytorch.distributions import MultivariateNormal
from gpytorch.lazy import DiagLazyTensor

class HadamardGP(ExactGP):
    """Exact-inference GP for Hadamard-style multitask regression.

    Each input row is ``[task_index, *location]``. The mean and covariance
    modules split every row into its task id and its location.

    Args:
        train_inputs: training inputs handed to ``ExactGP`` (may be ``None``
            and attached later through ``set_train_data``).
        train_targets: training targets handed to ``ExactGP``.
        likelihood: likelihood used together with this model.
        num_tasks (int): number of tasks fitted by the model.
        num_kernels (int, optional): number of ``HadamardKernel`` terms that
            are summed by an ``AdditiveKernel``.
        rank (int, optional): rank of each inter-task covariance.
    """

    def __init__(self, train_inputs, train_targets, likelihood, num_tasks, num_kernels=1, rank=1):
        super().__init__(train_inputs, train_targets, likelihood)
        self.mean_module = HadamardMean(ConstantMean(), num_tasks)
        # One Matern x task-covariance product per additive component.
        components = [HadamardKernel(MaternKernel(), num_tasks, rank) for _ in range(num_kernels)]
        self.covar_module = AdditiveKernel(*components)

    def forward(self, input):
        """Return the GP prior evaluated at ``input``."""
        return MultivariateNormal(self.mean_module(input), self.covar_module(input))


class HadamardMean(MultitaskMean):
    """Per-task mean for a Hadamard multitask GP.

    Inputs are rows ``[task_index, *location]``. Every task owns a sub-mean.
    Each row is evaluated under all of them, and the value that belongs to
    the row's own task is kept.

    Args:
        base_means (:obj:`list` or :obj:`gpytorch.means.Mean`): one mean per
            task, or a single mean (or a one-element list) that is copied
            ``num_tasks`` times.
        num_tasks (int): number of tasks; must equal ``len(base_means)`` when
            a list of several means is given.
    """

    def __init__(self, base_means, num_tasks):
        super().__init__(base_means, num_tasks)

    def forward(self, input):
        """Return one mean value per input row, chosen by the row's task index."""
        task_col = input[..., :1]
        locations = input[..., 1:]
        # candidates[..., k] holds task k's mean at every location.
        candidates = torch.stack([m(locations) for m in self.base_means], dim=-1)
        return candidates.gather(-1, task_col.long()).squeeze(-1)


class HadamardKernel(Kernel):
    r"""Kernel for a Hadamard multitask GP of the form

    .. math::
        K\big((i_1, x_1), (i_2, x_2)\big) = K_x(x_1, x_2) \times K_i(i_1, i_2)

    where ``x`` denotes locations (e.g., in time) and ``i`` denotes a task
    identifier. Input rows are ``[task_index, *location]``.

    Args:
        base_kernel (:obj:`gpytorch.kernels.Kernel`): the location kernel K_x.
        num_tasks (int): number of tasks, i.e. the size of the task covariance K_i.
        rank (int or None): rank of the inter-task correlation; ``None`` means
            full rank (``num_tasks``).
    """

    def __init__(self, base_kernel, num_tasks, rank):
        super(HadamardKernel, self).__init__()
        self.num_tasks = num_tasks
        self.base_kernel = base_kernel

        if rank is None:
            rank = num_tasks
        self.task_covar_module = IndexKernel(num_tasks, rank)

    def forward(self, input1, input2, diag=False, **params):
        """Return the elementwise product of the location and task covariances.

        Args:
            input1, input2: rows ``[task_index, *location]``.
            diag (bool): if True, return only the diagonal. The flag is passed
                to BOTH factor kernels so that their shapes agree: the task
                kernel previously always produced the full matrix.
            **params: extra kernel options forwarded to the location kernel.
        """
        i1, x1 = input1[..., 0], input1[..., 1:]
        i2, x2 = input2[..., 0], input2[..., 1:]

        # Location (e.g. time) covariance.
        covar_x = self.base_kernel(x1, x2, diag=diag, **params)
        # Task-task covariance, looked up by task index.
        covar_i = self.task_covar_module(i1, i2, diag=diag)
        # Hadamard (elementwise) product gives the joint covariance.
        return covar_x.mul(covar_i)


class HadamardGaussianLikelihood(_GaussianLikelihoodBase):
    r"""
    Likelihood for a Hadamard multitask GP regression. Assumes a different
    homoskedastic noise for each task i:

    .. math::
        y_i = f_i + \epsilon_i, \quad \epsilon_i \sim \mathcal N (0, \sigma_i^2)

    where :math:`\sigma_i^2` is the noise parameter of task i.

    .. note::
        Does not currently allow for batched training.

    :param num_tasks: The number of tasks in the multitask GP.
    :type num_tasks: int
    :param noise_prior: Prior for noise parameter :math:`\sigma^2`.
    :type noise_prior: ~gpytorch.priors.Prior, optional
    :param noise_constraint: Constraint for noise parameter :math:`\sigma^2`.
    :type noise_constraint: ~gpytorch.constraints.Interval, optional

    :var torch.Tensor noise: :math:`\sigma_i^2` parameters (noise)
    """

    def __init__(self, num_tasks, noise_prior=None, noise_constraint=None, **kwargs):
        # Extra keyword arguments are accepted for compatibility but unused.
        noise_covar = HadamardHomoskedasticNoise(
            noise_prior=noise_prior, noise_constraint=noise_constraint, num_tasks=num_tasks
        )
        super().__init__(noise_covar=noise_covar)

    @property
    def noise(self):
        return self.noise_covar.noise

    @noise.setter
    def noise(self, value):
        self.noise_covar.initialize(noise=value)

    @property
    def raw_noise(self):
        return self.noise_covar.raw_noise

    @raw_noise.setter
    def raw_noise(self, value):
        self.noise_covar.initialize(raw_noise=value)

    def __call__(self, input, *args, **kwargs):
        """Evaluate the likelihood, routing each row's task index to the noise model.

        Args:
            input: a ``torch.Tensor`` (conditional) or a ``MultivariateNormal``
                (marginal).
            *args: ``args[0]`` must be a list whose first element is the input
                tensor with task ids in column 0 (e.g. ``[train_x]``).

        Raises:
            ValueError: if no positional ``args`` are given.
            RuntimeError: if ``input`` is neither a tensor nor a
                ``MultivariateNormal``.
        """
        if not args:
            raise ValueError(
                "The first element of *args must be a list of the training "
                "inputs."
            )

        # Task ids are in column 0 of the inputs; keep an (n, 1) long tensor
        # so the noise model can gather one noise value per row.
        xi = args[0][0]
        i = xi[..., [0]].long()

        if torch.is_tensor(input):  # conditional p(y | f)
            return super().__call__(input, *args, i=i, **kwargs)
        if isinstance(input, MultivariateNormal):  # marginal p(y)
            return self.marginal(input, *args, i=i, **kwargs)
        raise RuntimeError(
            "Likelihoods expects a MultivariateNormal or Normal input to make marginal predictions, or a "
            "torch.Tensor for conditional predictions. Got a {}".format(input.__class__.__name__)
        )


class HadamardHomoskedasticNoise(_HomoskedasticNoiseBase):
    r"""
    Task-specific homoskedastic noise for a Hadamard multitask GP: every
    observation of task i receives the same noise variance :math:`\sigma_i^2`.
    """

    def __init__(self, noise_prior=None, noise_constraint=None, num_tasks=1):
        super().__init__(noise_prior, noise_constraint, torch.Size(), num_tasks)

    def forward(self, *params, shape=None, noise=None, i=None, **kwargs):
        # Batching and the additional checks of the gpytorch base class are
        # deliberately omitted. `shape` and `i` must be supplied; the `noise`
        # keyword is accepted for signature compatibility but ignored.
        # NOTE(review): `i` is expected to be an (n, 1) long tensor of task ids.
        per_task = self.noise
        num_obs = shape[0]
        tiled = per_task.expand(num_obs, per_task.shape[0]).contiguous()
        per_obs = torch.gather(tiled, -1, i).squeeze(-1)
        return DiagLazyTensor(per_obs)

---
# Train the kernel parameters over the entire dataset

In [None]:
# Fit the kernel hyper-parameters over the whole training set. Every sequence
# in a batch is its own GP regression problem sharing the model's
# hyper-parameters; the negative marginal log-likelihood is averaged over the
# sequences of a batch.

num_tasks = 4


def _observed_task_inputs(value, mask, t, tasks):
    """Flatten the observed entries of one sequence into Hadamard GP inputs.

    For every task ``k`` in ``tasks`` only the entries with
    ``mask[:, k] == 1`` are kept. Returns ``(inputs, targets)``: ``inputs``
    has rows ``[k, t]`` and ``targets`` the matching observed values,
    concatenated task by task.
    """
    task_ids, times, targets = [], [], []
    for k in tasks:
        observed = mask[:, k] == 1
        times.append(t[observed])
        targets.append(value[:, k][observed])
        task_ids.append(torch.full((times[-1].shape[0], 1), dtype=torch.long, fill_value=k))
    flat_t = torch.cat(times)
    inputs = torch.cat((torch.cat(task_ids), flat_t.view(flat_t.shape[0], 1)), dim=-1)
    return inputs, torch.cat(targets)


def _plot_validation(i_epoch, i_batch):
    """Impute the first validation sequence with the current model and save a figure.

    The model is conditioned on the observed validation points (all tasks or
    only the plotted task, see ``condition_on_all_tasks``). For every task the
    figure shows the observations, the predictive mean, a +/-1 std band,
    the ground truth and 100 posterior samples.
    """
    # NOTE(review): pairs the first batch of `val_loader` with the first batch
    # of `gt_validation_loader`; this only lines up if neither loader shuffles.
    sample_val = next(iter(val_loader))
    masks_val = sample_val[:, :, num_tasks:-1]
    values_val = sample_val[:, :, :num_tasks]
    t_val = sample_val[0, :, -1]
    gt_values_val = next(iter(gt_validation_loader))[:, :, :num_tasks]

    value_val = values_val[0]
    mask_val = masks_val[0]
    t_plot = t_val.detach().numpy()

    if condition_on_all_tasks:
        context_inputs, context_targets = _observed_task_inputs(value_val, mask_val, t_val, range(num_tasks))
        model.set_train_data(context_inputs, context_targets, strict=False)

    f, axes = plt.subplots(num_tasks, 1, figsize=(10, 28), squeeze=False)

    # Eval mode is required for posterior predictions (and confidence bounds);
    # always restore train mode, even if prediction fails.
    model.eval()
    likelihood.eval()
    try:
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            for k in range(num_tasks):
                ax = axes[k, 0]
                observed = mask_val[:, k] == 1
                if not condition_on_all_tasks:
                    context_inputs, context_targets = _observed_task_inputs(value_val, mask_val, t_val, [k])
                    model.set_train_data(context_inputs, context_targets, strict=False)

                # Predict task k on the full time grid.
                target_inputs = torch.cat(
                    (
                        torch.full((t_val.shape[0], 1), dtype=torch.long, fill_value=k),
                        t_val.view(t_val.shape[0], 1),
                    ),
                    dim=-1,
                )
                observed_pred = likelihood(model(target_inputs), [target_inputs])
                mean = observed_pred.mean
                _, upper_2sd = observed_pred.confidence_region()
                std = (upper_2sd - mean) / 2.0  # confidence_region spans mean +/- 2 std

                ax.plot(t_val[observed].detach().numpy(), value_val[:, k][observed].detach().numpy(), 'k*')
                ax.plot(t_plot, mean.detach().numpy(), 'b')
                ax.fill_between(t_plot, (mean - std).detach().numpy(), (mean + std).detach().numpy(), alpha=0.5)
                ax.plot(t_plot, gt_values_val[0, :, k].detach().numpy(), 'r')
                ax.legend(['Observed Data', 'Mean', 'Confidence', 'ground truth'])
                ax.set_title(titles[k])
                for _ in range(100):
                    ax.plot(t_plot, observed_pred.sample().detach().numpy(), 'g', alpha=0.1)
    finally:
        model.train()
        likelihood.train()

    print(f"subplots for {i_epoch}th epoch {i_batch}th batch: ")
    fig_path = os.path.join(fig_save_folder_path, f"plots_{i_epoch}e_{i_batch}b.png")
    f.savefig(fig_path)
    # Release the figure; otherwise every validation step leaks one.
    plt.close(f)


likelihood = HadamardGaussianLikelihood(num_tasks=num_tasks)
# Training data is attached per sequence via `set_train_data`, so the model
# starts without any.
model = HadamardGP(None, None, likelihood, num_tasks=num_tasks, num_kernels=10, rank=4)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

model.train()
likelihood.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

epoch = 10
fig_save_folder_path = "./figs_over_epoch"
os.makedirs(fig_save_folder_path, exist_ok=True)

# Whether to condition the imputer on all tasks or only on the plotted task.
condition_on_all_tasks = True
# Run a validation / plotting step every `val_every` batches.
val_every = 50
titles = ['Noise', 'Trend', "Seasonality", "Trend + Seasonality"]

for i_epoch in range(epoch):
    for i_batch, sample in enumerate(train_loader):
        # NOTE(review): columns appear to be [values | masks | time]; the time
        # grid of the first sequence is reused for every sequence, which
        # assumes all sequences in a batch share time stamps -- confirm.
        masks = sample[:, :, num_tasks:-1]
        values = sample[:, :, :num_tasks]
        t = sample[0, :, -1]

        optimizer.zero_grad()
        loss = 0
        num_seq = 0
        for mask, value in zip(masks, values):
            input_task_t, target_x = _observed_task_inputs(value, mask, t, range(num_tasks))
            model.set_train_data(input_task_t, target_x, strict=False)
            output = model(input_task_t)
            loss += -mll(output, target_x, [input_task_t])
            num_seq += 1

        if num_seq == 0:
            continue  # empty batch: nothing to optimise (avoids division by zero)

        loss /= num_seq
        loss.backward()
        print('Iter %dth batch - Loss: %.3f' % (i_batch + 1, loss.item()))
        optimizer.step()

        if i_batch % val_every == 0:
            _plot_validation(i_epoch, i_batch)
            
            
            
            

        

