# Train a CIFAR10 image classifier on top of ImageNet-pretrained ResNet18 features

1. Import `mlmodule` modules for the task

In [1]:
from mlmodule.models.resnet.modules import TorchResNetModule
from mlmodule.models.classification import LinearClassifierTorchModule
from mlmodule.torch.datasets import TorchTrainingDataset
from mlmodule.torch.runners import TorchTrainingRunner
from mlmodule.torch.runners import TorchInferenceRunner
from mlmodule.torch.options import TorchTrainingOptions
from mlmodule.torch.options import TorchRunnerOptions
from mlmodule.labels.base import LabelSet
from mlmodule.callbacks.memory import (
    CollectFeaturesInMemory,
)
from mlmodule.torch.datasets import (
    ListDataset,
    ListDatasetIndexed,
)
from mlmodule.states import StateKey
from mlmodule.stores import Store

  from .autonotebook import tqdm as notebook_tqdm


Enable logging output in the notebook

In [2]:
import logging
import sys

# Route INFO-level log records to stdout so they show up in the cell output
# with a timestamp and level prefix.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format='%(asctime)s | %(levelname)s : %(message)s',
)

2. Load CIFAR10 dataset from torchvision

In [3]:
from torchvision.datasets import CIFAR10

import os

# Directory holding (or receiving) the CIFAR10 download. Set the CIFAR10_ROOT
# environment variable to override it instead of editing a user-specific path.
root_dir = os.environ.get('CIFAR10_ROOT', '/home/lebret/data')
# Raw PIL images (transform=None); preprocessing is left to the mlmodule model.
train_cifar10 = CIFAR10(root=root_dir, train=True, download=True, transform=None)

Files already downloaded and verified


3. Format inputs and labels for `mlmodule`

In [4]:
# CIFAR10 stores integer targets 0..9; map each one to its class name.
labels_dict = dict(enumerate([
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]))
# Pair every training image with its class name, then split the pairs into
# two parallel tuples: images and label names.
train_samples = [(image, labels_dict[target]) for image, target in train_cifar10]
train_images, train_labels = zip(*train_samples)

4. Load the ImageNet-pretrained `resnet18` model

In [5]:
import torch

# Use the GPU when one is available, otherwise fall back to CPU so the notebook
# does not crash on machines without CUDA. All later cells reuse torch_device.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# ResNet-18 used in "features" mode; its outputs are collected as image
# features by CollectFeaturesInMemory in the next step.
resnet = TorchResNetModule(
    resnet_arch="resnet18",
    device=torch_device,
    training_mode="features",
)
# Load the weights stored under training_id "imagenet" for this model's state type.
Store().load(resnet, StateKey(resnet.state_type, training_id="imagenet"))


2022-05-11 14:09:34,865 | INFO : Found credentials in shared credentials file: ~/.aws/credentials


5. Extract image features

In [6]:
# Run the ResNet backbone over every training image and keep the resulting
# feature vectors (with their dataset indices) in memory via the callback.
ff = CollectFeaturesInMemory()

inference_options = TorchRunnerOptions(
    device=torch_device,
    data_loader_options={'batch_size': 32},
    tqdm_enabled=True,
)
runner = TorchInferenceRunner(
    model=resnet,
    dataset=ListDataset(train_images),
    options=inference_options,
    callbacks=[ff],
)
runner.run()

100%|██████████| 1563/1563 [02:42<00:00,  9.62it/s]


6. Create a linear classifier on top of ResNet features

In [7]:
from mlmodule.models.classification import LinearClassifierTorchModule

# Sorting the class names gives a deterministic label -> id ordering.
labels = sorted(labels_dict.values())
label_set = LabelSet(label_set_unique_id="cifar10", label_list=labels)

# A single linear layer whose input size matches the extracted feature width.
feature_dim = ff.features.shape[1]
classifier = LinearClassifierTorchModule(
    in_features=feature_dim,
    label_set=label_set,
)

7. Create train and validation splits

In [8]:
import torch

# Hold out 10% of the extracted training features for validation. A fixed seed
# makes the split (and therefore the reported metrics) reproducible on re-run.
TRAIN_FRACTION = 0.9
SPLIT_SEED = 0

n_samples = len(ff.indices)
n_train = int(n_samples * TRAIN_FRACTION)
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
# Explicit chunk sizes guarantee exactly two chunks (train, valid) for any
# fraction and sample count, instead of relying on a single split_size.
train_indices, valid_indices = torch.split(
    torch.randperm(n_samples, generator=split_generator),
    [n_train, n_samples - n_train],
)


