In [None]:
%pip install --quiet --upgrade git+https://github.com/MedMNIST/MedMNIST.git pytorch_lightning openvino-dev

In [None]:
## Full dataset and references
# https://scholar.cu.edu.eg/?q=afahmy/pages/dataset
# https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6906728/
# http://arxiv-export-lb.library.cornell.edu/pdf/2110.14795

## Imports

In [None]:
import copy
import datetime
import os
import random
import sys
from pathlib import Path

import dateutil
import matplotlib.pyplot as plt
import medmnist
import monai
import numpy as np
import pytorch_lightning as pl
import torch
from addict import Dict
from compression.api import DataLoader, Metric
from compression.engines.ie_engine import IEEngine
from compression.graph import load_model, save_model
from compression.graph.model_utils import compress_model_weights
from compression.pipeline.initializer import create_pipeline
from monai.data import Dataset
from monai.networks.nets import DenseNet
from monai.transforms import AddChannel, Compose, EnsureType, ToTensor
from openvino.inference_engine import IECore
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader as TorchDataLoader
from IPython.display import display, Markdown

## PyTorch Lightning MONAI Model

In [None]:
class MonaiModel(pl.LightningModule):
    """Binary BreastMNIST classifier: a small MONAI 2D DenseNet wrapped as a LightningModule.

    The network produces a single logit per image (``out_channels=1``). Training
    uses ``BCEWithLogitsLoss``; validation reports accuracy from MONAI's
    ``ConfusionMatrixMetric`` and tracks the best epoch seen so far.
    """

    def __init__(self):
        super().__init__()

        self._model = DenseNet(
            spatial_dims=2, in_channels=1, out_channels=1, block_config=[4, 8, 6]
        ).cpu()

        # https://docs.monai.io/en/latest/highlights.html?deterministic-training-for-reproducibility
        # NOTE(review): the seed is a float; confirm MONAI truncates it to an int as intended.
        monai.utils.set_determinism(seed=2.71828, additional_settings=None)

        # Default reduction is "mean": each batch loss is the mean over that batch.
        self.loss_function = torch.nn.BCEWithLogitsLoss()
        self.metric = monai.metrics.ConfusionMatrixMetric(metric_name="accuracy")
        self.best_val_accuracy = 0
        self.best_val_epoch = 0

    def forward(self, x):
        """Return raw logits from the wrapped DenseNet."""
        return self._model(x)

    def configure_optimizers(self):
        """Adam with default hyper-parameters over the DenseNet parameters."""
        optimizer = torch.optim.Adam(self._model.parameters())
        return optimizer

    def training_step(self, batch, batch_idx):
        """Compute and log the BCE-with-logits loss for one training batch."""
        images, labels = batch
        labels = labels.float()
        output = self.forward(images)
        loss = self.loss_function(output, labels)
        self.log("train_loss", loss.item())
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the batch loss and feed thresholded predictions to the accuracy metric.

        Returns the batch's mean loss and the number of items in the batch, which
        ``validation_epoch_end`` uses to compute a sample-weighted mean loss.
        """
        images, labels = batch
        labels = labels.float()
        output = self.forward(images)
        loss = self.loss_function(output, labels)

        # Compute statistics for metric computation: threshold sigmoid(logit) at 0.5
        y_true = labels.long()
        y_pred = torch.sigmoid(output).round().long()
        self.metric(y_pred, y_true)

        self.log("val_loss", loss)
        return {"val_loss": loss, "val_number": len(output)}

    def validation_epoch_end(self, outputs):
        """Aggregate per-batch validation results into epoch-level loss and accuracy.

        :param outputs: list of dicts returned by ``validation_step``
        """
        val_loss, num_items = 0.0, 0

        for output in outputs:
            # "val_loss" is a per-batch *mean*; weight it by the batch size so that
            # dividing by the total item count yields the mean over all samples.
            val_loss += output["val_loss"].item() * output["val_number"]
            num_items += output["val_number"]

        mean_val_accuracy = self.metric.aggregate()[0].item()
        # Clear the accumulated confusion-matrix state so each epoch (including the
        # pre-training sanity check) reports accuracy for its own batches only.
        self.metric.reset()
        mean_val_loss = torch.tensor(val_loss / num_items)
        self.logger.experiment.add_scalar("Loss/Validation", mean_val_loss, self.current_epoch)
        self.logger.experiment.add_scalar(
            "Accuracy/Validation", mean_val_accuracy, self.current_epoch
        )
        self.log("accuracy", mean_val_accuracy, prog_bar=True, logger=False)

        if mean_val_accuracy > self.best_val_accuracy:
            self.best_val_accuracy = mean_val_accuracy
            self.best_val_epoch = self.current_epoch

## PyTorch Lightning DataModule

In [None]:
class DataModule(pl.LightningDataModule):
    """LightningDataModule serving the MedMNIST BreastMNIST train and val splits.

    Each split is wrapped in a MONAI ``Dataset`` of ``(image_array, label)`` pairs,
    transformed with ``ToTensor(float) -> AddChannel -> EnsureType``. The test
    loader reuses the validation split.
    """

    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        """Download both splits (if needed) and build the MONAI datasets."""
        random.seed(1.414213)
        transforms = Compose([ToTensor(dtype=torch.float), AddChannel(), EnsureType()])

        mnist_splits = [
            medmnist.BreastMNIST(split=split_name, transform=None, download=True)
            for split_name in ("train", "val")
        ]

        def as_pairs(mnist_split):
            # PIL image -> numpy array; the label array is reduced to its first entry
            return [(np.array(image), label[0]) for image, label in mnist_split]

        self.dataset_train, self.dataset_val = (
            Dataset(as_pairs(split), transform=transforms) for split in mnist_splits
        )

        print(f"Setup train dataset: {len(self.dataset_train)} items")
        print(f"Setup val dataset: {len(self.dataset_val)} items")

        assert len(self.dataset_train) > 0, "Train dataset is empty."
        assert len(self.dataset_val) > 0, "Val dataset is empty"

    def _make_loader(self, dataset, shuffle):
        """Single-process DataLoader; memory is pinned only when CUDA is present."""
        return TorchDataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=0,
            pin_memory=torch.cuda.is_available(),
        )

    def train_dataloader(self):
        return self._make_loader(self.dataset_train, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.dataset_val, shuffle=False)

    def test_dataloader(self):
        return self.val_dataloader()

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir tb_logs --bind_all

## Train

In [None]:
# Training batch size for the BreastMNIST DataModule
TRAIN_BATCH_SIZE = 24

data = DataModule(batch_size=TRAIN_BATCH_SIZE)
model = MonaiModel()

In [None]:
# TensorBoard event files go under tb_logs/, in the "medmnist_breast" experiment
logger = TensorBoardLogger(save_dir="tb_logs", name="medmnist_breast")
# Would keep the 3 checkpoints with the highest logged "accuracy"; note that it
# is only active if passed to the Trainer via `callbacks=[...]`
checkpoint_callback = ModelCheckpoint(save_top_k=3, mode="max", monitor="accuracy")

In [None]:
data.setup()
# Pull a single validation batch to inspect input/label shapes before training
val_batches = iter(data.val_dataloader())
input_image, input_label = next(val_batches)

In [None]:
USE_CUDA = False
trainer = pl.Trainer(
    max_epochs=50,
    gpus=1 if USE_CUDA else 0,
    logger=logger,
    precision=16 if USE_CUDA else 32,  # mixed precision only on GPU
    limit_train_batches=0.5,
    limit_val_batches=0.5,
    # callbacks=[checkpoint_callback],  # enable to save the best checkpoints
    fast_dev_run=False,  # set to True to quickly test Lightning model
)

start = datetime.datetime.now()
print(start.strftime("%H:%M:%S"))
try:
    trainer.fit(model, data)
finally:
    # Report wall-clock duration even if training is interrupted.
    end = datetime.datetime.now()
    print(end.strftime("%H:%M:%S"))
    # Plain timedelta arithmetic: older python-dateutil releases do not load the
    # `relativedelta` submodule on a bare `import dateutil`, and
    # `relativedelta.hours` wraps at 24 hours (overflow goes into `.days`).
    hours, remainder = divmod(int((end - start).total_seconds()), 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f"Training duration: {hours:02d}:{minutes:02d}:{seconds:02d}")

## Convert to ONNX

In [None]:
onnx_path = "medmnist_breast.onnx"
# Trace with a single random 1x28x28 image, the input shape the exported graph assumes
dummy_input = torch.randn(1, 1, 28, 28)
export_net = model._model.cpu().eval()
torch.onnx.export(export_net, dummy_input, onnx_path, opset_version=10)
print(f"Exported ONNX model to {onnx_path}")

## Convert to OpenVINO IR

In [None]:
# Convert the ONNX file with OpenVINO Model Optimizer. The POT and benchmark cells
# below expect the resulting IR files medmnist_breast.xml / medmnist_breast.bin.
!mo --input_model medmnist_breast.onnx

## Quantization with the Post-training Optimization Tool (POT)

### Accuracy Metric

In [None]:
def sigmoid(x):
    """Numerically stable logistic sigmoid, 1 / (1 + exp(-x)).

    Turns the network's single output logit into a probability for the binary
    classification label, computed as exp(-log(1 + exp(-x))) to avoid overflow.
    """
    log_one_plus_exp_neg = np.logaddexp(0, -x)
    return np.exp(-log_one_plus_exp_neg)


class Accuracy(Metric):
    """Binary classification accuracy metric for POT, from single-logit model outputs."""

    def __init__(self):
        super().__init__()
        self._name = "accuracy"
        # One boolean array per `update` call: True where the thresholded
        # prediction equals the label for that item of the batch.
        self._matches = []

    @property
    def value(self):
        """Returns accuracy metric value for the last model output."""
        return {self._name: list(self._matches[-1])}

    @property
    def avg_value(self):
        """
        Returns accuracy metric value for all model outputs. Results per image are stored in
        self._matches, where True means a correct prediction and False a wrong prediction.
        Accuracy is computed as the number of correct predictions divided by the total
        number of predictions (per image, not per batch). Returns NaN if no predictions
        have been recorded yet.
        """
        if not self._matches:
            return {self._name: float("nan")}
        all_matches = np.concatenate(self._matches)
        return {self._name: np.count_nonzero(all_matches) / all_matches.size}

    def update(self, output, target):
        """Updates prediction matches.

        :param output: model outputs; ``output[0]`` holds one logit per batch item
        :param target: annotations, one label per batch item (tensor, array or scalar)
        """
        # Threshold sigmoid(logit) at 0.5 to get a 0/1 prediction per item.
        predictions = sigmoid(np.asarray(output[0])).round().astype(np.uint8).reshape(-1)
        # Compare against every label in the batch, not only the first one.
        labels = np.array([np.asarray(label).squeeze() for label in target]).astype(np.uint8).reshape(-1)
        self._matches.append(predictions == labels)

    def reset(self):
        """
        Resets the Accuracy metric. This is a required method that should initialize all
        attributes to their initial value.
        """
        self._matches = []

    def get_attributes(self):
        """
        Returns a dictionary of metric attributes {metric_name: {attribute_name: value}}.
        Required attributes: 'direction': 'higher-better' or 'higher-worse'
                             'type': metric type
        """
        return {self._name: {"direction": "higher-better", "type": "accuracy"}}

### Data Loader

In [None]:
class ClassificationDataLoader(DataLoader):
    """
    POT DataLoader over an in-memory, indexable dataset whose items unpack to
    ``(image, label)``, such as the MONAI ``Dataset`` built by ``DataModule``.
    Items are returned in the ``(annotation, image)`` form used by POT.
    """

    def __init__(self, dataset):
        """
        :param dataset: indexable, sized dataset of (image, label) items
        """
        self.dataset = dataset

    def __len__(self):
        """Number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """
        Returns ``((index, label), image)``: the annotation pairs the item index with
        its label, and the image is passed through unchanged from the dataset.
        """
        sample_image, sample_label = self.dataset[index]
        return (index, sample_label), sample_image

