In [None]:
%pip install --quiet --upgrade monai pytorch_lightning git+https://github.com/MedMNIST/MedMNIST.git pytorch_lightning openvino-dev

In [None]:
## full dataset
# https://scholar.cu.edu.eg/?q=afahmy/pages/dataset
# https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6906728/
# http://arxiv-export-lb.library.cornell.edu/pdf/2110.14795

## Imports

In [None]:
import copy
import datetime
import os
import random
import sys
from pathlib import Path

import dateutil
import matplotlib.pyplot as plt
import medmnist
import monai
import numpy as np
import pytorch_lightning as pl
import torch
import nncf  # Important - should be imported directly after torch
from nncf import NNCFConfig
from nncf.torch import create_compressed_model
from nncf.torch import register_default_init_args
from addict import Dict

from compression.api import DataLoader, Metric
from compression.engines.ie_engine import IEEngine
from compression.graph import load_model, save_model
from compression.graph.model_utils import compress_model_weights
from compression.pipeline.initializer import create_pipeline
from monai.data import Dataset
from monai.networks.nets import DenseNet
from monai.transforms import AddChannel, Compose, EnsureType, ToTensor
from openvino.inference_engine import IECore
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader as TorchDataLoader
from IPython.display import display, Markdown

## PyTorch Lightning MONAI Model

