# Model Training, Optimization and Quantization with MONAI, PyTorch Lightning and OpenVINO

This tutorial shows how to train a [MONAI](https://monai.io/) classification model on a [MedMNIST](https://medmnist.com/) dataset, quantize the model with OpenVINO's [Neural Network Compression Framework (NNCF)](https://github.com/openvinotoolkit/nncf), and compare the accuracy and inference performance of the original and quantized models in OpenVINO.

To run this notebook, please create a virtual environment with Python 3.7 or 3.8, with `python -m venv monai_env` (on Linux use `python3`) and install the requirements with `pip install -r requirements.txt`.

## Supported Networks and Datasets

The following MONAI Classification Networks are supported in this notebook:

`["DenseNet", "SENet154", "SEResNet50", "SEResNext50"]`.

The variants of these networks, `"DenseNet121", "DenseNet169", "DenseNet201", "DenseNet264", "SEResNet101", "SEResNet152", "SEResNext101"`, also work.

All MedMNIST datasets for multi-class and binary-class classification are supported: `['pathmnist', 'dermamnist', 'octmnist', 'pneumoniamnist', 'breastmnist', 'bloodmnist', 'tissuemnist', 'organamnist', 'organcmnist', 'organsmnist', 'organmnist3d', 'nodulemnist3d', 'adrenalmnist3d', 'fracturemnist3d', 'vesselmnist3d', 'synapsemnist3d']`

## Imports

In [None]:
import datetime
import inspect
import logging
import os
import random
import subprocess
import warnings
from operator import itemgetter
from pathlib import Path
from typing import Dict

import dateutil.relativedelta  # importing the submodule also binds the top-level dateutil package
import matplotlib.pyplot as plt
import medmnist
import monai
import monai.networks.nets as nets
import numpy as np
import pytorch_lightning as pl
import torch
import nncf  # must be imported after torch
from nncf import NNCFConfig
from nncf.common.utils.logger import set_log_level
from nncf.torch import create_compressed_model, register_default_init_args
from IPython.display import Markdown, display
from monai.metrics import ConfusionMatrixMetric
from monai.transforms import (
    Activations,
    AddChannel,
    AsChannelFirst,
    AsDiscrete,
    Compose,
    ToTensor,
)
from openvino.inference_engine import IECore
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.model_summary import summarize
from torch.jit import TracerWarning
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import Dataset

## Dataset, Metric and Model

We start by defining a Dataset and a DataModule to handle data transformation and loading, and a PyTorch Lightning Model that specifies the MONAI model to use and contains the training and evaluation code. Model accuracy is evaluated with MONAI's `ConfusionMatrixMetric`.

### Data

We create a PyTorch Dataset to load and transform the data, and a PyTorch Lightning DataModule for accessing this data during training. The dataset returns each item as a tuple `(image, label)`.

We use MONAI transforms to convert the images to float tensors with a channel-first layout: `AsChannelFirst` moves the channel axis of RGB images to the front, and `AddChannel` adds a channel axis to grayscale 2D images. 3D volumes are used as-is. No data augmentation is applied; the same transforms are used for the train, validation and test splits.

The specified MedMNIST dataset is downloaded if it has not been downloaded before. See the top of this notebook for the supported datasets.
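The channel handling in `MedMNISTDataset` can be illustrated with plain NumPy. This sketch is not MONAI code: `to_channel_first` is a hypothetical helper, and it assumes that 2D RGB images arrive as `(H, W, C)` arrays and grayscale images as `(H, W)` arrays, which is what `np.asarray` returns for RGB and grayscale PIL images.

```python
import numpy as np


def to_channel_first(image: np.ndarray, num_channels: int, num_dims: int) -> np.ndarray:
    """Hypothetical NumPy equivalent of the channel transforms in MedMNISTDataset.

    RGB 2D images (H, W, C) are moved to (C, H, W), like AsChannelFirst.
    Grayscale 2D images (H, W) get a leading channel axis, like AddChannel.
    Anything else (the 3D volumes) is returned unchanged apart from the cast.
    """
    image = image.astype(np.float32)  # ToTensor(dtype=torch.float) also gives float32
    if num_channels == 3:
        return np.moveaxis(image, -1, 0)
    if num_dims == 2:
        return image[np.newaxis, ...]
    return image


rgb = np.zeros((28, 28, 3), dtype=np.uint8)
gray = np.zeros((28, 28), dtype=np.uint8)
print(to_channel_first(rgb, num_channels=3, num_dims=2).shape)   # (3, 28, 28)
print(to_channel_first(gray, num_channels=1, num_dims=2).shape)  # (1, 28, 28)
```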

#### Dataset

In [None]:
class MedMNISTDataset(Dataset):
    def __init__(self, medmnist_dataset, split) -> None:

        supported_datasets = [
            (ds_name)
            for (ds_name, ds_info) in medmnist.INFO.items()
            if ds_info["task"] in ["multi-class", "binary-class"]
        ]

        if medmnist_dataset not in supported_datasets:
            raise ValueError(
                f"{medmnist_dataset} is not a supported dataset. Supported datasets are: "
                f"{supported_datasets}."
            )

        dataset = getattr(
            medmnist, medmnist.dataset.INFO[medmnist_dataset]["python_class"]
        )
        self.num_dims = 3 if medmnist_dataset.endswith("3d") else 2
        self.num_channels = medmnist.dataset.INFO[medmnist_dataset]["n_channels"]
        self.labels = medmnist.dataset.INFO[medmnist_dataset]["label"]
        self.num_classes = len(self.labels)

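        transforms = [ToTensor(dtype=torch.float)]
        if self.num_channels == 3:
            # RGB images are (H, W, C): move the channel axis to the front
            transforms.append(AsChannelFirst())
        elif self.num_dims == 2:
            # Grayscale 2D images are (H, W): add a leading channel axis
            transforms.append(AddChannel())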

        transform = Compose(transforms)

        print(
            f"Setup {medmnist_dataset} {split}, {self.num_channels} channels, "
            f"{self.num_classes} classes"
        )
        split_dataset = dataset(split=split, transform=None, download=True)

        data = [
            (transform(np.asarray(item[0])), torch.as_tensor(item[1]))
            for item in split_dataset
        ]
        self.data = data

    def __len__(self):
        """
        Returns the number of elements in the dataset
        """
        return len(self.data)

    def __getitem__(self, index):
        """
        Get item from self.data at the specified index.
        Returns (image, label), where image is a preprocessed float tensor in network
        shape (channels first) and label is a float tensor with the class index
        """
        image, label = self.data[index]
        image = torch.as_tensor(image, dtype=torch.float)
        label = torch.as_tensor(label, dtype=torch.float)
        return image, label

#### PyTorch Lightning DataModule

A [Lightning DataModule](https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html) defines dataloaders and connects the Lightning Module to the Dataset.

In [None]:
class DataModule(pl.LightningDataModule):
    def __init__(self, medmnist_dataset, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.medmnist_dataset = medmnist_dataset

    def setup(self, stage=None):
        random.seed(1.414213)

        self.dataset_train = MedMNISTDataset(self.medmnist_dataset, "train")
        self.dataset_val = MedMNISTDataset(self.medmnist_dataset, "val")
        self.dataset_test = MedMNISTDataset(self.medmnist_dataset, "test")

        print(f"Setup train dataset: {len(self.dataset_train)} items")
        print(f"Setup val dataset: {len(self.dataset_val)} items")
        print(f"Setup test dataset: {len(self.dataset_test)} items")

        assert len(self.dataset_train) > 0, "Train dataset is empty."
        assert len(self.dataset_val) > 0, "Val dataset is empty."
        assert len(self.dataset_test) > 0, "Test dataset is empty."

    def train_dataloader(self):
        return TorchDataLoader(
            self.dataset_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=0,
            pin_memory=torch.cuda.is_available(),
            drop_last=True,
        )

    def val_dataloader(self):
        return TorchDataLoader(
            self.dataset_val,
            batch_size=self.batch_size,
            num_workers=0,
            shuffle=False,
            pin_memory=torch.cuda.is_available(),
            drop_last=False,
        )

    def test_dataloader(self):
        return TorchDataLoader(
            self.dataset_test,
            batch_size=self.batch_size,
            num_workers=0,
            shuffle=False,
            pin_memory=torch.cuda.is_available(),
            drop_last=False,
        )
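The training dataloader uses `drop_last=True`, so the last, incomplete batch of each epoch is skipped, while the validation and test dataloaders keep it. A small sketch of the resulting number of batches per epoch, using a hypothetical helper `batches_per_epoch`:

```python
import math


def batches_per_epoch(num_items: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches a DataLoader yields in one pass over the dataset."""
    if drop_last:
        # The incomplete last batch is discarded
        return num_items // batch_size
    # The last batch is kept, even if it is smaller than batch_size
    return math.ceil(num_items / batch_size)


print(batches_per_epoch(100, 19, drop_last=True))   # 5
print(batches_per_epoch(100, 19, drop_last=False))  # 6
```

Because the training loader also uses `shuffle=True`, a different subset of items ends up in the dropped batch in each epoch.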

### Model

We create a PyTorch Lightning [LightningModule](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html) to train a [MONAI network](https://docs.monai.io/en/latest/networks.html#nets).

For binary classification we use Binary Cross Entropy loss on logits (`BCEWithLogitsLoss`), and for multi-class classification Cross Entropy loss (`CrossEntropyLoss`). As optimizer, we use Adam with its default learning rate of 0.001. The evaluation metric is accuracy, computed with MONAI's `ConfusionMatrixMetric`.
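The way a raw network output becomes a class index differs between the two cases. The sketch below shows the idea for a single image in plain Python; `predict_class` is a hypothetical helper, not part of MONAI or PyTorch, and it assumes one logit for binary tasks and one logit per class otherwise, mirroring the logic of `MonaiModel.predict_one` below.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # avoids overflow for large negative x
    return z / (1.0 + z)


def predict_class(logits):
    """Map the logits for ONE image to a class index.

    One logit: binary classification, sigmoid followed by rounding.
    Several logits: multi-class classification, argmax.
    """
    if len(logits) == 1:
        # A probability of exactly 0.5 maps to class 0, like torch.round (round half to even)
        return 1 if sigmoid(logits[0]) > 0.5 else 0
    return max(range(len(logits)), key=lambda i: logits[i])


print(predict_class([2.0]))             # 1
print(predict_class([-3.0]))            # 0
print(predict_class([0.1, 2.5, -1.0]))  # 1
```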

In [None]:
class MonaiModel(pl.LightningModule):
    def __init__(self, monai_model: str, config: Dict):
        """
        PyTorch Lightning Module for a given MONAI classification model.

        :param monai_model: MONAI model name. For example, DenseNet, SEResNet50
        :param config: Dictionary with configuration values to pass to MONAI model initialization.
                       Dictionary keys that are not supported model parameters are discarded.
        """
        super().__init__()
        model = getattr(nets, monai_model)
        model_config = config.copy()
        self.num_classes = model_config["num_classes"]
        
        # For binary classification (2 classes, e.g. normal/abnormal) the model is configured
        # with 1 class that can either be 0 or 1.
        if model_config["num_classes"] == 2:
            model_config["num_classes"] = 1

        for parameter in model_config.copy():
            if (
                parameter not in inspect.signature(model).parameters
                and parameter not in inspect.signature(model.__bases__[0]).parameters
            ):
                model_config.pop(parameter)
        self._model = model(**model_config).cpu()

        self.save_hyperparameters()
        # https://docs.monai.io/en/latest/highlights.html?deterministic-training-for-reproducibility
        monai.utils.set_determinism(seed=2.71828, additional_settings=None)
        self.loss_function = (
            torch.nn.CrossEntropyLoss()
            if self.num_classes > 2
            else torch.nn.BCEWithLogitsLoss()
        )
        self.metric = ConfusionMatrixMetric(metric_name="accuracy")
        self.best_val_accuracy = 0
        self.best_val_epoch = 0

        print(
            f"Initialized {monai_model} with settings: {model_config} {self.num_classes} classes"
        )

    def forward(self, x):
        return self._model(x)

    def forward_batch(self, batch):
        """
        Propagate images through the network

        :return: raw network output in layout expected by loss function
        """
        images, _ = batch
        images = torch.as_tensor(images, dtype=torch.float)
        output = self.forward(images)
        if isinstance(self.loss_function, torch.nn.BCEWithLogitsLoss):
            if len(output.shape) == 1:
                output = output.unsqueeze(-1)
        return output

    def process_labels(self, batch):
        """
        Return labels in the format expected by the loss function (BCEWithLogitsLoss
        for binary classification, CrossEntropyLoss for multiclass classification)

        :return: labels in correct datatype and layout
        """
        _, annotation = batch
        labels = torch.as_tensor(annotation, dtype=torch.float)
        if not isinstance(self.loss_function, torch.nn.BCEWithLogitsLoss):
            # CrossEntropyLoss expects integer class indices with shape (batch_size,)
            labels = labels.long().squeeze(dim=1)
        return labels

    @torch.no_grad()
    def predict_one(self, x):
        """
        Propagate one image through the network and return the result as a class index.
        Uses sigmoid for binary classification and argmax for multiclass classification.

        :return: prediction as class index (uint8 tensor)
        """
        output = self.forward(x)
        if self.num_classes <= 2:
            predict = torch.sigmoid(output).round().byte().squeeze()
        else:
            predict = torch.argmax(output, axis=1).byte().squeeze()
        return predict

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self._model.parameters())
        return optimizer

    def training_step(self, batch, batch_idx):
        output = self.forward_batch(batch)
        labels = self.process_labels(batch)
        loss = self.loss_function(input=output, target=labels)
        self.log("train_loss", loss.item())
        return loss

    def validation_step(self, batch, batch_idx):
        stage = self.trainer.state.stage
        output = self.forward_batch(batch)
        labels = self.process_labels(batch)
        loss = self.loss_function(input=output, target=labels)
        # Update statistics for metric computation

        if len(labels.shape) == 1:
            labels.unsqueeze_(-1)

        is_binary_classification = self.num_classes <= 2
        output_transforms = Compose(
            [
                ToTensor(),
                Activations(sigmoid=is_binary_classification),
                AsDiscrete(threshold=0.5 if is_binary_classification else None),
                AsDiscrete(
                    to_onehot=self.num_classes, argmax=not is_binary_classification
                ),
            ]
        )
        target_transforms = Compose(
            [ToTensor(), AsDiscrete(to_onehot=self.num_classes)]
        )

        onehot_output = [output_transforms(item.cpu()) for item in output]
        onehot_target = [target_transforms(item.cpu()) for item in labels]

        self.metric(onehot_output, onehot_target)

        self.log(f"{stage}_loss", loss)
        return {f"{stage}_loss": loss, "num_items": len(output)}

    def validation_epoch_end(self, outputs):
        stage = self.trainer.state.stage
        loss, num_items = 0, 0
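        for output in outputs:
            # The logged loss is the mean over one batch; weight it by the batch size
            # so that mean_loss below is the mean loss per item
            loss += output[f"{stage}_loss"].item() * output["num_items"]
            num_items += output["num_items"]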
        mean_accuracy = self.metric.aggregate()[0].item()
        self.metric.reset()

        mean_loss = torch.tensor(loss / num_items)
        self.logger.experiment.add_scalar(
            f"{stage}/loss", mean_loss, self.current_epoch
        )
        self.logger.experiment.add_scalar(
            f"{stage}/accuracy",
            mean_accuracy,
            self.current_epoch,
        )
        self.log(f"{stage}_accuracy", mean_accuracy, prog_bar=True, logger=False)
        if stage == "validate":
            if mean_accuracy > self.best_val_accuracy:
                self.best_val_accuracy = mean_accuracy
                self.best_val_epoch = self.current_epoch

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs)

## Model and Dataset Configuration

Running the next cell shows the supported MONAI models with their number of trainable parameters, and the supported MedMNIST datasets with their task type, number of labels, and number of samples in the train, validation and test splits.

In the cell after that, specify the model and dataset to use in this notebook.

For demo purposes, we define a `sample_config` with default values for all networks. Not all networks support all parameters; unsupported parameters will be ignored.
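The parameter filtering in `MonaiModel.__init__` relies on `inspect.signature`. A minimal, standard-library-only sketch of the same idea follows. `tiny_net` is a hypothetical stand-in for a MONAI network constructor, and unlike `MonaiModel`, this sketch does not also check the parameters of the base class.

```python
import inspect


def filter_kwargs(func, config):
    """Keep only the entries of config whose keys are parameters of func."""
    accepted = inspect.signature(func).parameters
    return {key: value for key, value in config.items() if key in accepted}


def tiny_net(spatial_dims, in_channels, out_channels):
    """Hypothetical network constructor that only accepts three parameters."""
    return (spatial_dims, in_channels, out_channels)


example_config = {
    "in_channels": 1,
    "out_channels": 1,
    "num_classes": 2,
    "spatial_dims": 2,
    "block_config": [6, 12, 8],
}
kwargs = filter_kwargs(tiny_net, example_config)
print(kwargs)  # {'in_channels': 1, 'out_channels': 1, 'spatial_dims': 2}
tiny_net(**kwargs)  # no TypeError: unsupported keys were dropped
```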

In [None]:
sample_config = {
    "in_channels": 1,
    "out_channels": 1,
    "num_classes": 2,
    "spatial_dims": 2,
    "block_config": [6, 12, 8],
}

supported_models = ["DenseNet", "SENet154", "SEResNet50", "SEResNext50"]
print(f"Supported models: {supported_models}")
for model_name in supported_models:
    model = MonaiModel(model_name, sample_config)
    param_size = summarize(model).trainable_parameters / 1000 / 1000
    print(f"Trainable parameters: {param_size:.2f} M")
    del model
print()

supported_datasets = [
    (ds_name)
    for (ds_name, ds_info) in medmnist.INFO.items()
    if ds_info["task"] in ["multi-class", "binary-class"]
]
print("Supported datasets:")
for key, value in medmnist.INFO.items():
    if key in supported_datasets:
        print(key, value["task"], f"{len(value['label'])} labels, {value['n_samples']}")

Specify the MedMNIST dataset and MONAI model to use. SENet154 is a large model which will take a long time to train. For training on CPU, DenseNet is recommended, in combination with a small dataset.

In [None]:
medmnist_dataset = "breastmnist"
monai_model = "DenseNet"

assert medmnist_dataset in supported_datasets
assert hasattr(nets, monai_model), f"{monai_model} is not a MONAI network"

The input arguments for the MONAI model are taken from the dataset information. For example, the `spatial_dims` argument to a MONAI model is set to the `num_dims` value of the MedMNIST dataset. 


In [None]:
data = DataModule(batch_size=19, medmnist_dataset=medmnist_dataset)
data.setup()
out_channels = (
    data.dataset_train.num_classes if data.dataset_train.num_classes > 2 else 1
)

In [None]:
# Set default values for MONAI models, based on the chosen medmnist_dataset.
# Not all models use all these parameters. Unused parameters are discarded in the MonaiModel.
# This makes it easier to test multiple models without creating custom configurations

default_config = {
    "in_channels": data.dataset_train.num_channels,
    "out_channels": out_channels,
    "num_classes": data.dataset_train.num_classes,
    "spatial_dims": data.dataset_train.num_dims,
    "block_config": [6, 12, 8],
}
print(default_config)

### Dataset Info and Visualization

Show MedMNIST information about the dataset, including description, labels and number of samples.

In [None]:
medmnist.INFO[medmnist_dataset]

Show a random sample of 10 images to verify (as much as that is possible with images of this size) that the dataset looks correct. For 3D images, the middle slice of the image is displayed.
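The visualization cells index 3D volumes with the hard-coded `image[0][14]`. A slightly more general sketch, assuming a channel-first NumPy array of shape `(C, D, H, W)`, computes the middle index from the depth; `middle_slice` is a hypothetical helper.

```python
import numpy as np


def middle_slice(volume: np.ndarray) -> np.ndarray:
    """Return the middle depth slice of the first channel of a (C, D, H, W) volume.

    For MedMNIST's 28x28x28 volumes, D // 2 == 14, matching image[0][14].
    """
    return volume[0, volume.shape[1] // 2]


volume = np.random.rand(1, 28, 28, 28)
print(middle_slice(volume).shape)  # (28, 28)
```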

In [None]:
indices = random.sample(range(len(data.dataset_train)), 10)
data_subset = itemgetter(*indices)(data.dataset_train)

fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(20, 6))
plt.suptitle(medmnist_dataset)
for i, ax in enumerate(axs.ravel()):
    image, annotation = data_subset[i]
    if len(image.shape) == 3:
        image = image.permute(1, 2, 0).numpy().astype(np.uint8)
    elif len(image.shape) == 4:
        image = image[0][14]
    label = data.dataset_train.labels[str(annotation.short().item())]
    ax.imshow(image, cmap="gray")
    ax.set_title(label)

### Show Model Information

In [None]:
monai_lightning_model = MonaiModel(monai_model=monai_model, config=default_config)
summarize(monai_lightning_model)

## Start Training

Create a PyTorch Lightning [Trainer](https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html) and call its `.fit()` method to start training. `USE_CUDA` is set to `True` automatically when a CUDA-enabled GPU is available; set it to `False` to force training on CPU. Adjust the other settings as needed. The settings below train the model for 5 epochs, log to TensorBoard, and save the 3 best checkpoints, where "best" means highest validation accuracy. On CUDA, 16-bit training is enabled; this is not supported on CPU. `limit_train_batches` and `limit_val_batches` can be useful to reduce training time on large datasets. At the end of training, the total training duration is displayed.

See the [documentation](https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html) for information on all parameters and settings.

To show a TensorBoard dashboard in the notebook, run `%load_ext tensorboard` and then `%tensorboard --logdir tb_logs` in a separate cell before starting training. The dashboard initially shows no data; click the refresh button in the TensorBoard cell to show data during and after training.

To cancel training, click the stop button in the Jupyter toolbar at the top of the notebook. Training then stops gracefully, and the best checkpoint saved up to that point is kept.
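Conceptually, `ModelCheckpoint(monitor="validate_accuracy", mode="max", save_top_k=3)` keeps the three checkpoints with the highest validation accuracy seen so far. The sketch below illustrates that bookkeeping with a min-heap. It is not Lightning's implementation: `update_top_k` is a hypothetical helper, the accuracies are made-up values, and no files are written or deleted.

```python
import heapq


def update_top_k(top_k, score, checkpoint, k=3):
    """Keep the k (score, checkpoint) pairs with the highest score (mode="max").

    top_k is a min-heap, so top_k[0] is the worst kept checkpoint and is
    the first one to be replaced by a better one.
    """
    if len(top_k) < k:
        heapq.heappush(top_k, (score, checkpoint))
    elif score > top_k[0][0]:
        heapq.heapreplace(top_k, (score, checkpoint))
    return top_k


top_k = []
for epoch, accuracy in enumerate([0.70, 0.75, 0.72, 0.80, 0.74]):  # made-up accuracies
    update_top_k(top_k, accuracy, f"epoch={epoch}.ckpt")
print(sorted(top_k, reverse=True))
# [(0.8, 'epoch=3.ckpt'), (0.75, 'epoch=1.ckpt'), (0.74, 'epoch=4.ckpt')]
```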

In [None]:
logger = TensorBoardLogger("tb_logs", name=f"{medmnist_dataset}_{monai_model.lower()}")
checkpoint_callback = ModelCheckpoint(
    monitor="validate_accuracy", mode="max", save_top_k=3
)
# Ignore PyTorch Lightning warnings about possible improvements
warnings.filterwarnings(
    "ignore", ".*Consider increasing the value of the `num_workers` argument.*"
)
warnings.filterwarnings("ignore", ".*smaller than the logging interval*")
warnings.filterwarnings("ignore", ".*has already been called*")

limit_train_batches = 1.0
limit_val_batches = 1.0

# For demonstration purposes, you can limit the number of batches for large datasets to
# reduce training time by uncommenting the lines below. This may lower accuracy; the default
# value of 1.0 uses all batches.
# if len(data.train_dataloader().dataset) > 2000:
#     limit_train_batches = 1/(len(data.train_dataloader().dataset) / 2000)
# if len(data.val_dataloader().dataset) > 5000:
#     limit_val_batches = 0.5

USE_CUDA = torch.cuda.is_available()
trainer = pl.Trainer(
    max_epochs=5,
    gpus=1 if USE_CUDA else 0,
    logger=logger,
    precision=16 if USE_CUDA else 32,
    limit_train_batches=limit_train_batches,
    limit_val_batches=limit_val_batches,
    callbacks=[checkpoint_callback],
    fast_dev_run=False,  # set to True to quickly test the Lightning model
)

start = datetime.datetime.now()
try:
    trainer.fit(model=monai_lightning_model, datamodule=data)
finally:
    end = datetime.datetime.now()
    delta = dateutil.relativedelta.relativedelta(end, start)
    print(
        f"Training duration: {delta.hours:02d}:{delta.minutes:02d}:{delta.seconds:02d}"
    )

### Evaluate Trained Model on Test Set

In [None]:
monai_lightning_model.best_val_accuracy

In [None]:
torch_test_accuracy = trainer.test(ckpt_path="best", dataloaders=data)

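Visualize 10 random images from the test set with their actual and predicted labels. Note that `monai_lightning_model` holds the weights from the last training epoch, which are not necessarily the weights of the best checkpoint.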

In [None]:
indices = random.sample(range(len(data.dataset_test)), 10)
data_subset = itemgetter(*indices)(data.dataset_test)

fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(20, 6))
plt.suptitle(f"{monai_model}, {medmnist_dataset} labels/predictions")
monai_lightning_model.eval()

for i, ax in enumerate(axs.ravel()):
    image, annotation = data_subset[i]
    prediction = monai_lightning_model.predict_one(image.unsqueeze(0)).item()
    if len(image.shape) == 3:
        image = image.permute(1, 2, 0).numpy().astype(np.uint8)
    elif len(image.shape) == 4:
        image = image[0][14]
    target_label = f"label: {data.dataset_train.labels[str(annotation.short().item())]}"
    predicted_label = f"prediction: {data.dataset_train.labels[str(prediction)]}"

    ax.imshow(image, cmap="gray")
    ax.set_title(f"{target_label}\n{predicted_label}")
    ax.axis("off")

## Convert FP32 model to ONNX and IR

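Load the best checkpoint, export the model to ONNX, and convert the ONNX model to OpenVINO IR with the Inference Engine's `read_network` and `serialize` methods.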

In [None]:
MODEL_DIR = Path("nncf_models")
MODEL_DIR.mkdir(exist_ok=True)
onnx_path = MODEL_DIR / f"{monai_model}_{medmnist_dataset}_fp32.onnx"
checkpoint_path = checkpoint_callback.best_model_path
best_model = MonaiModel.load_from_checkpoint(checkpoint_path).cpu().eval()

if data.dataset_train.num_dims == 3:
    input_shape = [1, data.dataset_train.num_channels, 28, 28, 28]
else:
    input_shape = [1, data.dataset_train.num_channels, 28, 28]


dummy_input = torch.randn(*input_shape)
torch.onnx.export(best_model, dummy_input, onnx_path, opset_version=10)
print(f"Exported ONNX model to {onnx_path}")

# Convert ONNX model and export to OpenVINO IR
fp_ir_path = onnx_path.with_suffix(".xml")
IECore().read_network(onnx_path).serialize(str(fp_ir_path))
print(f"Exported FP32 IR model to {fp_ir_path}")

## NNCF

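Creating an NNCF compressed model requires a configuration dictionary, a dataloader, and a model. We pass the MONAI network of the `best_model` loaded in the previous cell (`best_model._model`), and use the validation dataloader, which NNCF uses to initialize the quantization parameters.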
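NNCF quantization inserts fake-quantization operations that round values to 8-bit integer levels and map them back to floating point; the validation data is used to initialize their ranges. The sketch below shows a simplified symmetric, per-tensor scheme in plain Python, assumed for illustration only. NNCF's actual quantizer settings (for example per-channel weight quantization or the range initialization method) may differ, and `fake_quantize_symmetric` is a hypothetical helper.

```python
def fake_quantize_symmetric(values, num_bits=8):
    """Simplified symmetric, per-tensor fake quantization (illustration only).

    Each value is divided by a scale, rounded to the nearest integer level,
    clamped to the signed integer range and multiplied by the scale again.
    The scale maps the largest absolute value onto the largest level.
    """
    max_level = 2 ** (num_bits - 1) - 1  # 127 for 8 bits
    min_level = -max_level - 1  # -128 for 8 bits
    max_abs = max(abs(v) for v in values)
    scale = max_abs / max_level if max_abs > 0 else 1.0
    levels = [min(max_level, max(min_level, round(v / scale))) for v in values]
    dequantized = [level * scale for level in levels]
    return dequantized, levels, scale


weights = [-0.9, -0.31, 0.0, 0.02, 0.45, 1.2]  # hypothetical weight values
dequantized, levels, scale = fake_quantize_symmetric(weights)
max_error = max(abs(w - d) for w, d in zip(weights, dequantized))
print(levels)
print(f"scale={scale:.6f}, max rounding error={max_error:.6f}")
```

The rounding error per value is at most half a quantization step (`scale / 2`), which is why accuracy usually changes only slightly after quantization.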

In [None]:
NNCF_OUTPUT_DIR = Path("output")
NNCF_OUTPUT_DIR.mkdir(exist_ok=True)
input_shape = list(next(iter(data.val_dataloader()))[0].shape)
input_shape[0] = 1
nncf_config_dict = {
    "input_info": {"sample_size": input_shape},
    "log_dir": str(NNCF_OUTPUT_DIR),
    "compression": {
        "algorithm": "quantization",
    },
}
nncf_config = NNCFConfig.from_dict(nncf_config_dict)
nncf_config = register_default_init_args(nncf_config, data.val_dataloader())

set_log_level(logging.ERROR)  # Disables all NNCF info and warning messages
compression_ctrl, compressed_model = create_compressed_model(
    best_model._model, nncf_config
)
del best_model

### Export Quantized Model to ONNX

The quantized model can be exported to ONNX with the `export_model` method of the NNCF compression controller. The ONNX model file has about the same size as the PyTorch model file; after conversion to OpenVINO IR in the next section, the quantized IR model file is smaller.
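A rough estimate of why the IR file shrinks: each weight takes 4 bytes in FP32, 2 bytes in FP16 and 1 byte in INT8. The sketch below uses a hypothetical parameter count (the real count for the chosen model is shown by `summarize` earlier in this notebook). In practice, the file also contains the network topology, quantization parameters and any layers that stay in floating point, so the actual size reduction is smaller than this estimate suggests.

```python
def weight_storage_mb(num_parameters: int, bytes_per_value: int) -> float:
    """Approximate storage for the weights alone, in megabytes (10**6 bytes)."""
    return num_parameters * bytes_per_value / 1e6


num_parameters = 1_000_000  # hypothetical parameter count
for precision, num_bytes in [("FP32", 4), ("FP16", 2), ("INT8", 1)]:
    print(precision, weight_storage_mb(num_parameters, num_bytes), "MB")
# FP32 4.0 MB
# FP16 2.0 MB
# INT8 1.0 MB
```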

In [None]:
warnings.filterwarnings("ignore", category=TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)

int8_onnx_path = MODEL_DIR / f"{monai_model}_{medmnist_dataset}_nncf.onnx"
compression_ctrl.export_model(int8_onnx_path)

### Convert to OpenVINO IR

We use Model Optimizer to convert the ONNX model to OpenVINO's Intermediate Representation (IR) format. In Jupyter, we can call it with `! mo`; in a script, we can use the `subprocess` module. We use `subprocess` here to make it easy to convert this notebook to a script.

For ONNX conversion, Model Optimizer only needs a path to an input model. We also specify `--output_dir` to set the directory where the model is saved. Model Optimizer creates an .xml and a .bin file with the same base filename as the ONNX model. The .xml file contains information about the network topology; the .bin file contains the binary weights and biases. By default, weights and biases are stored as FP32. To convert them to FP16, set `--data_type` to FP16. This saves space (FP16 takes half as much space as FP32) and increases inference speed when using an Intel integrated GPU.

Run `!mo --help` to see information about all Model Optimizer parameters.
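When the conversion is run through `subprocess`, the optional `--data_type FP16` argument can be appended to the argument list. A minimal sketch with a hypothetical helper `build_mo_command` and a hypothetical file name; it only uses the Model Optimizer flags mentioned in this notebook, and converts all arguments to strings so that `pathlib.Path` objects are safe to pass on every platform.

```python
def build_mo_command(input_model, output_dir, data_type=None):
    """Build the Model Optimizer argument list for subprocess.run."""
    command = ["mo", "--input_model", str(input_model), "--output_dir", str(output_dir)]
    if data_type is not None:
        # For example "FP16" to store floating point weights as FP16
        command += ["--data_type", data_type]
    return command


print(build_mo_command("nncf_models/model.onnx", "nncf_models", data_type="FP16"))
# ['mo', '--input_model', 'nncf_models/model.onnx', '--output_dir', 'nncf_models', '--data_type', 'FP16']
```

The resulting list can be passed as the first argument of `subprocess.run`, as in the next cell.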

In [None]:
mo_result = subprocess.run(
    # Convert paths to str: Path arguments are not supported by subprocess on all platforms
    ["mo", "--input_model", str(int8_onnx_path), "--output_dir", str(MODEL_DIR)],
    check=False,
    universal_newlines=True,
    capture_output=True,
)
if mo_result.returncode == 0:
    # Only show the lines that Model Optimizer marks as SUCCESS
    print(
        "\n".join(line for line in mo_result.stdout.splitlines() if "SUCCESS" in line)
    )
else:
    raise RuntimeError(
        f"Model optimization failed with the following error:\n{mo_result.stderr}"
    )

## Compare Accuracy

Create helper functions to run inference on an IR model and to compute its accuracy, and load the FP32 and INT8 IR models into the Inference Engine. Then compare the accuracy of the PyTorch model, the FP32 IR model, and the quantized INT8 IR model.

In [None]:
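def sigmoid(x):
    """Numerically stable logistic sigmoid, 1 / (1 + exp(-x))"""
    return np.exp(-np.logaddexp(0, -x))


def predict_ir(exec_net, image, return_class_index: bool = True):
    """
    Do inference of image on exec_net. Return the result as class index integer,
    or the raw network output if return_class_index is False
    """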
    input_layer = next(iter(exec_net.input_info))
    output_layer = next(iter(exec_net.outputs))
    output = exec_net.infer(inputs={input_layer: image})[output_layer]
    if return_class_index:
        output_shape = exec_net.outputs[output_layer].shape
        if output_shape[-1] == 1:
            predicted_class_index = sigmoid(output).round().astype(np.uint8).squeeze()
        else:
            predicted_class_index = np.argmax(output, axis=1).astype(np.uint8).squeeze()
        result = predicted_class_index
    else:
        result = output
    return result

In [None]:
def compute_accuracy(exec_net, dataset, num_classes):
    """
    Compute the accuracy of exec_net on dataset with MONAI's ConfusionMatrixMetric,
    using the same output and target transforms as MonaiModel.validation_step
    """
    metric = ConfusionMatrixMetric(metric_name="accuracy")
    is_binary_classification = num_classes <= 2
    # The transforms do not depend on the data, so create them once
    output_transforms = Compose(
        [
            ToTensor(),
            Activations(sigmoid=is_binary_classification),
            AsDiscrete(threshold=0.5 if is_binary_classification else None),
            AsDiscrete(to_onehot=num_classes, argmax=not is_binary_classification),
        ]
    )
    target_transforms = Compose([ToTensor(), AsDiscrete(to_onehot=num_classes)])

    for image, label in dataset:
        output = predict_ir(exec_net, image, return_class_index=False)
        if len(label.shape) == 1:
            label.unsqueeze_(-1)

        onehot_output = [output_transforms(item) for item in output]
        onehot_target = [target_transforms(item) for item in label]
        metric(onehot_output, onehot_target)
    return metric.aggregate()[0].item()

In [None]:
# Load IR and PyTorch Lightning Models
ie = IECore()
int8_ir_path = int8_onnx_path.with_suffix(".xml")
fp_net = ie.read_network(fp_ir_path)
fp_exec_net = ie.load_network(fp_net, "CPU")
int8_net = ie.read_network(int8_ir_path)
int8_exec_net = ie.load_network(int8_net, "CPU")
best_model = MonaiModel.load_from_checkpoint(checkpoint_path).cpu().eval()
num_classes = monai_lightning_model.num_classes

In [None]:
validation_dataset = MedMNISTDataset(medmnist_dataset=medmnist_dataset, split="val")
fp_acc = compute_accuracy(fp_exec_net, validation_dataset, num_classes)
int8_acc = compute_accuracy(int8_exec_net, validation_dataset, num_classes)

print(f"Validation accuracy of the PyTorch model: {monai_lightning_model.best_val_accuracy:.5f}")
print(f"Validation accuracy of the original IR model: {fp_acc:.5f}")
print(f"Validation accuracy of the quantized IR model: {int8_acc:.5f}")

In [None]:
test_dataset = MedMNISTDataset(medmnist_dataset=medmnist_dataset, split="test")
fp_acc = compute_accuracy(fp_exec_net, test_dataset, num_classes)
int8_acc = compute_accuracy(int8_exec_net, test_dataset, num_classes)

print(f"Test accuracy of the PyTorch model: {torch_test_accuracy[0]['test_accuracy']:.5f}")
print(f"Test accuracy of the original IR model: {fp_acc:.5f}")
print(f"Test accuracy of the quantized IR model: {int8_acc:.5f}")

## Visually Compare Inference Results

In [None]:
def sigmoid(x):
    return np.exp(-np.logaddexp(0, -x))

# Load 10 random images. Run this cell again to see inference results on
# 10 different images
indices = random.sample(range(len(data.dataset_test)), 10)
data_subset = itemgetter(*indices)(data.dataset_test)

fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(20, 6))
fig.tight_layout(h_pad=5)
monai_lightning_model.eval()

for i, ax in enumerate(axs.ravel()):
    image, annotation = data_subset[i]
    input_image = image.unsqueeze(0)
    torch_prediction = best_model.predict_one(input_image).item()
    fp_ir_prediction = predict_ir(fp_exec_net, input_image)
    int8_ir_prediction = predict_ir(int8_exec_net, input_image)

    if len(image.shape) == 3:
        image = image.permute(1, 2, 0).numpy().astype(np.uint8)
    elif len(image.shape) == 4:
        image = image[0][14]
    target_label = (
        f"annotation: {data.dataset_train.labels[str(annotation.short().item())]}"
    )
    torch_predicted_label = f"torch: {data.dataset_train.labels[str(torch_prediction)]}"
    fp_predicted_label = f"FP IR: {data.dataset_train.labels[str(fp_ir_prediction)]}"
    int8_predicted_label = (
        f"INT8 IR: {data.dataset_train.labels[str(int8_ir_prediction)]}"
    )
    ax.imshow(image, cmap="gray")
    ax.set_title(
        f"{target_label}\n{torch_predicted_label}\n{fp_predicted_label}\n{int8_predicted_label}"
    )
    ax.axis("off")

## Benchmark Performance

To measure the inference performance of the FP32 and INT8 models, we use Benchmark Tool, OpenVINO's inference performance measurement tool. Benchmark Tool is a command line application that can be run in the notebook with `! benchmark_app` or `%sx benchmark_app`.

In the next cell, we create a wrapper function for `benchmark_app` that prints the `benchmark_app` command with the chosen parameters. To make the results easier to compare, it filters logging information from the output of `benchmark_app` and prefixes the latency and throughput lines with the value of `key`, for example `FP_Latency`.

> NOTE: For the most accurate performance estimation, we recommend running `benchmark_app` in a terminal/command prompt after closing other applications. Run `benchmark_app --help` to see all command line options.
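The filtering rule in `benchmark_model` can be tried without running `benchmark_app`. In the sketch below, `filter_benchmark_output` is a hypothetical helper, and `sample_output` is made-up text with placeholders instead of measurements; the real `benchmark_app` output differs between OpenVINO versions.

```python
def filter_benchmark_output(stdout: str, key: str = "") -> list:
    """Drop log lines starting with "[", indented lines and empty lines,
    and prefix the Latency and Throughput lines with key (if given)."""
    prefix = f"{key}_" if key else ""
    return [
        line.replace("Latency", f"{prefix}Latency").replace("Throughput", f"{prefix}Throughput")
        for line in stdout.splitlines()
        if not (line.startswith("[") or line.startswith("  ") or line == "")
    ]


# Made-up sample output with placeholders, not real measurements
sample_output = "\n".join(
    [
        "[Step 1/11] Parsing and validating input arguments",
        "[ INFO ] Parsing input parameters",
        "",
        "Count:      <count> iterations",
        "Latency:    <latency> ms",
        "Throughput: <throughput> FPS",
    ]
)
print("\n".join(filter_benchmark_output(sample_output, key="INT8")))
# Count:      <count> iterations
# INT8_Latency:    <latency> ms
# INT8_Throughput: <throughput> FPS
```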

In [None]:
def benchmark_model(
    model_path: os.PathLike,
    device: str = "CPU",
    seconds: int = 60,
    api: str = "async",
    batch: int = 1,
    cache_dir="model_cache",
    key=None
):
    ie = IECore()
    model_path = Path(model_path)
    if ("GPU" in device) and ("GPU" not in ie.available_devices):
        raise ValueError(
            f"A GPU device is not available. Available devices are: {ie.available_devices}"
        )
    else:
        benchmark_command = [
            "benchmark_app", "-m", str(model_path), "-d", device, "-t", str(seconds),
            "-api", api, "-b", str(batch), "-cdir", str(cache_dir),
        ]
        display(
            Markdown(
                f"**Benchmark {model_path.name} with {device} for {seconds} seconds with {api} inference**"
            )
        )
        display(Markdown(f"Benchmark command: `{' '.join(benchmark_command)}`"))

        # Pass the arguments as a list, so paths that contain spaces are handled correctly
        benchmark_output = subprocess.run(
            benchmark_command, capture_output=True, universal_newlines=True
        )
        # Only prefix the Latency and Throughput lines if a key is given
        prefix = f"{key}_" if key else ""
        benchmark_result = [
            line.replace("Latency", f"{prefix}Latency").replace("Throughput", f"{prefix}Throughput")
            for line in benchmark_output.stdout.splitlines()
            if not (line.startswith("[") or line.startswith("  ") or line == "")
        ]
        print("\n".join(benchmark_result))
        print()
        if "MULTI" in device:
            devices = device.replace("MULTI:", "").split(",")
            for single_device in devices:
                print(
                    f"{single_device} device: {ie.get_metric(device_name=single_device, metric_name='FULL_DEVICE_NAME')}"
                )
        else:
            print(
                f"Device: {ie.get_metric(device_name=device, metric_name='FULL_DEVICE_NAME')}"
            )

In [None]:
# FP model on CPU
benchmark_model(fp_ir_path, device="CPU", seconds=15, api="sync", key="FP")

In [None]:
# INT8 model on CPU
benchmark_model(int8_ir_path, device="CPU", seconds=15, api="sync", key="INT8")