# Installation

In [None]:
!pip install -r requirements.txt

# Required Code definitions

In [1]:
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from pathlib import Path

import nemo
import numpy
import torch
import torchvision
from PIL import Image
from pycocotools.coco import COCO
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm


class CNNBackbone(Enum):
    """Feature-extractor architectures that can be requested for the detector."""

    # Placeholder for a hand-written backbone; Detector currently raises for it.
    CUSTOM = 0
    # torchvision's MobileNetV2; Detector uses its ``features`` sub-network.
    MOBILENET_V2 = 1


@dataclass
class DatasetConfig:
    """Location of one COCO-style dataset split on disk."""

    # Directory against which the ``file_name`` entries of the annotation
    # file are resolved when images are loaded.
    image_directory: str | Path
    # Annotation file handed to ``pycocotools.coco.COCO``.
    annotation_file: str | Path


@dataclass
class TrainerConfig:
    """Settings consumed by ``Trainer.train``."""

    # Device string passed to ``torch.device`` (e.g. "cpu").
    device: str
    # Number of passes over the training data.
    epochs: int
    # Batch size used for both the training and the validation DataLoader.
    batch_size: int
    # Learning rate of the Adam optimizer.
    learning_rate: float
    # Backbone architecture the Detector is built with.
    backbone: CNNBackbone
    # NOTE(review): not read anywhere in this file; Detector hard-codes
    # ``pretrained=False``.
    pretrained: bool
    # Dataset split used for training.
    train_data: DatasetConfig
    # Dataset split used for validation.
    validation_data: DatasetConfig
    # Destination of the trained model's ``state_dict``; missing parent
    # directories are created.
    model_output_file: str | Path
    # Optional transform pipeline handed to the datasets.
    # NOTE(review): COCOBBoxDataset stores it but does not apply it yet.
    augmentation: (
        None | torchvision.transforms.Compose | torchvision.transforms.v2.Compose
    ) = None
    # When True the validation dataset receives ``augmentation`` as well;
    # otherwise it receives None.
    augmentation_in_validation: bool = False


@dataclass
class ExporterConfig:
    """Settings consumed by ``Exporter.export``."""

    # Device string passed to ``torch.device`` and forwarded to the nemo calls.
    device: str
    # File containing the trained ``state_dict`` (loaded with ``torch.load``).
    model_input_file: str | Path
    # Target file passed to ``nemo.utils.export_onnx``.
    model_quantized_output_file: str | Path
    # Backbone architecture; must match the one the state dict was trained with
    # so that ``load_state_dict`` succeeds.
    backbone: CNNBackbone
    # Width/height of the random dummy input, which has shape
    # (1, 3, image_height, image_width).
    image_width: int
    image_height: int
    # Directory receiving ``input.txt`` and one ``out_layer<index>.txt`` per
    # recorded activation; created if missing.
    layers_output_dir: str | Path


