# How to add a custom model

This Jupyter notebook is a simple example showing how a developer can add a new model for an OTX task (multi-class classification, for now) and run model training.
First, let's import everything we need.

In [None]:
from typing import Any

import torch
from torch import nn
from torchvision.models.resnet import resnet50, ResNet50_Weights

from hydra import compose, initialize

from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.classification import (
    MulticlassClsBatchDataEntity,
    MulticlassClsBatchPredEntity,
)
from otx.core.model.entity.classification import OTXClassificationModel
from otx.core.engine.train import train
from otx.cli.utils.hydra import configure_hydra_outputs


Now we have everything we need. Before we start, please keep in mind that this is not the final form of the API. The design for training from the Python API has not been settled yet, and this is only a first step.

First, we need to write the actual PyTorch model, which is created in the `OTXModel._create_model()` function.
During training, an `OTXModel` is required to produce the task losses.
During evaluation, on the other hand, it should produce model predictions from the images.
Therefore, the `nn.Module` we write must be able to compute the task losses itself.
This is an important point to keep in mind.
Let's look at the code.

In [1]:
class ResNet50WithLossComputation(nn.Module):
    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.num_classes = num_classes
        net = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        net.fc = nn.Linear(
            in_features=net.fc.in_features, out_features=self.num_classes
        )
        self.net = net
        self.softmax = nn.Softmax(dim=-1)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, images: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        logits = self.net(images)

        if self.training:
            return self.criterion(logits, labels)

        return self.softmax(logits)
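
To make the train/eval switch concrete without downloading pretrained weights, here is a minimal sketch using a hypothetical stand-in module (`TinyNetWithLoss` is not part of OTX). It follows the same pattern as the class above: in training mode `forward` returns a scalar loss, and in evaluation mode it returns per-class probabilities.

```python
import torch
from torch import nn


class TinyNetWithLoss(nn.Module):
    """Hypothetical stand-in for ResNet50WithLossComputation (a single linear layer)."""

    def __init__(self, in_features: int, num_classes: int) -> None:
        super().__init__()
        self.net = nn.Linear(in_features, num_classes)
        self.softmax = nn.Softmax(dim=-1)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, images: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        logits = self.net(images)
        if self.training:
            # Training mode: return the task loss.
            return self.criterion(logits, labels)
        # Evaluation mode: return class probabilities.
        return self.softmax(logits)


torch.manual_seed(0)
model = TinyNetWithLoss(in_features=8, num_classes=3)
features = torch.randn(4, 8)  # 4 samples with 8 features each
labels = torch.tensor([0, 2, 1, 0])

model.train()
loss = model(features, labels)  # 0-dimensional (scalar) tensor

model.eval()
with torch.no_grad():
    probs = model(features, labels)  # shape (4, 3), each row sums to 1
```

Calling `model.train()` or `model.eval()` only flips the `training` flag that `forward` checks, which is why one module can serve both purposes.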


Next, we need to write a class derived from `OTXModel`.
Because this example adds a multi-class classification model, we derive from `OTXClassificationModel`.
For another OTX task, such as `OTXTaskType.DETECTION`, it should be possible to make a custom model by deriving from `OTXDetectionModel`.

Every `OTXModel` is an abstract class that requires the developer to implement three abstract functions:

1) `_create_model()`
2) `_customize_inputs()`
3) `_customize_outputs()`
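
As a rough mental model of how these three hooks might fit together, here is a framework-free sketch. It is an assumption about the control flow, not the actual OTX implementation: `HookedModel`, `Doubler`, and the body of `forward` are hypothetical names and logic chosen for illustration. Only the three hook names come from OTX.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable


class HookedModel(ABC):
    """Hypothetical stand-in for an OTXModel-like base class."""

    def __init__(self) -> None:
        # Assumption: the base class builds the wrapped model in its constructor.
        self.model = self._create_model()

    @abstractmethod
    def _create_model(self) -> Callable[..., Any]:
        """Build the underlying model."""

    @abstractmethod
    def _customize_inputs(self, inputs: Any) -> dict[str, Any]:
        """Convert a batch entity into keyword arguments for the model."""

    @abstractmethod
    def _customize_outputs(self, outputs: Any, inputs: Any) -> Any:
        """Convert raw model outputs into the entity the caller expects."""

    def forward(self, inputs: Any) -> Any:
        kwargs = self._customize_inputs(inputs)
        outputs = self.model(**kwargs)
        return self._customize_outputs(outputs, inputs)


class Doubler(HookedModel):
    """Toy subclass: the 'model' doubles every input value."""

    def _create_model(self) -> Callable[..., Any]:
        return lambda values: [2 * v for v in values]

    def _customize_inputs(self, inputs: Any) -> dict[str, Any]:
        return {"values": inputs["raw"]}

    def _customize_outputs(self, outputs: Any, inputs: Any) -> Any:
        return {"name": inputs["name"], "doubled": outputs}


result = Doubler().forward({"name": "demo", "raw": [1, 2, 3]})
```

If the base constructor does call `_create_model()` as assumed here, any attribute that hook reads must be set before `super().__init__()` runs.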

The following example implements exactly these three functions.
Let's go through it together.

In [None]:
class OTXResNet50(OTXClassificationModel):
    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes
        super().__init__()
        self.register_buffer(
            "mean",
            torch.FloatTensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
            False,
        )
        self.register_buffer(
            "std",
            torch.FloatTensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
            False,
        )

    def _create_model(self) -> nn.Module:
        # ResNet50_Weights.IMAGENET1K_V2 is a strong pretrained model trained with a modern training recipe:
        # ImageNet-1K acc@1: 80.858, acc@5: 95.434.
        return ResNet50WithLossComputation(num_classes=self.num_classes)

    def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]:
        images = torch.stack(inputs.images, dim=0).to(dtype=torch.float32)
        images = (images - self.mean) / self.std
        return {
            "images": images,
            "labels": torch.cat(inputs.labels, dim=0),
        }

    def _customize_outputs(
        self, outputs: Any, inputs: MulticlassClsBatchDataEntity
    ) -> MulticlassClsBatchPredEntity | OTXBatchLossEntity:
        if self.training:
            return {"loss": outputs}

        # To list, batch-wise
        scores = torch.unbind(outputs, 0)

        return MulticlassClsBatchPredEntity(
            batch_size=inputs.batch_size,
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            scores=scores,
            labels=inputs.labels,
        )
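
The tensor handling in `_customize_inputs()` and `_customize_outputs()` can be checked in isolation. The sketch below uses made-up 4x4 images and made-up scores. It shows how the `(3, 1, 1)`-shaped mean and std broadcast over a stacked `(N, 3, H, W)` batch, and how `torch.unbind` splits a batch of scores into one tensor per sample.

```python
import torch

# Same per-channel statistics as in the class above, shaped (3, 1, 1).
mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

# Two hypothetical 3x4x4 images: one equal to the mean, one a std above it.
images = [
    mean.expand(3, 4, 4).clone(),
    (mean + std).expand(3, 4, 4).clone(),
]

# Stack the list into one (N, C, H, W) batch, then normalize per channel.
batch = torch.stack(images, dim=0).to(dtype=torch.float32)
normalized = (batch - mean) / std  # (3, 1, 1) broadcasts over (2, 3, 4, 4)

# Split a (N, num_classes) score matrix into a tuple of N score vectors.
batch_scores = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
scores = torch.unbind(batch_scores, 0)
```

Here the first normalized image is all zeros and the second is all ones (up to floating-point error).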


Now we have our own custom model that can be used in the OTX training process.
However, model training involves many other settings and configurations.
Setting all of them up from scratch would be tedious.
Therefore, we borrow the configuration of a similar model that OTX provides for the multi-class classification task: `classification/otx_efficientnet_b0`.
We simply override its model with our custom one.
The following cell shows how.

In [2]:
num_classes = 2
data_dir = "../tests/assets/classification_dataset"

with initialize(
    config_path="../src/otx/config", version_base="1.3", job_name="otx_train"
):
    overrides = [
        "+recipe=classification/otx_efficientnet_b0",
        "base.output_dir=outputs",
        "trainer.accelerator=gpu",
        f"base.data_dir={data_dir}",
    ]
    cfg = compose(config_name="train", overrides=overrides, return_hydra_config=True)
    configure_hydra_outputs(cfg)

    otx_model = OTXResNet50(num_classes=num_classes)
    train(cfg, otx_model=otx_model)


[2023-12-08 14:44:09,539][root][INFO] - Instantiating datamodule <{'data_format': 'imagenet_with_subset_dirs', 'data_root': '${base.data_dir}', 'train_subset': {'batch_size': 64, 'subset_name': 'train', 'transform_lib_type': <TransformLibType.MMPRETRAIN: 'MMPRETRAIN'>, 'transforms': [{'type': 'LoadImageFromFile'}, {'backend': 'cv2', 'scale': 224, 'type': 'RandomResizedCrop'}, {'type': 'PackInputs'}], 'num_workers': 2}, 'val_subset': {'batch_size': 64, 'subset_name': 'val', 'transform_lib_type': <TransformLibType.MMPRETRAIN: 'MMPRETRAIN'>, 'transforms': [{'type': 'LoadImageFromFile'}, {'backend': 'cv2', 'edge': 'short', 'scale': 256, 'type': 'ResizeEdge'}, {'crop_size': 224, 'type': 'CenterCrop'}, {'type': 'PackInputs'}], 'num_workers': 2}, 'test_subset': {'batch_size': 64, 'subset_name': 'test', 'transform_lib_type': <TransformLibType.MMPRETRAIN: 'MMPRETRAIN'>, 'transforms': [{'type': 'LoadImageFromFile'}, {'backend': 'cv2', 'edge': 'short', 'scale': 256, 'type': 'ResizeEdge'}, {'crop_

Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


[2023-12-08 14:44:11,479][root][INFO] - Logging hyperparameters!
[2023-12-08 14:44:11,483][root][INFO] - Starting training!


You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/vinnamki/miniconda3/envs/otx-v2/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:639: Checkpoint directory outputs/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Output()

`Trainer.fit` stopped: `max_epochs=10` reached.
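
The training cell above hard-codes `trainer.accelerator=gpu`. On a machine without a GPU, one option is to build the override list with a small helper. This is a sketch only: `build_overrides` is a hypothetical function, not part of OTX. `cpu` is an accelerator name that Lightning's `Trainer` accepts, but whether the rest of the recipe runs unchanged on CPU is an assumption.

```python
def build_overrides(
    data_dir: str, use_gpu: bool, output_dir: str = "outputs"
) -> list[str]:
    """Build the Hydra override strings used in the training cell (hypothetical helper)."""
    accelerator = "gpu" if use_gpu else "cpu"
    return [
        "+recipe=classification/otx_efficientnet_b0",
        f"base.output_dir={output_dir}",
        f"trainer.accelerator={accelerator}",
        f"base.data_dir={data_dir}",
    ]


overrides = build_overrides("../tests/assets/classification_dataset", use_gpu=False)
```

In the notebook, `use_gpu=torch.cuda.is_available()` would choose the accelerator automatically, and the resulting list can be passed to `compose(..., overrides=overrides)` as before.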


**To repeat: this is not the final form of the OTX training API. We will keep working to improve it so that users can use it conveniently. I believe this will not be difficult, since we already have a solid core design and this is just an entry point.**