In [1]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

import logging
import torch
from torchsummary import summary


logging.basicConfig(level=logging.DEBUG)

from src.data_loader.data_loader import DataLoader
from src.model.model import TwoHeadConvNeXtV2
from src.config.configuration import CLASS_NUM

INFO:numexpr.utils:NumExpr defaulting to 16 threads.
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logging.info(f"Using device: {device}")
! nvidia-smi

INFO:root:Using device: cuda


Sat Nov 22 20:40:27 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 581.57                 Driver Version: 581.57         CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4060 ...  WDDM  |   00000000:01:00.0  On |                  N/A |
| N/A   48C    P8              3W /   30W |    1621MiB /   8188MiB |      9%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [3]:
img_path = "data/train_images_small"
label_path = "data/venomous_status_metadata.csv"
meta_data_path = "data/train_images_metadata.csv"

data_loader = DataLoader(
    image_data_set_path=img_path,
    meta_data_path=meta_data_path,
    label_info_path=label_path
    )

INFO:root:Initializing DataLoader...
INFO:root:Checking paths...
INFO:root:Loading metadata from data/train_images_metadata.csv...
DEBUG:root:Metadata columns: ['observation_id', 'endemic', 'binomial_name', 'code', 'image_path', 'class_id']
INFO:root:Loading label info from data/venomous_status_metadata.csv...
DEBUG:root:Label info columns: ['class_id', 'MIVS']
DEBUG:root:Merged Metadata columns: ['observation_id', 'endemic', 'binomial_name', 'code', 'image_path', 'class_id', 'MIVS']
INFO:root:Loading image data from data/train_images_small...
Loading metadata: 100%|██████████| 66454/66454 [00:09<00:00, 7194.05it/s]
INFO:root:Train: 53163, Val: 13291


In [4]:
model = TwoHeadConvNeXtV2(num_multi_classes=CLASS_NUM)

INFO:root:Using device: cuda
INFO:root:Creating TwoHeadConvNeXtV2 with backbone convnextv2_tiny.fcmae
INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/convnextv2_tiny.fcmae)
DEBUG:httpcore.connection:connect_tcp.started host='huggingface.co' port=443 local_address=None timeout=10 socket_options=None
DEBUG:httpcore.connection:connect_tcp.complete return_value=<httpcore._backends.sync.SyncStream object at 0x000002E8FC268A30>
DEBUG:httpcore.connection:start_tls.started ssl_context=<ssl.SSLContext object at 0x000002E8FFDF9840> server_hostname='huggingface.co' timeout=10
DEBUG:httpcore.connection:start_tls.complete return_value=<httpcore._backends.sync.SyncStream object at 0x000002E8FC268A00>
DEBUG:httpcore.http11:send_request_headers.started request=<Request [b'HEAD']>
DEBUG:httpcore.http11:send_request_headers.complete
DEBUG:httpcore.http11:send_request_body.started request=<Request [b'HEAD']>
DEBUG:httpcore.http11:send_request_body.complete
DEBUG:httpcore.

In [5]:
summary(model, input_size=(3, 224, 224), device=str(device))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 56, 56]           4,704
       LayerNorm2d-2           [-1, 96, 56, 56]             192
          Identity-3           [-1, 96, 56, 56]               0
            Conv2d-4           [-1, 96, 56, 56]           4,800
         LayerNorm-5           [-1, 56, 56, 96]             192
            Linear-6          [-1, 56, 56, 384]          37,248
              GELU-7          [-1, 56, 56, 384]               0
           Dropout-8          [-1, 56, 56, 384]               0
GlobalResponseNorm-9          [-1, 56, 56, 384]             768
           Linear-10           [-1, 56, 56, 96]          36,960
          Dropout-11           [-1, 56, 56, 96]               0
GlobalResponseNormMlp-12           [-1, 56, 56, 96]               0
         Identity-13           [-1, 96, 56, 56]               0
         Identity-14           [-1,

In [None]:
from src.model.utils import train_model
from src.data_loader.augmentation import Augmentor
augmentor = Augmentor(num_augmentations=40000, center_n_transforms=2, center_magnitude=10)
train_model(data_loader, model, augmentor=augmentor)

INFO:root:Using device: cuda
