In [1]:
%load_ext autoreload
%aimport data
%autoreload 1

In [2]:
%load_ext tensorboard

In [3]:
from typing import Optional, Callable
from functools import partial
from itertools import takewhile
from pipe import select

from omegaconf import OmegaConf

import pandas as pd
import numpy as np

import plotly.express as px
import plotly.graph_objects as go

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split, StackDataset, Subset
from torch import optim
from torch import distributions as distr

import pyro
import pyro.nn as pnn
import pyro.distributions as pdistr
from pyro.distributions import Delta
import pyro.infer
from pyro.infer import Trace_ELBO
from pyro.infer.autoguide import AutoNormal

import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import EarlyStopping

from torchmetrics.classification import Accuracy

from bmm_multitask_learning.variational.elbo import MultiTaskElbo
from bmm_multitask_learning.variational.distr import build_predictive

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
config = OmegaConf.load("config.yaml")

In [5]:
torch.manual_seed(config.seed);

## Data

In [6]:
from data import build_linked_datasets, build_solo_dataset

NUM_MODELS = 3


datasets = [*build_linked_datasets(config.size, config.dim), build_solo_dataset(config.size, config.dim)]
# extract w
w_list = list(
    datasets | select(lambda w_dataset: w_dataset[0])
)
# extract (X, y) pairs
datasets = list(
    datasets |
    select(lambda w_dataset: w_dataset[1]) |
    select(lambda d: random_split(d, [1 - config.test_ratio, config.test_ratio]))
)

train_datasets, test_datasets = zip(*datasets)

In [7]:
# TODO: draw X and hyperplane
points_df = []

for i, dataset in enumerate(train_datasets):
    X = dataset.dataset.tensors[0].numpy()
    cur_df = pd.DataFrame(X, columns=["x", "y"])
    cur_df["dataset"] = str(i)
    points_df.append(cur_df)
points_df = pd.concat(points_df, axis=0)

points_df.head()

Unnamed: 0,x,y,dataset
0,-0.094634,-2.335958,0
1,10.453236,3.087356,0
2,-2.813742,-2.073199,0
3,-6.502639,-0.509825,0
4,-5.492137,-2.334638,0


In [8]:
plane_df = []
for i, w in enumerate(w_list):
    w = w.numpy()
    cur_df = pd.DataFrame(np.linspace(-20, 20, 10)[:, None] * w[None, :], columns=["x", "y"])
    cur_df["dataset"] = str(i)
    plane_df.append(cur_df)
plane_df = pd.concat(plane_df, axis=0)

In [9]:
fig_points = px.scatter(points_df, x="x", y="y", color="dataset", symbol="dataset")
fig_plane = px.line(plane_df, x="x", y="y", color="dataset")
fig_plane.update_traces(line=dict(width=3))
fig = go.Figure(data=fig_points.data + fig_plane.data)
fig.show()

## Variational multitask models

Register KL computation for $\delta$-distribution. Here we don't have state variables and don't count its KL at all

In [10]:
@distr.kl.register_kl(Delta, Delta)
def kl_delta_delta(d1: Delta, d2: Delta):
    return torch.zeros(d1.batch_shape)
    #return torch.zeros(d1.batch_shape) if torch.allclose(d1.v, d2.v) else torch.full(torch.inf, d1.batch_shape)

Define batched distributions for tasks

In [11]:
# same for all tasks
def target_distr(Z: torch.Tensor, W: torch.Tensor) -> distr.Distribution:
    return distr.Bernoulli(logits=torch.tensordot(Z, W.swapaxes(0, -1), dims=1))

# same for all tasks
def predictive_distr(Z: torch.Tensor, W: torch.Tensor) -> distr.Distribution:
    return distr.Bernoulli(logits=torch.tensordot(Z, W.swapaxes(0, -1), dims=1).flatten(1, 2))

# we don't have latents here, but we need it formaly as distribution
def latent_distr(X: torch.Tensor) -> distr.Distribution:
    return Delta(X, event_dim=1)

In [None]:
task_distrs = [target_distr for _ in range(NUM_MODELS)]
task_num_samples = list(train_datasets | select(len))

# parametric variational distr
cl_loc = torch.zeros((config.dim, ))
cl_scale = torch.ones((config.dim, ))
classifier_distrs = [
    distr.Independent(
        distr.Normal(
            nn.Parameter(cl_loc + 1e-1 * torch.rand_like(cl_loc)),
            nn.Parameter(cl_scale + 1e-1 * torch.rand_like(cl_scale))
        ),
        reinterpreted_batch_ndims=1
    )
    for _ in range(NUM_MODELS)
]
latent_distrs = [latent_distr for _ in range(NUM_MODELS)]

