# Variational multitask learning elementary example

This is a practical demonstration of how to use the `variational` subpackage on a simple classification example. We are going to solve 3 classification tasks using logistic regression as the model. Additionally, we put a prior on the weights, which makes the models Bayesian. Two of the tasks are probabilistically connected; the third has no probabilistic connection with the others.

First, we apply the variational principle to learn each task individually. We use the [`pyro`](https://pyro.ai/examples/index.html) package to compute the ELBO automatically and optimize it (Pyro's `Trace_ELBO` returns the negative ELBO as a loss, which we minimize).

Second, we use the `variational` subpackage to learn all 3 tasks jointly. Learning is again ELBO optimization, but for a special variational structure; see the [doc](../../../docs/variational/intro.md) for more details.

Finally, we compare the two approaches in terms of accuracy.

In [None]:
from typing import Optional

from pipe import select

from omegaconf import OmegaConf

import pandas as pd
import numpy as np

import plotly.express as px
import plotly.graph_objects as go

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, StackDataset
from torch import optim
from torch import distributions as distr

import pyro
import pyro.nn as pnn
import pyro.distributions as pdistr
from pyro.distributions import Delta
from pyro.infer import Trace_ELBO
from pyro.infer.autoguide import AutoNormal

import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import EarlyStopping

from torchmetrics.classification import Accuracy

from bmm_multitask_learning.variational.elbo import MultiTaskElbo
from bmm_multitask_learning.variational.distr import build_predictive

In [2]:
%load_ext tensorboard

The experiment is configurable via [yaml config](config.yaml).

In [None]:
config = OmegaConf.load("config.yaml")

In [4]:
torch.manual_seed(config.seed);

## Data

For each task we generate 2-dimensional inputs $X$, where each individual input $x$ is drawn as

$$
    x \sim \mathcal{N}(\mu, \sigma_x^2)
$$

For tasks 1 and 2, these input distributions are very close to each other.

Then we generate the logistic regression weights $w$ of each task as

$$
\begin{aligned}
    w_1 &\sim \mathcal{N}(\mathbb{E}[X_2], \sigma_w^2) \\
    w_2 &\sim \mathcal{N}(\mathbb{E}[X_1], \sigma_w^2) \\
    w_3 &\sim \mathcal{N}(\mathbf{1}, \sigma_w^2)
\end{aligned}
$$

Because $X_1$ and $X_2$ are close, the distributions of $w_1$ and $w_2$ are close too. The variational multitask approach can exploit this connection.

Finally, the label $y$ for an input $x$ is generated as

$$
    y \sim \text{Bern}(\sigma(x^{T}w))
$$

Here $\sigma(\cdot)$ denotes the logistic sigmoid. Data generation following this scheme is implemented in [data.py](data.py).
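To make the scheme concrete, here is a NumPy sketch of the generative process. It is not the contents of `data.py`: the helper `generate_task` and all numeric values (means, $\sigma_x$, $\sigma_w$, sample size) are hypothetical, and $\mathbb{E}[X_t]$ is replaced by the true mean $\mu_t$ of task $t$.

```python
import numpy as np

def sigmoid(z: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-z))

def generate_task(rng, n, mu, w_mean, sigma_x, sigma_w):
    """Hypothetical helper: draw (X, w, y) for one task of the scheme above."""
    dim = len(mu)
    X = rng.normal(loc=mu, scale=sigma_x, size=(n, dim))  # x ~ N(mu, sigma_x^2 I)
    w = rng.normal(loc=w_mean, scale=sigma_w, size=dim)    # w ~ N(w_mean, sigma_w^2 I)
    y = rng.binomial(1, sigmoid(X @ w))                    # y ~ Bern(sigmoid(x^T w))
    return X, w, y

rng = np.random.default_rng(0)
sigma_x, sigma_w, n = 3.0, 0.1, 200   # hypothetical; a small sigma_w keeps w_1 and w_2 close
mu_1 = np.array([1.0, -1.0])
mu_2 = mu_1 + 0.1                     # close to mu_1: tasks 1 and 2 are linked
mu_3 = np.array([-8.0, 6.0])          # unrelated third task

X1, w1, y1 = generate_task(rng, n, mu_1, w_mean=mu_2, sigma_x=sigma_x, sigma_w=sigma_w)
X2, w2, y2 = generate_task(rng, n, mu_2, w_mean=mu_1, sigma_x=sigma_x, sigma_w=sigma_w)
X3, w3, y3 = generate_task(rng, n, mu_3, w_mean=np.ones(2), sigma_x=sigma_x, sigma_w=sigma_w)

# linked tasks get nearby weights, the solo task does not
print(np.linalg.norm(w1 - w2), np.linalg.norm(w1 - w3))
```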

In [5]:
from data import build_linked_datasets, build_solo_dataset

NUM_MODELS = 3


datasets = [*build_linked_datasets(config.size, config.dim), build_solo_dataset(config.size, config.dim)]
# extract w
w_list = list(
    datasets | select(lambda w_dataset: w_dataset[0])
)
# extract (X, y) pairs
datasets = list(
    datasets |
    select(lambda w_dataset: w_dataset[1]) |
    select(lambda d: random_split(d, [1 - config.test_ratio, config.test_ratio]))
)

train_datasets, test_datasets = zip(*datasets)
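`random_split` is called with fractions (`[1 - test_ratio, test_ratio]`) rather than integer lengths. As far as we know, PyTorch (1.13 and later) floors `n * fraction` for every split and then hands out the leftover samples one at a time, round-robin, starting from the first split. The plain-Python sketch below reimplements that rule to show how many samples end up in each split; the function name is hypothetical and this is an illustration, not torch's code.

```python
import math

def fractional_split_lengths(n: int, fractions: list[float]) -> list[int]:
    """Sketch: turn fractions summing to 1 into integer subset sizes."""
    lengths = [math.floor(n * frac) for frac in fractions]
    remainder = n - sum(lengths)
    # distribute the samples lost to flooring, round-robin
    for i in range(remainder):
        lengths[i % len(lengths)] += 1
    return lengths

print(fractional_split_lengths(101, [0.8, 0.2]))  # [81, 20]
```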

Let's visualize the generated inputs and the $w$ vectors for each task.

In [6]:
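points_df = []

for i, dataset in enumerate(train_datasets):
    # random_split returns Subsets; .dataset is the underlying TensorDataset,
    # so all points of the task (train and test) are plotted
    X = dataset.dataset.tensors[0].numpy()
    cur_df = pd.DataFrame(X, columns=["x1", "x2"])
    # number tasks from 1, as in the text and in the w-vector frame below
    cur_df["dataset"] = str(i + 1)
    points_df.append(cur_df)
points_df = pd.concat(points_df, axis=0)

points_df.head()

|   | x1 | x2 | dataset |
|---|---|---|---|
| 0 | -0.094634 | -2.335958 | 1 |
| 1 | 10.453235 | 3.087356 | 1 |
| 2 | -2.813742 | -2.073199 | 1 |
| 3 | -6.502639 | -0.509825 | 1 |
| 4 | -5.492137 | -2.334638 | 1 |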


In [7]:
plane_df = []
for i, w in enumerate(w_list):
    w = w.numpy()
    # points t * w for t in [-20, 20]: a segment along the direction of w
    cur_df = pd.DataFrame(np.linspace(-20, 20, 10)[:, None] * w[None, :], columns=["x1", "x2"])
    cur_df["dataset"] = str(i+1)
    plane_df.append(cur_df)
plane_df = pd.concat(plane_df, axis=0)

plane_df.head()

|   | x1 | x2 | dataset |
|---|---|---|---|
| 0 | 9.379683 | -1.807698 | 1 |
| 1 | 7.295309 | -1.405987 | 1 |
| 2 | 5.210935 | -1.004277 | 1 |
| 3 | 3.126561 | -0.602566 | 1 |
| 4 | 1.042187 | -0.200855 | 1 |


In [8]:
fig_points = px.scatter(points_df, x="x1", y="x2", color="dataset", symbol="dataset")
fig_plane = px.line(plane_df, x="x1", y="x2", color="dataset")
fig_plane.update_traces(line=dict(width=3))
fig = go.Figure(data=fig_points.data + fig_plane.data)
fig.show()

As we can see, the inputs for tasks 1 and 2 are indeed close. Because $\sigma_w$ is small, $w_1$ and $w_2$ are close as well.

## Solo models

In [9]:
# here we define probabilistic model of each task using pyro
class SoloModel(pnn.PyroModule):
    def __init__(
        self,
        dim: int = 2,
        num_data_samples: Optional[int] = None
    ):
        super().__init__()

        self.num_data_samples = num_data_samples

        # set parametric prior on w
        self.w_loc = pnn.PyroParam(torch.zeros((dim, )))
        self.log_w_scale = pnn.PyroParam(torch.zeros((dim, )))
        self.w = pnn.PyroSample(
            lambda self: pdistr.Normal(self.w_loc, torch.exp(self.log_w_scale)).to_event(1)
        )

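    def forward(self, X: torch.Tensor, y: Optional[torch.Tensor] = None):
        batch_size = X.shape[0]
        if self.num_data_samples:
            # minibatch training: pyro rescales the batch log-likelihood
            # by num_data_samples / batch_size
            size = self.num_data_samples
            subsample_size = batch_size
        else:
            # full-batch training: the plate covers exactly the given data
            size = batch_size
            subsample_size = None

        # Bernoulli probabilities of the logistic regression
        p = torch.sigmoid(X.matmul(self.w))
        with pyro.plate("data_batch", size=size, subsample_size=subsample_size):
            pyro.sample("y", pdistr.Bernoulli(p), obs=y)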
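Passing `size=num_data_samples` together with `subsample_size=batch_size` to `pyro.plate` makes Pyro scale the minibatch log-likelihood by `num_data_samples / batch_size`, so each minibatch gives an unbiased estimate of the full-data likelihood term of the ELBO. A plain-Python sketch of that scaling for a Bernoulli likelihood (the helper names and the random probabilities are hypothetical; this is not Pyro's implementation):

```python
import math
import random

def bernoulli_log_lik(y: int, p: float) -> float:
    return math.log(p) if y == 1 else math.log(1.0 - p)

def minibatch_log_lik(ys, ps, batch_idx, num_data_samples):
    """Scaled minibatch estimate of the full-data log-likelihood."""
    scale = num_data_samples / len(batch_idx)  # the factor plate(size, subsample_size) applies
    return scale * sum(bernoulli_log_lik(ys[i], ps[i]) for i in batch_idx)

random.seed(0)
N = 12
ys = [random.randint(0, 1) for _ in range(N)]
ps = [random.uniform(0.05, 0.95) for _ in range(N)]

full = sum(bernoulli_log_lik(y, p) for y, p in zip(ys, ps))
# averaging the estimate over a partition into equal-size batches recovers the full sum
batches = [list(range(k, k + 4)) for k in range(0, N, 4)]
avg = sum(minibatch_log_lik(ys, ps, b, N) for b in batches) / len(batches)
print(math.isclose(full, avg))  # True
```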

We use [`lightning`](https://lightning.ai/docs/pytorch/stable/) for training and logging.

In [10]:
class LitSoloModel(L.LightningModule):
    def __init__(
        self,
        elbo_f: pyro.infer.elbo.ELBOModule,
        predictive: pyro.infer.Predictive,
        num_data_samples: int,
    ):
        super().__init__()

        self.num_data_samples = num_data_samples
        self.accuracy_computer = Accuracy('binary')

        self.elbo_f = elbo_f
        self.model: SoloModel = elbo_f.model
        self.guide = elbo_f.guide

        self.predictive = predictive

    def training_step(self, batch: tuple[torch.Tensor], batch_idx: int):
        X, y = batch

        elbo_loss = self.elbo_f(X, y)

        self.log("Train/ELBO", elbo_loss, prog_bar=True)

        return elbo_loss
    
    def validation_step(self, batch: tuple[torch.Tensor], batch_idx: int):
        X, y = batch

        y_pred = (self.predictive(X, y=None)["y"].mean(dim=0) > 0.5).to(torch.float32)

        self.accuracy_computer.update(y_pred, y)

    def on_validation_epoch_end(self):
        self.log("Test/Accuracy", self.accuracy_computer.compute())
        self.accuracy_computer.reset()  

    def configure_optimizers(self):
        return optim.Adam(self.elbo_f.parameters())
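`validation_step` turns posterior predictive samples of `y` into hard labels by averaging them over the samples and thresholding at 0.5. This is a majority vote, with ties (possible for an even number of samples) going to class 0. A NumPy sketch with made-up samples:

```python
import numpy as np

def majority_vote(y_samples: np.ndarray) -> np.ndarray:
    """y_samples: (num_samples, batch) array of 0/1 predictive draws."""
    # fraction of draws predicting class 1, thresholded at 0.5
    return (y_samples.mean(axis=0) > 0.5).astype(np.float32)

samples = np.array([
    [1, 0, 1, 0],
    [1, 0, 0, 1],
    [0, 0, 1, 1],
    [1, 1, 1, 0],
])
print(majority_vote(samples))  # [1. 0. 1. 0.]
```

The last column is a 2-to-2 tie, so it is labeled 0.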

Now train the individual models.
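Each solo model is trained with `EarlyStopping(monitor="Train/ELBO", min_delta=1e-3, patience=10, mode="min")`. As far as we understand Lightning's rule in `mode="min"`, a check counts as an improvement only if the monitored value falls below the best value so far by more than `min_delta`, and training stops after `patience` consecutive checks without improvement. A plain-Python sketch of that rule (the function name and the loss sequence are hypothetical):

```python
def early_stopping_check(losses, min_delta=1e-3, patience=10):
    """Return the index of the check at which training would stop, or None."""
    best = float("inf")
    wait = 0
    for idx, loss in enumerate(losses):
        if loss < best - min_delta:  # improvement larger than min_delta
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return idx
    return None

# the loss plateaus after three checks: 0.7999 is not 1e-3 better than 0.8
losses = [1.0, 0.9, 0.8] + [0.7999] * 10
print(early_stopping_check(losses))  # 12
```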

In [11]:
for i in range(NUM_MODELS):
    print(f"Training model {i}\n")

    # plate size = number of training samples of task i, so the minibatch likelihood
    # is rescaled the same way as in the multitask ELBO (task_num_samples below)
    model = SoloModel(config.dim, len(train_datasets[i]))
    guide = AutoNormal(model)

    # use the same number of ELBO particles as the variational multitask model
    num_elbo_particles = config.mt_elbo.classifier_num_particles * config.mt_elbo.latent_num_particles

    elbo_f = Trace_ELBO(num_elbo_particles)(model, guide)

    # All relevant parameters need to be initialized before ``configure_optimizers`` is called.
    # Since we use an AutoNormal guide, its parameters have not been initialized yet.
    # Therefore we initialize the model and guide by running one mini-batch through the loss.
    mini_batch = next(iter(DataLoader(train_datasets[i], batch_size=1)))
    elbo_f(*mini_batch)

    # draw as many predictive samples as ELBO particles
    num_predictive_particles = num_elbo_particles
    predictive = pyro.infer.Predictive(model, guide=guide, num_samples=num_predictive_particles)

    lit_model = LitSoloModel(elbo_f, predictive, len(train_datasets[i]))

    train_dataloader = DataLoader(train_datasets[i], batch_size=config.batch_size, shuffle=True)
    test_dataloader = DataLoader(test_datasets[i], batch_size=config.batch_size)

    logger = TensorBoardLogger("mt_logs/solo", name=f"solo_{i}")

    callbacks = [
        EarlyStopping(monitor="Train/ELBO", min_delta=1e-3, patience=10, mode="min")
    ]

    trainer = L.Trainer(logger=logger, callbacks=callbacks, **dict(config.trainer))
    trainer.fit(lit_model, train_dataloader, test_dataloader)

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


  | Name              | Type           | Params | Mode 
-------------------------------------------------------------
0 | accuracy_computer | BinaryAccuracy | 0      | train
1 | elbo_f            | ELBOModule     | 8      | train
2 | model             | SoloModel      | 4      | train
3 | guide             | AutoNormal     | 4      | train
4 | predictive        | Predictive     | 8      | train
-------------------------------------------------------------
8         Trainable params
0         Non-trainable params
8         Total params
0.000     Total estimated model params size (MB)
7         Modules in train mode
0         Modules in eval mode

Training model 0





The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.



Training model 1



Training model 2




Metrics are logged to TensorBoard:

In [27]:
%tensorboard --logdir mt_logs/solo

Reusing TensorBoard on port 6007 (pid 35928), started 0:00:27 ago. (Use '!kill 35928' to kill it.)

## Variational multitask models

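First, register a KL divergence for the $\delta$-distribution. In this example there are no real latent variables (the latent is just a point mass at the input), so this KL carries no information and the registered function simply returns zeros.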

In [13]:
@distr.kl.register_kl(Delta, Delta)
def kl_delta_delta(d1: Delta, d2: Delta):
    return torch.zeros(d1.batch_shape)
    # the exact KL is 0 if the two point masses coincide and +inf otherwise, e.g.:
    # return torch.zeros(d1.batch_shape) if torch.allclose(d1.v, d2.v) else torch.full(d1.batch_shape, torch.inf)
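The commented-out line above refers to the exact KL between two point masses: 0 where they coincide and $+\infty$ otherwise. Evaluated per batch element instead of with one global `allclose`, that rule looks like the NumPy sketch below. It assumes `event_dim=1` (as for `latent_distr` further down); the function is hypothetical and not part of the library.

```python
import numpy as np

def kl_delta_delta_exact(v1: np.ndarray, v2: np.ndarray) -> np.ndarray:
    """KL between batches of point masses; the last axis is the event dimension."""
    same = np.all(v1 == v2, axis=-1)       # one flag per batch element
    return np.where(same, 0.0, np.inf)

kl = kl_delta_delta_exact(np.array([[1.0, 2.0], [3.0, 4.0]]),
                          np.array([[1.0, 2.0], [3.0, 5.0]]))
print(kl)  # [ 0. inf]
```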

Define batched distributions for the tasks. Here we assume that *latents* come with shape (batch_size, num_latent_particles, ...) and *classifiers* come with shape (num_classifier_samples, ...).

In [14]:
# same for all tasks
def target_distr(Z: torch.Tensor, W: torch.Tensor) -> distr.Distribution:
    return distr.Bernoulli(logits=torch.tensordot(Z, W, dims=[[-1], [-1]]))

# same for all tasks
def predictive_distr(Z: torch.Tensor, W: torch.Tensor) -> distr.Distribution:
    return distr.Bernoulli(logits=torch.tensordot(Z, W, dims=[[-1], [-1]]).flatten(1, 2))
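To make these shape conventions concrete, the NumPy sketch below mirrors the two functions: `np.tensordot` contracting the last axis of `Z` with the last axis of `W` behaves like `torch.tensordot(Z, W, dims=[[-1], [-1]])`, and a row-major `reshape` merges axes 1 and 2 like `flatten(1, 2)`. The sizes are arbitrary.

```python
import numpy as np

batch_size, num_latent_particles, num_classifier_samples, dim = 5, 3, 4, 2

Z = np.random.default_rng(0).normal(size=(batch_size, num_latent_particles, dim))
W = np.random.default_rng(1).normal(size=(num_classifier_samples, dim))

# contract the feature axis of Z (axis 2) with the feature axis of W (axis 1)
logits = np.tensordot(Z, W, axes=([2], [1]))
print(logits.shape)  # (5, 3, 4): one logit per (example, latent particle, classifier sample)

# predictive_distr merges both particle axes into one
flat_logits = logits.reshape(batch_size, num_latent_particles * num_classifier_samples)
print(flat_logits.shape)  # (5, 12)

# every entry is an ordinary dot product z^T w
print(np.isclose(logits[2, 1, 3], Z[2, 1] @ W[3]))  # True
```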

In [15]:
task_distrs = [target_distr for _ in range(NUM_MODELS)]
task_num_samples = list(train_datasets | select(len))

In [16]:
# there are no real latents here, but MultiTaskElbo formally expects a latent
# distribution, so we use a point mass at the input
def latent_distr(X: torch.Tensor) -> distr.Distribution:
    return Delta(X, event_dim=1)

In [17]:
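class NormalLogits(distr.Normal):
    """Normal distribution whose scale is parametrized by its logarithm (``logit`` = log-scale)
    """
    def __init__(self, loc, logit, validate_args=None):
        self.logit = logit
        super().__init__(loc, torch.exp(logit), validate_args)

    def __getattribute__(self, name):
        # recompute the scale from the current logit on every access, so it follows
        # optimizer updates of the parameter and gets a fresh autograd graph each time
        if name == "scale":
            return self.logit.exp()
        else:
            return super().__getattribute__(name)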


# parametric variational distr for classifiers
classifier_distrs_params = {}
classifier_distrs = []
for i in range(NUM_MODELS):
    # set initial values for the distribution's parameters
    loc, scale_logit = nn.Parameter(torch.zeros((config.dim, ))), nn.Parameter(torch.zeros((config.dim, )))
    classifier_distrs_params.update({
        f"distr_{i}": [loc, scale_logit]
    })
    classifier_distrs.append(
        distr.Independent(
            NormalLogits(loc, scale_logit),
            reinterpreted_batch_ndims=1
        )
    )
# variational distributions for latents (point masses, see latent_distr)
latent_distrs = [latent_distr for _ in range(NUM_MODELS)]
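Each classifier posterior is thus a diagonal Gaussian whose scale is stored as its logarithm (`scale_logit`), so any real value maps to a positive scale. Assuming `MultiTaskElbo` draws classifier samples with `rsample`, a sample is `loc + exp(scale_logit) * eps` (the reparameterization trick), which keeps gradients with respect to both parameters. A NumPy sketch of that sampling step, without gradients and with hypothetical parameter values:

```python
import numpy as np

rng = np.random.default_rng(0)

# hypothetical variational parameters for one 2-dimensional classifier
loc = np.array([0.5, -1.0])
scale_logit = np.log(np.array([0.3, 2.0]))  # log of the desired scales

def sample_classifier(num_samples: int) -> np.ndarray:
    # reparameterized draw: w = loc + exp(scale_logit) * eps with eps ~ N(0, I);
    # exp keeps the scale positive for any real scale_logit
    eps = rng.standard_normal((num_samples, loc.shape[0]))
    return loc + np.exp(scale_logit) * eps

W = sample_classifier(100_000)
print(W.mean(axis=0), W.std(axis=0))  # close to loc and to [0.3, 2.0]
```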

In [18]:
# temperature must decrease over steps
temp_scheduler = lambda step: 1. / torch.sqrt(torch.tensor(step + 1))
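The schedule decays as $1/\sqrt{\text{step}+1}$: it starts at 1, halves after 3 steps and reaches 0.1 after 99 steps. How `MultiTaskElbo` uses the temperature is described in the linked docs, not here; the sketch below only tabulates the same schedule with plain Python floats instead of tensors.

```python
import math

def temp_schedule(step: int) -> float:
    # same formula as temp_scheduler above, without tensors
    return 1.0 / math.sqrt(step + 1)

for step in (0, 3, 15, 99):
    print(step, temp_schedule(step))  # 1.0, 0.5, 0.25, 0.1
```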

In [19]:
# create variational multitask elbo module
mt_elbo = MultiTaskElbo(
    task_distrs,
    task_num_samples,
    classifier_distrs,
    latent_distrs,
    temp_scheduler=temp_scheduler,
    **dict(config.mt_elbo)
)

As for the solo models, we use [`lightning`](https://lightning.ai/docs/pytorch/stable/) for training and logging.

In [20]:
class LitMtModel(L.LightningModule):
    def __init__(
        self,
        mt_elbo: MultiTaskElbo
    ):
        super().__init__()

        num_tasks = mt_elbo.num_tasks
        self.accuracy_computers = [Accuracy('binary') for _ in range(num_tasks)]
        self.mt_elbo = mt_elbo

        self.distr_params = nn.ParameterList()
        for param_list in classifier_distrs_params.values():
            self.distr_params.extend(
                param_list
            )
        

    def training_step(self, batch: tuple[tuple[torch.Tensor]], batch_idx: int):
        mt_loss_dict = self.mt_elbo(*list(zip(*batch)), step=self.global_step)
        self.log_dict(mt_loss_dict, prog_bar=True)

        # Lightning minimizes the returned value
        return mt_loss_dict["elbo"]

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # log the sum of the gradient norms of each classifier distribution's parameters
        with torch.no_grad():
            for distr_name, distr_params in classifier_distrs_params.items():
                params_grad_norm = sum(distr_params | select(lambda param: param.grad.norm()))
                self.log(f"{distr_name}_grad", params_grad_norm)

    def on_train_epoch_end(self):
        # log the learned latent and classifier mixing matrices as heatmaps
        fig, ax = plt.subplots()
        cax = ax.matshow(self.mt_elbo.latent_mixings_params.detach().numpy())
        fig.colorbar(cax)
        self.logger.experiment.add_figure("Latent_mixing", fig, self.global_step)

        fig, ax = plt.subplots()
        cax = ax.matshow(self.mt_elbo.classifier_mixings_params.detach().numpy())
        fig.colorbar(cax)
        self.logger.experiment.add_figure("Classifier_mixing", fig, self.global_step)

    def validation_step(self, batch: tuple[tuple[torch.Tensor]], batch_idx: int):
        NUM_PREDICTIVE_SAMPLES = 10
        # batch holds one (X, y) pair per task
        for i, (X, y) in enumerate(batch):
            cur_predictive = build_predictive(
                predictive_distr,
                classifier_distrs[i],
                latent_distrs[i],
                X,
                config.mt_elbo.classifier_num_particles,
                config.mt_elbo.latent_num_particles
            )
            y_pred = (cur_predictive.sample((NUM_PREDICTIVE_SAMPLES, )).mean(dim=0) > 0.5).float()
            self.accuracy_computers[i].update(y_pred, y)

    def on_validation_epoch_end(self):
        for i, accuracy_computer in enumerate(self.accuracy_computers):
            self.log(f"Test/Accuracy_{i}", accuracy_computer.compute())
            accuracy_computer.reset()

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)

In [21]:
lit_mt_model = LitMtModel(mt_elbo)

In [22]:
# stack task datasets
unified_train_dataset = StackDataset(*train_datasets)
unified_test_dataset = StackDataset(*test_datasets)

mt_train_dataloader = DataLoader(unified_train_dataset, config.batch_size, shuffle=True)
mt_test_dataloader = DataLoader(unified_test_dataset, config.batch_size, shuffle=False)
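`StackDataset` pairs the $i$-th sample of every task dataset, so each multitask batch holds one `(X, y)` pair per task. To our knowledge it requires all datasets to have the same length, which holds here because every task uses the same `size` and `test_ratio`. `training_step` then regroups the batch with `zip(*batch)` into a tuple of inputs and a tuple of targets before calling `mt_elbo`. The sketch uses strings as hypothetical stand-ins for the tensors:

```python
# Hypothetical stand-in for one collated multitask batch with 3 tasks
batch = (
    ("X_1", "y_1"),
    ("X_2", "y_2"),
    ("X_3", "y_3"),
)

# zip(*batch) regroups the per-task pairs into all inputs and all targets,
# which training_step passes to mt_elbo as two positional arguments
Xs, ys = zip(*batch)
print(Xs)  # ('X_1', 'X_2', 'X_3')
print(ys)  # ('y_1', 'y_2', 'y_3')
```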

In [23]:
logger = TensorBoardLogger("mt_logs", name="multitask")

callbacks = [
    EarlyStopping(monitor="elbo", min_delta=1e-3, patience=10, mode="min")
]

In [24]:
trainer = L.Trainer(logger=logger, callbacks=callbacks, **dict(config.trainer))
trainer.fit(lit_mt_model, mt_train_dataloader, mt_test_dataloader)

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


  | Name         | Type          | Params | Mode 
-------------------------------------------------------
0 | mt_elbo      | MultiTaskElbo | 9      | train
1 | distr_params | ParameterList | 12     | train
-------------------------------------------------------
21        Trainable params
0         Non-trainable params
21        Total params
0.000     Total estimated model params size (MB)
2         Modules in train mode
0         Modules in eval mode




The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.




`Trainer.fit` stopped: `max_epochs=40` reached.


Metrics are logged to TensorBoard:

In [26]:
%tensorboard --logdir mt_logs/multitask

Reusing TensorBoard on port 6008 (pid 35966), started 0:00:03 ago. (Use '!kill 35966' to kill it.)

As we can see, the training is successful. The visualized classifier mixing parameters show clearly that the algorithm has found the connection between tasks 1 and 2.

## Solo and multitask comparison

The final accuracies of the solo and multitask models are equal within the error margin. Because the tasks are simple and synthetic, no performance difference is observed. On harder tasks with possible probabilistic connections between them, the variational multitask approach may give a significant boost.