class COCOBBoxDataset(Dataset):
    """COCO-backed dataset yielding an image tensor and one square-box target.

    Every item is a pair ``(image, target)``:

    * ``image`` is the RGB image converted with
      ``torchvision.transforms.ToTensor``.
    * ``target`` is a float32 tensor ``[xc, yc, side]`` derived from the first
      non-crowd annotation of the image: the box centre, with x divided by the
      image width and y by the image height, and ``side``, the larger of the
      width-normalised box width and the height-normalised box height.
      Images without such an annotation get an all-zero target.
    """

    def __init__(
        self,
        image_dir: str | Path,
        annotation_file: str | Path,
        augmentation: (
            None | torchvision.transforms.Compose | torchvision.transforms.v2.Compose
        ) = None,
    ) -> None:
        """Load the COCO index.

        Args:
            image_dir: Directory the annotation file's ``file_name`` entries
                are resolved against.
            annotation_file: COCO annotation file.
            augmentation: Optional transform pipeline. It is stored but not
                applied yet (see the TODO in ``__getitem__``).

        Raises:
            AssertionError: If ``image_dir`` is not a directory or
                ``annotation_file`` is not a file.
        """
        super().__init__()

        # Path() accepts both str and Path, so no isinstance check is needed.
        self._image_dir = Path(image_dir)
        assert self._image_dir.is_dir(), "Image directory does not exists"

        self._annotation_file = Path(annotation_file)
        assert self._annotation_file.is_file(), "Annotation file does not exists"

        self._coco = COCO(self._annotation_file)
        self._image_ids = self._coco.getImgIds()

        self._augmentation = augmentation
        # Built once instead of on every __getitem__ call.
        self._to_tensor = torchvision.transforms.ToTensor()

    def __len__(self) -> int:
        """Return the number of images listed in the annotation file."""
        return len(self._image_ids)

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(image_tensor, [xc, yc, side])`` for the image at ``index``."""
        image_id = self._image_ids[index]
        image_name = self._coco.loadImgs(image_id)[0]["file_name"]
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed right away instead of leaking one handle per item.
        with Image.open(self._image_dir / image_name) as image_file:
            image = image_file.convert("RGB")

        annotation_ids = self._coco.getAnnIds(imgIds=[image_id], iscrowd=False)
        annotations = self._coco.loadAnns(annotation_ids)

        if not annotations:
            bboxes = torch.zeros(3, dtype=torch.float32)
        else:
            # TODO: This only works if every image has only one bbox
            x, y, w, h = annotations[0]["bbox"]
            x_n, y_n = x / image.width, y / image.height
            w_n, h_n = w / image.width, h / image.height
            xc = x_n + w_n / 2
            yc = y_n + h_n / 2
            side = max(w_n, h_n)
            bboxes = torch.tensor([xc, yc, side], dtype=torch.float32)

        # TODO: Apply self._augmentation. This would transform the current
        # image, but will not create new images. In addition, if the image is
        # augmented it does not fit to the bbox anymore?

        return self._to_tensor(image), bboxes


class Detector(torch.nn.Module):
    """Bounding-box regressor: CNN features -> global average pool -> 3 outputs.

    The head ends in a sigmoid, so each of the three outputs lies in (0, 1).
    In this file the model is trained against the ``[xc, yc, side]`` targets
    produced by ``COCOBBoxDataset``.
    """

    def __init__(self, backbone: CNNBackbone) -> None:
        """Build the network for the requested backbone.

        Args:
            backbone: Feature-extractor architecture.

        Raises:
            NotImplementedError: For ``CNNBackbone.CUSTOM``.
            ValueError: For any other unsupported backbone value.
        """
        super().__init__()

        if backbone == CNNBackbone.CUSTOM:
            # NotImplementedError is still an Exception subclass, so callers
            # catching the previous bare Exception keep working.
            raise NotImplementedError("Not implemented")
        elif backbone == CNNBackbone.MOBILENET_V2:
            detector = torchvision.models.mobilenet_v2(
                pretrained=False, width_mult=0.25
            )  # TODO: pretrained && width_mult configurable
            self._feature_extractor = detector.features
            # Capture the head width here so nothing after the branch depends
            # on a backbone-specific local variable.
            feature_channels = detector.last_channel
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")

        self._pooling = torch.nn.AdaptiveAvgPool2d(1)
        self._classifier = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(feature_channels, 3),
            torch.nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run feature extraction, pooling and the regression head on ``x``."""
        x = self._feature_extractor(x)
        x = self._pooling(x)
        x = self._classifier(x)
        return x