temp_scheduler = lambda step: 1.

In [13]:
mt_elbo = MultiTaskElbo(
    task_distrs,
    task_num_samples,
    classifier_distrs,
    latent_distrs,
    temp_scheduler=temp_scheduler,
    **dict(config.mt_elbo)
)

In [14]:
class LitMtModel(L.LightningModule):
    def __init__(
        self,
        mt_elbo: MultiTaskElbo
    ):
        super().__init__()

        num_tasks = mt_elbo.num_tasks
        self.accuracy_computers = [Accuracy('binary') for _ in range(num_tasks)]
        self.mt_elbo = mt_elbo

    def training_step(self, batch: tuple[tuple[torch.Tensor]], batch_idx: int):
        mt_elbo_loss = self.mt_elbo(*list(zip(*batch)), step=self.global_step)
        self.log("Train/ELBO", mt_elbo_loss, prog_bar=True)

        return mt_elbo_loss
    
    def on_train_epoch_end(self):
        # log mixing params
            
        fig, ax = plt.subplots()
        cax = ax.matshow(self.mt_elbo.latent_mixings_params.detach().numpy())
        fig.colorbar(cax)
        self.logger.experiment.add_figure("Latent_mixing", fig, self.global_step)

        fig, ax = plt.subplots()
        cax = ax.matshow(self.mt_elbo.classifier_mixings_params.detach().numpy())
        fig.colorbar(cax)
        self.logger.experiment.add_figure("Classifier_mixing", fig, self.global_step)

    def validation_step(self, batch: tuple[tuple[torch.Tensor]], batch_idx: int):
        for i, (X, y) in enumerate(batch):
            NUM_PREDICTIVE_SAMPLES = 20
            cur_predictive = build_predictive(
                predictive_distr,
                classifier_distrs[i],
                latent_distrs[i],
                X,
                config.mt_elbo.classifier_num_particles,
                config.mt_elbo.latent_num_particles
            )
            y_pred = (cur_predictive.sample((NUM_PREDICTIVE_SAMPLES, )).mean(dim=0) > 0.5).float()
            self.accuracy_computers[i].update(y_pred, y)

    def on_validation_epoch_end(self):
        for i, accuracy_computer in enumerate(self.accuracy_computers):
            self.log(f"Test/Accuracy_{i}", accuracy_computer.compute())
            accuracy_computer.reset()

    def configure_optimizers(self):
        return optim.Adam(self.mt_elbo.parameters(), lr=1e-3)

In [15]:
lit_mt_model = LitMtModel(mt_elbo)

In [16]:
# stack task datasets
unified_train_dataset = StackDataset(*train_datasets)
unified_test_dataset = StackDataset(*test_datasets)

mt_train_dataloader = DataLoader(unified_train_dataset, config.batch_size, shuffle=True)
mt_test_dataloader = DataLoader(unified_test_dataset, config.batch_size, shuffle=False)

In [17]:
logger = TensorBoardLogger("mt_logs")

callbacks = [
    #EarlyStopping(monitor="Train/ELBO", min_delta=1e-3, patience=5, mode="min")
]

In [18]:
trainer = L.Trainer(logger=logger, callbacks=callbacks, **dict(config.trainer))
trainer.fit(lit_mt_model, mt_train_dataloader, mt_test_dataloader)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs



GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


  | Name    | Type          | Params | Mode 
--------------------------------------------------
0 | mt_elbo | MultiTaskElbo | 9      | train
--------------------------------------------------
9         Trainable params
0         Non-trainable params
9         Total params
0.000     Total estimated model params size (MB)
1         Modules in train mode
0         Modules in eval mode


                                                                            


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.



Epoch 19: 100%|██████████| 15/15 [00:00<00:00, 87.86it/s, v_num=1, Train/ELBO=1.42e+3] 

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 15/15 [00:00<00:00, 48.90it/s, v_num=1, Train/ELBO=1.42e+3]


Plot metrics: see tensorboard

In [19]:
for cl_d, w_true in zip(classifier_distrs, w_list):
    print("Variational:", cl_d.mean.data, cl_d.stddev.data)
    print("True:", w_true)

Variational: tensor([0.0251, 0.0882]) tensor([1.0528, 1.0676])
True: tensor([-0.4690,  0.0904])
Variational: tensor([0.0303, 0.0645]) tensor([1.0993, 1.0573])
True: tensor([-0.4459,  0.0512])
Variational: tensor([0.0089, 0.0741]) tensor([1.0050, 1.0892])
True: tensor([1.0489, 0.9532])
