In [1]:
%load_ext autoreload
%aimport data
%autoreload 1

In [None]:
from typing import Optional, Callable
from functools import partial
from itertools import takewhile
from pipe import select

from omegaconf import OmegaConf

import pandas as pd
import numpy as np

import plotly.express as px

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split, StackDataset
from torch import optim
from torch import distributions as distr

import pyro
import pyro.nn as pnn
import pyro.distributions as pdistr
from pyro.distributions import Delta
import pyro.infer
from pyro.infer import Trace_ELBO
from pyro.infer.autoguide import AutoNormal

import lightning as L
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import EarlyStopping

from torchmetrics.classification import Accuracy

from bmm_multitask_learning.variational.elbo import MultiTaskElbo
from bmm_multitask_learning.variational.distr import build_predictive

In [3]:
config = OmegaConf.load("config.yaml")

In [4]:
torch.manual_seed(config.seed);

## Data

In [5]:
from data import build_linked_datasets, build_solo_dataset

NUM_MODELS = 3

train_datasets, test_datasets = zip(*
    (*build_linked_datasets(config.size, config.dim), build_solo_dataset(config.size, config.dim)) |
    select(lambda dataset: random_split(dataset, [1 - config.test_ratio, config.test_ratio]))
)

## Solo models

In [6]:
class SoloModel(pnn.PyroModule):
    def __init__(
        self,
        dim: int = 2,
        num_data_samples: Optional[int] = None
    ):
        super().__init__()

        self.num_data_samples = num_data_samples

        # set parametric prior on w
        self.w_loc = pnn.PyroParam(torch.zeros((dim, )))
        self.log_w_scale = pnn.PyroParam(torch.zeros((dim, )))
        self.w = pnn.PyroSample(
            lambda self: pdistr.Normal(self.w_loc, torch.exp(self.log_w_scale)).to_event(1)
        )

    def forward(self, X: torch.Tensor, y: torch.Tensor = None):
        batch_size = X.shape[0]
        if self.num_data_samples:
            size = self.num_data_samples
            subsample_size = batch_size
        else:
            size = batch_size
            subsample_size = None

        p = torch.sigmoid(X.matmul(self.w))
        with pyro.plate("data_batch", size=size, subsample_size=subsample_size):
            pyro.sample("y", pdistr.Bernoulli(p), obs=y)

In [7]:
class LitSoloModel(L.LightningModule):
    def __init__(
        self,
        elbo_f: pyro.infer.elbo.ELBOModule,
        predictive: pyro.infer.Predictive,
        num_data_samples: int,
    ):
        super().__init__()

        self.num_data_samples = num_data_samples
        self.accuracy_computer = Accuracy('binary')

        self.elbo_f = elbo_f
        self.model: SoloModel = elbo_f.model
        self.guide = elbo_f.guide

        self.predictive = predictive

    def training_step(self, batch: tuple[torch.Tensor], batch_idx: int):
        X, y = batch

        elbo_loss = self.elbo_f(X, y)

        self.log("Train/ELBO", elbo_loss, prog_bar=True)

        return elbo_loss
    
    def validation_step(self, batch: tuple[torch.Tensor], batch_idx: int):
        X, y = batch

        y_pred = (self.predictive(X, y=None)["y"].mean(dim=0) > 0.5).to(torch.float32)

        self.accuracy_computer.update(y_pred, y)

    def on_validation_epoch_end(self):
        self.log("Test/Accuracy", self.accuracy_computer.compute())
        self.accuracy_computer.reset()  

    def configure_optimizers(self):
        return optim.Adam(self.elbo_f.parameters())

Now train individual models

