# How to add a custom model

This is a simple Jupyter notebook example that shows how a developer can add a new model for an OTX task (currently multi-class classification) and train it.
First, let's import everything we need.

In [1]:
from typing import Any

import torch
from lightning.pytorch.cli import ReduceLROnPlateau
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.classification import (
    MulticlassClsBatchDataEntity,
    MulticlassClsBatchPredEntity,
)
from otx.core.model.entity.classification import OTXClassificationModel
from otx.core.model.module.classification import OTXClassificationLitModule
from torch import nn
from torchvision.models.resnet import ResNet50_Weights, resnet50




Now we have everything we need. Before we start, please keep in mind that this is not the final design. The Python training API has not been finalized yet, and this is a very early draft.

First, we need to develop the actual PyTorch model, which is created in the `OTXModel._create_model()` method.
`OTXModel` must produce the task losses during training.
During evaluation, on the other hand, it must produce the model predictions from the images.
Therefore, this `nn.Module` must be able to compute the task losses itself.
This is an important point to keep in mind.
Let's look at the code.

In [2]:
class ResNet50WithLossComputation(nn.Module):
    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.num_classes = num_classes
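        # Start from ImageNet-pretrained weights and replace the final fully connected
        # layer so that the network outputs `num_classes` logits.
        net = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        net.fc = nn.Linear(
            in_features=net.fc.in_features, out_features=self.num_classes
        )
        self.net = net
        self.softmax = nn.Softmax(dim=-1)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, images: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        logits = self.net(images)

        if self.training:
            # Training mode: return the scalar cross-entropy loss.
            return self.criterion(logits, labels)

        # Evaluation mode: return per-class probabilities (`labels` is not used).
        return self.softmax(logits)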
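The train/eval contract described above can be checked on a small stand-in. The sketch below replaces ResNet-50 with a single `nn.Linear` layer (a hypothetical `TinyClassifierWithLoss`, chosen only so it runs without downloading weights) and keeps the same `forward` logic: a scalar cross-entropy loss in training mode, per-class probabilities in evaluation mode.

```python
import torch
from torch import nn


class TinyClassifierWithLoss(nn.Module):
    """Hypothetical stand-in for ResNet50WithLossComputation (linear head only)."""

    def __init__(self, in_features: int, num_classes: int) -> None:
        super().__init__()
        self.net = nn.Linear(in_features, num_classes)
        self.softmax = nn.Softmax(dim=-1)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        logits = self.net(features)
        if self.training:
            # Training mode: return the scalar task loss.
            return self.criterion(logits, labels)
        # Evaluation mode: return class probabilities.
        return self.softmax(logits)


torch.manual_seed(0)
model = TinyClassifierWithLoss(in_features=8, num_classes=2)
features = torch.randn(4, 8)
labels = torch.tensor([0, 1, 1, 0])

model.train()
loss = model(features, labels)
print(loss.shape)  # torch.Size([]), i.e. a scalar

model.eval()
with torch.no_grad():
    probs = model(features, labels)
print(probs.shape)  # torch.Size([4, 2])
```

Switching between `model.train()` and `model.eval()` is all it takes to flip the output type, which is exactly what the OTX wrapper relies on.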


Next, we need to implement a class derived from `OTXModel`.
Since this example adds a multi-class classification model, we derive from `OTXClassificationModel` specifically.
For other OTX tasks, such as `OTXTaskType.DETECTION`, it should be possible to make a custom model in the same way by deriving from `OTXDetectionModel`.

`OTXModel` is an abstract class, so a developer is required to implement three abstract methods:

1) `_create_model()`
2) `_customize_inputs()`
3) `_customize_outputs()`

The `OTXResNet50` class below implements exactly these three methods.
Let's go through it.
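To see how the three hooks fit together, here is a simplified, framework-free sketch of the pattern. `MiniOTXModel` and `ToyClassifier` are hypothetical names, and the `forward` composition (customize the inputs, call the wrapped model, customize the outputs) is an assumption about how `OTXModel` uses these hooks, not a copy of its source.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable


class MiniOTXModel(ABC):
    """Hypothetical, simplified stand-in for OTXModel's three-hook design."""

    def __init__(self) -> None:
        # The concrete network is built once, through the first hook.
        self.model = self._create_model()

    @abstractmethod
    def _create_model(self) -> Callable[..., Any]:
        """Build the underlying network."""

    @abstractmethod
    def _customize_inputs(self, inputs: Any) -> dict[str, Any]:
        """Turn a batch entity into keyword arguments for the network."""

    @abstractmethod
    def _customize_outputs(self, outputs: Any, inputs: Any) -> Any:
        """Turn raw network outputs back into a result entity."""

    def forward(self, inputs: Any) -> Any:
        outputs = self.model(**self._customize_inputs(inputs))
        return self._customize_outputs(outputs, inputs)


class ToyClassifier(MiniOTXModel):
    def _create_model(self) -> Callable[..., Any]:
        # A "network" that predicts class 1 for positive values, else class 0.
        return lambda values: [1 if v > 0 else 0 for v in values]

    def _customize_inputs(self, inputs: dict[str, Any]) -> dict[str, Any]:
        return {"values": inputs["raw"]}

    def _customize_outputs(self, outputs: Any, inputs: dict[str, Any]) -> dict[str, Any]:
        return {"batch_size": len(inputs["raw"]), "labels": outputs}


print(ToyClassifier().forward({"raw": [0.5, -2.0, 3.0]}))
# {'batch_size': 3, 'labels': [1, 0, 1]}
```

Because the base class owns `forward`, a subclass only describes what is specific to its network; instantiating `MiniOTXModel` itself raises `TypeError` until all three hooks are implemented.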

In [3]:
class OTXResNet50(OTXClassificationModel):
    def __init__(self, num_classes: int) -> None:
        # Set before super().__init__(), which builds the network through _create_model().
        self.num_classes = num_classes
        super().__init__()
        # ImageNet per-channel mean and std in the 0-255 pixel range, shaped (C, 1, 1)
        # for broadcasting. Non-persistent buffers follow the module's device but are
        # not saved in the state_dict.
        self.register_buffer(
            "mean",
            torch.FloatTensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "std",
            torch.FloatTensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
            persistent=False,
        )

    def _create_model(self) -> nn.Module:
        # ResNet50_Weights.IMAGENET1K_V2 is a strong pretrained model trained with a modern recipe:
        # ImageNet-1K acc@1: 80.858, acc@5: 95.434.
        return ResNet50WithLossComputation(num_classes=self.num_classes)

    def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]:
        # Collate the list of (C, H, W) images into an (N, C, H, W) float batch and normalize it.
        images = torch.stack(inputs.images, dim=0).to(dtype=torch.float32)
        images = (images - self.mean) / self.std
        return {
            "images": images,
            "labels": torch.cat(inputs.labels, dim=0),
        }

    def _customize_outputs(
        self, outputs: Any, inputs: MulticlassClsBatchDataEntity
    ) -> MulticlassClsBatchPredEntity | OTXBatchLossEntity:
        if self.training:
            # In training mode, the wrapped model has already returned the loss.
            return {"loss": outputs}

        # Split the (N, num_classes) scores into a list with one tensor per image.
        scores = torch.unbind(outputs, 0)

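        # Predicted class index per image, shaped like the ground-truth labels ([1] each).
        labels = [score.argmax(dim=-1, keepdim=True) for score in scores]

        return MulticlassClsBatchPredEntity(
            batch_size=inputs.batch_size,
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            scores=scores,
            labels=labels,
        )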
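The tensor bookkeeping in `_customize_inputs()` and `_customize_outputs()` can be reproduced with plain tensors. The sketch below uses random images and a random score matrix (illustrative stand-ins, not OTX entities) to show that stacking a list of `(C, H, W)` images gives an `(N, C, H, W)` batch, that the `(C, 1, 1)` mean/std tensors broadcast per channel, and that `torch.unbind` turns an `(N, num_classes)` score tensor back into a list with one entry per image.

```python
import torch

# Two hypothetical images with 0-255 pixel values, each (C, H, W) = (3, 4, 4).
images = [torch.randint(0, 256, (3, 4, 4)) for _ in range(2)]
labels = [torch.tensor([0]), torch.tensor([1])]

mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

# Same steps as _customize_inputs(): stack, cast to float, normalize per channel.
batch = torch.stack(images, dim=0).to(dtype=torch.float32)
batch = (batch - mean) / std
targets = torch.cat(labels, dim=0)
print(batch.shape, targets.tolist())  # torch.Size([2, 3, 4, 4]) [0, 1]

# Same step as _customize_outputs(): split (N, num_classes) scores per image.
scores = torch.softmax(torch.randn(2, 2), dim=-1)
per_image = torch.unbind(scores, 0)
print(len(per_image), per_image[0].shape)  # 2 torch.Size([2])
```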

Next, we need to prepare a datamodule. It could also be declared in a config file, but in this example we build it directly in Python without one.

In [4]:
from otx.core.data.module import OTXDataModule
from otx.core.config.data import DataModuleConfig, SubsetConfig

task = "MULTI_CLASS_CLS"
data_dir = "../../tests/assets/classification_dataset"

train_transform = [
    {"type": "LoadImageFromFile"},
    {"type": "RandomResizedCrop", "scale": 224, "backend": "cv2"},
    {"type": "PackInputs"},
]
val_transform = [
    {"type": "LoadImageFromFile"},
    {"type": "ResizeEdge", "scale": 256, "edge": "short", "backend": "cv2"},
    {"type": "PackInputs"},
]

datamodule = OTXDataModule(
    task=task,
    config=DataModuleConfig(
        data_format="imagenet_with_subset_dirs",
        data_root=data_dir,
        train_subset=SubsetConfig(
            batch_size=2,
            subset_name="train",
            transform_lib_type="MMPRETRAIN",
            transforms=train_transform,
        ),
        val_subset=SubsetConfig(
            batch_size=1,
            subset_name="val",
            transform_lib_type="MMPRETRAIN",
            transforms=val_transform,
        ),
        test_subset=SubsetConfig(
            batch_size=1,
            subset_name="test",
            transform_lib_type="MMPRETRAIN",
            transforms=val_transform,
        ),
    ),
)
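One practical detail of this configuration: the training transform produces fixed-size crops with `RandomResizedCrop`, but the validation/test transform only resizes the short edge to 256 without cropping, so images with different aspect ratios keep different shapes. Since `OTXResNet50._customize_inputs()` collates images with `torch.stack`, which requires equal shapes, `batch_size=1` for validation and test sidesteps the problem; a larger batch would need a fixed-size crop after the resize. The sketch below uses made-up image sizes to show the failure mode.

```python
import torch

# Hypothetical short-edge-resized images: short edge 256, different aspect ratios.
img_a = torch.zeros(3, 256, 320)
img_b = torch.zeros(3, 256, 341)

try:
    torch.stack([img_a, img_b], dim=0)
except RuntimeError as err:
    # torch.stack refuses tensors of different sizes.
    print("stack failed:", err)

# With one image per batch there is nothing to mismatch.
print(torch.stack([img_a], dim=0).shape)  # torch.Size([1, 3, 256, 320])
```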


Now we can train the model with OTX's `Engine` class.

In [5]:
from otx.core.engine.engine import Engine

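num_classes = 2
lightning_module = OTXClassificationLitModule(
    otx_model=OTXResNet50(num_classes=num_classes),
    torch_compile=False,
    # Factories, called later with the model parameters and the optimizer respectively.
    optimizer=lambda p: torch.optim.SGD(p, lr=0.0049, momentum=0.9, weight_decay=0.0001),
    # Halve the learning rate when "train/loss" stops improving (patience=1).
    scheduler=lambda o: ReduceLROnPlateau(o, patience=1, factor=0.5, monitor="train/loss"),
)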

engine = Engine(
    task=task,
    work_dir="./otx-workspace",
    device="gpu",
)

engine.train(
    model=lightning_module,
    datamodule=datamodule,
    max_epochs=3,
    precision="16",
)

/home/harimkan/workspace/repo/otx-fork/venv/lib/python3.10/site-packages/lightning/fabric/connector.py:565: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1)` was configured so validation will run after every batch.
/home/harimkan/workspace/repo/otx-fork/venv/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:43: attribute 'optimizer' removed from hparams because it cannot be pickled
/home/harimkan/workspace/repo/otx-fork/venv/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:43: attribute 'scheduler' removed from hparams because it cannot be pickled
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them


/home/harimkan/workspace/repo/otx-fork/venv/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (13) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 2: 100%|██████████| 13/13 [00:02<00:00,  5.33it/s, v_num=0, train/loss=0.311, val/accuracy=1.000] 

`Trainer.fit` stopped: `max_epochs=3` reached.




{'train/loss': tensor(0.3110), 'val/accuracy': tensor(1.)}
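The `optimizer` and `scheduler` arguments passed to `OTXClassificationLitModule` in the training cell are factories: callables that receive the model parameters (or the optimizer) later, instead of ready-made objects. The sketch below calls equivalent factories by hand on a tiny `nn.Linear` model. As an assumption made so it runs without Lightning, it uses PyTorch's own `torch.optim.lr_scheduler.ReduceLROnPlateau` rather than Lightning's subclass (which additionally takes `monitor`). With `patience=1` and `factor=0.5`, the learning rate is halved once the monitored loss fails to improve for two consecutive steps.

```python
import torch
from torch import nn


def optimizer_factory(params):
    # Same hyperparameters as the training cell.
    return torch.optim.SGD(params, lr=0.0049, momentum=0.9, weight_decay=0.0001)


def scheduler_factory(optimizer):
    # Plain PyTorch scheduler; Lightning's subclass would also take monitor="train/loss".
    return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1, factor=0.5)


model = nn.Linear(4, 2)
optimizer = optimizer_factory(model.parameters())
scheduler = scheduler_factory(optimizer)

# The first loss sets the best value; the next two do not improve on it,
# which exceeds patience=1 and triggers one reduction on the third step.
lrs = []
for loss in [1.0, 1.0, 1.0]:
    scheduler.step(loss)
    lrs.append(optimizer.param_groups[0]["lr"])
print(lrs)
```

Passing factories rather than instances lets the module build the optimizer only after the model parameters exist, which is also why Lightning reports that `optimizer` and `scheduler` were dropped from the pickled hyperparameters.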

**To say it again: this is not the final form of the OTX training API. We will keep improving it so that users can use it conveniently. I believe this will not be difficult, since we already have a solid core design and this is just the entry point.**