# HerdNet Training with Custom Configuration

This notebook contains the same config as this command line (adapt paths as needed):
```bash
PYTHONPATH="$pwd:PYTHONPATH" uv run tools/train.py \
    train.datasets.train.csv_file=$PWD/data/groundtruth/csv/train_big_size_A_B_E_K_WH_WB_points.csv \
    train.datasets.train.root_dir=$PWD/data/train \
    train.datasets.validate.csv_file=$PWD/data/groundtruth/csv/val_big_size_A_B_E_K_WH_WB_points.csv \
    train.datasets.validate.root_dir=$PWD/data/val \
    train.datasets.anno_type=point
```

## Setup and Imports

In [1]:
import os
import sys
import torch
import pandas as pd
import wandb
import albumentations as A
from torch.utils.data import DataLoader
from omegaconf import OmegaConf

# Add project root to Python path
project_root = os.path.abspath('..')
sys.path.insert(0, project_root)

import animaloc
from animaloc.models.utils import LossWrapper, load_model
from animaloc.eval import Evaluator, PointsMetrics, Stitcher, BoxesMetrics, ImageLevelMetrics
from animaloc.utils.seed import set_seed
from animaloc.utils.useful_funcs import current_date

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name()}")

PyTorch version: 2.8.0+cu128
CUDA available: True
CUDA device: NVIDIA GeForce RTX 4050 Laptop GPU


## Configuration Setup

In [2]:
# Load base configuration using the main config (like train.py does)
main_cfg_path = os.path.join(project_root, 'configs/train/herdnet.yaml')
cfg = OmegaConf.load(main_cfg_path)

# Override with custom paths (matching the command line args)
cfg.datasets.train.csv_file = os.path.join(project_root, 'data/groundtruth/csv/train_big_size_A_B_E_K_WH_WB_points.csv')
cfg.datasets.train.root_dir = os.path.join(project_root, 'data/train')
cfg.datasets.validate.csv_file = os.path.join(project_root, 'data/groundtruth/csv/val_big_size_A_B_E_K_WH_WB_points.csv')
cfg.datasets.validate.root_dir = os.path.join(project_root, 'data/val')
cfg.datasets.anno_type = 'point'

# Update class definitions for 6 classes based on your data
cfg.datasets.num_classes = 7  # Including background class
cfg.datasets.class_def = {
    1: 'buffalo',
    2: 'elephant', 
    3: 'kob',
    4: 'warthog',
    5: 'waterbuck',
    6: 'other'
}

# Update loss weights for 6 classes + background
cfg.losses.CrossEntropyLoss.kwargs.weight = [0.1, 5., 15., 1., 5., 5., 1.]

# Resolve interpolations - fix the incorrect interpolation keys in the config
# The config uses ${train.datasets.num_classes} but should use ${datasets.num_classes}
# We need to manually resolve these until the config is fixed
if 'end_transforms' in cfg.datasets.train and cfg.datasets.train.end_transforms is not None:
    if 'MultiTransformsWrapper' in cfg.datasets.train.end_transforms:
        if 'FIDT' in cfg.datasets.train.end_transforms.MultiTransformsWrapper:
            cfg.datasets.train.end_transforms.MultiTransformsWrapper.FIDT.num_classes = cfg.datasets.num_classes
            cfg.datasets.train.end_transforms.MultiTransformsWrapper.FIDT.down_ratio = cfg.model.kwargs.down_ratio
        if 'PointsToMask' in cfg.datasets.train.end_transforms.MultiTransformsWrapper:
            cfg.datasets.train.end_transforms.MultiTransformsWrapper.PointsToMask.num_classes = cfg.datasets.num_classes

if 'end_transforms' in cfg.datasets.validate and cfg.datasets.validate.end_transforms is not None:
    if 'DownSample' in cfg.datasets.validate.end_transforms:
        cfg.datasets.validate.end_transforms.DownSample.down_ratio = cfg.model.kwargs.down_ratio
        cfg.datasets.validate.end_transforms.DownSample.anno_type = cfg.datasets.anno_type

