In [1]:
import sys
import copy
from abc import ABC, abstractmethod
from collections import OrderedDict

from collections.abc import Sequence

from typing import Dict, Any

import numpy as np
import structlog
import torch
from pydantic import BaseModel
from pytorch_lightning import LightningModule, Trainer, Callback
from torch import optim
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import vgg16
from torchvision.transforms import transforms
from tqdm import tqdm

from baal.active import ActiveLearningDataset, ActiveLearningLoop
from baal.active.heuristics import BALD
from baal.modelwrapper import mc_inference
from baal.utils.cuda_utils import to_cuda
from baal.utils.iterutils import map_on_tensor

from baal.utils.pytorch_lightning import ActiveLearningMixin, ResetCallback, BaalTrainer

### We need to implement our model following the PyTorch Lightning specifications. Below you can see an example using VGG16

In [25]:
class VGG16(ActiveLearningMixin, LightningModule):
    """VGG16 image classifier packaged as a Lightning module for active learning.

    The module trains on ``active_dataset`` (the labelled part of an active
    learning dataset), evaluates on the CIFAR10 test split and exposes the
    unlabelled pool through :meth:`pool_loader`.

    :param active_dataset: dataset used for training; its ``pool`` attribute is
        read by :meth:`pool_loader`.
    :param hparams: object exposing ``num_classes``, ``learning_rate``,
        ``batch_size`` and ``data_root``.
    """

    def __init__(self, active_dataset, hparams):
        super().__init__()
        self.name = "VGG16"
        self.version = "0.0.1"
        self.active_dataset = active_dataset
        self.hparams = hparams
        self.criterion = CrossEntropyLoss()

        self.train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                                   transforms.ToTensor()])
        self.test_transform = transforms.Compose([transforms.ToTensor()])
        self._build_model()

    def _build_model(self):
        """Create the underlying torchvision VGG16 with ``hparams.num_classes`` outputs."""
        self.vgg16 = vgg16(num_classes=self.hparams.num_classes)

    def forward(self, x):
        """Return the VGG16 output for the batch ``x``."""
        return self.vgg16(x)

    def log_hyperparams(self, *args):
        """Print whatever hyperparameter arguments are supplied."""
        print(args)

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop.

        :param batch: an ``(inputs, targets)`` pair.
        :param batch_idx: index of the batch (unused).
        :return: ordered dict holding the loss under ``'loss'`` and the same
            value, keyed ``'train_loss'``, under ``'progress_bar'`` and ``'log'``.
        """
        # forward pass
        x, y = batch
        y_hat = self(x)

        # calculate loss
        loss_val = self.criterion(y_hat, y)

        tqdm_dict = {'train_loss': loss_val}
        output = OrderedDict({
            'loss': loss_val,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })
        return output

    def test_step(self, batch, batch_idx):
        """
        Compute the loss on one evaluation batch.

        :param batch: an ``(inputs, targets)`` pair.
        :param batch_idx: index of the batch (unused).
        :return: ordered dict holding the loss under ``'loss'`` and the same
            value, keyed ``'val_loss'``, under ``'progress_bar'`` and ``'log'``.
        """
        x, y = batch
        y_hat = self(x)

        # calculate loss: the criterion takes (predictions, targets), the same
        # order used in training_step.
        loss_val = self.criterion(y_hat, y)

        tqdm_dict = {'val_loss': loss_val}
        output = OrderedDict({
            'loss': loss_val,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })
        return output

    def configure_optimizers(self):
        """
        Build the optimizer used for training.

        :return: a list with a single Adam optimizer and an empty list of schedulers.
        """
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return [optimizer], []

    def train_dataloader(self):
        """Return a shuffled loader over ``active_dataset``."""
        return DataLoader(self.active_dataset, self.hparams.batch_size, shuffle=True,
                          num_workers=4)

    def test_dataloader(self):
        """Return an unshuffled loader over the CIFAR10 test split (downloaded if needed)."""
        ds = CIFAR10(root=self.hparams.data_root, train=False,
                     transform=self.test_transform, download=True)
        return DataLoader(ds, self.hparams.batch_size, shuffle=False,
                          num_workers=4)

    def pool_loader(self):
        """Return an unshuffled loader over ``active_dataset.pool``."""
        return DataLoader(self.active_dataset.pool, self.hparams.batch_size, shuffle=False,
                          num_workers=4)

    def log_metrics(self, metrics, step_num):
        """Print the metrics together with the step number."""
        print('Epoch', step_num, metrics)

    def agg_and_log_metrics(self, metrics, step):
        """Forward to :meth:`log_metrics` without any aggregation."""
        self.log_metrics(metrics, step)

    def validation_epoch_end(self, outputs):
        """Average the validation step outputs; see :meth:`epoch_end`."""
        return self.epoch_end(outputs)

    def epoch_end(self, outputs):
        """
        Average the tensor-valued entries of the step outputs.

        :param outputs: sequence of dicts as returned by the step methods.
        :return: dict mapping each tensor-valued key of the first output to the
            mean over all outputs; empty dict when ``outputs`` is empty.
        """
        out = {}
        if len(outputs) > 0:
            first = outputs[0]
            # The step outputs also carry nested dicts ('progress_bar', 'log');
            # torch.stack only accepts tensors, so those entries are skipped.
            tensor_keys = [key for key in first.keys() if torch.is_tensor(first[key])]
            out = {key: torch.stack([x[key] for x in outputs]).mean() for key in tensor_keys}
        return out

    def test_epoch_end(self, outputs):
        """Average the test step outputs; see :meth:`epoch_end`."""
        return self.epoch_end(outputs)

    def training_epoch_end(self, outputs):
        """Average the training step outputs; see :meth:`epoch_end`."""
        return self.epoch_end(outputs)

### Now we need to specify our hyperparameters

In [26]:
class HParams(BaseModel):
    """Hyperparameters for the VGG16 active learning example.

    Declared as a pydantic model: every field has a type annotation and a default.
    """
    # Batch size given to the train, test and pool DataLoaders.
    batch_size: int = 10
    # Root directory handed to the CIFAR10 datasets.
    data_root: str = '/tmp'
    # Number of output classes passed to vgg16(num_classes=...).
    num_classes: int = 10
    # Learning rate of the Adam optimizer.
    learning_rate: float = 0.001
    # Forwarded to ActiveLearningLoop as ndata_to_label.
    query_size: int = 100
    # Forwarded to ActiveLearningLoop as max_sample.
    max_sample: int = -1
    # Not read anywhere in the visible code; presumably the number of stochastic
    # prediction iterations consumed inside baal -- TODO confirm.
    iterations: int = 20
    # Not read anywhere in the visible code; presumably consumed inside baal
    # during pool prediction -- TODO confirm.
    replicate_in_memory: bool = True

### Define the transformations to be used with our dataset

In [27]:
# Training-time pipeline: random horizontal flip followed by tensor conversion.
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor()])

In [28]:
# Transform handed to the unlabelled pool (see pool_specifics below in the notebook).
# Kept deterministic -- tensor conversion only, no random flip -- so it matches
# the test_transform defined inside VGG16.
test_transform = transforms.Compose([transforms.ToTensor()])

### Above, we defined our hyperparameters using a Pydantic-based class, which ensures we're using the correct data types. We instantiate that class here

In [29]:
# No overrides are passed, so every field keeps the default declared on HParams.
hparams = HParams()

### We instantiate an ActiveLearning Dataset

In [30]:
# Wrap the CIFAR10 training split (downloaded into hparams.data_root if missing).
# pool_specifics presumably overrides the listed dataset attributes when items are
# read from the unlabelled pool, i.e. the pool uses test_transform -- confirm in baal docs.
active_set = ActiveLearningDataset(
        CIFAR10(hparams.data_root, train=True, transform=train_transform, download=True),
        pool_specifics={
            'transform': test_transform
        })

Files already downloaded and verified


### Label a few random items

In [31]:
# Start the labelled set with 10 randomly chosen items.
active_set.label_randomly(10)

In [32]:
# BALD heuristic with default settings; handed to ActiveLearningLoop as `heuristic`.
heuristic = BALD()

In [33]:
# The model trains on active_set and reads its settings from hparams.
model = VGG16(active_set, hparams)

In [36]:
# A deep copy of the initial weights is given to ResetCallback; presumably the
# callback restores them before each training run -- confirm in baal docs.
trainer = BaalTrainer(callbacks=[ResetCallback(copy.deepcopy(model.state_dict()))])

GPU available: True, used: False
TPU available: False, using: 0 TPU cores


In [37]:
# Wire the dataset, the trainer's pool-prediction generator and the heuristic together.
# query_size is passed as the number of items to label per step.
loop = ActiveLearningLoop(active_set, get_probabilities=trainer.predict_on_dataset_generator,
                          heuristic=heuristic,
                          ndata_to_label=hparams.query_size,
                          max_sample=hparams.max_sample)

In [38]:
# Upper bound on the number of active learning rounds run by the loop below.
AL_STEPS = 100

In [None]:
# One active learning round per iteration: train on the current labelled set,
# then let the loop perform its step.
for al_step in range(AL_STEPS):
    print(f'Step {al_step} DS size {len(active_set)}')
    trainer.fit(model)
    should_continue = loop.step()
    # Stop early once loop.step() returns a falsy value.
    if not should_continue:
        break


  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | vgg16     | VGG              | 134 M 


Step 0 DS size 10


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…