class Trainer:
    """Training entry point for the bounding-box ``Detector``."""

    class Metric:
        """Running weighted average of a scalar (e.g. the per-sample loss)."""

        def __init__(self) -> None:
            self.sum = 0.0
            self.count = 0

        def update(self, val: float, n: int = 1) -> None:
            """Add ``val`` with weight ``n`` (typically the batch size)."""
            self.sum += val * n
            self.count += n

        @property
        def avg(self) -> float:
            """Weighted mean of all updates, or 0 if nothing was recorded."""
            return self.sum / self.count if self.count else 0

    @classmethod
    def train(cls, config: TrainerConfig) -> None:
        """Train a ``Detector`` and save its ``state_dict``.

        For every epoch the model is trained on ``config.train_data`` and then
        evaluated on ``config.validation_data``; both average losses are
        printed. Afterwards the weights are written to
        ``config.model_output_file`` (parent directories are created).

        Args:
            config: Training settings.
        """
        device = torch.device(config.device)

        train_dataloader = DataLoader(
            COCOBBoxDataset(
                image_dir=config.train_data.image_directory,
                annotation_file=config.train_data.annotation_file,
                augmentation=config.augmentation,
            ),
            batch_size=config.batch_size,
            # Reshuffle every epoch so batches are not always composed of the
            # same images in annotation-file order.
            shuffle=True,
        )

        validation_dataloader = DataLoader(
            COCOBBoxDataset(
                image_dir=config.validation_data.image_directory,
                annotation_file=config.validation_data.annotation_file,
                augmentation=(
                    config.augmentation if config.augmentation_in_validation else None
                ),
            ),
            batch_size=config.batch_size,
        )

        model = Detector(backbone=config.backbone).to(device)
        optimizer = torch.optim.Adam(
            model.parameters(), lr=config.learning_rate
        )  # TODO: config optimizer
        loss_criterion = torch.nn.MSELoss()  # TODO: config loss

        for epoch in range(1, config.epochs + 1):
            train_loss = cls._train_epoch(
                model, device, train_dataloader, optimizer, loss_criterion
            )
            validation_loss = cls._validate_epoch(
                model, device, validation_dataloader, loss_criterion
            )
            print(
                f"Epoch {epoch} / {config.epochs}: "
                f"train_loss={train_loss:.4f}, validation_loss={validation_loss:.4f}"
            )

        model_output_file = Path(config.model_output_file)
        model_output_file.parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), model_output_file)
        print("Training done")

    @classmethod
    def _train_epoch(
        cls,
        model: torch.nn.Module,
        device: torch.device,
        dataloader: DataLoader,
        optimizer: Optimizer,
        loss_criterion,
    ) -> float:
        """Run one optimisation pass over ``dataloader``.

        Returns:
            The loss averaged over all samples of the epoch.
        """
        metric = cls.Metric()
        model.train()
        for images, bboxes in tqdm(dataloader, desc="Train", leave=False, unit="batch"):
            images, bboxes = images.to(device), bboxes.to(device)
            optimizer.zero_grad()
            predictions = model(images)
            loss = loss_criterion(predictions, bboxes)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a smaller final batch does not skew the mean.
            metric.update(loss.item(), images.size(0))
        return metric.avg

    @classmethod
    def _validate_epoch(
        cls,
        model: torch.nn.Module,
        device: torch.device,
        dataloader: DataLoader,
        loss_criterion,
    ) -> float:
        """Evaluate ``model`` on ``dataloader`` without updating any weights.

        Returns:
            The loss averaged over all samples, or 0 for an empty dataloader.
        """
        metric = cls.Metric()
        model.eval()
        with torch.no_grad():
            for images, bboxes in tqdm(
                dataloader, desc="Validate", leave=False, unit="batch"
            ):
                images, bboxes = images.to(device), bboxes.to(device)
                loss = loss_criterion(model(images), bboxes)
                metric.update(loss.item(), images.size(0))
        return metric.avg


class Exporter:
    """Dumps reference activations and exports a NEMO-quantized ONNX model."""

    @staticmethod
    def export(config: ExporterConfig) -> None:
        """Load a trained ``Detector``, record activations, quantize and export.

        Steps:
            1. Load the state dict from ``config.model_input_file``.
            2. Run one random dummy input through the model and write the input
               (``input.txt``) and the output of every leaf module
               (``out_layer<index>.txt``) to ``config.layers_output_dir``.
            3. Transform the model with NEMO and export it as ONNX to
               ``config.model_quantized_output_file``.

        Args:
            config: Export settings.
        """
        device = torch.device(config.device)

        model = Detector(backbone=config.backbone).to(device)
        model.load_state_dict(torch.load(config.model_input_file, map_location=device))
        model.eval()

        dummy_input = torch.randn((1, 3, config.image_height, config.image_width)).to(
            device
        )

        activations = []

        def record_output(module, inputs, output):
            # Flattened CPU copy so the dump works for any device.
            activations.append(output.detach().cpu().numpy().ravel())

        # Only leaf modules (those without children) get a hook; containers
        # would merely repeat the output of their last child.
        hooks = [
            module.register_forward_hook(record_output)
            for module in model.modules()
            if next(module.children(), None) is None
        ]
        try:
            # Execute model to fill activations; no autograd graph is needed.
            with torch.no_grad():
                model(dummy_input)
        finally:
            # Always detach the hooks, even if the forward pass fails, so they
            # cannot fire again during quantization/export.
            for hook in hooks:
                hook.remove()

        layers_output_dir = Path(config.layers_output_dir)
        layers_output_dir.mkdir(parents=True, exist_ok=True)

        numpy.savetxt(
            layers_output_dir / "input.txt",
            # .cpu() is required: Tensor.numpy() fails for non-CPU tensors.
            dummy_input.detach().cpu().numpy().ravel(),
            delimiter=",",
        )

        for index, activation in enumerate(activations):
            numpy.savetxt(
                layers_output_dir / f"out_layer{index}.txt", activation, delimiter=","
            )

        # NOTE(review): the positional device arguments below are kept exactly
        # as they were; confirm them against the signatures of the nemo build
        # installed from requirements.txt.
        model_quantized = nemo.transform.quantize_pact(model, device, dummy_input=dummy_input)
        model_quantized.change_precision(bits=1)  # TODO: bits configurable
        model_quantized = nemo.transform.bn_to_identity(model_quantized)
        model_quantized.qd_stage(eps_in=1 / 255)
        model_quantized.id_stage()
        model_quantized.eval()

        nemo.utils.export_onnx(
            config.model_quantized_output_file,
            model,
            model_quantized,
            (3, config.image_height, config.image_width),
            config.device
        )
        print("Export done")

