## [Clustered Multitask GP (w/ Pyro/GPyTorch High-Level Interface)](https://docs.gpytorch.ai/en/stable/examples/07_Pyro_Integration/Clustered_Multitask_GP_Regression.html#Clustered-Multitask-GP-(w/-Pyro/GPyTorch-High-Level-Interface))

In this example, we use the Pyro integration for a GP model with additional latent variables.

We are modelling a multitask GP in this example. Rather than assuming that the tasks are linearly correlated, we assume that they have cluster structure: there are $k$ different clusters of tasks. The generative model for task $i$ is:

$$p(\mathbf y_i \mid \mathbf x_i) = \int \sum_{z_i=1}^k p(\mathbf y_i \mid \mathbf f (\mathbf x_i), z_i) \: p(z_i) \: p(\mathbf f (\mathbf x_i) ) \: d \mathbf f$$

where $z_i$ is the cluster assignment for task $i$. There are $k$ latent functions $\mathbf f = [f_1, \ldots, f_k]$, each modelled by a GP and each representing one cluster.
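To make the generative model concrete, the following minimal sketch draws synthetic data from it. It is an illustration under stated assumptions rather than part of this example: the sizes, the noise level, the shared inputs `x`, and the fixed sine and cosine curves standing in for GP draws are all hypothetical.

```python
import math
import torch

torch.manual_seed(0)

num_tasks, num_clusters, num_points = 6, 2, 50

# Shared inputs for every task (an assumption made for brevity).
x = torch.linspace(0, 1, num_points)

# Stand-ins for the k latent functions; in the model each f_j is a GP draw.
latent_functions = torch.stack(
    [torch.sin(2 * math.pi * x), torch.cos(2 * math.pi * x)],
    dim=-1,
)  # shape: (num_points, num_clusters)

# p(z_i): a uniform prior over clusters, one assignment per task.
prior = torch.distributions.Categorical(logits=torch.zeros(num_tasks, num_clusters))
z = prior.sample()  # shape: (num_tasks,)

# p(y_i | f(x_i), z_i): the assigned latent function plus Gaussian noise.
noise_std = 0.1
y = latent_functions[:, z] + noise_std * torch.randn(num_points, num_tasks)
```

Tasks that share a value of `z` are noisy copies of the same curve, which is the structure that inference is meant to recover.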

Our goal is therefore to infer:

1. The latent functions $f_1, \ldots, f_k$
2. The cluster assignments $z_i$ for each task

In [1]:
import math
import torch
import pyro
import tqdm
import gpytorch

In [2]:
from matplotlib import pyplot as plt
%matplotlib inline
import matplotlib as mpl
mpl.rc_file_defaults()

Customized likelihood: $\sum_{z_i=1}^k p(\mathbf y_i \mid \mathbf f (\mathbf x_i), z_i) \: p(z_i)$, the part of the generative model that sums over the cluster assignment $z_i$.
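As a worked illustration of this sum over cluster assignments, the sketch below evaluates it for a single task with a Gaussian observation model. The observations, latent function values, noise level, and uniform prior are hypothetical numbers chosen for the example, and the Gaussian form is an assumption.

```python
import torch

# Hypothetical values for one task: 3 observations, k = 2 clusters.
y_i = torch.tensor([0.9, 1.1, 1.0])
f = torch.tensor([[1.0, -1.0],
                  [1.0, -1.0],
                  [1.0, -1.0]])  # f[n, j] = f_j(x_n)
noise_std = 0.5  # assumed observation noise
log_prior = torch.log(torch.tensor([0.5, 0.5]))  # log p(z_i)

# log p(y_i | f, z_i = j): sum the Gaussian log-density over observations.
log_lik = torch.distributions.Normal(f, noise_std).log_prob(y_i.unsqueeze(-1)).sum(0)

# log sum_j p(y_i | f, z_i = j) p(z_i = j), computed stably with logsumexp.
log_marginal = torch.logsumexp(log_lik + log_prior, dim=0)

# Posterior probability of each cluster for this task, given f.
responsibilities = torch.softmax(log_lik + log_prior, dim=0)
```

Here the observations sit close to the first latent function, so nearly all of the posterior mass falls on the first cluster.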


GPyTorch’s likelihoods are capable of modeling additional latent variables. Our custom likelihood needs to define the following three methods:

1. `pyro_model` (needs to call through to `super().pyro_model` at the end), which defines the prior distribution for additional latent variables

2. `pyro_guide` (needs to call through to `super().pyro_guide` at the end), which defines the variational (guide) distribution for additional latent variables

3. `forward`, which defines the observation distribution conditioned on $\mathbf f (\mathbf x_i)$ and any additional latent variables.
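The third method is where the latent functions and the cluster assignments meet. The sketch below shows one way a `forward` method could combine them, written in plain PyTorch outside any class. The tensor shapes, the name `function_samples`, the example values, and the fixed noise scale are assumptions made for illustration, not the implementation used in this example.

```python
import torch

# Hypothetical samples of k = 2 latent functions at 4 inputs:
# function_samples[n, j] = f_j(x_n).
function_samples = torch.tensor([[0.0, 10.0],
                                 [1.0, 11.0],
                                 [2.0, 12.0],
                                 [3.0, 13.0]])

# One-hot cluster assignments, one row per task: tasks 0 and 2 use
# cluster 0, task 1 uses cluster 1.
cluster_assignment_samples = torch.tensor([[1.0, 0.0],
                                           [0.0, 1.0],
                                           [1.0, 0.0]])

# Broadcast (num_points, 1, num_clusters) against (num_tasks, num_clusters),
# then sum out the cluster dimension to select one latent function per task.
loc = (function_samples.unsqueeze(-2) * cluster_assignment_samples).sum(-1)

# Observation distribution with an assumed fixed noise scale of 0.1;
# a real likelihood would typically learn this value.
obs_dist = torch.distributions.Normal(loc, 0.1)
```

Because each assignment is one-hot, the multiply-and-sum acts as a differentiable lookup: column `i` of `loc` is the latent function of the cluster that task `i` belongs to.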

In [3]:
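# Method of the custom likelihood class (shown without the enclosing class).
def pyro_model(self, fun_dist, target):
    # Prior distribution for the additional latent variables
    # (the cluster assignments).
    cluster_assignment_samples = pyro.sample(
        self.name_prefix + ".cluster_logits",
        pyro.distributions.OneHotCategorical(logits=self.prior_cluster_logits).to_event(1),
    )

    # Call through to the parent implementation, passing the sampled
    # assignments along as a keyword argument.
    return super().pyro_model(
        fun_dist,
        target,
        cluster_assignment_samples=cluster_assignment_samples,
    )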
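The `.to_event(1)` call changes which dimensions count as a single event. The sketch below reproduces the effect with `torch.distributions.Independent`, which plays the same role as Pyro's `.to_event(1)`. The logits shape `(num_tasks, num_clusters)` and the sizes are assumptions, since the definition of `prior_cluster_logits` is not shown here.

```python
import torch

num_tasks, num_clusters = 5, 3  # hypothetical sizes
prior_cluster_logits = torch.zeros(num_tasks, num_clusters)

# One categorical distribution per task: batch_shape (5,), event_shape (3,).
base = torch.distributions.OneHotCategorical(logits=prior_cluster_logits)

# Reinterpret the task dimension as part of the event, so that all
# assignments form one joint sample: batch_shape (), event_shape (5, 3).
joint = torch.distributions.Independent(base, 1)

sample = joint.sample()             # one one-hot row per task
log_prob = joint.log_prob(sample)   # a single scalar, summed over tasks
```

With the task dimension folded into the event, the whole matrix of assignments is scored with a single log-probability at one sample site.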

In [None]:
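# Method of the custom likelihood class (shown without the enclosing class).
def pyro_guide(self, fun_dist, target):
    # Variational (guide) distribution for the cluster assignments.
    # The sample site name matches the one used in pyro_model.
    pyro.sample(
        self.name_prefix + ".cluster_logits",
        pyro.distributions.OneHotCategorical(logits=self.variational_cluster_logits).to_event(1),
    )

    # Call through to the parent implementation.
    return super().pyro_guide(fun_dist, target)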
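Since the guide parameterizes the assignments by logits, the approximate posterior over clusters can be read off with a softmax. The following sketch uses hypothetical logit values for four tasks and two clusters; in practice they would come from the trained `variational_cluster_logits`.

```python
import torch

# Hypothetical learned logits for 4 tasks and 2 clusters.
variational_cluster_logits = torch.tensor([[ 2.0, -1.0],
                                           [-0.5,  1.5],
                                           [ 3.0,  0.0],
                                           [ 0.1,  0.0]])

# Per-task cluster probabilities implied by the logits.
cluster_probs = torch.softmax(variational_cluster_logits, dim=-1)

# The most probable cluster for each task.
most_likely_cluster = cluster_probs.argmax(dim=-1)
```

A row with nearly equal logits, such as the last one, indicates a task whose cluster membership remains uncertain.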