def _make_training_dataset(positions):
    """Wrap the feature rows at `positions` together with their label ids.

    `positions` is a 1-D tensor of row positions into ff.features. Each label
    is looked up through ff.indices, the dataset index recorded for that
    feature row, rather than assuming row i came from train_images[i].
    NOTE(review): assumes ff.indices holds integer indices into train_images
    (true for a ListDataset); confirm against CollectFeaturesInMemory.
    """
    label_names = [train_labels[ff.indices[pos]] for pos in positions.tolist()]
    return TorchTrainingDataset(
        dataset=ListDatasetIndexed(positions, ff.features[positions]),
        targets=label_set.get_label_ids(label_names),
    )


# define training set
train_dset = _make_training_dataset(train_indices)
# define valid set
valid_dset = _make_training_dataset(valid_indices)

8. Train the image classifier using the `TorchTrainingRunner`

In [18]:
from ignite.metrics import Precision, Recall, Loss, Accuracy
from mlmodule.callbacks.states import SaveModelState
from mlmodule.stores.local import LocalStateStore

import torch.nn.functional as F
import torch.optim as optim

import os

# Define the evaluation metrics. Precision/Recall with average=False yield one
# value per class; F1 is computed per class and then averaged (macro F1).
precision = Precision(average=False)
recall = Recall(average=False)
# The tiny epsilon keeps F1 finite when a class has precision + recall == 0
# (e.g. it is never predicted); without it the averaged F1 becomes NaN.
F1 = (precision * recall * 2 / (precision + recall + 1e-20)).mean()

# Callbacks: checkpoint the classifier state under the same data directory as
# the dataset, instead of a second hard-coded copy of that path.
model_state = SaveModelState(
    store=LocalStateStore(os.path.join(root_dir, 'mlmodule')),
    state_key=StateKey(classifier.state_type, 'train-1'))
# define a loss function
loss_fn = F.cross_entropy

# define the trainer: metrics are logged on the train and validation splits
# every epoch, and a checkpoint is written every 3 epochs.
trainer = TorchTrainingRunner(
    model=classifier,
    dataset=(train_dset, valid_dset),
    callbacks=[model_state],
    options=TorchTrainingOptions(
        data_loader_options={'batch_size': 32},
        criterion=loss_fn,
        optimizer=optim.Adam(classifier.parameters(), lr=1e-3),
        metrics={
            "pre": precision,
            "recall": recall,
            "f1": F1,
            "acc": Accuracy(),
            "ce_loss": Loss(loss_fn),
        },
        validate_every=1,
        checkpoint_every=3,
        num_epoch=5,
        tqdm_enabled=True,
    ),
)
trainer.run()

2022-05-11 14:46:01,176 | INFO : Engine run starting with max_epochs=5.


Epoch [1/5]: 100%|██████████| 1407/1407 [00:14<00:00, 96.65it/s, avg_loss=0.448] 

2022-05-11 14:46:22,973 | INFO : Engine run starting with max_epochs=1.



                                                      

2022-05-11 14:46:33,581 | INFO : Epoch[1] Complete. Time taken: 00:00:11
2022-05-11 14:46:33,582 | INFO : Engine run complete. Time taken: 00:00:11
2022-05-11 14:46:33,583 | INFO : Epoch 1 - Evaluation time (seconds): 10.61 - Train metrics

2022-05-11 14:46:33,584 | INFO : 	f1: 0.8512




2022-05-11 14:46:33,585 | INFO : 	acc: 0.8524
2022-05-11 14:46:33,586 | INFO : 	ce_loss: 0.4333
2022-05-11 14:46:33,586 | INFO : Engine run starting with max_epochs=1.


                                                    

2022-05-11 14:46:34,780 | INFO : Epoch[1] Complete. Time taken: 00:00:01
2022-05-11 14:46:34,781 | INFO : Engine run complete. Time taken: 00:00:01
2022-05-11 14:46:34,781 | INFO : Epoch 1 - Evaluation time (seconds): 1.19 - Test metrics

2022-05-11 14:46:34,782 | INFO : 	f1: 0.8437




2022-05-11 14:46:34,782 | INFO : 	acc: 0.8440
2022-05-11 14:46:34,783 | INFO : 	ce_loss: 0.4312
2022-05-11 14:46:34,784 | INFO : Epoch[1] Complete. Time taken: 00:00:34


Epoch [2/5]: 100%|██████████| 1407/1407 [00:14<00:00, 96.02it/s, avg_loss=0.406] 


2022-05-11 14:46:49,437 | INFO : Engine run starting with max_epochs=1.


                                                      

2022-05-11 14:46:59,957 | INFO : Epoch[1] Complete. Time taken: 00:00:11
2022-05-11 14:46:59,958 | INFO : Engine run complete. Time taken: 00:00:11
2022-05-11 14:46:59,958 | INFO : Epoch 2 - Evaluation time (seconds): 10.52 - Train metrics