# Also resolve interpolations in training_settings.stitcher
if cfg.training_settings.stitcher is not None:
    cfg.training_settings.stitcher.kwargs.down_ratio = cfg.model.kwargs.down_ratio

print("Configuration loaded and customized")
print(f"Training CSV: {cfg.datasets.train.csv_file}")
print(f"Training root: {cfg.datasets.train.root_dir}")
print(f"Validation CSV: {cfg.datasets.validate.csv_file}")
print(f"Validation root: {cfg.datasets.validate.root_dir}")
print(f"Number of classes: {cfg.datasets.num_classes}")
print(f"Annotation type: {cfg.datasets.anno_type}")

Configuration loaded and customized
Training CSV: /home/lmanrique/Do/HerdNet/data/groundtruth/csv/train_big_size_A_B_E_K_WH_WB_points.csv
Training root: /home/lmanrique/Do/HerdNet/data/train
Validation CSV: /home/lmanrique/Do/HerdNet/data/groundtruth/csv/val_big_size_A_B_E_K_WH_WB_points.csv
Validation root: /home/lmanrique/Do/HerdNet/data/val
Number of classes: 7
Annotation type: point


## Utility Functions

These functions replicate the helper functions from the train.py script

In [3]:
def _load_albu_transforms(tr_cfg: dict) -> list:
    transforms = []
    for name, kwargs in tr_cfg.items():
        transforms.append(A.__dict__[name](**kwargs))
    return transforms

def _load_end_transforms(tr_cfg):
    if tr_cfg is not None:
        transforms = []
        for name, kwargs in tr_cfg.items():
            if name == 'MultiTransformsWrapper':
                tr_list = []
                for n, k in kwargs.items():
                    tr_list.append(animaloc.data.transforms.__dict__[n](**k))
                transforms.append(animaloc.data.transforms.__dict__[name](tr_list))
            else:
                transforms.append(animaloc.data.transforms.__dict__[name](**kwargs))
        return transforms
    else:
        return None

def _get_collate_fn(cfg):
    fn = cfg.datasets.collate_fn
    if fn is not None:
        fn = animaloc.data.batch_utils.__dict__[fn]
    return fn

def _build_model(cfg) -> torch.nn.Module:
    name = cfg.model.name
    from_torchvision = cfg.model.from_torchvision
    
    if from_torchvision:
        import torchvision
        assert name in torchvision.models.__dict__.keys(), f"'{name}' unfound in torchvision's models"
        model = torchvision.models.__dict__[name]
    else:
        assert name in animaloc.models.__dict__.keys(), f"'{name}' class unfound, make sure you have included the class in the models list"
        model = animaloc.models.__dict__[name]
    
    kwargs = dict(cfg.model.kwargs)
    for k in ['num_classes']:
        kwargs.pop(k, None)
    
    model = model(**kwargs, num_classes=cfg.datasets.num_classes)
    return model

def _load_losses(cfg) -> tuple:
    criterions = []
    if cfg.losses is not None:
        for loss, args in cfg.losses.items():
            kwargs = {}
            if 'kwargs' in args.keys():
                kwargs = dict(args.kwargs)
                
                if 'weights' in kwargs.keys():
                    kwargs['weights'] = torch.Tensor(kwargs['weights'])
                elif 'weight' in kwargs.keys():
                    kwargs['weight'] = torch.Tensor(kwargs['weight']).to(torch.device(cfg.device_name))
            
            crit_dict = {}
            if args.from_torch:
                crit_dict.update({'loss': torch.nn.__dict__[loss](**kwargs)})
            else:
                crit_dict.update({'loss': animaloc.train.losses.__dict__[loss](**kwargs)})
            
            crit_dict.update({
                'idx': args.output_idx,
                'idy': args.target_idx,
                'lambda': args.lambda_const,
                'name': args.print_name
            })
            
            criterions.append(crit_dict)
    
    return criterions

