# How to add a custom model

This simple Jupyter notebook example shows how a developer can add a new model for an OTX task (multi-class classification, for now) and run model training.
First, let's import everything we need.

In [1]:
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import torch
from otx.core.data.entity.classification import (
    MulticlassClsBatchDataEntity,
    MulticlassClsBatchPredEntity,
)
from otx.core.model.entity.classification import OTXMulticlassClsModel
from torch import nn
from torchvision.models.resnet import ResNet50_Weights, resnet50

if TYPE_CHECKING:
    from otx.core.data.entity.base import OTXBatchLossEntity


Now we have everything we need. Before we start, please keep in mind that this is not the final form of the API. The design of the Python training API has not been settled yet, and this is only a first step.

First, we need to develop the actual PyTorch model, which is created in the `OTXModel._create_model()` function.
An `OTXModel` is required to produce the task losses during training.
During evaluation, on the other hand, it should produce model predictions from the images.
Therefore, this `nn.Module` must be able to compute the task losses itself.
This is an important point to keep in mind.
Let's look at the code.

In [2]:
class ResNet50WithLossComputation(nn.Module):
    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.num_classes = num_classes
        net = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        net.fc = nn.Linear(
            in_features=net.fc.in_features, out_features=self.num_classes
        )
        self.net = net
        self.softmax = nn.Softmax(dim=-1)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, images: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        logits = self.net(images)

        if self.training:
            return self.criterion(logits, labels)

        return self.softmax(logits)
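
To make the train/eval switch easier to see, here is a minimal sketch of the same pattern with a tiny linear layer in place of ResNet-50, so that no pretrained weights are downloaded. `TinyNetWithLoss`, the feature size, and the random inputs are hypothetical and only serve the illustration; they are not part of OTX.

```python
import torch
from torch import nn


class TinyNetWithLoss(nn.Module):
    """Hypothetical stand-in for ResNet50WithLossComputation with a tiny backbone."""

    def __init__(self, in_features: int, num_classes: int) -> None:
        super().__init__()
        self.net = nn.Linear(in_features, num_classes)
        self.softmax = nn.Softmax(dim=-1)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, images: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        logits = self.net(images)
        if self.training:
            # Training mode: return a scalar loss.
            return self.criterion(logits, labels)
        # Evaluation mode: return per-class probabilities.
        return self.softmax(logits)


torch.manual_seed(0)
model = TinyNetWithLoss(in_features=8, num_classes=2)
images = torch.randn(4, 8)  # made-up "images" already flattened to 8 features
labels = torch.tensor([0, 1, 1, 0])

model.train()  # sets self.training = True
loss = model(images, labels)

model.eval()  # sets self.training = False
with torch.no_grad():
    probs = model(images, labels)
```

Here `loss` is a zero-dimensional tensor, while `probs` has one row per sample and each row sums to 1. The same `forward()` therefore serves both phases, and `model.train()` / `model.eval()` select the behavior.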


Next, we need to develop a class derived from `OTXModel`.
Since this example adds a multi-class classification model, we implement a class derived from `OTXMulticlassClsModel`.
For another OTX task, such as `OTXTaskType.DETECTION`, we could make a custom model by deriving from `OTXDetectionModel`.

Since every `OTXModel` is an abstract class, a developer is required to implement three abstract functions:

1) `_create_model()`
2) `_customize_inputs()`
3) `_customize_outputs()`
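
As a rough mental model, the three hooks can be pictured with a plain Python abstract class. This is only a sketch: `BaseModelSketch` and `DoublingModel` are hypothetical names, and the order in which `forward()` calls the hooks is an assumption, since this notebook does not show how the real `OTXModel` wires them together.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any


class BaseModelSketch(ABC):
    """Hypothetical base class; NOT the real OTXModel."""

    def __init__(self) -> None:
        # The subclass decides which underlying model to build.
        self.model = self._create_model()

    @abstractmethod
    def _create_model(self) -> Any: ...

    @abstractmethod
    def _customize_inputs(self, inputs: Any) -> dict[str, Any]: ...

    @abstractmethod
    def _customize_outputs(self, outputs: Any, inputs: Any) -> Any: ...

    def forward(self, inputs: Any) -> Any:
        # Assumed flow: adapt inputs -> run the model -> adapt outputs.
        kwargs = self._customize_inputs(inputs)
        outputs = self.model(**kwargs)
        return self._customize_outputs(outputs, inputs)


class DoublingModel(BaseModelSketch):
    """Toy subclass whose 'model' doubles every value."""

    def _create_model(self) -> Any:
        return lambda values: [2 * v for v in values]

    def _customize_inputs(self, inputs: Any) -> dict[str, Any]:
        return {"values": inputs["raw"]}

    def _customize_outputs(self, outputs: Any, inputs: Any) -> Any:
        return {"doubled": outputs, "count": len(inputs["raw"])}


result = DoublingModel().forward({"raw": [1, 2, 3]})
```

Instantiating `BaseModelSketch` directly raises a `TypeError`, because its abstract functions are not implemented. That is the mechanism that forces a subclass to provide all three.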

You can see that the following example implements exactly those three functions.
Let's go through it together.

In [3]:
class OTXResNet50(OTXMulticlassClsModel):
    def __init__(self, num_classes: int) -> None:
        super().__init__(num_classes=num_classes)
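        # Per-channel normalization statistics for pixel values in the 0-255 range.
        # persistent=False keeps these buffers out of the state_dict.
        self.register_buffer(
            "mean",
            torch.FloatTensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "std",
            torch.FloatTensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
            persistent=False,
        )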

    def _create_model(self) -> nn.Module:
        # ResNet50_Weights.IMAGENET1K_V2 is a strong pretrained checkpoint produced with a modern training recipe:
        # ImageNet-1K acc@1: 80.858, acc@5: 95.434.
        return ResNet50WithLossComputation(num_classes=self.num_classes)

    def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]:
        images = inputs.images.to(dtype=torch.float32)
        images = (images - self.mean) / self.std
        return {
            "images": images,
            "labels": torch.cat(inputs.labels, dim=0),
        }

    def _customize_outputs(
        self, outputs: Any, inputs: MulticlassClsBatchDataEntity
    ) -> MulticlassClsBatchPredEntity | OTXBatchLossEntity:
        if self.training:
            return {"loss": outputs}

        # To list, batch-wise
        scores = torch.unbind(outputs, 0)

        return MulticlassClsBatchPredEntity(
            batch_size=inputs.batch_size,
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            scores=scores,
            labels=inputs.labels,
        )
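
The tensor operations used in `_customize_inputs()` and `_customize_outputs()` can be tried in isolation. The sketch below uses made-up data: a batch of two 4x4 RGB images, one label tensor of shape `(1,)` per image, and a hand-written score matrix. The label layout in particular is an assumption about what `inputs.labels` contains.

```python
import torch

# Per-channel statistics shaped (3, 1, 1) so that they broadcast over (N, 3, H, W).
mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

# Hypothetical batch: two 4x4 RGB images whose pixels all equal the channel mean.
images = mean.expand(2, 3, 4, 4).clone()
normalized = (images.to(dtype=torch.float32) - mean) / std

# Assumed label layout: a list with one (1,)-shaped tensor per image,
# concatenated into a single tensor of shape (N,).
labels = torch.cat([torch.tensor([0]), torch.tensor([1])], dim=0)

# Split an (N, num_classes) score matrix into a tuple of N per-image tensors.
scores = torch.unbind(torch.tensor([[0.9, 0.1], [0.2, 0.8]]), 0)
```

Because every pixel equals its channel mean, `normalized` is all zeros, which confirms that the `(3, 1, 1)` buffers are applied per channel. `torch.unbind` returns a tuple with one score tensor per image, matching the "to list, batch-wise" comment in the cell above.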


Now we have our own custom model that can be used in the OTX training process.
However, model training requires many other settings and configurations.
**Setting all of these up from scratch is tedious.**
Therefore, we will **borrow the configuration of a similar model that OTX provides for the multi-class classification task: `classification/otx_efficientnet_b0`.**
We simply override it with our custom model.

This means that we can combine **1) data transforms from MMPretrain and 2) a custom model from TorchVision**.
Our design keeps the two independent of each other.
Therefore, other combinations are possible as well, such as 1) data transforms from TorchVision and 2) a custom model from Detectron.

The following cell shows how to do that.

In [4]:
from otx.engine import Engine

data_dir = "../tests/assets/classification_dataset"
otx_model = OTXResNet50(num_classes=2)

engine = Engine(
    data_root=data_dir,
    model=otx_model,
    device="gpu",
    work_dir="otx-workspace",
)

engine.train(max_epochs=10)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type               | Params
---------------------------------------------------
0 | model       | OTXResNet50        | 23.5 M
1 | val_metric  | MulticlassAccuracy | 0     
2 | test_metric | MulticlassAccura

                                                                           

/home/harimkan/workspace/repo/otx-regression/venv/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 9: 100%|██████████| 1/1 [00:00<00:00,  8.83it/s, v_num=9, train/loss=0.563, val/accuracy=1.000]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 1/1 [00:00<00:00,  2.92it/s, v_num=9, train/loss=0.563, val/accuracy=1.000]


{'train/loss': tensor(0.5628), 'val/accuracy': tensor(1.)}

**To repeat: this is not the final form of the OTX training API. We will keep working to improve it so that users can use it conveniently. I believe this will not be difficult, since we already have a solid core design and this is only an entry point.**