### Config

In [None]:
# IR files produced by Model Optimizer from medmnist_breast.onnx
model_config = Dict(
    model_name="medmnist_breast",
    model="medmnist_breast.xml",
    weights="medmnist_breast.bin",
)

# Inference Engine settings used for statistics collection and evaluation
engine_config = Dict(device="CPU", stat_requests_number=2, eval_requests_number=2)

# Single POT algorithm: accuracy-aware INT8 quantization for CPU
algorithms = [
    dict(
        name="AccuracyAwareQuantization",
        params=dict(target_device="CPU", preset="mixed", stat_subset_size=1000),
    )
]

### Execute POT 

In [None]:
# Step 1: Load the FP32 IR model. It gets its own name so that the trained
# Lightning `model` from the training section is not overwritten (re-running the
# ONNX export cell after this one would otherwise fail).
ir_model = load_model(model_config=model_config)
# Keep an untouched copy for the FP32 evaluation below
original_model = copy.deepcopy(ir_model)

# Step 2: Initialize the data loader
# NOTE(review): the validation split is used both for calibration and for the
# accuracy checks, so the reported INT8 accuracy is not a held-out estimate.
data_loader = ClassificationDataLoader(dataset=data.dataset_val)

# Step 3 (Optional. Required for AccuracyAwareQuantization): Initialize the metric
#        Compute metric results on original model
metric = Accuracy()