In [8]:
for i in range(NUM_MODELS):
    print(f"Training model {i}\n")

    model = SoloModel(config.dim, config.size)
    guide = AutoNormal(model)

    # num elbo particles is equivallent to variational multitask
    num_elbo_particles = config.classifier_num_particles * config.latent_num_particles
    elbo_f = Trace_ELBO(num_elbo_particles)(model, guide)

    # All relevant parameters need to be initialized before ``configure_optimizer`` is called.
    # Since we used AutoNormal guide our parameters have not be initialized yet.
    # Therefore we initialize the model and guide by running one mini-batch through the loss.
    mini_batch = next(iter(DataLoader(train_datasets[0], batch_size=1)))
    elbo_f(*mini_batch)

    # this choice of num_predictive_particles is rather balancing
    num_predictive_particles = num_elbo_particles
    predictive = pyro.infer.Predictive(model, guide=guide, num_samples=num_predictive_particles)

    lit_model = LitSoloModel(elbo_f, predictive, config.size)
 
    train_dataloader = DataLoader(train_datasets[i], batch_size=config.batch_size, shuffle=True)
    test_dataloader = DataLoader(test_datasets[i], batch_size=config.batch_size)

    logger = CSVLogger("solo_logs", f"model_{i}")

    callbacks = [
        EarlyStopping(monitor="Train/ELBO", min_delta=1e-3, patience=5, mode="min")
    ]

    trainer = L.Trainer(logger=logger, callbacks=callbacks, **dict(config.trainer))
    trainer.fit(lit_model, train_dataloader, test_dataloader)

Training model 0



GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type           | Params | Mode 
-------------------------------------------------------------
0 | accuracy_computer | BinaryAccuracy | 0      | train
1 | elbo_f            | ELBOModule     | 8      | train
2 | model             | SoloModel      | 4      | train
3 | guide             | AutoNormal     | 4      | train
4 | predictive        | Predictive     | 8      | train
-------------------------------------------------------------
8         Trainable params
0         Non-trainable params
8         Total params
0.000     Total estimated model params size (MB)
7         Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/home/kirill/.cache/pypoetry/virtualenvs/bmm-multitask-learning-_LQwnQjl-py3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


                                                                           

/home/kirill/.cache/pypoetry/virtualenvs/bmm-multitask-learning-_LQwnQjl-py3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/kirill/.cache/pypoetry/virtualenvs/bmm-multitask-learning-_LQwnQjl-py3.12/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (7) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 9: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s, v_num=0, Train/ELBO=131.0]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s, v_num=0, Train/ELBO=131.0]
Training model 1



GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type           | Params | Mode 
-------------------------------------------------------------
0 | accuracy_computer | BinaryAccuracy | 0      | train
1 | elbo_f            | ELBOModule     | 8      | train
2 | model             | SoloModel      | 4      | train
3 | guide             | AutoNormal     | 4      | train
4 | predictive        | Predictive     | 8      | train
-------------------------------------------------------------
8         Trainable params
0         Non-trainable params
8         Total params
0.000     Total estimated model params size (MB)
7         Modules in train mode
0         Modules in eval mode


Epoch 9: 100%|██████████| 7/7 [00:04<00:00,  1.65it/s, v_num=0, Train/ELBO=108.0]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 7/7 [00:04<00:00,  1.65it/s, v_num=0, Train/ELBO=108.0]
Training model 2



GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type           | Params | Mode 
-------------------------------------------------------------
0 | accuracy_computer | BinaryAccuracy | 0      | train
1 | elbo_f            | ELBOModule     | 8      | train
2 | model             | SoloModel      | 4      | train
3 | guide             | AutoNormal     | 4      | train
4 | predictive        | Predictive     | 8      | train
-------------------------------------------------------------
8         Trainable params
0         Non-trainable params
8         Total params
0.000     Total estimated model params size (MB)
7         Modules in train mode
0         Modules in eval mode


Epoch 9: 100%|██████████| 7/7 [00:03<00:00,  1.91it/s, v_num=0, Train/ELBO=197.0]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 7/7 [00:03<00:00,  1.91it/s, v_num=0, Train/ELBO=197.0]


In [9]:
models_log = [pd.read_csv(f"solo_logs/model_{i}/version_0/metrics.csv") for i in range(NUM_MODELS)]

In [10]:
models_log[0].head()

Unnamed: 0,Test/Accuracy,Train/ELBO,epoch,step
0,0.625,,0,6
1,0.75,,1,13
2,0.6,,2,20
3,0.6,,3,27
4,0.75,,4,34


## Variational multitask models

Define batched distributions for tasks

In [None]:
# same for all tasks
def target_distr(Z: torch.Tensor, W: torch.Tensor) -> distr.Distribution:
    return distr.Bernoulli(logits=torch.tensordot(Z, W.swapaxes(0, -1), dims=1))

