# Train a CIFAR10 image classifier on top of ImageNet-pretrained ResNet18 features

1. Import `mlmodule` modules for the task

In [None]:
from mlmodule.models.resnet.modules import TorchResNetModule
from mlmodule.models.classification import LinearClassifierTorchModule
from mlmodule.torch.datasets import TorchTrainingDataset
from mlmodule.torch.runners import TorchTrainingRunner
from mlmodule.torch.runners import TorchInferenceRunner
from mlmodule.torch.options import TorchTrainingOptions
from mlmodule.torch.options import TorchRunnerOptions
from mlmodule.labels.base import LabelSet
from mlmodule.callbacks.memory import (
    CollectFeaturesInMemory,
)
from mlmodule.torch.datasets import (
    ListDataset,
    ListDatasetIndexed,
)
from mlmodule.states import StateKey
from mlmodule.stores import Store

Enable logging in the notebook

In [None]:
import logging
import sys

# Route INFO-and-above log records to stdout so they appear in the cell output.
log_settings = {
    "stream": sys.stdout,
    "level": logging.INFO,
    "format": '%(asctime)s | %(levelname)s : %(message)s',
}
logging.basicConfig(**log_settings)

2. Load CIFAR10 dataset from torchvision

In [None]:
import os

from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms

# Store the dataset under the *current* user's home directory instead of one
# specific user's absolute path, so the notebook runs unchanged on any machine.
root_dir = os.path.expanduser(os.path.join("~", "data"))

# Both splits are requested with download=True and without any transform;
# iterating over them yields (image, integer class id) pairs.
train_cifar10 = CIFAR10(root=root_dir, train=True, download=True, transform=None)
valid_cifar10 = CIFAR10(root=root_dir, train=False, download=True, transform=None)


3. Format inputs and labels for `mlmodule`

In [None]:
# Human-readable class name for each CIFAR10 integer class id (0..9).
labels_dict = dict(enumerate([
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]))

# Swap every integer class id for its class name, keeping the image untouched.
train_samples = [
    (image, labels_dict[class_id]) for image, class_id in train_cifar10
]

# Unzip the pairs into two parallel tuples: images and label names.
train_images, train_labels = zip(*train_samples)

4. Load the pretrained `resnet18` model

In [None]:
import torch

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at model creation.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# ResNet18 backbone created with training_mode="features"
# (presumably makes the module output feature vectors rather than class
# scores -- confirm against the mlmodule documentation).
resnet = TorchResNetModule(
    resnet_arch="resnet18",
    device=torch_device,
    training_mode="features",
)

# Load the state stored under the "imagenet" training id into the model.
Store().load(resnet, StateKey(resnet.state_type, training_id="imagenet"))


5. Extract image features

In [None]:
# Callback that keeps the model outputs in memory; later cells read them back
# through `ff.features` and `ff.indices`.
ff = CollectFeaturesInMemory()

# Inference settings: batches of 32 on the selected device, tqdm enabled.
inference_options = TorchRunnerOptions(
    data_loader_options={'batch_size': 32},
    device=torch_device,
    tqdm_enabled=True,
)

# Run the ResNet over every training image, feeding results to the callback.
runner = TorchInferenceRunner(
    model=resnet,
    dataset=ListDataset(train_images),
    callbacks=[ff],
    options=inference_options,
)
runner.run()

6. Create a linear classifier on top of ResNet features

In [None]:
from mlmodule.models.classification import LinearClassifierTorchModule

# Class names in alphabetical order; this list defines the label set.
labels = sorted(labels_dict.values())
label_set = LabelSet(label_set_unique_id="cifar10", label_list=labels)

# Linear classifier whose input width is taken from the ResNet's own
# classifier module, so it matches the extracted feature size.
classifier = LinearClassifierTorchModule(
    in_features=resnet.classifier_module.in_features,
    label_set=label_set,
)

7. Create train and validation splits

In [None]:
import torch

# Fixed seed so the train/validation split is identical on every run.
SPLIT_SEED = 42
# Fraction of the extracted features used for training; the rest is validation.
TRAIN_FRACTION = .9

n_samples = len(ff.indices)
n_train = int(n_samples * TRAIN_FRACTION)

# A dedicated generator makes the shuffle reproducible without touching the
# global torch RNG state used by the rest of the notebook.
split_rng = torch.Generator().manual_seed(SPLIT_SEED)
shuffled = torch.randperm(n_samples, generator=split_rng)

# Explicit slicing always yields exactly two parts (train first, then valid).
train_indices, valid_indices = shuffled[:n_train], shuffled[n_train:]

# NOTE(review): indexing `train_labels` with positions into `ff.features`
# assumes the features were collected in the same order as `train_images`
# (i.e. ff.indices == 0..n-1) -- confirm.

# define training set
train_dset = TorchTrainingDataset(
    dataset=ListDatasetIndexed(train_indices, ff.features[train_indices]),
    targets=label_set.get_label_ids(
        [train_labels[idx] for idx in train_indices.tolist()]
    ),
)
# define valid set
valid_dset = TorchTrainingDataset(
    dataset=ListDatasetIndexed(valid_indices, ff.features[valid_indices]),
    targets=label_set.get_label_ids(
        [train_labels[idx] for idx in valid_indices.tolist()]
    ),
)

8. Train the image classifier using `TorchTrainingRunner` module

In [None]:
from ignite.metrics import Precision, Recall, Loss, Accuracy
import torch.nn.functional as F
import torch.optim as optim

# Un-averaged precision and recall, combined below into a single F1 score.
precision = Precision(average=False)
recall = Recall(average=False)
# Mean of the per-class F1 values. The tiny epsilon keeps the division
# defined when precision and recall are both zero for a class; without it
# that class yields 0/0 = NaN and the whole mean becomes NaN.
F1 = (precision * recall * 2 / (precision + recall + 1e-20)).mean()

# The same cross-entropy function is used as the training criterion and as
# the reported "ce_loss" metric.
loss_fn = F.cross_entropy
trainer = TorchTrainingRunner(
    model=classifier,
    # (training set, validation set) built in the previous cell
    dataset=(train_dset, valid_dset),
    callbacks=[],
    options=TorchTrainingOptions(
        data_loader_options={'batch_size': 32},
        criterion=loss_fn,
        # Only the linear classifier's parameters are optimised.
        optimizer=optim.Adam(classifier.parameters(), lr=1e-3),
        metrics={
            "pre": precision,
            "recall": recall,
            "f1": F1,
            "acc": Accuracy(),
            "ce_loss": Loss(loss_fn),
        },
        validate_every=1,
        num_epoch=5,
        tqdm_enabled=True,
    ),
)
trainer.run()