In [2]:
# TODO: implement augmentation
# augmentation = torchvision.transforms.v2.Compose([
# torchvision.transforms.RandomResizedCrop((244, 324), scale=(0.8, 1.0), ratio=(0.75, 1.33)),
# torchvision.transforms.RandomHorizontalFlip(p=0.5),
# torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
# torchvision.transforms.ToTensor()
# ])

# Training

In [None]:
# Minimal training run: a single epoch on the CPU with the MobileNetV2
# backbone. Image file names from the annotation files are resolved relative
# to the current directory, and the weights end up in output/model.pth.
config = TrainerConfig(
    device="cpu",
    epochs=1,
    batch_size=32,
    learning_rate=1e-3,
    backbone=CNNBackbone.MOBILENET_V2,
    pretrained=False,
    train_data=DatasetConfig(
        image_directory=".", annotation_file="annotations/train.json"
    ),
    validation_data=DatasetConfig(
        image_directory=".", annotation_file="annotations/validation.json"
    ),
    model_output_file="output/model.pth",
    # augmentation=augmentation
)

Trainer.train(config)

# Export to ONNX & Quantization

In [None]:
# Export the weights written by the training cell. The backbone must match
# the one used for training; image_width/image_height define the shape of
# the dummy input (1, 3, 244, 324) used for the activation dump and the export.
config = ExporterConfig(
    device="cpu",
    model_input_file="output/model.pth",
    model_quantized_output_file="output/model_q.onnx",
    backbone=CNNBackbone.MOBILENET_V2,
    image_width=324,
    image_height=244,
    layers_output_dir="output",
)

Exporter.export(config)

# DORY (ONNX -> C)

#### Download the DORY generator script

In [None]:
!wget -O network_generate.py https://raw.githubusercontent.com/nkaaf/pulp-platform-dory/refs/heads/master/network_generate.py

#### Parameters

In [44]:
# Name of the DORY config file that the next cell creates inside config/.
CONFIG_FILE_NAME="drone_ai.json"
# Quantized ONNX model written by the export cell; the config references it
# as ../<MODEL_FILE_NAME>.
MODEL_FILE_NAME="output/model_q.onnx"
# Sub-directory of export/ that receives the generated application.
EXPORT_DIR="c_export"

#### Create config

In [None]:
!mkdir --parents config && cp "config.template" "config/$CONFIG_FILE_NAME" && sed -i "s|%MODEL_FILE_NAME|../$MODEL_FILE_NAME|" "config/$CONFIG_FILE_NAME" && echo "Config created in: config/$CONFIG_FILE_NAME"

#### Start

In [51]:
!rm --recursive --force "export/$EXPORT_DIR" && mkdir --parents "export/$EXPORT_DIR" && python3 network_generate.py NEMO PULP.GAP8_L2 "config/$CONFIG_FILE_NAME" --app_dir "export/$EXPORT_DIR"  && echo "Exported to: export/$EXPORT_DIR"

Using NEMO as frontend. Targeting PULP.GAP8_L2 platform. 
Using config/../output/model_q.onnx target input onnx.

Creating Original_graph.onnx in logs/Frontend/onnx_files/
Creating Original_graph.json in logs/Frontend/json_files/
onnx_to_dory net prefix: 

##################################
## DORY GENERAL PARSING OF ONNX ##
## FINAL RAPRESENTATION:DORY IR ##
##################################

Parsing ONNX Graph to create DORY graph.
Creating 00_DORY_raw_graph.json in logs/Frontend/json_files/
Creating 00_DORY_raw_graph.onnx in logs/Frontend/onnx_files/

Embedding constant nodes inside nodes to which the tensors belong.
Creating 01_DORY_graph_constants_removed.json in logs/Frontend/json_files/
Creating 01_DORY_graph_constants_removed.onnx in logs/Frontend/onnx_files/

NEMO Frontend: Matching patterns from generated ONNX to DORY.

NEMO Frontend: Updating Add nodes with constants.
Creating 02_DORY_mapped_graph.json in logs/Frontend/json_files/
Creating 02_DORY_ma