# same for all tasks
def predictive_distr(Z: torch.Tensor, W: torch.Tensor) -> distr.Distribution:
    return distr.Bernoulli(logits=torch.tensordot(Z, W.swapaxes(0, -1), dims=1).flatten(1, 2))

# we don't have latents here, but we need it formaly as distribution
def latent_distr(X: torch.Tensor) -> distr.Distribution:
    return Delta(X, event_dim=1)

Register KL computation for $\delta$-distribution for correct elbo computations

In [None]:
@distr.kl.register_kl(Delta, Delta)
def kl_delta_delta(d1: Delta, d2: Delta):
    return torch.zeros(d1.batch_shape) if torch.allclose(d1.v, d2.v) else torch.full(torch.inf, d1.batch_shape)

In [None]:
task_distrs = [target_distr] * NUM_MODELS
task_num_samples = list(train_datasets | select(len))

# parametric variational distr
cl_loc = torch.zeros((config.dim, ))
cl_scale = torch.ones((config.dim, ))
classifier_distrs = [
    distr.Independent(
        distr.Normal(
            nn.Parameter(cl_loc + 1e-1 * torch.rand_like(cl_loc)),
            nn.Parameter(cl_scale + 1e-1 * torch.rand_like(cl_scale))
        ),
        reinterpreted_batch_ndims=1
    )
    for _ in range(NUM_MODELS)
]
latent_distrs = [latent_distr] * NUM_MODELS

temp_scheduler = lambda step: 1.

In [None]:
mt_elbo = MultiTaskElbo(
    task_distrs,
    task_num_samples,
    classifier_distrs,
    latent_distrs,
    temp_scheduler=temp_scheduler,
    **dict(config.mt_elbo)
)

In [None]:
class LitMtModel(L.LightningModule):
    def __init__(
        self,
        mt_elbo: MultiTaskElbo
    ):
        super().__init__()

        num_tasks = mt_elbo.num_tasks

        self.accuracy_computers = [Accuracy('binary')] * num_tasks

        self.mt_elbo = mt_elbo

    def training_step(self, batch: tuple[tuple[torch.Tensor]], batch_idx: int):
        mt_elbo_loss = self.mt_elbo(*list(zip(*batch)), step=self.global_step)

        self.log("Train/ELBO", mt_elbo_loss, prog_bar=True)

        return mt_elbo_loss

    def validation_step(self, batch: tuple[tuple[torch.Tensor]], batch_idx: int):
        for i, (X, y) in enumerate(batch):
            cur_predictive = build_predictive(
                predictive_distr,
                classifier_distrs[i],
                latent_distrs[i],
                X,
                config.mt_elbo.classifier_num_particles,
                config.mt_elbo.latent_num_particles
            )
            # TODO: sample more?
            y_pred = cur_predictive.sample()

            self.accuracy_computers[i].update(y_pred, y)

    def on_validation_epoch_end(self):
        for i, accuracy_computer in enumerate(self.accuracy_computers):
            self.log(f"Test/Accuracy_{i}", accuracy_computer.compute())
            accuracy_computer.reset()

    def configure_optimizers(self):
        return optim.Adam(self.mt_elbo.parameters())

In [None]:
lit_mt_model = LitMtModel(mt_elbo)

In [None]:
# stack task datasets
unified_train_dataset = StackDataset(*train_datasets)
unified_test_dataset = StackDataset(*test_datasets)

mt_train_dataloader = DataLoader(unified_train_dataset, config.batch_size, shuffle=True)
mt_test_dataloader = DataLoader(unified_test_dataset, config.batch_size, shuffle=False)

In [None]:
logger = CSVLogger("mt_logs")

callbacks = [
    EarlyStopping(monitor="Train/ELBO", min_delta=1e-3, patience=5, mode="min")
]

In [None]:
trainer = L.Trainer(logger=logger, callbacks=callbacks, **dict(config.trainer))
trainer.fit(lit_model, mt_train_dataloader, mt_test_dataloader)

Plot metrics

In [None]:
metrics = pd.read_csv("mt_logs/lightning_logs/version_0/metrics.csv")

for col in takewhile(lambda col: col != "step", metrics.columns):
    fig = px.line(metrics[[col, "step"]].dropna(), x="step", y=col, title=col, markers=True)
    fig.show()