def _define_stitcher(model, cfg):
    kwargs = dict(cfg.training_settings.stitcher.kwargs)
    for k in ['model', 'size', 'device_name']:
        kwargs.pop(k, None)
    
    stitcher = animaloc.eval.stitchers.__dict__[cfg.training_settings.stitcher.name](
        model=model,
        size=cfg.datasets.img_size,
        **kwargs,
        device_name=cfg.device_name
    )
    
    return stitcher

def _define_evaluator(model, dataloader, cfg):
    name = cfg.training_settings.evaluator.name
    anno_type = cfg.datasets.anno_type
    
    assert name in animaloc.eval.evaluators.__dict__.keys(), f"'{name}' class unfound"
    
    if anno_type == 'point':
        metrics = PointsMetrics(
            radius=cfg.training_settings.evaluator.threshold,
            num_classes=cfg.datasets.num_classes
        )
    elif anno_type == 'bbox':
        metrics = BoxesMetrics(
            iou=cfg.training_settings.evaluator.threshold,
            num_classes=cfg.datasets.num_classes
        )
    elif anno_type == 'image':
        metrics = ImageLevelMetrics(
            num_classes=cfg.datasets.num_classes
        )
    else:
        raise NotImplementedError
    
    stitcher = None
    if cfg.training_settings.stitcher is not None:
        stitcher = _define_stitcher(model, cfg)
    
    kwargs = dict(cfg.training_settings.evaluator.kwargs)
    for k in ['model', 'dataloader', 'metrics', 'device_name', 'stitcher', 'header', 'vizual_fn']:
        kwargs.pop(k, None)
    
    vizual_fn = None
    if cfg.training_settings.vizual_fn is not None:
        vizual_fn = animaloc.vizual.plots.__dict__[cfg.training_settings.vizual_fn]
    
    evaluator = animaloc.eval.evaluators.__dict__[name](
        model=model,
        dataloader=dataloader,
        metrics=metrics,
        device_name=cfg.device_name,
        stitcher=stitcher,
        header='[TEST]',
        vizual_fn=vizual_fn,
        **kwargs
    )
    
    return evaluator

print("Utility functions defined")

Utility functions defined


## Initialize Training

In [4]:
# Set the seed
print(f'Setting the seed to {cfg.seed}')
set_seed(cfg.seed)

# Setup device
device = torch.device(cfg.device_name)
print(f"Using device: {device}")

Setting the seed to 1
Using device: cuda


## Create Datasets and DataLoaders

In [5]:
print('Building datasets ...')

# Training dataset
train_args = cfg.datasets.train
val_args = cfg.datasets.validate

train_df = pd.read_csv(train_args.csv_file)
print(f"Training data loaded: {len(train_df)} samples")
print(f"Training CSV columns: {list(train_df.columns)}")
print(f"Training data shape: {train_df.shape}")

train_dataset = animaloc.datasets.__dict__[train_args.name](
    csv_file=train_df,
    root_dir=train_args.root_dir,
    albu_transforms=_load_albu_transforms(train_args.albu_transforms),
    end_transforms=_load_end_transforms(train_args.end_transforms)
)

train_dl_kwargs = dict(
    batch_size=cfg.training_settings.batch_size,
    shuffle=True,
    collate_fn=_get_collate_fn(cfg)
)

train_dataloader = DataLoader(train_dataset, **train_dl_kwargs)
print(f"Training dataloader created with batch size: {cfg.training_settings.batch_size}")

# Validation dataset
val_dataloader = None
if val_args is not None:
    val_df = pd.read_csv(val_args.csv_file)
    print(f"Validation data loaded: {len(val_df)} samples")
    print(f"Validation data shape: {val_df.shape}")
    
    val_dataset = animaloc.datasets.__dict__[val_args.name](
        csv_file=val_df,
        root_dir=val_args.root_dir,
        albu_transforms=_load_albu_transforms(val_args.albu_transforms),
        end_transforms=_load_end_transforms(val_args.end_transforms)
    )
    
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=_get_collate_fn(cfg))
    print("Validation dataloader created")

