In [1]:
"""CLI Example"""

! otx train --config src/configs/classification_default.yaml --data.config.data_root tests/assets/classification_dataset --model.otx_model.config.head.num_classes 2 --print_config

engine:
  max_epochs: 10
  deterministic: false
  precision: 32
  val_check_interval: 1
  callbacks:
  - class_path: lightning.pytorch.callbacks.EarlyStopping
    init_args:
      monitor: val/accuracy
      min_delta: 0.0
      patience: 100
      verbose: false
      mode: max
      strict: true
      check_finite: true
      log_rank_zero_only: false
  - class_path: lightning.pytorch.callbacks.RichProgressBar
    init_args:
      refresh_rate: 1
      leave: false
      theme:
        description: white
        progress_bar: '#6206E0'
        progress_bar_finished: '#6206E0'
        progress_bar_pulse: '#6206E0'
        batch_progress: white
        time: grey54
        processing_speed: grey70
        metrics: white
        metrics_text_delimiter: ' '
        metrics_format: .3f
  accelerator: auto
  devices: auto
  strategy: auto
  num_nodes: 1
  fast_dev_run: false
  max_steps: -1
  overfit_batches: 0.0
  check_val_every_n_epoch: 1
  accumulate_grad_batches: 1
  inference_mode: t

In [None]:
! otx train --config src/configs/classification_default.yaml --data.config.data_root tests/assets/classification_dataset --model.otx_model.config.head.num_classes 2 --print_config > test_config.yaml

In [1]:
"""API Example with Config file."""

from otx.core.data.module import OTXDataModule
from otx.core.engine.engine import Engine
from otx.core.model.module.classification import OTXClassificationLitModule

config_file = "test_config.yaml"

model = OTXClassificationLitModule.from_config(config=config_file)
datamodule = OTXDataModule.from_config(task="MULTI_CLASS_CLS", config=config_file)

engine = Engine.from_config(config=config_file)
engine.train(model=model, datamodule=datamodule)

  from .autonotebook import tqdm as notebook_tqdm


Loads checkpoint by http backend from path: https://github.com/d-li14/mobilenetv3.pytorch/blob/master/pretrained/mobilenetv3-large-1cd25616.pth?raw=true
The model and loaded state dict do not match exactly

unexpected key in source state_dict: classifier.0.weight, classifier.0.bias, classifier.3.weight, classifier.3.bias

init weight - https://github.com/d-li14/mobilenetv3.pytorch/blob/master/pretrained/mobilenetv3-large-1cd25616.pth?raw=true
12/15 16:01:15 - mmengine - [4m[97mINFO[0m - 
backbone.features.0.0.weight - torch.Size([16, 3, 3, 3]): 
The value is the same before and after calling `init_weights` of ImageClassifier  
 
12/15 16:01:15 - mmengine - [4m[97mINFO[0m - 
backbone.features.0.1.weight - torch.Size([16]): 
The value is the same before and after calling `init_weights` of ImageClassifier  
 
12/15 16:01:15 - mmengine - [4m[97mINFO[0m - 
backbone.features.0.1.bias - torch.Size([16]): 
The value is the same before and after calling `init_weights` of ImageClassifie

Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1)` was configured so validation will run after every batch.
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/harimkan/workspace/repo/otx-fork/venv/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoi

`Trainer.fit` stopped: `max_epochs=10` reached.


{'train/loss': tensor(0.5327), 'val/accuracy': tensor(0.9200)}

In [2]:
from typing import Any

import torch
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.classification import (
    MulticlassClsBatchDataEntity,
    MulticlassClsBatchPredEntity,
)
from otx.core.model.entity.classification import OTXClassificationModel
from torch import nn
from torchvision.models.resnet import ResNet50_Weights, resnet50


class ResNet50WithLossComputation(nn.Module):
    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.num_classes = num_classes
        net = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        net.fc = nn.Linear(
            in_features=net.fc.in_features, out_features=self.num_classes
        )
        self.net = net
        self.softmax = nn.Softmax(dim=-1)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, images: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        logits = self.net(images)

        if self.training:
            return self.criterion(logits, labels)

        return self.softmax(logits)


class OTXResNet50(OTXClassificationModel):
    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes
        super().__init__()
        self.register_buffer(
            "mean",
            torch.FloatTensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
            False,
        )
        self.register_buffer(
            "std",
            torch.FloatTensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
            False,
        )

    def _create_model(self) -> nn.Module:
        # ResNet50_Weights.IMAGENET1K_V2 is a really powerful pretrained model equipped with the modern training scheme:
        # ImageNet-1K acc@1: 80.858, acc@5": 95.434.
        return ResNet50WithLossComputation(num_classes=self.num_classes)

    def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]:
        images = torch.stack(inputs.images, dim=0).to(dtype=torch.float32)
        images = (images - self.mean) / self.std
        return {
            "images": images,
            "labels": torch.cat(inputs.labels, dim=0),
        }

    def _customize_outputs(
        self, outputs: Any, inputs: MulticlassClsBatchDataEntity
    ) -> MulticlassClsBatchPredEntity | OTXBatchLossEntity:
        if self.training:
            return {"loss": outputs}

        # To list, batch-wise
        scores = torch.unbind(outputs, 0)

        return MulticlassClsBatchPredEntity(
            batch_size=inputs.batch_size,
            images=inputs.images,
            imgs_info=inputs.imgs_info,
            scores=scores,
            labels=inputs.labels,
        )

In [4]:
from otx.core.engine.engine import Engine

engine = Engine(max_epochs=5)

model.model=OTXResNet50(num_classes=2)
engine.train(model=model, datamodule=datamodule)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1)` was configured so validation will run after every batch.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type               | Params
---------------------------------------------------
0 | model       | OTXResNet50        | 23.5 M
1 | val_metric  | MulticlassAccuracy | 0     
2 | test_metric | MulticlassAccuracy | 0     
---------------------------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.049    Total estimated model params size (MB)


                                                                           

/home/harimkan/workspace/repo/otx-fork/venv/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 4: 100%|██████████| 1/1 [00:00<00:00,  8.17it/s, v_num=56, train/loss=0.595, val/accuracy=1.000]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 1/1 [00:00<00:00,  2.89it/s, v_num=56, train/loss=0.595, val/accuracy=1.000]


{'train/loss': tensor(0.5950), 'val/accuracy': tensor(1.)}