# Step 4: Initialize the engine for metric calculation and statistics collection
engine = IEEngine(config=engine_config, data_loader=data_loader, metric=metric)

# Step 5: Create a pipeline of compression algorithms
pipeline = create_pipeline(algo_config=algorithms, engine=engine)

# Step 6: Execute the pipeline
compressed_model = pipeline.run(model=ir_model)

# Step 7 (Optional): Compress model weights quantized precision
#                    in order to reduce the size of final .bin file
compress_model_weights(model=compressed_model)

# Step 8: Save the compressed model and get the path to the model
compressed_model_paths = save_model(
    model=compressed_model, save_path=os.path.join(os.path.curdir, "model/optimized")
)

compressed_model_xml = Path(compressed_model_paths[0]["model"])
print(f"The quantized model is stored in {compressed_model_xml}")

## Evaluate Results

In [None]:
# Step 9 (Optional): Evaluate the original (FP32) model with the same POT pipeline/metric
original_metric_results = pipeline.evaluate(original_model)
if original_metric_results:
    original_accuracy = list(original_metric_results.values())[0]
    print(f"Accuracy of the original model:  {original_accuracy:.5f}")

In [None]:
# Evaluate the quantized (INT8) model with the same pipeline/metric
quantized_metric_results = pipeline.evaluate(compressed_model)
if quantized_metric_results:
    quantized_accuracy = list(quantized_metric_results.values())[0]
    print(f"Accuracy of the quantized model: {quantized_accuracy:.5f}")