print(f"Dataset setup complete!")

Building datasets ...
Training data loaded: 6962 samples
Training CSV columns: ['images', 'x', 'y', 'labels']
Training data shape: (6962, 4)
Training dataloader created with batch size: 4
Validation data loaded: 978 samples
Validation data shape: (978, 4)
Validation dataloader created
Dataset setup complete!


## Setup Weights & Biases Logging

In [6]:
print('Connecting to Weights & Biases ...')
settings = cfg.training_settings
losses = cfg.losses
if losses is not None:
    losses = list(cfg.losses.keys())

wandb.init(
    project=cfg.wandb_project,
    entity=cfg.wandb_entity,
    config=dict(
        batch_size=settings.batch_size,
        optimizer=settings.optimizer,
        lr=settings.lr,
        weight_decay=settings.weight_decay,
        warmup_iters=settings.warmup_iters,
        epochs=settings.epochs,
        losses=losses,
        seed=cfg.seed,
        data_augmentation=list(cfg.datasets.train.albu_transforms.keys()),
        input_size=cfg.datasets.img_size,
        **cfg.model.kwargs
    )
)

date = current_date()
wandb.run.name = f'{date}_' + cfg.wandb_run + f'_RUN_{wandb.run.id}'
print(f"W&B run name: {wandb.run.name}")

Connecting to Weights & Biases ...