In [None]:
class MonaiModel(pl.LightningModule):
    """LightningModule wrapping a MONAI DenseNet that emits one logit per 2D, single-channel image.

    Training and validation use ``BCEWithLogitsLoss``. Validation additionally
    accumulates a confusion-matrix based accuracy, which is aggregated, logged
    and reset once per validation epoch.

    Attributes set in ``__init__``:
        _model: the wrapped DenseNet (one output channel).
        loss_function: ``torch.nn.BCEWithLogitsLoss`` instance.
        metric: ``ConfusionMatrixMetric`` configured for "accuracy".
        best_val_accuracy / best_val_epoch: best validation accuracy seen so
            far and the epoch in which it occurred.
    """

    def __init__(self):
        super().__init__()

        # Seed the RNGs *before* the network is built so that the initial weights
        # are reproducible as well, not only the training that follows.
        # https://docs.monai.io/en/latest/highlights.html?deterministic-training-for-reproducibility
        # NOTE(review): `seed` is documented as an int; the float is presumably
        # truncated by MONAI - confirm.
        monai.utils.set_determinism(seed=2.71828, additional_settings=None)

        self._model = DenseNet(
            spatial_dims=2, in_channels=1, out_channels=1, block_config=[4, 8, 6]
        ).cpu()

        self.loss_function = torch.nn.BCEWithLogitsLoss()
        self.metric = monai.metrics.ConfusionMatrixMetric(metric_name="accuracy")
        self.best_val_accuracy = 0
        self.best_val_epoch = 0

    def forward(self, x):
        """Return the raw (pre-sigmoid) network output for ``x``."""
        return self._model(x)

    def configure_optimizers(self):
        """Return an Adam optimizer (default hyper-parameters) over the wrapped network."""
        optimizer = torch.optim.Adam(self._model.parameters())
        return optimizer

    def training_step(self, batch, batch_idx):
        """Compute and log the BCE-with-logits loss for one training batch."""
        images, labels = batch
        labels = labels.float()  # BCEWithLogitsLoss expects float targets
        output = self.forward(images)
        loss = self.loss_function(output, labels)
        self.log("train_loss", loss.item())
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one validation batch and feed the accuracy metric.

        Returns:
            dict with ``val_loss`` (the batch loss) and ``val_number`` (number of
            samples in the batch), consumed by ``validation_epoch_end``.
        """
        images, labels = batch
        labels = labels.float()
        output = self.forward(images)
        loss = self.loss_function(output, labels)

        # Compute statistics for metric computation: threshold the sigmoid
        # probabilities at 0.5 to get hard 0/1 predictions.
        y_true = labels.long()
        y_pred = torch.sigmoid(output).round().long()
        self.metric(y_pred, y_true)

        self.log("val_loss", loss)
        return {"val_loss": loss, "val_number": len(output)}

    def validation_epoch_end(self, outputs):
        """Aggregate validation loss and accuracy for the epoch and log them.

        Args:
            outputs: list of the dicts returned by ``validation_step``.
        """
        val_loss, num_items = 0.0, 0

        for output in outputs:
            batch_size = output["val_number"]
            # "val_loss" is already a per-batch mean, so weight it by the batch
            # size to obtain a true per-sample mean over the whole epoch.
            val_loss += output["val_loss"].sum().item() * batch_size
            num_items += batch_size

        if num_items == 0:
            # Nothing was validated: drop any stale metric state and avoid
            # dividing by zero below.
            self.metric.reset()
            return

        mean_val_accuracy = self.metric.aggregate()[0].item()
        # Clear the accumulated buffers so the next epoch's accuracy reflects
        # only that epoch instead of every validation batch seen so far.
        self.metric.reset()

        mean_val_loss = torch.tensor(val_loss / num_items)
        self.logger.experiment.add_scalar("Loss/Validation", mean_val_loss, self.current_epoch)
        self.logger.experiment.add_scalar(
            "Accuracy/Validation", mean_val_accuracy, self.current_epoch
        )
        self.log("accuracy", mean_val_accuracy, prog_bar=True, logger=False)

        if mean_val_accuracy > self.best_val_accuracy:
            self.best_val_accuracy = mean_val_accuracy
            self.best_val_epoch = self.current_epoch

## PyTorch Lightning DataModule

In [None]:
class DataModule(pl.LightningDataModule):
    """Lightning data module serving the BreastMNIST train and val splits.

    Each sample is stored as an ``(image array, label)`` pair and passed through
    the ``ToTensor -> AddChannel -> EnsureType`` transform chain on access.
    """

    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        """Download both splits (if needed) and wrap them in MONAI datasets."""
        random.seed(1.414213)
        transforms = Compose([ToTensor(dtype=torch.float), AddChannel(), EnsureType()])

        def to_samples(split_data):
            # (image as ndarray, first element of the label array) per item
            return [(np.array(sample[0]), sample[1][0]) for sample in split_data]

        raw_train = medmnist.BreastMNIST(split="train", transform=None, download=True)
        raw_val = medmnist.BreastMNIST(split="val", transform=None, download=True)

        self.dataset_train = Dataset(to_samples(raw_train), transform=transforms)
        self.dataset_val = Dataset(to_samples(raw_val), transform=transforms)

        print(f"Setup train dataset: {len(self.dataset_train)} items")
        print(f"Setup val dataset: {len(self.dataset_val)} items")

        assert len(self.dataset_train) > 0, "Train dataset is empty."
        assert len(self.dataset_val) > 0, "Val dataset is empty"

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        loader_options = dict(
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=0,
            pin_memory=torch.cuda.is_available(),
        )
        return TorchDataLoader(self.dataset_train, **loader_options)

    def val_dataloader(self):
        """Unshuffled loader over the validation dataset."""
        loader_options = dict(
            batch_size=self.batch_size,
            num_workers=0,
            shuffle=False,
            pin_memory=torch.cuda.is_available(),
        )
        return TorchDataLoader(self.dataset_val, **loader_options)

    def test_dataloader(self):
        """The validation loader doubles as the test loader."""
        return self.val_dataloader()

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir tb_logs --bind_all

## Train

In [None]:
# Instantiate the data module (batches of 24 samples) and the Lightning model.
data = DataModule(batch_size=24)
monai_model = MonaiModel()

In [None]:
# TensorBoard logger rooted at "tb_logs", experiment name "medmnist_breast".
logger = TensorBoardLogger("tb_logs", name="medmnist_breast")
# Checkpoint callback monitoring the logged "accuracy" value (higher is better, top 3 kept).
# NOTE(review): this callback is only referenced in commented-out Trainer arguments
# in the training cells, so it currently has no effect.
checkpoint_callback = ModelCheckpoint(monitor="accuracy", mode="max", save_top_k=3)

In [None]:
# Build the datasets explicitly, then fetch the first validation batch.
data.setup()
input_image, input_label = next(iter(data.val_dataloader()))

In [None]:
# Toggle GPU training; also switches the Trainer to 16-bit precision when enabled.
USE_CUDA = False
trainer = pl.Trainer(
    max_epochs=50,
    gpus=1 if USE_CUDA else 0,
    logger=logger,
    precision=16 if USE_CUDA else 32,
    limit_train_batches=0.5,
    # callbacks=[checkpoint_callback],
    fast_dev_run=False,  # set to True to quickly test Lightning model
)

# Train and report wall-clock start time, end time and duration.
start = datetime.datetime.now()
print(start.strftime("%H:%M:%S"))
trainer.fit(monai_model, data)
end = datetime.datetime.now()
print(end.strftime("%H:%M:%S"))
# Split the elapsed time into hh:mm:ss with the stdlib timedelta. This avoids
# relying on `dateutil.relativedelta` being loaded as a side effect of other
# imports (only the top-level `dateutil` package is imported by this notebook).
delta = end - start
hours, remainder = divmod(int(delta.total_seconds()), 3600)
minutes, seconds = divmod(remainder, 60)
print(f"Training duration: {hours:02d}:{minutes:02d}:{seconds:02d}")

## Convert to ONNX

In [None]:
onnx_path = "medmnist_breast.onnx"
# Dummy input handed to the exporter: a batch of one single-channel 28x28 image.
dummy_input = torch.randn(1, 1, 28, 28)
# Export the wrapped network (moved to CPU, switched to eval mode) to ONNX opset 10.
torch.onnx.export(monai_model._model.cpu().eval(), dummy_input, onnx_path, opset_version=10)
print(f"Exported ONNX model to {onnx_path}")

## NNCF

In [None]:
# Directory that receives NNCF's log output; created if it does not exist yet.
OUTPUT_DIR = Path("output")
OUTPUT_DIR.mkdir(exist_ok=True)

# Minimal NNCF configuration: quantization of a model fed with [1, 1, 28, 28] inputs
# (the same shape as the dummy input used for the ONNX export).
nncf_config_dict = {
    "input_info": {"sample_size": [1, 1, 28, 28]},
    "log_dir": str(OUTPUT_DIR),  # log directory for NNCF-specific logging outputs
    "compression": {
        "algorithm": "quantization",  # specify the algorithm here
    },
}
nncf_config = NNCFConfig.from_dict(nncf_config_dict)

In [None]:
# Attach the training data loader to the NNCF config.
# NOTE(review): presumably NNCF draws samples from it to initialize the
# quantization parameters - confirm against the NNCF documentation.
train_loader = data.train_dataloader()
nncf_config = register_default_init_args(nncf_config, train_loader)

In [None]:
# Wrap the trained network according to `nncf_config`; yields the compression
# controller and the compressed model. The trailing `;` suppresses cell output.
compression_ctrl, model = create_compressed_model(monai_model._model, nncf_config);

In [None]:
# Swap the NNCF-compressed network into the Lightning module so the
# fine-tuning run below trains the quantized model.
monai_model._model = model

trainer = pl.Trainer(
    max_epochs=10,
    gpus=1 if USE_CUDA else 0,
    logger=logger,
    limit_train_batches=0.5,
    limit_val_batches=0.5,
    # callbacks=[checkpoint_callback],
    fast_dev_run=False,  # set to True to quickly test Lightning model
)

# Fine-tune and report wall-clock start time, end time and duration.
start = datetime.datetime.now()
print(start.strftime("%H:%M:%S"))
trainer.fit(monai_model, data)
end = datetime.datetime.now()
print(end.strftime("%H:%M:%S"))
# Split the elapsed time into hh:mm:ss with the stdlib timedelta. This avoids
# relying on `dateutil.relativedelta` being loaded as a side effect of other
# imports (only the top-level `dateutil` package is imported by this notebook).
delta = end - start
hours, remainder = divmod(int(delta.total_seconds()), 3600)
minutes, seconds = divmod(remainder, 60)
print(f"Training duration: {hours:02d}:{minutes:02d}:{seconds:02d}")

In [None]:
# Export the fine-tuned, NNCF-compressed model to ONNX via the compression controller.
compression_ctrl.export_model("medmnist_breast_nncf_int8.onnx")

In [None]:
# Convert the quantized ONNX model with the `mo` command-line tool (FP16 data type).
# NOTE(review): presumably this writes medmnist_breast_nncf_int8.xml to the working
# directory, which the benchmark and accuracy cells read - confirm.
!mo --data_type FP16 --input_model medmnist_breast_nncf_int8.onnx

In [None]:
# Convert the original (non-quantized) ONNX model with the `mo` command-line tool (FP16 data type).
# NOTE(review): presumably this writes medmnist_breast.xml to the working
# directory, which the benchmark and accuracy cells read - confirm.
!mo --data_type FP16 --input_model medmnist_breast.onnx

## Benchmark

In [None]:
def benchmark_model(model_path: os.PathLike,
                    device: str = "CPU",
                    seconds: int = 60, api: str = "async",
                    batch: int = 1, 
                    cache_dir="model_cache"):
    ie = IECore()
    model_path = Path(model_path)
    if ("GPU" in device) and ("GPU" not in ie.available_devices):
        raise ValueError(f"A GPU device is not available. Available devices are: {ie.available_devices}")
    else:
        benchmark_command = f"benchmark_app -m {model_path} -d {device} -t {seconds} -api {api} -b {batch} -cdir {cache_dir}"
        display(Markdown(f"**Benchmark {model_path.name} with {device} for {seconds} seconds with {api} inference**"));
        display(Markdown(f"Benchmark command: `{benchmark_command}`"));

        benchmark_output = %sx $benchmark_command
        benchmark_result = [line for line in benchmark_output
                            if not (line.startswith(r"[") or line.startswith("  ") or line == "")]
        print("\n".join(benchmark_result))
        print()
        if "MULTI" in device:
            devices = device.replace("MULTI:","").split(",")
            for single_device in devices:
                print(f"{single_device} device: {ie.get_metric(device_name=single_device, metric_name='FULL_DEVICE_NAME')}")
        else:
            print(f"Device: {ie.get_metric(device_name=device, metric_name='FULL_DEVICE_NAME')}")

In [None]:
## Original (not NNCF-quantized) model; note its IR was generated with `--data_type FP16`
benchmark_model("medmnist_breast.xml", device="CPU", seconds=15)

In [None]:
## NNCF-quantized (INT8) model
benchmark_model("medmnist_breast_nncf_int8.xml", "CPU", seconds=15)

## Check Accuracy

In [None]:
# The sigmoid function maps the raw network output (logits) to values
# in the range [0, 1], which are then rounded to a 0/1 class prediction.
def sigmoid(x):
    """Numerically stable logistic function: ``1 / (1 + exp(-x))``, element-wise."""
    # log(1 + exp(-x)) computed without overflow for large |x|
    log_denominator = np.logaddexp(0, -x)
    return np.exp(-log_denominator)

# IR files of the original and the NNCF-quantized model.
model_files = ["medmnist_breast.xml", "medmnist_breast_nncf_int8.xml"]

ie = IECore()

# Rebinds `data` to a data module with batch size 6, matching `net.batch_size` below.
data = DataModule(batch_size=6)
data.setup()

for model_xml in model_files:
    net = ie.read_network(model_xml)
    # NOTE(review): the network batch is fixed to 6; this presumably requires the
    # validation set size to be a multiple of 6 so no smaller final batch occurs - confirm.
    net.batch_size = 6
    counts = []  # one boolean array per batch: True where prediction equals label
    exec_net = ie.load_network(net, "CPU")
    # Use the first input and the first output of the network.
    input_layer = next(iter(exec_net.input_info))
    output_layer = next(iter(exec_net.outputs))
    for input_image, input_label in data.val_dataloader():
        raw_result = exec_net.infer(inputs={input_layer: input_image})[output_layer]
        # Sigmoid, then round: hard 0/1 prediction per sample.
        result = sigmoid(raw_result).round().astype(np.uint8)
        counts.append(input_label.numpy().squeeze() == result.squeeze())

    # Accuracy = number of correct predictions / total number of predictions.
    accuracy = np.count_nonzero(counts) / np.size(counts)
    print(model_xml, accuracy)