2022-05-11 14:46:59,959 | INFO : 	f1: 0.8618
2022-05-11 14:46:59,960 | INFO : 	acc: 0.8628
2022-05-11 14:46:59,960 | INFO : 	ce_loss: 0.4002
2022-05-11 14:46:59,961 | INFO : Engine run starting with max_epochs=1.


                                                    

2022-05-11 14:47:01,118 | INFO : Epoch[1] Complete. Time taken: 00:00:01
2022-05-11 14:47:01,119 | INFO : Engine run complete. Time taken: 00:00:01
2022-05-11 14:47:01,120 | INFO : Epoch 2 - Evaluation time (seconds): 1.16 - Test metrics

2022-05-11 14:47:01,121 | INFO : 	f1: 0.8491
2022-05-11 14:47:01,122 | INFO : 	acc: 0.8492
2022-05-11 14:47:01,122 | INFO : 	ce_loss: 0.4078
2022-05-11 14:47:01,123 | INFO : Epoch[2] Complete. Time taken: 00:00:26


Epoch [3/5]: 100%|██████████| 1407/1407 [00:14<00:00, 96.53it/s, avg_loss=0.387] 

2022-05-11 14:47:15,701 | INFO : Engine run starting with max_epochs=1.



                                                      

2022-05-11 14:47:26,240 | INFO : Epoch[1] Complete. Time taken: 00:00:11
2022-05-11 14:47:26,241 | INFO : Engine run complete. Time taken: 00:00:11
2022-05-11 14:47:26,242 | INFO : Epoch 3 - Evaluation time (seconds): 10.54 - Train metrics

2022-05-11 14:47:26,242 | INFO : 	f1: 0.8664
2022-05-11 14:47:26,243 | INFO : 	acc: 0.8673
2022-05-11 14:47:26,244 | INFO : 	ce_loss: 0.3854
2022-05-11 14:47:26,244 | INFO : Engine run starting with max_epochs=1.


                                                    

2022-05-11 14:47:27,438 | INFO : Epoch[1] Complete. Time taken: 00:00:01
2022-05-11 14:47:27,439 | INFO : Engine run complete. Time taken: 00:00:01
2022-05-11 14:47:27,439 | INFO : Epoch 3 - Evaluation time (seconds): 1.19 - Test metrics





2022-05-11 14:47:27,440 | INFO : 	f1: 0.8532
2022-05-11 14:47:27,440 | INFO : 	acc: 0.8532
2022-05-11 14:47:27,441 | INFO : 	ce_loss: 0.4006
2022-05-11 14:47:27,441 | INFO : Epoch[3] Complete. Time taken: 00:00:26


Epoch [4/5]: 100%|██████████| 1407/1407 [00:14<00:00, 96.07it/s, avg_loss=0.375] 

2022-05-11 14:47:42,090 | INFO : Engine run starting with max_epochs=1.



                                                      

2022-05-11 14:47:52,878 | INFO : Epoch[1] Complete. Time taken: 00:00:11
2022-05-11 14:47:52,878 | INFO : Engine run complete. Time taken: 00:00:11
2022-05-11 14:47:52,879 | INFO : Epoch 4 - Evaluation time (seconds): 10.79 - Train metrics

2022-05-11 14:47:52,880 | INFO : 	f1: 0.8688




2022-05-11 14:47:52,880 | INFO : 	acc: 0.8696
2022-05-11 14:47:52,881 | INFO : 	ce_loss: 0.3762
2022-05-11 14:47:52,882 | INFO : Engine run starting with max_epochs=1.


                                                    

2022-05-11 14:47:54,015 | INFO : Epoch[1] Complete. Time taken: 00:00:01
2022-05-11 14:47:54,015 | INFO : Engine run complete. Time taken: 00:00:01
2022-05-11 14:47:54,016 | INFO : Epoch 4 - Evaluation time (seconds): 1.13 - Test metrics





2022-05-11 14:47:54,017 | INFO : 	f1: 0.8543
2022-05-11 14:47:54,017 | INFO : 	acc: 0.8542
2022-05-11 14:47:54,018 | INFO : 	ce_loss: 0.3976
2022-05-11 14:47:54,018 | INFO : Epoch[4] Complete. Time taken: 00:00:27


Epoch [5/5]: 100%|██████████| 1407/1407 [00:14<00:00, 96.17it/s, avg_loss=0.367] 

2022-05-11 14:48:08,649 | INFO : Engine run starting with max_epochs=1.



                                                      

2022-05-11 14:48:19,469 | INFO : Epoch[1] Complete. Time taken: 00:00:11
2022-05-11 14:48:19,470 | INFO : Engine run complete. Time taken: 00:00:11
2022-05-11 14:48:19,470 | INFO : Epoch 5 - Evaluation time (seconds): 10.82 - Train metrics

2022-05-11 14:48:19,471 | INFO : 	f1: 0.8703
2022-05-11 14:48:19,472 | INFO : 	acc: 0.8712