[34m[1mwandb[0m: Currently logged in as: [33mluis-manrique-car[0m ([33mluis-manrique-car-camera-traps[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


W&B run name: 20250913_camera-traps-train_RUN_ax140pht


## Build Model and Setup Training

In [7]:
print('Building the model ...')
model = _build_model(cfg)
print(f"Model created: {cfg.model.name}")

print('Preparing for training ...')
criterions = _load_losses(cfg)
model = LossWrapper(model, criterions).to(device)
print(f"Model wrapped with losses and moved to {device}")

# Load pretrained weights if specified
if cfg.model.load_from is not None:
    model = load_model(model, cfg.model.load_from)
    print(f"Loaded pretrained weights from {cfg.model.load_from}")
    
    if 'HerdNet' in cfg.model.name:
        if cfg.model.freeze is not None:
            model.model.freeze(layers=list(cfg.model.freeze))
            print(f"Layers {list(cfg.model.freeze)} frozen")

# Setup optimizer
if cfg.training_settings.optimizer == 'adam':
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=cfg.training_settings.lr,
        weight_decay=cfg.training_settings.weight_decay
    )
else:
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=cfg.training_settings.lr,
        weight_decay=cfg.training_settings.weight_decay
    )

print(f"Optimizer: {cfg.training_settings.optimizer}")
print(f"Learning rate: {cfg.training_settings.lr}")
print(f"Weight decay: {cfg.training_settings.weight_decay}")

# Watch the model's gradients during training
wandb.watch(model)

Building the model ...
Model created: HerdNet
Preparing for training ...
Model wrapped with losses and moved to cuda
Optimizer: adam
Learning rate: 0.0001
Weight decay: 0.0005


## Setup Evaluator

In [8]:
evaluator = None
validate_on = 'recall'
select = 'min'

if cfg.training_settings.evaluator is not None:
    assert val_dataloader is not None, 'A validation dataset must be defined to build an evaluator'
    
    evaluator = _define_evaluator(model, val_dataloader, cfg)
    select = cfg.training_settings.evaluator.select_mode
    validate_on = cfg.training_settings.evaluator.validate_on
    print(f"Evaluator created - selecting {select} {validate_on}")
else:
    print("No evaluator configured")

Evaluator created - selecting max f1_score


## Create Trainer and Start Training

In [9]:
# Setup auto learning rate scheduler
auto_lr = cfg.training_settings.auto_lr
if auto_lr:
    auto_lr = dict(cfg.training_settings.auto_lr)
    print(f"Auto LR scheduler enabled: {auto_lr}")

# Setup visualization function
vizual_fn = None
if cfg.training_settings.vizual_fn is not None:
    vizual_fn = animaloc.vizual.plots.__dict__[cfg.training_settings.vizual_fn]

# Create trainer
trainer = animaloc.train.trainers.__dict__[cfg.training_settings.trainer](
    model,
    train_dataloader,
    optimizer=optimizer,
    num_epochs=cfg.training_settings.epochs,
    auto_lr=auto_lr,
    val_dataloader=val_dataloader,
    evaluator=evaluator,
    device_name=cfg.device_name,
    vizual_fn=vizual_fn,
    work_dir=None,
    print_freq=cfg.training_settings.print_freq,
    valid_freq=cfg.training_settings.valid_freq,
)

print(f"Trainer created: {cfg.training_settings.trainer}")
print(f"Training for {cfg.training_settings.epochs} epochs")
print(f"Validation frequency: every {cfg.training_settings.valid_freq} epoch(s)")
print(f"Print frequency: every {cfg.training_settings.print_freq} iterations")

Auto LR scheduler enabled: {'mode': 'max', 'patience': 10, 'threshold': 0.0001, 'threshold_mode': 'rel', 'cooldown': 10, 'min_lr': 1e-06}
Trainer created: Trainer
Training for 100 epochs
Validation frequency: every 1 epoch(s)
Print frequency: every 100 iterations


In [10]:
# Start training
if cfg.model.resume_from is not None:
    print(f'Resuming training from \'{cfg.model.resume_from}\' ...')
    trainer.resume(
        pth_path=cfg.model.resume_from,
        select=select,
        validate_on=validate_on,
        load_optim=True,
        wandb_flag=True
    )
else:
    print('Starting training ...')
    trainer.start(
        cfg.training_settings.warmup_iters,
        select=select,
        validate_on=validate_on,
        wandb_flag=True
    )

Starting training ...
[TRAINING] - Epoch: [1] [  1/232] eta: 0:04:22 lr: 0.000002 loss: 15318.1357 (15318.1357) focal_loss: 15316.1514 (15316.1514) ce_loss: 1.9848 (1.9848) time: 1.1331 data: 0.5681 max mem: 2104


KeyboardInterrupt: 

## Post-Training: Add Metadata to Model Files

In [None]:
# Add information in .pth files
print("Adding metadata to model files...")
for pth_name in ['best_model.pth', 'latest_model.pth']:
    path = os.path.join(os.getcwd(), pth_name)
    if os.path.exists(path):
        pth_file = torch.load(path)
        norm_trans = _load_albu_transforms(train_args.albu_transforms)[-1]
        pth_file['classes'] = dict(cfg.datasets.class_def)
        pth_file['mean'] = list(norm_trans.mean)
        pth_file['std'] = list(norm_trans.std)
        torch.save(pth_file, path)
        print(f"Updated {pth_name} with metadata")
    else:
        print(f"Warning: {pth_name} not found")

print("Training completed!")

## Optional: Display Training Results

In [None]:
# Display some training information
print("\n=== Training Summary ===")
print(f"Model: {cfg.model.name}")
print(f"Number of classes: {cfg.datasets.num_classes}")
print(f"Training samples: {len(train_dataset)}")
if val_dataloader:
    print(f"Validation samples: {len(val_dataset)}")
print(f"Batch size: {cfg.training_settings.batch_size}")
print(f"Learning rate: {cfg.training_settings.lr}")
print(f"Epochs: {cfg.training_settings.epochs}")
print(f"Device: {device}")

# Show W&B run URL
print(f"\nW&B Run: {wandb.run.url}")