In [1]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

import logging
import torch
from torchsummary import summary
import os
import gdown
from dotenv import load_dotenv

logging.basicConfig(level=logging.INFO)

from src.data_loader.data_loader import DataLoader
from src.model.model import TwoHeadConvNeXtV2
from src.config.configuration import CLASS_NUM, IMAGE_ROOT, META_CSV, LABEL_INFO_CSV, NUM_AUGMENTATIONS
from src.model.utils import train_model
from src.data_loader.augmentation import Augmentor

load_dotenv()

INFO:numexpr.utils:NumExpr defaulting to 16 threads.
  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
# Fetch the training-image archive from Google Drive. The share URL comes from the DATASET_URL
# environment variable (populated from .env by load_dotenv). The download is skipped when the
# archive already exists, so Restart & Run All does not re-download the dataset every time.
# NOTE(review): this writes `train_images_medium.zip`, but the DataLoader cell logs that it reads
# `data/train_images_large`, and nothing here unzips the archive. Confirm which dataset IMAGE_ROOT
# should point at and whether an extraction step is missing.
url = os.getenv("DATASET_URL")
output = "data/train_images_medium.zip"
if os.path.exists(output):
    logging.info(f"Dataset archive already present at {output}; skipping download.")
elif not url:
    logging.error("DATASET_URL not found in .env file! Please set it and reload the environment variables.")
else:
    # Make sure the target directory exists on a fresh checkout; gdown may not create it.
    os.makedirs(os.path.dirname(output), exist_ok=True)
    logging.info("Downloading dataset from Google Drive...")
    # Some gdown versions return None on failure instead of raising; don't report success then.
    if gdown.download(url, output, quiet=False, fuzzy=True) is None:
        logging.error(f"Download failed; {output} was not written.")
    else:
        logging.info(f"Downloaded to {output}")

ERROR:root:DATASET_URL not found in .env file! Please set it and reload the environment variables.


In [3]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logging.info(f"Using device: {device}")
! nvidia-smi

INFO:root:Using device: cuda


Fri Nov 28 00:47:19 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 581.57                 Driver Version: 581.57         CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4060 ...  WDDM  |   00000000:01:00.0  On |                  N/A |
| N/A   44C    P0             13W /  115W |    1367MiB /   8188MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [4]:
# Build the project DataLoader from the configured image root and the two metadata CSVs.
# On construction it loads the metadata and logs the resulting train/validation split sizes.
img_path, label_path, meta_data_path = IMAGE_ROOT, LABEL_INFO_CSV, META_CSV
data_loader = DataLoader(
    image_data_set_path=img_path,
    meta_data_path=meta_data_path,
    label_info_path=label_path,
)

INFO:root:Initializing DataLoader...
INFO:root:Checking paths...
INFO:root:Loading metadata from data/train_images_metadata.csv...
INFO:root:Loading label info from data/venomous_status_metadata.csv...
INFO:root:Loading image data from data/train_images_large...
Loading metadata: 100%|██████████| 66454/66454 [00:16<00:00, 4002.82it/s]
INFO:root:Train: 53163, Val: 13291


In [5]:
# Exploratory: list the ConvNeXt-V2 architectures timm provides (the model below uses convnextv2_atto).
import timm

models = timm.list_models('*convnextv2*')
models


['convnextv2_atto', 'convnextv2_base', 'convnextv2_femto', 'convnextv2_huge', 'convnextv2_large', 'convnextv2_nano', 'convnextv2_pico', 'convnextv2_small', 'convnextv2_tiny']


In [6]:
# Two-head ConvNeXt-V2 classifier; CLASS_NUM sets the size of the multi-class head.
# Move it explicitly to the notebook's `device` (nn.Module.to returns the module itself, and is a
# no-op if the model already placed itself there) so it matches the device passed to summary().
model = TwoHeadConvNeXtV2(num_multi_classes=CLASS_NUM).to(device)