2022-05-11 14:48:19,472 | INFO : 	ce_loss: 0.3698
2022-05-11 14:48:19,473 | INFO : Engine run starting with max_epochs=1.


                                                    

2022-05-11 14:48:20,656 | INFO : Epoch[1] Complete. Time taken: 00:00:01
2022-05-11 14:48:20,657 | INFO : Engine run complete. Time taken: 00:00:01
2022-05-11 14:48:20,658 | INFO : Epoch 5 - Evaluation time (seconds): 1.18 - Test metrics

2022-05-11 14:48:20,658 | INFO : 	f1: 0.8542
2022-05-11 14:48:20,659 | INFO : 	acc: 0.8542
2022-05-11 14:48:20,660 | INFO : 	ce_loss: 0.3963
2022-05-11 14:48:20,660 | INFO : Epoch[5] Complete. Time taken: 00:00:27




2022-05-11 14:48:20,664 | INFO : Engine run starting with max_epochs=1.


                                                      

2022-05-11 14:48:31,485 | INFO : Epoch[1] Complete. Time taken: 00:00:11
2022-05-11 14:48:31,486 | INFO : Engine run complete. Time taken: 00:00:11
2022-05-11 14:48:31,487 | INFO : Epoch 5 - Evaluation time (seconds): 10.82 - Train metrics

2022-05-11 14:48:31,487 | INFO : 	f1: 0.8703




2022-05-11 14:48:31,488 | INFO : 	acc: 0.8712
2022-05-11 14:48:31,489 | INFO : 	ce_loss: 0.3698
2022-05-11 14:48:31,489 | INFO : Engine run starting with max_epochs=1.


                                                    

2022-05-11 14:48:32,681 | INFO : Epoch[1] Complete. Time taken: 00:00:01
2022-05-11 14:48:32,681 | INFO : Engine run complete. Time taken: 00:00:01
2022-05-11 14:48:32,682 | INFO : Epoch 5 - Evaluation time (seconds): 1.19 - Test metrics

2022-05-11 14:48:32,682 | INFO : 	f1: 0.8542
2022-05-11 14:48:32,683 | INFO : 	acc: 0.8542
2022-05-11 14:48:32,683 | INFO : 	ce_loss: 0.3963




2022-05-11 14:48:32,684 | INFO : Engine run complete. Time taken: 00:02:32


9. Evaluate the classifier on the test set

In [19]:
from mlmodule.callbacks.memory import CollectLabelsInMemory

# Load the CIFAR10 test split and map integer targets to class names.
test_cifar10 = CIFAR10(root=root_dir, train=False, download=True,  transform=None)
test_samples = [(image, labels_dict[target]) for image, target in test_cifar10]
test_images, test_labels = zip(*test_samples)


def _test_runner_options():
    """Build a fresh TorchRunnerOptions for each of the two inference passes."""
    return TorchRunnerOptions(
        device=torch_device,
        data_loader_options={'batch_size': 32},
        tqdm_enabled=True,
    )


# Callbacks that accumulate the outputs of each pass in memory.
ff_test = CollectFeaturesInMemory()
score_test = CollectLabelsInMemory()

# Pass 1: ResNet features for every test image.
features_test_runner = TorchInferenceRunner(
    model=resnet,
    dataset=ListDataset(test_images),
    callbacks=[ff_test],
    options=_test_runner_options(),
)
features_test_runner.run()

# Pass 2: classifier predictions computed from those features.
scores_test_runner = TorchInferenceRunner(
    model=classifier,
    dataset=ListDataset(ff_test.features),
    callbacks=[score_test],
    options=_test_runner_options(),
)
scores_test_runner.run()

Files already downloaded and verified


100%|██████████| 313/313 [00:22<00:00, 14.14it/s]
100%|██████████| 313/313 [00:00<00:00, 2084.19it/s]


10. Print classification report

In [20]:
from sklearn.metrics import classification_report
# Per-class precision / recall / F1 of the predicted labels on the test split.
test_report = classification_report(y_true=test_labels, y_pred=score_test.labels)
print(test_report)
    

              precision    recall  f1-score   support

    airplane       0.88      0.85      0.86      1000
  automobile       0.93      0.92      0.93      1000
        bird       0.88      0.73      0.79      1000
         cat       0.81      0.66      0.73      1000
        deer       0.79      0.85      0.82      1000
         dog       0.74      0.86      0.80      1000
        frog       0.78      0.96      0.86      1000
       horse       0.93      0.79      0.85      1000
        ship       0.87      0.96      0.91      1000
       truck       0.94      0.91      0.93      1000

    accuracy                           0.85     10000
   macro avg       0.85      0.85      0.85     10000
weighted avg       0.85      0.85      0.85     10000