## Benchmark 

In [None]:
def benchmark_model(model_path: os.PathLike,
                    device: str = "CPU",
                    seconds: int = 60, api: str = "async",
                    batch: int = 1, 
                    cache_dir="model_cache"):
    ie = IECore()
    model_path = Path(model_path)
    if ("GPU" in device) and ("GPU" not in ie.available_devices):
        raise ValueError(f"A GPU device is not available. Available devices are: {ie.available_devices}")
    else:
        benchmark_command = f"benchmark_app -m {model_path} -d {device} -t {seconds} -api {api} -b {batch} -cdir {cache_dir}"
        display(Markdown(f"**Benchmark {model_path.name} with {device} for {seconds} seconds with {api} inference**"));
        display(Markdown(f"Benchmark command: `{benchmark_command}`"));

        benchmark_output = %sx $benchmark_command
        benchmark_result = [line for line in benchmark_output
                            if not (line.startswith(r"[") or line.startswith("  ") or line == "")]
        print("\n".join(benchmark_result))
        print()
        if "MULTI" in device:
            devices = device.replace("MULTI:","").split(",")
            for single_device in devices:
                print(f"{single_device} device: {ie.get_metric(device_name=single_device, metric_name='FULL_DEVICE_NAME')}")
        else:
            print(f"Device: {ie.get_metric(device_name=device, metric_name='FULL_DEVICE_NAME')}")

In [None]:
## FP32 model (IR produced by Model Optimizer), 15-second CPU run
benchmark_model(model_path="medmnist_breast.xml", device="CPU", seconds=15)

In [None]:
## INT8 model (POT output), 15-second CPU run
benchmark_model(model_path=compressed_model_xml, device="CPU", seconds=15)

## Check Results

In [None]:
# Re-run the quantized IR model directly with the Inference Engine on the
# validation split, independently of POT, as a sanity check of its accuracy.
CHECK_BATCH_SIZE = 6

ie = IECore()

# Separate DataModule so the training `data` object is not overwritten
check_data = DataModule(batch_size=CHECK_BATCH_SIZE)
check_data.setup()

net = ie.read_network(compressed_model_xml)
net.batch_size = CHECK_BATCH_SIZE
exec_net = ie.load_network(net, "CPU")
input_layer = next(iter(exec_net.input_info))
# Look the output up by position instead of hard-coding an ONNX node id
# (previously "386"), which changes whenever the architecture/export changes.
output_layer = next(iter(net.outputs))

counts = []  # one bool per validation image: prediction == label
for input_image, input_label in check_data.val_dataloader():
    images = input_image.numpy()
    num_items = images.shape[0]
    # The loaded network has a fixed batch size: zero-pad a short final batch and
    # discard the padded predictions afterwards.
    if num_items < CHECK_BATCH_SIZE:
        padding = np.zeros((CHECK_BATCH_SIZE - num_items, *images.shape[1:]), dtype=images.dtype)
        images = np.concatenate([images, padding])
    raw_result = exec_net.infer(inputs={input_layer: images})[output_layer][:num_items]
    result = sigmoid(raw_result).round().astype(np.uint8)
    labels = input_label.numpy().reshape(num_items)
    counts.extend((labels == result.reshape(num_items)).tolist())

In [None]:
# Fraction of validation images the INT8 model classifies correctly
int8_check_accuracy = np.count_nonzero(counts) / np.size(counts)
int8_check_accuracy

In [None]:
# fig, axs = plt.subplots(nrows=4, ncols=6, figsize=(15, 12))
# for i, ax in enumerate(axs.ravel()):
#     ax.imshow(input_image[i][0], cmap="gray")