INFO:root:Using device: cuda
INFO:root:Creating TwoHeadConvNeXtV2 with backbone convnextv2_atto.fcmae
INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/convnextv2_atto.fcmae)
INFO:httpx:HTTP Request: HEAD https://huggingface.co/timm/convnextv2_atto.fcmae/resolve/main/model.safetensors "HTTP/1.1 302 Found"
INFO:timm.models._hub:[timm/convnextv2_atto.fcmae] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


In [7]:
# Show the backbone's pretrained config (input size, interpolation, normalization mean/std).
# `pretrained_cfg` is timm's current name; `default_cfg` is a backwards-compat alias to the same dict.
model.backbone.pretrained_cfg

{'url': 'https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_atto_1k_224_fcmae.pt',
 'hf_hub_id': 'timm/convnextv2_atto.fcmae',
 'architecture': 'convnextv2_atto',
 'tag': 'fcmae',
 'custom_load': False,
 'input_size': (3, 224, 224),
 'fixed_input_size': False,
 'interpolation': 'bicubic',
 'crop_pct': 0.875,
 'crop_mode': 'center',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'num_classes': 0,
 'pool_size': (7, 7),
 'first_conv': 'stem.0',
 'classifier': 'head.fc',
 'license': 'cc-by-nc-4.0',
 'origin_url': 'https://github.com/facebookresearch/ConvNeXt-V2',
 'paper_name': 'ConvNeXt-V2: Co-designing and Scaling ConvNets with Masked Autoencoders',
 'paper_ids': 'arXiv:2301.00808'}

In [8]:
# Layer-by-layer shape/parameter summary at the backbone's pretrained input size (3, 224, 224).
# torchsummary asserts device is exactly "cuda" or "cpu"; device.type stays valid even for
# indexed devices such as cuda:0, where str(device) would fail that check.
summary(model, input_size=(3, 224, 224), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 40, 56, 56]           1,960
       LayerNorm2d-2           [-1, 40, 56, 56]              80
          Identity-3           [-1, 40, 56, 56]               0
            Conv2d-4           [-1, 40, 56, 56]           2,000
       LayerNorm2d-5           [-1, 40, 56, 56]              80
            Conv2d-6          [-1, 160, 56, 56]           6,560
              GELU-7          [-1, 160, 56, 56]               0
           Dropout-8          [-1, 160, 56, 56]               0
GlobalResponseNorm-9          [-1, 160, 56, 56]             320
           Conv2d-10           [-1, 40, 56, 56]           6,440
          Dropout-11           [-1, 40, 56, 56]               0
GlobalResponseNormMlp-12           [-1, 40, 56, 56]               0
         Identity-13           [-1, 40, 56, 56]               0
         Identity-14           [-1,

In [9]:
# Augmentation policy handed to train_model (which logs class-balanced augmentation weights).
# NOTE(review): center_n_transforms / center_magnitude look like RandAugment-style
# (number of ops, magnitude) settings. Confirm their meaning in src/data_loader/augmentation.py.
AUG_CENTER_N_TRANSFORMS = 2
AUG_CENTER_MAGNITUDE = 10
augmentor = Augmentor(
    num_augmentations=NUM_AUGMENTATIONS,
    center_n_transforms=AUG_CENTER_N_TRANSFORMS,
    center_magnitude=AUG_CENTER_MAGNITUDE,
)

In [10]:
# Train the model. train_model logs "PHASE 1: Training only the heads (backbone frozen)";
# presumably a later phase unfreezes the backbone. Confirm in src/model/utils.py.
# NOTE(review): the recorded run was interrupted ~1% into phase-1 epoch 1 at ~1 it/s
# (13830 steps/epoch, ~3.7 h per epoch). Consider a smaller virtual-sample budget, or
# profiling the input pipeline, which is possibly the bottleneck with a frozen backbone.
train_model(
    data_loader,
    model,
    augmentor=augmentor,
)

INFO:root:Class-balanced augmentation: Rarest class weight=2.3, Most common weight=0.1
INFO:root:Created 500000 virtual augmented samples
INFO:root:PHASE 1: Training only the heads (backbone frozen)
Phase1 Epoch 1:   1%|          | 105/13830 [01:41<3:41:44,  1.03it/s]


KeyboardInterrupt: 