# In [1]:
# ============================================================
# 🧠 DEEP LAYER AGGREGATION (DLA) - FULL REIMPLEMENTATION
# ============================================================
# Clean-room reimplementation faithful to the original DLA paper
# and the code by Zhou et al. (2019) used in HerdNet v0.2.1.
# No external dependencies, fully self-contained and notebook-safe.
# ============================================================

import math
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


# ------------------------------------------------------------
# 🔹 Utility Layers
# ------------------------------------------------------------
class Identity(nn.Module):
    """No-op module that hands its input back unchanged.

    Used as a placeholder wherever a projection or upsampling step is not
    needed (see IDAUp), so callers can invoke every slot uniformly.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Pure pass-through: no parameters and no computation.
        return x


def fill_up_weights(up: nn.ConvTranspose2d) -> None:
    """Overwrite a transposed convolution's kernel with a bilinear filter.

    The 2-D kernel is the outer product of two 1-D "tent" profiles. It is
    written into ``weight[0, 0]`` and then copied to ``weight[k, 0]`` for every
    other index ``k`` along dim 0. For ``nn.ConvTranspose2d``, dim 0 of the
    weight indexes *input* channels and dim 1 holds ``out_channels // groups``.
    Only slot 0 of dim 1 is filled, which matches a depthwise layer
    (``groups == channels``) such as the ones built in IDAUp.

    Parameters
    ----------
    up : nn.ConvTranspose2d
        Layer whose ``weight`` is modified in place, bypassing autograd.
    """
    weight = up.weight.data
    n_slices = weight.size(0)
    k_h, k_w = weight.size(2), weight.size(3)

    factor = math.ceil(k_h / 2)
    center = (2 * factor - 1 - factor % 2) / (2.0 * factor)

    # 1-D tent profiles for each spatial axis, evaluated in Python floats.
    row_profile = [1 - abs(r / factor - center) for r in range(k_h)]
    col_profile = [1 - abs(s / factor - center) for s in range(k_w)]

    for r, row_val in enumerate(row_profile):
        for s, col_val in enumerate(col_profile):
            weight[0, 0, r, s] = row_val * col_val

    # Replicate the filter into every remaining slice along dim 0.
    if n_slices > 1:
        weight[1:, 0, :, :] = weight[0, 0, :, :]


# Normalization layer used by every block below. Kept as a module-level alias
# (the notation of the original DLA code) so it can be swapped in one place.
BatchNorm = torch.nn.BatchNorm2d


# ------------------------------------------------------------
# 🔹 Basic Blocks
# ------------------------------------------------------------
class BasicBlock(nn.Module):
    """Two 3x3 conv-BN layers with an identity (or caller-supplied) shortcut.

    Parameters
    ----------
    inplanes : int
        Channels of the input tensor.
    planes : int
        Channels produced by both convolutions.
    stride : int
        Stride of the first convolution. Default = 1.
    dilation : int
        Dilation (and matching padding) of both convolutions. Default = 1.
    """

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(BasicBlock, self).__init__()
        # Sub-modules are created in a fixed order. Reordering would change
        # which random numbers each layer's default initialisation draws.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn1 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn2 = BatchNorm(planes)
        self.stride = stride  # stored for introspection; forward() never reads it

    def forward(self, x, residual=None):
        """Run the block. ``residual`` overrides the shortcut, which defaults to ``x``.

        The shortcut is added without any projection, so it must already match
        the output shape. Tree supplies a downsampled/projected tensor when the
        stride or the channel count changes.
        """
        shortcut = x if residual is None else residual
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += shortcut
        return self.relu(y)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block whose middle is ``planes // expansion`` wide.

    Parameters
    ----------
    inplanes : int
        Channels of the input tensor.
    planes : int
        Channels of the block output.
    stride : int
        Stride of the 3x3 convolution. Default = 1.
    dilation : int
        Dilation (and matching padding) of the 3x3 convolution. Default = 1.
    """
    expansion = 2

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(Bottleneck, self).__init__()
        mid = planes // self.expansion
        self.conv1 = nn.Conv2d(inplanes, mid, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(mid)
        self.conv2 = nn.Conv2d(mid, mid, kernel_size=3, stride=stride,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn2 = BatchNorm(mid)
        self.conv3 = nn.Conv2d(mid, planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, residual=None):
        """Run the block. ``residual`` overrides the shortcut, which defaults to ``x``."""
        shortcut = x if residual is None else residual
        y = x
        # The two leading conv-BN pairs are each followed by a ReLU.
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            y = self.relu(bn(conv(y)))
        y = self.bn3(self.conv3(y))
        y += shortcut
        return self.relu(y)


class BottleneckX(nn.Module):
    """ResNeXt-style bottleneck: the 3x3 convolution is split into ``cardinality`` groups.

    Parameters
    ----------
    inplanes : int
        Channels of the input tensor.
    planes : int
        Channels of the block output.
    stride : int
        Stride of the grouped 3x3 convolution. Default = 1.
    dilation : int
        Dilation (and matching padding) of the grouped convolution. Default = 1.
    """
    expansion = 2
    cardinality = 32

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(BottleneckX, self).__init__()
        groups = BottleneckX.cardinality
        # Hidden width scales with the cardinality (equals `planes` for 32).
        mid = planes * groups // 32

        self.conv1 = nn.Conv2d(inplanes, mid, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(mid)
        self.conv2 = nn.Conv2d(mid, mid, kernel_size=3, stride=stride,
                               padding=dilation, dilation=dilation,
                               groups=groups, bias=False)
        self.bn2 = BatchNorm(mid)
        self.conv3 = nn.Conv2d(mid, planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, residual=None):
        """Run the block. ``residual`` overrides the shortcut, which defaults to ``x``."""
        shortcut = x if residual is None else residual
        y = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            y = self.relu(bn(conv(y)))
        y = self.bn3(self.conv3(y))
        y += shortcut
        return self.relu(y)


# ------------------------------------------------------------
# 🔹 Root and Tree (hierarchical aggregation)
# ------------------------------------------------------------
class Root(nn.Module):
    """Fuse several maps: concat -> 1x1 conv -> BN (-> + first input) -> ReLU.

    Parameters
    ----------
    in_channels : int
        Total channel count of all inputs once concatenated.
    out_channels : int
        Channels of the fused output.
    kernel_size : int
        Used only to compute the padding ``(kernel_size - 1) // 2``. The
        convolution itself is always 1x1, so a value > 1 would enlarge the
        output spatially. The default ``root_kernel_size=1`` used throughout
        this file gives zero padding.
    residual : bool
        If True, add the first input to the normalized result before the ReLU.
    """

    def __init__(self, in_channels, out_channels, kernel_size, residual):
        super(Root, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                              stride=1, padding=(kernel_size - 1) // 2,
                              bias=False)
        self.bn = BatchNorm(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.residual = residual

    def forward(self, *x):
        """Aggregate the positional inputs along the channel axis."""
        fused = self.bn(self.conv(torch.cat(x, 1)))
        if self.residual:
            fused += x[0]
        return self.relu(fused)


class Tree(nn.Module):
    """Recursive hierarchical aggregation node of DLA.

    A tree of depth ``levels`` contains two sub-trees (or two residual blocks
    at depth 1). Only depth-1 trees own a ``Root``. Deeper trees pass the
    outputs they collect down to their right-most leaf through ``children``,
    and that leaf's root fuses them.

    Parameters
    ----------
    levels : int
        Recursion depth (1 = leaf built from two ``block`` instances).
    block : type
        Residual block class, called as ``block(in, out, stride, dilation=...)``.
    in_channels, out_channels : int
        Input and output channel counts.
    stride : int
        Downsampling factor applied by the first block and by the max-pooled
        shortcut. Default = 1.
    level_root : bool
        Also feed the (downsampled) input to the root. Default = False.
    root_dim : int
        Channels entering the root. 0 means ``2 * out_channels``. Default = 0.
    root_kernel_size, dilation, root_residual :
        Forwarded to ``Root`` / ``block`` / sub-trees.
    """

    def __init__(self, levels, block, in_channels, out_channels, stride=1,
                 level_root=False, root_dim=0, root_kernel_size=1,
                 dilation=1, root_residual=False):
        super(Tree, self).__init__()
        if root_dim == 0:
            root_dim = 2 * out_channels
        if level_root:
            # The root additionally receives the downsampled input.
            root_dim += in_channels

        # Creation order (tree1, tree2, root, downsample, project) is kept
        # stable so that default parameter initialisation is reproducible.
        if levels == 1:
            self.tree1 = block(in_channels, out_channels, stride, dilation=dilation)
            self.tree2 = block(out_channels, out_channels, 1, dilation=dilation)
            self.root = Root(root_dim, out_channels, root_kernel_size, root_residual)
        else:
            shared = dict(root_kernel_size=root_kernel_size, dilation=dilation,
                          root_residual=root_residual)
            self.tree1 = Tree(levels - 1, block, in_channels, out_channels,
                              stride, root_dim=0, **shared)
            # The right sub-tree's root also absorbs tree1's output.
            self.tree2 = Tree(levels - 1, block, out_channels, out_channels,
                              root_dim=root_dim + out_channels, **shared)

        self.level_root = level_root
        self.root_dim = root_dim
        self.levels = levels
        self.downsample = nn.MaxPool2d(stride, stride=stride) if stride > 1 else None
        if in_channels != out_channels:
            self.project = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                BatchNorm(out_channels),
            )
        else:
            self.project = None

    def forward(self, x, residual=None, children=None):
        """Aggregate ``x``. The incoming ``residual`` is ignored and recomputed from ``x``."""
        children = [] if children is None else children
        bottom = x if self.downsample is None else self.downsample(x)
        residual = bottom if self.project is None else self.project(bottom)
        if self.level_root:
            children.append(bottom)

        x1 = self.tree1(x, residual)
        if self.levels == 1:
            return self.root(self.tree2(x1), x1, *children)

        children.append(x1)
        return self.tree2(x1, children=children)


# ------------------------------------------------------------
# 🔹 DLA (Encoder)
# ------------------------------------------------------------
class DLA(nn.Module):
    """Deep Layer Aggregation encoder (no classification layers are defined).

    A stride-1 7x7 stem is followed by six stages. ``level0`` keeps full
    resolution. ``level1`` and each Tree stage ``level2``..``level5`` halve
    it, so ``level{i}`` (i >= 1) runs at 1/2**i of the input size.

    Parameters
    ----------
    levels : sequence of 6 int
        Number of convs for level0/level1, tree depth for level2..level5.
    channels : sequence of 6 int
        Output channels of each level.
    block : type
        Residual block used inside the trees. Default = BasicBlock.
    residual_root : bool
        Passed to every Tree as ``root_residual``. Default = False.
    return_levels : bool
        If True, ``forward`` returns the list of all six level outputs,
        otherwise only the last one. Default = False.

    NOTE(review): reference DLA code may also define avgpool/fc classifier
    layers that are absent here. A checkpoint carrying ``fc.*`` keys would then
    be rejected by a strict ``load_state_dict``. Confirm against the weights
    being loaded.
    """

    def __init__(self, levels, channels, block=BasicBlock,
                 residual_root=False, return_levels=False):
        super(DLA, self).__init__()
        self.channels = channels
        self.return_levels = return_levels

        self.base_layer = nn.Sequential(
            nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, bias=False),
            BatchNorm(channels[0]),
            nn.ReLU(inplace=True),
        )
        self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
        self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
        # Tree stages, each downsampling by 2. Every stage after level2 also
        # aggregates its own input at the root.
        for idx in range(2, 6):
            setattr(self, f"level{idx}",
                    Tree(levels[idx], block, channels[idx - 1], channels[idx], 2,
                         level_root=idx > 2, root_residual=residual_root))

        self._init_weights()

    def _init_weights(self):
        """He-normal init for convolutions. BatchNorm is set to weight 1, bias 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, BatchNorm):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1):
        """Stack ``convs`` 3x3 conv-BN-ReLU triplets. Only the first one is strided."""
        layers = []
        for idx in range(convs):
            first = idx == 0
            layers += [
                nn.Conv2d(inplanes if first else planes, planes, kernel_size=3,
                          stride=stride if first else 1,
                          padding=dilation, bias=False, dilation=dilation),
                BatchNorm(planes),
                nn.ReLU(inplace=True),
            ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x``. Returns all six level outputs or only the deepest one."""
        stages = [getattr(self, f"level{idx}") for idx in range(6)]
        features = []
        out = self.base_layer(x)
        for stage in stages:
            out = stage(out)
            features.append(out)
        return features if self.return_levels else out


# ------------------------------------------------------------
# 🔹 IDAUp (Iterative Deep Aggregation Upsampling)
# ------------------------------------------------------------
class IDAUp(nn.Module):
    """Iterative Deep Aggregation: project, upsample and progressively fuse maps.

    Every input map is first projected to ``out_dim`` channels (1x1 conv-BN-ReLU,
    or identity when the width already matches). It is then upsampled by its
    factor with a depthwise transposed convolution initialised as bilinear
    interpolation. Finally the maps are fused left to right with 3x3 "node"
    convolutions.

    Parameters
    ----------
    node_kernel : int
        Kernel size of the fusion convolutions.
    out_dim : int
        Channel count of every projected and fused map.
    channels : sequence of int
        Channel count of each incoming map.
    up_factors : sequence of int-like
        Upsampling factor per incoming map (1 = no upsampling).
    """

    def __init__(self, node_kernel, out_dim, channels, up_factors):
        super(IDAUp, self).__init__()
        self.channels = channels
        self.out_dim = out_dim
        for i, c in enumerate(channels):
            if c == out_dim:
                proj = Identity()
            else:
                proj = nn.Sequential(
                    nn.Conv2d(c, out_dim, kernel_size=1, stride=1, bias=False),
                    BatchNorm(out_dim),
                    nn.ReLU(inplace=True)
                )
            f = int(up_factors[i])
            if f == 1:
                up = Identity()
            else:
                # Learnable depthwise upsampler. With kernel 2f, stride f and
                # padding f // 2, an even f scales H and W exactly by f.
                up = nn.ConvTranspose2d(
                    out_dim, out_dim, f * 2, stride=f, padding=f // 2,
                    output_padding=0, groups=out_dim, bias=False
                )
                fill_up_weights(up)
            setattr(self, f"proj_{i}", proj)
            setattr(self, f"up_{i}", up)

        for i in range(1, len(channels)):
            node = nn.Sequential(
                nn.Conv2d(out_dim * 2, out_dim, kernel_size=node_kernel,
                          stride=1, padding=node_kernel // 2, bias=False),
                BatchNorm(out_dim),
                nn.ReLU(inplace=True)
            )
            setattr(self, f"node_{i}", node)

        # He-normal initialisation of the plain convolutions, the same scheme
        # DLA applies to the encoder. nn.ConvTranspose2d is not an nn.Conv2d
        # subclass, so the bilinear kernels written above are left untouched.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, layers):
        """Fuse ``layers`` (one per entry of ``channels``).

        Returns
        -------
        x : torch.Tensor
            Final fused map (same as ``y[-1]``).
        y : list of torch.Tensor
            The ``len(layers) - 1`` successive fusion results.
        """
        assert len(self.channels) == len(layers), f"{len(self.channels)} vs {len(layers)} layers"
        layers = list(layers)
        for i, l in enumerate(layers):
            upsample = getattr(self, f"up_{i}")
            project = getattr(self, f"proj_{i}")
            layers[i] = upsample(project(l))
        x = layers[0]
        y = []
        for i in range(1, len(layers)):
            node = getattr(self, f"node_{i}")
            x = node(torch.cat([x, layers[i]], 1))
            y.append(x)
        return x, y


# ------------------------------------------------------------
# 🔹 DLAUp (Full multi-scale upsampling)
# ------------------------------------------------------------
class DLAUp(nn.Module):
    """Merge encoder levels from deepest to shallowest into one high-resolution map.

    Stage ``ida_i`` fuses levels ``j = -i - 2`` through the last one into
    ``channels[j]`` channels at level ``j``'s resolution. Those results
    replace the deeper levels for the next stage.

    Parameters
    ----------
    channels : sequence of int
        Output channels of each level, shallowest first.
    scales : sequence of int
        Stride of each level relative to the first one (e.g. 1, 2, 4, ...).
    in_channels : sequence of int, optional
        Actual channels of the incoming maps. Defaults to ``channels``.
    """

    def __init__(self, channels, scales=(1, 2, 4, 8, 16), in_channels=None):
        super(DLAUp, self).__init__()
        # Work on private copies. The loop below rewrites `in_channels` in
        # place, and it must not touch the caller's sequence (or
        # `self.channels`). Copying also lets tuples be passed in.
        channels = list(channels)
        in_channels = list(channels if in_channels is None else in_channels)
        self.channels = list(channels)
        scales = np.array(scales, dtype=int)

        for i in range(len(channels) - 1):
            j = -i - 2
            setattr(self, f"ida_{i}",
                    IDAUp(3, channels[j], in_channels[j:],
                          scales[j:] // scales[j]))
            # After this stage every level deeper than j has been replaced by a
            # map at level j's scale with channels[j] channels.
            scales[j + 1:] = scales[j]
            in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]]

    def forward(self, layers):
        """Return the fused map at the resolution of ``layers[0]``.

        Raises AssertionError when fewer than two levels are given.
        """
        layers = list(layers)
        assert len(layers) > 1
        for i in range(len(layers) - 1):
            ida = getattr(self, f"ida_{i}")
            x, y = ida(layers[-i - 2:])
            layers[-i - 1:] = y
        return x


# ------------------------------------------------------------
# 🔹 Factory (dla34)
# ------------------------------------------------------------
def dla34(pretrained=False, return_levels=True):
    """Build a DLA-34 encoder (BasicBlock, non-residual roots).

    Parameters
    ----------
    pretrained : bool
        Accepted for API compatibility only. No weights are loaded.
    return_levels : bool
        Forwarded to ``DLA``. When True, ``forward`` returns all six levels.

    Returns
    -------
    DLA
        Freshly initialised encoder.
    """
    stage_depths = [1, 1, 1, 2, 2, 1]
    stage_widths = [16, 32, 64, 128, 256, 512]
    return DLA(stage_depths, stage_widths, block=BasicBlock,
               residual_root=False, return_levels=return_levels)

# In [2]:
# ============================================================
# 🧠 HERDNET ARCHITECTURE - FROM SCRATCH REIMPLEMENTATION
# ============================================================
# Faithful reimplementation of HerdNet v0.2.1 (Delplanque, 2024)
# Compatible with the custom DLA + DLAUp backbone defined above.
# No external dependencies or registration mechanisms.
# ============================================================

import torch
import torch.nn as nn
import numpy as np


class HerdNet(nn.Module):
    """
    HerdNet architecture
    --------------------
    Reimplementation of the model introduced by Alexandre Delplanque (v0.2.1),
    designed for density-based localization and classification from aerial images.

    A DLA encoder feeds two heads:
        - heatmap: single-channel localization map (Sigmoid), decoded by DLAUp
          at 1/down_ratio of the input resolution;
        - clsmap : raw per-class scores (no activation) computed from the
          deepest encoder level, at 1/32 of the input resolution.

    Parameters
    ----------
    num_layers : int
        Depth of the DLA backbone. Only 34 is supported. Default = 34.
    num_classes : int
        Number of output classes (including background). Default = 2.
    pretrained : bool
        Forwarded to ``dla34``, which ignores it in this reimplementation.
        Default = True.
    down_ratio : int
        Output stride of the heatmap, one of 1, 2, 4, 8, 16. Decoding starts at
        encoder level ``log2(down_ratio)``. Default = 2.
    head_conv : int
        Number of filters in the hidden 3x3 conv of each head. Default = 64.

    Raises
    ------
    AssertionError
        If ``down_ratio`` is not an allowed value.
    ValueError
        If ``num_layers`` is not 34.

    Notes
    -----
    ``self.device`` records the device chosen at construction (CUDA when
    available), and the model is moved there immediately. Later ``.to()`` /
    ``.cuda()`` / ``.cpu()`` calls do not update it. ``forward`` and
    ``reshape_classes`` therefore use the device of the model's parameters.
    """

    def __init__(
        self,
        num_layers: int = 34,
        num_classes: int = 2,
        pretrained: bool = True,
        down_ratio: int = 2,
        head_conv: int = 64
    ):
        super(HerdNet, self).__init__()

        # --------------------------------------------------------
        # 🔹 Sanity check for down_ratio
        # --------------------------------------------------------
        assert down_ratio in [1, 2, 4, 8, 16], \
            f"Invalid down_ratio={down_ratio}. Must be one of [1, 2, 4, 8, 16]."

        self.num_layers = num_layers
        self.num_classes = num_classes
        self.pretrained = pretrained
        self.down_ratio = down_ratio
        self.head_conv = head_conv

        # --------------------------------------------------------
        # 🔹 Determine the level index for decoding
        # --------------------------------------------------------
        self.first_level = int(np.log2(down_ratio))

        # --------------------------------------------------------
        # 🔹 Select appropriate DLA backbone
        # --------------------------------------------------------
        if num_layers == 34:
            self.base_0 = dla34(pretrained=pretrained, return_levels=True)
        else:
            raise ValueError(f"Unsupported DLA depth: {num_layers}")

        self.channels_0 = self.base_0.channels
        channels = self.channels_0

        # --------------------------------------------------------
        # 🔹 DLAUp Decoder (multi-scale feature aggregation)
        # --------------------------------------------------------
        scales = [2 ** i for i in range(len(channels[self.first_level:]))]
        self.dla_up = DLAUp(channels[self.first_level:], scales=scales)

        # --------------------------------------------------------
        # 🔹 Bottleneck convolution (applied to the deepest level)
        # --------------------------------------------------------
        self.bottleneck_conv = nn.Conv2d(
            channels[-1], channels[-1],
            kernel_size=1, stride=1, padding=0, bias=True
        )

        # --------------------------------------------------------
        # 🔹 Localization head (density map)
        # --------------------------------------------------------
        self.loc_head = nn.Sequential(
            nn.Conv2d(channels[self.first_level], head_conv,
                      kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_conv, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Sigmoid()
        )
        self.loc_head[-2].bias.data.fill_(0.00)

        # --------------------------------------------------------
        # 🔹 Classification head (category map)
        # --------------------------------------------------------
        self.cls_head = nn.Sequential(
            nn.Conv2d(channels[-1], head_conv,
                      kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_conv, num_classes,
                      kernel_size=1, stride=1, padding=0, bias=True)
        )
        self.cls_head[-1].bias.data.fill_(0.00)

        # --------------------------------------------------------
        # 🔹 Device configuration (auto-detect GPU)
        # --------------------------------------------------------
        # Construction-time device only; see the class Notes.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)

    # ============================================================
    # 🔹 Forward pass
    # ============================================================
    def forward(self, x: torch.Tensor):
        """Forward propagation of HerdNet.

        Parameters
        ----------
        x : torch.Tensor
            Input image tensor of shape (B, 3, H, W). It is moved to the
            device of the model's parameters.

        Returns
        -------
        heatmap : torch.Tensor
            Localization map of shape (B, 1, H / down_ratio, W / down_ratio),
            values in [0, 1].
        clsmap : torch.Tensor
            Unnormalized class scores of shape (B, num_classes, H / 32, W / 32).

        Spatial sizes assume H and W are divisible by 32.
        """
        # Follow the weights rather than the construction-time `self.device`,
        # so this stays correct after `.to()`/`.cpu()` and in DataParallel replicas.
        x = x.to(next(self.parameters()).device)

        # Encoder: list of the six level outputs (level0 .. level5).
        encode = self.base_0(x)

        # Bottleneck on the deepest level. It feeds both the decoder and the class head.
        bottleneck = self.bottleneck_conv(encode[-1])
        encode[-1] = bottleneck

        # Decoder (multi-scale upsampling)
        decode_hm = self.dla_up(encode[self.first_level:])

        # Heads
        heatmap = self.loc_head(decode_hm)
        clsmap = self.cls_head(bottleneck)

        return heatmap, clsmap

    # ============================================================
    # 🔹 Layer freezing utilities
    # ============================================================
    def freeze(self, layers: list) -> None:
        """Disable gradients for the named sub-modules.

        Parameters
        ----------
        layers : list of str
            Attribute names of sub-modules (e.g. ``"base_0"``, ``"dla_up"``).
            A single name passed as a plain string is also accepted.

        Raises
        ------
        AttributeError
            If a name is not an attribute of the model.
        """
        if isinstance(layers, str):
            # Iterating a str would try to freeze one attribute per character.
            layers = [layers]
        for layer in layers:
            self._freeze_layer(layer)

    def _freeze_layer(self, layer_name: str) -> None:
        """Set ``requires_grad = False`` on every parameter of ``self.<layer_name>``."""
        for param in getattr(self, layer_name).parameters():
            param.requires_grad = False

    # ============================================================
    # 🔹 Adapt class head
    # ============================================================
    def reshape_classes(self, num_classes: int) -> None:
        """Replace the last classification conv so it outputs ``num_classes`` maps.

        The new layer gets a zero bias and is placed on the same device and
        dtype as the layer it replaces.
        """
        old_head = self.cls_head[-1]
        new_head = nn.Conv2d(
            self.head_conv, num_classes,
            kernel_size=1, stride=1, padding=0, bias=True
        )
        # A fresh nn.Conv2d lives on the CPU in float32. Without this, forward()
        # fails once the model has been moved (e.g. to CUDA in __init__).
        new_head = new_head.to(device=old_head.weight.device, dtype=old_head.weight.dtype)
        new_head.bias.data.fill_(0.00)
        self.cls_head[-1] = new_head
        self.num_classes = num_classes


# In [3]:
# ============================================================
# 🧩 SEMSEG-DLA (Semantic Segmentation Variant)
# ============================================================
# Faithful reimplementation of the original SemSegDLA (Delplanque, 2024)
# Standalone, compatible with our reimplemented DLA and DLAUp modules.
# ============================================================

import torch
import torch.nn as nn
import numpy as np


class SemSegDLA(nn.Module):
    """
    Semantic-segmentation variant of DLA with multi-channel output.

    Uses the same DLA encoder / DLAUp decoder as HerdNet but a single dense
    head (no classification branch). The head emits ``num_classes - 1``
    sigmoid maps (background is implicit) at 1/down_ratio of the input
    resolution.

    Parameters
    ----------
    num_layers : int
        Depth of the DLA backbone. Only 34 is supported. Default = 34.
    num_classes : int
        Number of classes, background included. Must be >= 2. Default = 2.
    pretrained : bool
        Forwarded to ``dla34``, which ignores it here. Default = True.
    down_ratio : int
        Downsampling ratio of the output (1, 2, 4, 8, or 16). Default = 2.
    head_conv : int
        Number of channels in the hidden head convolution. Default = 64.

    Raises
    ------
    AssertionError
        If ``down_ratio`` is not an allowed value.
    ValueError
        If ``num_classes < 2`` or ``num_layers`` is not 34.

    Notes
    -----
    ``self.device`` is the construction-time device (CUDA when available).
    ``forward`` moves inputs to the parameters' device, which stays correct
    after later ``.to()`` calls.
    """

    def __init__(
        self,
        num_layers: int = 34,
        num_classes: int = 2,
        pretrained: bool = True,
        down_ratio: int = 2,
        head_conv: int = 64
    ):
        super(SemSegDLA, self).__init__()

        # --------------------------------------------------------
        # 🔹 Sanity checks
        # --------------------------------------------------------
        assert down_ratio in [1, 2, 4, 8, 16], \
            f"Downsample ratio must be one of [1, 2, 4, 8, 16], got {down_ratio}."
        if num_classes < 2:
            # The head has num_classes - 1 output channels. Fewer than two
            # classes would build an empty (zero-channel) output layer.
            raise ValueError(
                f"num_classes must be >= 2 (background included), got {num_classes}."
            )

        self.num_layers = num_layers
        self.num_classes = num_classes
        self.down_ratio = down_ratio
        self.head_conv = head_conv
        self.pretrained = pretrained

        # --------------------------------------------------------
        # 🔹 Compute first level from down_ratio
        # --------------------------------------------------------
        self.first_level = int(np.log2(down_ratio))

        # --------------------------------------------------------
        # 🔹 Select appropriate DLA backbone
        # --------------------------------------------------------
        if num_layers == 34:
            self.base = dla34(pretrained=pretrained, return_levels=True)
        else:
            raise ValueError(f"Unsupported DLA depth: {num_layers}")

        channels = self.base.channels

        # --------------------------------------------------------
        # 🔹 Multi-scale decoder
        # --------------------------------------------------------
        scales = [2 ** i for i in range(len(channels[self.first_level:]))]
        self.dla_up = DLAUp(channels[self.first_level:], scales=scales)

        # --------------------------------------------------------
        # 🔹 Heatmap head (Sigmoid output)
        # --------------------------------------------------------
        self.hm_head = nn.Sequential(
            nn.Conv2d(channels[self.first_level], head_conv,
                      kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_conv, num_classes - 1,
                      kernel_size=1, stride=1, padding=0, bias=True),
            nn.Sigmoid()
        )
        # Unlike HerdNet's heads, the last conv bias keeps PyTorch's default
        # initialisation (it is not zero-filled).

        # --------------------------------------------------------
        # 🔹 Device configuration (auto-detect GPU)
        # --------------------------------------------------------
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)

    # ============================================================
    # 🔹 Forward pass
    # ============================================================
    def forward(self, x: torch.Tensor):
        """Return per-class heatmaps of shape (B, num_classes - 1, H / down_ratio, W / down_ratio).

        ``x`` is expected as (B, 3, H, W) and is moved to the parameters'
        device. Spatial sizes assume H and W are divisible by 32.
        """
        x = x.to(next(self.parameters()).device)
        encode = self.base(x)
        decode_hm = self.dla_up(encode[self.first_level:])
        heatmap = self.hm_head(decode_hm)
        return heatmap


# In [4]:
# ============================================================
# ⚙️ MODEL UTILITIES AND LOSS WRAPPER - FROM SCRATCH
# ============================================================
# Faithful reimplementation of load_model, count_parameters,
# and LossWrapper (Delplanque, 2024) for standalone use.
# ============================================================

import torch
from typing import Union, Tuple, List, Optional


def load_model(model: torch.nn.Module, pth_path: str) -> torch.nn.Module:
    """
    Restore ``model``'s weights from a checkpoint file and return it.

    The file may hold either a bare ``state_dict`` or a dict that stores it
    under the ``"model_state_dict"`` key. Tensors are mapped onto CUDA when it
    is available, otherwise onto the CPU. ``load_state_dict`` is called with
    its default (strict) settings, so missing or unexpected keys raise.

    Security: depending on the installed torch version's ``weights_only``
    default, ``torch.load`` may unpickle arbitrary objects. Only load
    checkpoints from trusted sources.

    Parameters
    ----------
    model : torch.nn.Module
        Model updated in place.
    pth_path : str
        Path to the checkpoint (.pth) file.

    Returns
    -------
    torch.nn.Module
        The same ``model`` object, with the loaded parameters.
    """
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(pth_path, map_location=torch.device(device_name))

    # Training scripts often wrap the weights together with optimizer state.
    state = checkpoint["model_state_dict"] if "model_state_dict" in checkpoint else checkpoint
    model.load_state_dict(state)
    return model


def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    """
    Count the model's trainable and total parameter elements, and print both.

    Parameters
    ----------
    model : torch.nn.Module
        The model to inspect. Shared parameters are counted once, following
        ``model.parameters()``.

    Returns
    -------
    (trainable_params, total_params) : tuple of int
    """
    trainable = 0
    total = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size

    print(f"TRAINABLE PARAMETERS: {trainable}")
    print(f"TOTAL PARAMETERS: {total}")

    return trainable, total


class LossWrapper(torch.nn.Module):
    """
    Wrap a model so that its forward pass can also compute weighted losses.

    Parameters
    ----------
    model : torch.nn.Module
        Network to wrap. It is called as ``model(x)``. If that raises
        ``ValueError``, it is called again as ``model(x, target)``.
    losses : list of dict
        One dict per loss term, with keys:
            - 'idx'   : index into the model outputs,
            - 'idy'   : index into the targets,
            - 'name'  : key under which the term is reported,
            - 'lambda': multiplier applied to the term,
            - 'loss'  : callable ``loss(prediction, target)``.
        The dicts are stored in a plain list, so loss modules are not
        registered as sub-modules (``.to()`` / ``state_dict()`` do not reach them).
    mode : str
        What ``forward`` returns:
            - 'loss_only' : the loss dict;
            - 'preds_only': the raw model output;
            - 'both'      : ``(output, loss dict)``;
            - 'module'    : training -> loss dict (or the raw output when no
                            loss was computed); eval -> ``(output, loss dict)``.
        Default = 'module'.

    Raises
    ------
    AssertionError
        If ``losses`` is not a list or ``mode`` is not recognised.
    """

    _MODES = ("loss_only", "preds_only", "both", "module")

    def __init__(
        self,
        model: torch.nn.Module,
        losses: List[dict],
        mode: str = "module"
    ) -> None:
        super().__init__()

        assert isinstance(losses, list), "losses must be a list of dictionaries."
        assert mode in self._MODES, (
            f"Invalid mode '{mode}'. Expected one of "
            "['loss_only', 'preds_only', 'both', 'module']."
        )

        self.model = model
        self.losses = losses
        self.output_mode = mode

    def forward(
        self,
        x: torch.Tensor,
        target: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None
    ) -> Union[Tuple[torch.Tensor, dict], dict, torch.Tensor]:
        """
        Run the model and, when ``target`` is given, the configured losses.

        Parameters
        ----------
        x : torch.Tensor
            Model input.
        target : torch.Tensor or list/tuple of torch.Tensor, optional
            Ground truth. Losses are computed only when it is not None.

        Returns
        -------
        Depends on ``self.output_mode`` (see the class docstring).
        """
        try:
            output = self.model(x)
        except ValueError:
            # Fallback for models whose forward also needs the target.
            output = self.model(x, target)

        def as_sequence(value):
            # Normalise a single object to a one-element list for indexing.
            return value if isinstance(value, (list, tuple)) else [value]

        preds_seq = as_sequence(output)
        targets_seq = as_sequence(target)

        loss_terms = {}
        if target is not None:
            for spec in self.losses:
                out_idx, tgt_idx = spec["idx"], spec["idy"]
                weight, criterion = spec["lambda"], spec["loss"]
                value = criterion(preds_seq[out_idx], targets_seq[tgt_idx])
                loss_terms[spec["name"]] = weight * value

        mode = self.output_mode
        if mode == "module":
            if not self.training:
                return output, loss_terms
            # Training without any computed loss: fall back to the predictions.
            return loss_terms if loss_terms else output
        if mode == "loss_only":
            return loss_terms
        if mode == "preds_only":
            return output
        if mode == "both":
            return output, loss_terms


# In [5]:
# ============================================================
# 🎲 RANDOMNESS & REPRODUCIBILITY UTILITIES
# ============================================================
# Faithful reimplementation of set_seed and seed_worker
# (Delplanque, 2024). Ensures deterministic training behavior
# across CPU, GPU, NumPy, and DataLoader workers.
# ============================================================

import random
import numpy as np
import torch


def set_seed(seed: int) -> None:
    """
    Seed every random number generator used during training.

    Seeds Python's ``random`` module, NumPy's global generator and
    PyTorch (CPU, the current CUDA device and all CUDA devices), then puts
    cuDNN into deterministic, non-benchmarking mode.

    Notes
    -----
    Bit-exact reproducibility is still not guaranteed: some CUDA kernels
    are non-deterministic by design.
    See: https://pytorch.org/docs/stable/notes/randomness.html

    Parameters
    ----------
    seed : int
        Seed value handed to every generator.
    """
    seeders = (
        random.seed,               # Python stdlib
        np.random.seed,            # NumPy global RNG
        torch.manual_seed,         # PyTorch CPU
        torch.cuda.manual_seed,    # PyTorch current GPU
        torch.cuda.manual_seed_all,  # PyTorch all GPUs
    )
    for seeder in seeders:
        seeder(seed)

    # Force deterministic cuDNN algorithms and disable autotuning, which
    # may otherwise pick different kernels from one run to the next.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def seed_worker(worker_id: int) -> None:
    """
    Re-seed NumPy and ``random`` from PyTorch's current base seed.

    Intended as a DataLoader ``worker_init_fn``: the value returned by
    ``torch.initial_seed()`` is reduced to 32 bits (NumPy's seed range)
    and applied to NumPy's global RNG and to Python's ``random`` module.

    Parameters
    ----------
    worker_id : int
        Worker index supplied by the DataLoader (not used).
    """
    seed32 = torch.initial_seed() % (2 ** 32)
    for reseed in (np.random.seed, random.seed):
        reseed(seed32)


In [6]:
# ============================================================
# ⚙️ ADALOSS MODULE - FROM SCRATCH
# ============================================================
# Faithful reimplementation of Adaloss (Delplanque, 2024).
# Adaptive loss weighting for dynamic end-transform parameters.
# ============================================================

import torch


class Adaloss:
    """
    Adaptive Loss Optimizer (Adaloss)
    ---------------------------------
    Adjusts a scalar dataset parameter dynamically during training
    based on the variance ratio of recent loss windows.

    Protocol: call ``feed(loss)`` for every batch, then ``step()`` once per
    epoch. Each ``step`` archives the losses fed since the previous step.
    Once ``w`` steps are available, it records the variance of all losses
    from the last ``w`` steps. From step ``w + 1`` onwards, the parameter is
    shifted in place by ``rho * (1 - var_prev / var_curr)``, clamped to
    ``[-delta_max, delta_max]``.

    Parameters
    ----------
    param : torch.Tensor
        Parameter tensor adapted in place (e.g., end-transform weight).
    w : int
        Number of steps pooled into each variance window. Default = 3.
    rho : float
        Adaptation rate (0–1). Default = 0.9.
    delta_max : float
        Maximum allowed update step magnitude (>= 0). Default = 5.0.
    """

    def __init__(
        self,
        param: torch.Tensor,
        w: int = 3,
        rho: float = 0.9,
        delta_max: float = 5.0
    ) -> None:

        assert isinstance(param, torch.Tensor), \
            f"param must be a torch.Tensor, got {type(param)}"

        assert isinstance(w, int) and w > 0, \
            "w must be a positive integer"

        assert 0.0 <= rho <= 1.0, \
            f"rho must be between 0.0 and 1.0, got {rho}"

        # A negative bound would make torch.clamp's min exceed its max.
        assert delta_max >= 0.0, \
            f"delta_max must be non-negative, got {delta_max}"

        self.param = param
        self.w = w
        self.rho = rho
        self.delta_max = delta_max

        self._step = 1             # 1-based index of the next step() call
        self._losses = []          # losses fed since the last step()
        self.loss_history = []     # one stacked tensor of fed losses per step
        self.var_history = []      # window variances, recorded from step w on
        self.param_tracker = []    # copy of param taken after every step

    def step(self) -> None:
        """
        Perform one Adaloss adaptation step (normally once per epoch).

        Raises
        ------
        RuntimeError
            If no loss was fed since the previous step.
        """
        self._update_losses()
        self._update_vars()

        if self._step > self.w:
            delta = torch.clamp(self._compute_delta(), -self.delta_max, self.delta_max)
            # Fed losses may live on a different device (e.g. CUDA) than
            # the adapted parameter (e.g. a CPU tensor held by a dataset);
            # an in-place add needs both on the parameter's device.
            self.param.add_(delta.to(self.param.device))

        self._step += 1
        self.param_tracker.append(torch.clone(self.param))

    def feed(self, loss: torch.Tensor) -> None:
        """Record one (detached) loss value for the current step."""
        self._losses.append(loss.detach())

    def _update_losses(self) -> None:
        """Move the losses fed since the last step into ``loss_history``."""
        if not self._losses:
            raise RuntimeError(
                "Adaloss.step() called with no fed losses; "
                "call feed(loss) at least once before each step()."
            )
        self.loss_history.append(torch.stack(self._losses))
        self._losses = []

    def _update_vars(self) -> None:
        """Record the variance of all losses from the last ``w`` steps."""
        if self._step >= self.w:
            # loss_history holds exactly self._step entries at this point.
            window_losses = torch.cat(self.loss_history[self._step - self.w:self._step])
            variance = torch.var(window_losses)
            self.var_history.append(variance)

    def _compute_delta(self) -> torch.Tensor:
        """Return ``rho * (1 - var_prev / var_curr)`` (0 until two variances exist)."""
        if len(self.var_history) < 2:
            return torch.tensor(0.0)
        # 1e-8 avoids division by zero when the current window is constant.
        ratio = self.var_history[-2] / (self.var_history[-1] + 1e-8)
        return self.rho * (1.0 - ratio)


In [7]:
# ============================================================
# ⚙️ TRAINER MODULE - FULL STANDALONE VERSION
# ============================================================
# Faithful reimplementation of Trainer and FasterRCNNTrainer
# (Delplanque, 2024), stripped of animaloc dependencies.
# ============================================================

import os
import sys
import math
import time
import torch
import wandb
import matplotlib.pyplot as plt

from collections import deque, defaultdict
from typing import Any, List, Optional, Union, Callable


# ============================================================
# 🔧 UTILS: SmoothedValue
# ============================================================
class SmoothedValue:
    """
    Track a series of values and provide smoothed averages.

    ``median``, ``avg`` and ``max`` are computed over the last
    ``window_size`` values. ``global_avg`` covers every value ever added.
    ``value`` is the most recent value. Every statistic is 0.0 while no
    value has been added. ``fmt`` may reference any of these names.
    """

    def __init__(self, window_size: int = 20, fmt: str = "{median:.4f} ({global_avg:.4f})"):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value: float, n: int = 1):
        """Add ``value``, counting it ``n`` times in the global average."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def _window(self) -> torch.Tensor:
        # float32 so that mean() also works when only ints were added
        # (mean() is undefined for integer tensors).
        return torch.tensor(list(self.deque), dtype=torch.float32)

    @property
    def median(self):
        return self._window().median().item() if self.deque else 0.0

    @property
    def avg(self):
        return self._window().mean().item() if self.deque else 0.0

    @property
    def global_avg(self):
        return self.total / max(1, self.count)

    @property
    def max(self):
        return max(self.deque) if self.deque else 0.0

    @property
    def value(self):
        return self.deque[-1] if self.deque else 0.0

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )


# ============================================================
# 🔧 UTILS: reduce_dict (single-process replacement)
# ============================================================
def reduce_dict(input_dict):
    """
    Single-process stand-in for a distributed ``reduce_dict`` helper.

    With one process there is nothing to reduce across workers, so the
    very same mapping object is handed back (not a copy).
    """
    reduced = input_dict  # no-op: nothing to synchronize
    return reduced


# ============================================================
# 🔧 UTILS: CustomLogger
# ============================================================
class CustomLogger:
    """
    Minimal logger replicating Delplanque's behavior.

    Keeps one ``SmoothedValue`` per metric name (created on first use) and
    prints a progress line every ``print_freq`` iterations of ``log_every``.

    Parameters
    ----------
    delimiter : str
        Separator placed between metrics in the progress line.
    filename : str, optional
        Base name of the optional CSV file.
    work_dir : str
        Directory of the CSV file (created if missing).
    csv : bool
        When ``True`` and ``filename`` is given, ``<work_dir>/<filename>.csv``
        is created with a ``step,loss`` header.
    """

    def __init__(self, delimiter=" ", filename=None, work_dir=".", csv=False):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter
        self.filename = filename
        self.work_dir = work_dir
        self.csv = csv
        self.filepath = None
        if csv and filename:
            # The directory may not exist yet (e.g. a fresh Trainer work_dir).
            os.makedirs(work_dir, exist_ok=True)
            self.filepath = os.path.join(work_dir, f"{filename}.csv")
            with open(self.filepath, "w", encoding="utf-8") as f:
                f.write("step,loss\n")

    def update(self, **kwargs):
        """Record one value per keyword; tensors are converted with ``.item()``."""
        for name, value in kwargs.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
            self.meters[name].update(value)

    def add_meter(self, name, meter):
        """
        Register ``meter`` under ``name``.

        A ``SmoothedValue`` replaces the current meter for that name, which
        lets callers choose its window size and format. Any other value is
        recorded as an ordinary ``update``, matching the
        ``add_meter(name, value)`` form of the other logger variants in
        this notebook.
        """
        if isinstance(meter, SmoothedValue):
            self.meters[name] = meter
        else:
            self.update(**{name: meter})

    def __str__(self):
        metrics = [f"{k}: {v}" for k, v in self.meters.items()]
        return self.delimiter.join(metrics)

    def log_every(self, iterable, print_freq, header=""):
        """
        Yield the items of ``iterable`` and print progress.

        A progress line is printed after every ``print_freq``-th item is
        processed, and the total time is printed at the end.
        """
        total = len(iterable)
        start_time = time.time()
        for i, obj in enumerate(iterable):
            yield obj
            if i % print_freq == 0:
                print(f"{header} [{i}/{total}] {self}")
        total_time = time.time() - start_time
        print(f"{header} Total time: {total_time:.2f}s")


# ============================================================
# 🧠 TRAINER BASE CLASS
# ============================================================
class Trainer:
    """
    Base class for supervised training of models.

    The wrapped ``model`` is expected to behave like the loss wrapper used
    in this notebook. In train mode, ``model(images, targets)`` returns a
    dict of losses. In eval mode, it returns ``(outputs, loss_dict)``.

    Parameters
    ----------
    model : torch.nn.Module
        Model (or loss wrapper) to train; moved to the selected device.
    train_dataloader : torch.utils.data.DataLoader
        Yields ``(images, targets)`` training batches.
    optimizer : torch.optim.Optimizer
        Optimizer stepped once per batch.
    num_epochs : int
        Number of training epochs.
    lr_milestones : list of int, optional
        Milestones for ``MultiStepLR`` (used only when ``auto_lr`` is falsy).
    auto_lr : bool or dict
        ``True`` enables ``ReduceLROnPlateau`` with default settings.
        A dict is forwarded to it as keyword arguments.
    adaloss : optional
        Object with ``feed(loss)`` (called per batch) and ``step()``
        (called per epoch), e.g. ``Adaloss``.
    val_dataloader : torch.utils.data.DataLoader, optional
        Validation batches. Validation runs every ``valid_freq`` epochs and
        also on the first and last epoch.
    evaluator : optional
        Stored for subclasses; not used by this class.
    vizual_fn : callable, optional
        ``vizual_fn(image=..., target=..., output=...)`` returning a figure
        that is logged to wandb during validation.
    work_dir : str, optional
        Directory for checkpoints and logs (default: current directory).
    device_name : str
        Requested device; falls back to CPU when CUDA is unavailable.
    print_freq : int
        Progress print interval, in batches.
    valid_freq : int
        Validation interval, in epochs.
    csv_logger : bool
        Forwarded to ``CustomLogger`` as ``csv``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_dataloader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        num_epochs: int,
        lr_milestones: Optional[List[int]] = None,
        auto_lr: Union[bool, dict] = False,
        adaloss: Optional[Any] = None,
        val_dataloader: Optional[torch.utils.data.DataLoader] = None,
        evaluator: Optional[Any] = None,
        vizual_fn: Optional[Callable] = None,
        work_dir: Optional[str] = None,
        device_name: str = "cuda",
        print_freq: int = 50,
        valid_freq: int = 1,
        csv_logger: bool = False
    ):

        self.device = torch.device(device_name if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.optimizer = optimizer
        self.epochs = num_epochs
        self.lr_milestones = lr_milestones
        self.auto_lr = auto_lr
        self.auto_lr_flag = bool(auto_lr)
        self.adaloss = adaloss
        self.evaluator = evaluator
        self.vizual_fn = vizual_fn
        self.print_freq = print_freq
        self.valid_freq = valid_freq
        self.work_dir = work_dir or os.getcwd()
        self.csv_logger = csv_logger

        # local loggers
        self.train_logger = CustomLogger(delimiter=" ", filename="training", work_dir=self.work_dir, csv=csv_logger)
        self.val_logger = CustomLogger(delimiter=" ", filename="validation", work_dir=self.work_dir, csv=csv_logger)

    # ----------------------------------------------------------
    def prepare_data(self, images, targets):
        """Move a batch (tensor images; tensor or list-of-tensor targets) to the device."""
        images = images.to(self.device)
        if isinstance(targets, (list, tuple)):
            targets = [t.to(self.device) for t in targets]
        else:
            targets = targets.to(self.device)
        return images, targets

    # ----------------------------------------------------------
    def _lr_scheduler(self):
        """Build the epoch-level LR scheduler (plateau, multi-step, or None)."""
        if self.auto_lr is True:
            return torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer)
        elif isinstance(self.auto_lr, dict):
            return torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, **self.auto_lr)
        elif self.lr_milestones:
            return torch.optim.lr_scheduler.MultiStepLR(self.optimizer, self.lr_milestones)
        return None

    # ----------------------------------------------------------
    def _warmup_lr_scheduler(self, warmup_iters, warmup_factor):
        """Return a LambdaLR that ramps the LR factor linearly from
        ``warmup_factor`` to 1 over ``warmup_iters`` scheduler steps."""
        def warmup_func(x):
            if x >= warmup_iters:
                return 1
            alpha = float(x) / warmup_iters
            return warmup_factor * (1 - alpha) + alpha
        return torch.optim.lr_scheduler.LambdaLR(self.optimizer, warmup_func)

    # ----------------------------------------------------------
    def _is_best(self, val_output, mode="min"):
        """Update ``best_val`` and return True if ``val_output`` improves on it."""
        if not hasattr(self, "best_val"):
            self.best_val = float("inf") if mode == "min" else 0
        if (mode == "min" and val_output < self.best_val) or (mode == "max" and val_output > self.best_val):
            self.best_val = val_output
            return True
        return False

    # ----------------------------------------------------------
    def _save_checkpoint(self, epoch, mode):
        """Save model/optimizer state as epoch_<n>.pth ('all'), best_model.pth
        ('best') or latest_model.pth (anything else) in ``work_dir``."""
        os.makedirs(self.work_dir, exist_ok=True)
        if mode == "all":
            outpath = os.path.join(self.work_dir, f"epoch_{epoch}.pth")
        elif mode == "best":
            outpath = os.path.join(self.work_dir, "best_model.pth")
        else:
            outpath = os.path.join(self.work_dir, "latest_model.pth")

        torch.save({
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "best_val": getattr(self, "best_val", None)
        }, outpath)

    # ----------------------------------------------------------
    def _vizual(self, image, target, output):
        return self.vizual_fn(image=image, target=target, output=output)

    # ----------------------------------------------------------
    def evaluate(self, epoch, reduction="mean", wandb_flag=False, returns="all"):
        """
        Run one validation pass and return the reduced loss.

        ``returns="all"`` sums every entry of the loss dict per batch;
        any other value selects that single loss. ``reduction`` is "mean"
        or anything else (treated as "sum") over batches.
        """
        self.model.eval()
        header = f"[VALIDATION] - Epoch: [{epoch}]"
        batch_losses = []

        with torch.no_grad():
            for i, (images, targets) in enumerate(self.val_logger.log_every(self.val_dataloader, self.print_freq, header)):
                images, targets = self.prepare_data(images, targets)
                output, loss_dict = self.model(images, targets)
                losses = sum(loss for loss in loss_dict.values()) if returns == "all" else loss_dict[returns]
                batch_losses.append(losses.detach())
                self.val_logger.update(loss=losses.detach())

                if wandb_flag and self.vizual_fn and i % self.print_freq == 0:
                    fig = self._vizual(images, targets, output)
                    wandb.log({"validation_viz": fig})

        batch_losses = torch.stack(batch_losses)
        val_loss = torch.mean(batch_losses).item() if reduction == "mean" else torch.sum(batch_losses).item()
        print(f"{header} {reduction} loss: {val_loss:.4f}")
        return val_loss

    # ----------------------------------------------------------
    def _train(self, epoch, warmup_iters=None, wandb_flag=False):
        """
        Run one training epoch and return the mean total batch loss.

        ``warmup_iters`` enables a linear LR warmup during epoch 1 only.
        ``wandb_flag`` is accepted for interface symmetry and is unused here.
        """
        self.model.train()
        # Window of 1, so the meter shows the LR of the latest batch ("median"
        # of a single value is that value). Set directly on ``meters`` because
        # the logger defined with this class does not provide ``add_meter``.
        self.train_logger.meters["lr"] = SmoothedValue(window_size=1, fmt="{median:.6f}")
        header = f"[TRAINING] - Epoch: [{epoch}]"

        warmup = bool(warmup_iters) and epoch == 1
        if warmup:
            self.start_lr_scheduler = self._warmup_lr_scheduler(
                min(warmup_iters, len(self.train_dataloader) - 1),
                1.0 / warmup_iters
            )

        batch_losses = []

        for images, targets in self.train_logger.log_every(self.train_dataloader, self.print_freq, header):
            images, targets = self.prepare_data(images, targets)
            self.optimizer.zero_grad()
            loss_dict = self.model(images, targets)
            loss_total = sum(loss for loss in loss_dict.values())
            loss_value = loss_total.item()

            if not math.isfinite(loss_value):
                print("Loss is NaN or Inf. Stopping.")
                print(loss_dict)
                sys.exit(1)

            # LR actually used for this update (read before any scheduler step).
            current_lr = self.optimizer.param_groups[0]["lr"]
            loss_total.backward()
            self.optimizer.step()
            if self.adaloss:
                self.adaloss.feed(loss_total)

            if warmup:
                self.start_lr_scheduler.step()

            self.train_logger.update(loss=loss_value, lr=current_lr)
            batch_losses.append(loss_total.detach())

        mean_loss = torch.mean(torch.stack(batch_losses)).item()
        print(f"{header} mean loss: {mean_loss:.4f}")
        return mean_loss

    # ----------------------------------------------------------
    def _step_lr_scheduler(self, lr_scheduler, train_loss, val_loss):
        """
        Advance the epoch-level LR scheduler.

        ``ReduceLROnPlateau.step`` requires a metric. It is stepped with this
        epoch's validation loss. If the trainer has no validation data at
        all, the training loss is used instead. On epochs that were not
        validated, it is left untouched. Other schedulers are stepped
        without arguments.
        """
        if not isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            lr_scheduler.step()
        elif val_loss is not None:
            lr_scheduler.step(val_loss)
        elif not self.val_dataloader:
            lr_scheduler.step(train_loss)

    # ----------------------------------------------------------
    def start(self, warmup_iters=None, checkpoints="best", select="min", validate_on="all", wandb_flag=False):
        """
        Run the full training loop and return the model.

        ``checkpoints="best"`` writes best_model.pth whenever the validation
        loss improves according to ``select`` ("min" or "max").
        ``checkpoints="all"`` writes epoch_<n>.pth on every validated epoch.
        latest_model.pth is written after every epoch.
        ``validate_on`` is forwarded to ``evaluate(returns=...)``.
        """
        lr_scheduler = self._lr_scheduler()
        self.best_val = float("inf") if select == "min" else 0

        for epoch in range(1, self.epochs + 1):
            train_loss = self._train(epoch, warmup_iters, wandb_flag)
            val_loss = None

            if self.val_dataloader and (epoch % self.valid_freq == 0 or epoch in [1, self.epochs]):
                val_loss = self.evaluate(epoch, wandb_flag=wandb_flag, returns=validate_on)

                if checkpoints == "all" or self._is_best(val_loss, select):
                    self._save_checkpoint(epoch, checkpoints if checkpoints == "all" else "best")

            self._save_checkpoint(epoch, "latest")

            if lr_scheduler is not None:
                self._step_lr_scheduler(lr_scheduler, train_loss, val_loss)

            if self.adaloss:
                self.adaloss.step()

        return self.model


# ============================================================
# 🦋 FASTER-RCNN TRAINER SUBCLASS
# ============================================================
class FasterRCNNTrainer(Trainer):
    """
    Trainer subclass for object detection tasks.

    Batches are a sequence of image tensors plus a sequence of target
    dicts. Each element is moved to the device individually. Target-dict
    entries that are not tensors are dropped.
    """

    def prepare_data(self, images, targets):
        device = self.device
        moved_images = list(map(lambda img: img.to(device), images))
        moved_targets = []
        for target in targets:
            moved_targets.append(
                {key: val.to(device) for key, val in target.items() if torch.is_tensor(val)}
            )
        return moved_images, moved_targets


In [8]:
# ============================================================
# 🧩 CORE UTILITIES FOR HERDNET REBUILD
# ============================================================
# This block restores the missing low-level structures from
# the original animaloc framework:
# - Registry
# - Geometry classes (Point, BoundingBox, processors)
# - Image patch utilities
# - Logger compatible with Evaluator
# ============================================================

import os
import math
import torch
import numpy as np
import logging
from collections import defaultdict
from typing import Any, Callable, Dict, List, Tuple, Optional
from torch.utils.data import Dataset


# ============================================================
# 📦 1. REGISTRY SYSTEM
# ============================================================
class Registry:
    """
    Lightweight registry to store callable modules (e.g. models, evaluators, metrics).
    Mimics animaloc.utils.registry.Registry.

    Parameters
    ----------
    name : str
        Human-readable registry name used in error messages and ``repr``.
    module_key : str, optional
        Stored as-is; not used by this class.
    """

    def __init__(self, name: str, module_key: str = None):
        self.name = name
        self.module_key = module_key
        self._registry = {}

    def register(self, name: Optional[str] = None):
        """
        Register a class or function under ``name`` (default: its ``__name__``).

        Supported forms::

            @REG.register()          # key = obj.__name__
            @REG.register("alias")   # explicit key
            @REG.register            # bare decorator, key = obj.__name__

        The bare form registers and returns the object directly; the other
        forms return a decorator that does so.

        Raises
        ------
        ValueError
            If the key is already registered.
        """
        if callable(name):
            # Bare ``@REG.register``: Python passes the decorated object as
            # ``name``. Without this branch the object would be replaced by
            # the inner decorator function.
            return self._add(name.__name__, name)

        def decorator(obj):
            return self._add(name or obj.__name__, obj)

        return decorator

    def _add(self, key: str, obj):
        """Store ``obj`` under ``key`` (refusing duplicates) and return it."""
        if key in self._registry:
            raise ValueError(f"{key} already registered in {self.name}")
        self._registry[key] = obj
        return obj

    def get(self, name: str):
        """Return the object registered under ``name``, or ``None`` if absent."""
        return self._registry.get(name)

    @property
    def registry_names(self):
        """Registered keys, in registration order."""
        return list(self._registry.keys())

    def __contains__(self, key):
        return key in self._registry

    def __getitem__(self, key):
        return self._registry[key]

    def __repr__(self):
        return f"Registry({self.name}, {len(self._registry)} items)"


# ============================================================
# 📏 2. GEOMETRY CLASSES
# ============================================================
class Point:
    """A 2-D point; both coordinates are coerced to ``float``."""

    def __init__(self, x: float, y: float):
        self.x, self.y = float(x), float(y)

    def __iter__(self):
        # Enables ``x, y = point`` and ``tuple(point)``.
        yield from (self.x, self.y)

    def __repr__(self):
        return "Point(x={:.2f}, y={:.2f})".format(self.x, self.y)


class BoundingBox:
    """
    Axis-aligned box with corners ``(x_min, y_min)`` and ``(x_max, y_max)``.

    Coordinates are coerced to ``float``. Zero-width or zero-height boxes
    are allowed; inverted ones fail an assertion on construction.
    """

    def __init__(self, x_min: float, y_min: float, x_max: float, y_max: float):
        self.x_min, self.y_min, self.x_max, self.y_max = (
            float(coord) for coord in (x_min, y_min, x_max, y_max)
        )
        self._validate()

    def _validate(self):
        assert self.x_max >= self.x_min, "x_max must be >= x_min"
        assert self.y_max >= self.y_min, "y_max must be >= y_min"

    @property
    def width(self):
        return self.x_max - self.x_min

    @property
    def height(self):
        return self.y_max - self.y_min

    @property
    def area(self):
        return self.width * self.height

    def intersect(self, other: "BoundingBox") -> "BoundingBox":
        """Return the overlap with ``other``; disjoint boxes give the all-zero box."""
        lo_x, lo_y = max(self.x_min, other.x_min), max(self.y_min, other.y_min)
        hi_x, hi_y = min(self.x_max, other.x_max), min(self.y_max, other.y_max)
        disjoint = hi_x < lo_x or hi_y < lo_y
        if disjoint:
            return BoundingBox(0, 0, 0, 0)
        return BoundingBox(lo_x, lo_y, hi_x, hi_y)

    def __repr__(self):
        return "BBox({:.2f},{:.2f},{:.2f},{:.2f})".format(
            self.x_min, self.y_min, self.x_max, self.y_max
        )


class PointProcessor:
    """Geometry helper bound to a reference point."""

    def __init__(self, point: Point):
        self.point = point

    def dist(self, other: Point) -> float:
        """Return the Euclidean distance from the wrapped point to ``other``."""
        dx = self.point.x - other.x
        dy = self.point.y - other.y
        return math.sqrt(dx ** 2 + dy ** 2)


class BboxProcessor:
    """
    Geometry helper bound to a reference bounding box.

    Only intersection is provided.
    """

    def __init__(self, bbox: BoundingBox):
        self.bbox = bbox

    def intersect(self, other: BoundingBox) -> BoundingBox:
        """Return the overlap of the wrapped box with ``other`` (via ``BoundingBox.intersect``)."""
        reference = self.bbox
        return reference.intersect(other)


# ============================================================
# 🧩 3. IMAGE TO PATCHES (FOR STITCHER)
# ============================================================
class ImageToPatches:
    """
    Split an image tensor ``[C, H, W]`` into fixed-size, overlapping patches.

    Patch start offsets advance by ``size - overlap`` along each axis. If
    the last regular window would leave a strip of pixels uncovered, one
    extra window aligned with the bottom/right border is added, so every
    pixel belongs to at least one patch. If the image is smaller than the
    patch along an axis, it is zero-padded (bottom/right) to the patch size.

    Parameters
    ----------
    image : torch.Tensor
        Image of shape ``[C, H, W]``.
    size : tuple of int
        Patch ``(height, width)``.
    overlap : int
        Overlap between neighbouring patches, in pixels.

    Raises
    ------
    ValueError
        If ``overlap`` is not smaller than both patch dimensions.
    """

    def __init__(self, image: torch.Tensor, size: Tuple[int, int], overlap: int):
        patch_h, patch_w = int(size[0]), int(size[1])
        if overlap >= min(patch_h, patch_w):
            # The stride would be zero or negative.
            raise ValueError(
                f"overlap ({overlap}) must be smaller than the patch size {(patch_h, patch_w)}"
            )
        self.image = image
        self.size = np.array(size)
        self.overlap = overlap
        # NOTE(review): despite their names, ``_ncol`` counts windows along the
        # height axis (image.shape[1]) and ``_nrow`` along the width axis
        # (image.shape[2]). Kept unchanged in case other code reads them.
        self._ncol = math.ceil((image.shape[1] - overlap) / (size[0] - overlap))
        self._nrow = math.ceil((image.shape[2] - overlap) / (size[1] - overlap))

    def _patch_hw(self) -> Tuple[int, int]:
        """Patch (height, width) as plain Python ints."""
        return int(self.size[0]), int(self.size[1])

    @staticmethod
    def _starts(length: int, patch: int, stride: int) -> List[int]:
        """
        Window start offsets covering ``[0, length)`` with windows of size ``patch``.

        Regular offsets step by ``stride``. If the last regular window ends
        before ``length``, a final window flush with the end is appended.
        A ``length`` not larger than ``patch`` gives a single window at 0.
        """
        if length <= patch:
            return [0]
        starts = list(range(0, length - patch + 1, stride))
        if starts[-1] + patch < length:
            starts.append(length - patch)
        return starts

    def _grid(self) -> Tuple[List[int], List[int]]:
        """Vertical (y) and horizontal (x) start offsets shared by all methods."""
        h, w = self._patch_hw()
        _, H, W = self.image.shape
        return (self._starts(H, h, h - self.overlap),
                self._starts(W, w, w - self.overlap))

    def make_patches(self) -> torch.Tensor:
        """
        Return the patches stacked as a tensor ``[N, C, h, w]``.

        Order is row by row (top to bottom), left to right within a row,
        matching the indices returned by ``get_limits``.
        """
        h, w = self._patch_hw()
        image = self.image
        C, H, W = image.shape
        if H < h or W < w:
            # Zero-pad bottom/right so at least one full patch fits.
            canvas = image.new_zeros((C, max(H, h), max(W, w)))
            canvas[:, :H, :W] = image
            image = canvas
        ys, xs = self._grid()
        return torch.stack([image[:, y:y + h, x:x + w] for y in ys for x in xs], dim=0)

    def get_limits(self) -> "Dict[int, BoundingBox]":
        """
        Return each patch's footprint as a BoundingBox, keyed by patch index.

        Boxes are ``(x_min, y_min, x_max, y_max)`` in image pixel coordinates.
        For an image smaller than the patch, they extend past the image
        border (over the zero padding used by ``make_patches``).
        """
        h, w = self._patch_hw()
        ys, xs = self._grid()
        offsets = [(y, x) for y in ys for x in xs]
        return {idx: BoundingBox(x, y, x + w, y + h) for idx, (y, x) in enumerate(offsets)}


# ============================================================
# 🪵 4. EXTENDED CUSTOM LOGGER
# ============================================================
class CustomLogger:
    """
    Minimal logging utility compatible with Evaluator and Trainer.
    Supports add_meter(), log_every(), and persistent output file.

    Values passed to ``add_meter`` are accumulated per name. ``log_every``
    prints their running means and appends the same line to
    ``<work_dir>/<filename>.txt`` through a ``logging`` logger named
    ``filename``.
    """

    def __init__(self, delimiter: str = " ", filename: str = "log", work_dir: Optional[str] = None):
        self.delimiter = delimiter
        self.meters = defaultdict(list)
        self.work_dir = work_dir or os.getcwd()
        self.log_path = os.path.join(self.work_dir, f"{filename}.txt")
        os.makedirs(self.work_dir, exist_ok=True)
        self.logger = logging.getLogger(filename)
        self.logger.setLevel(logging.INFO)
        # Loggers are process-wide and keyed by name, so a handler is attached
        # only once (re-running a cell would otherwise duplicate every line).
        # The handler is created only when it will be attached: FileHandler
        # opens its file on construction, so an unused one leaks a handle.
        if not self.logger.handlers:
            handler = logging.FileHandler(self.log_path, mode="a", encoding="utf-8")
            handler.setFormatter(logging.Formatter("[%(asctime)s] %(message)s", datefmt="%H:%M:%S"))
            self.logger.addHandler(handler)

    def log_every(self, iterable, print_freq: int, header: Optional[str] = None):
        """Yield items, printing/logging progress every ``print_freq`` items and on the last one."""
        header = header or ""
        total = len(iterable)
        for i, obj in enumerate(iterable):
            if i % print_freq == 0 or i == total - 1:
                msg = self._get_meter_summary()
                line = f"{header} [{i+1}/{total}]{self.delimiter}{msg}"
                print(line)
                self.logger.info(line)
            yield obj

    def add_meter(self, name: str, value: Any):
        """Append a metric value under ``name``."""
        self.meters[name].append(value)

    def _get_meter_summary(self) -> str:
        """Summarize non-empty meters as ``name: mean`` (3 decimals)."""
        parts = []
        for k, v in self.meters.items():
            if isinstance(v, (list, tuple)) and len(v) > 0:
                parts.append(f"{k}: {np.mean(v):.3f}")
        return self.delimiter.join(parts)

    def save_csv(self):
        """Save meters as ``<filename>.csv`` next to the text log (one column per meter)."""
        import pandas as pd
        # Meters can hold different numbers of values; Series alignment pads
        # the shorter columns with NaN instead of raising.
        df = pd.DataFrame({k: pd.Series(v) for k, v in self.meters.items()})
        # splitext only touches the extension, not a ".txt" elsewhere in the path.
        csv_path = os.path.splitext(self.log_path)[0] + ".csv"
        df.to_csv(csv_path, index=False)
        print(f"Saved meters to {csv_path}")


In [9]:
# ============================================================
# 📜 LOGGER.PY — Console & File Logging Utility
# ============================================================
# Compatible with both training and evaluation.
# Logs metrics to the console and to timestamped .txt and .csv files.
# Inspired by animaloc.utils.logger (Alexandre Delplanque, 2024)
# ============================================================

import os
import csv
import sys
import time
import errno
import datetime
from collections import defaultdict


class CustomLogger:
    """
    CustomLogger
    ------------
    Simple and robust logger for training/evaluation progress.

    Features:
    - Console + text-file logging (``write`` / ``log_every``)
    - CSV rows via ``log`` (columns fixed by the first row written)
    - ``add_meter(name, value)`` to accumulate scalar values
    - Timestamped filenames (``<filename>_YYYYMMDD-HHMMSS.{txt,csv}``)
    - Usable as a context manager; both files are closed on exit
    """

    def __init__(self, delimiter: str = "\t", filename: str = "log", work_dir: str = None):
        self.delimiter = delimiter
        self.meters = defaultdict(list)
        self.header = None

        # Prepare output directory
        self.work_dir = work_dir or os.getcwd()
        os.makedirs(self.work_dir, exist_ok=True)

        # Timestamp for unique file names
        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        self.filename_txt = os.path.join(self.work_dir, f"{filename}_{timestamp}.txt")
        self.filename_csv = os.path.join(self.work_dir, f"{filename}_{timestamp}.csv")

        # Text file is line-buffered so each progress line reaches disk at once.
        self.txt_file = open(self.filename_txt, "w", encoding="utf-8", buffering=1)
        self.csv_file = open(self.filename_csv, "w", newline="", encoding="utf-8")
        self.csv_writer = None

    # --------------------------------------------------------
    # BASIC LOGGING METHODS
    # --------------------------------------------------------
    def add_meter(self, name: str, value: float):
        """Append a scalar value under ``name``."""
        self.meters[name].append(value)

    def write(self, msg: str):
        """Write ``msg`` to the console and to the text file."""
        print(msg)
        self.txt_file.write(msg + "\n")

    def log(self, **kwargs):
        """Append one CSV row of key-value pairs, prefixed with a ``time`` column."""
        timestamp = datetime.datetime.now().strftime("%H:%M:%S")
        entry = {"time": timestamp, **kwargs}
        self._write_row(entry)

    def _write_row(self, row: dict):
        """Write ``row``; the first row written defines the CSV header."""
        if self.csv_writer is None:
            fieldnames = list(row.keys())
            self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
            self.csv_writer.writeheader()
        self.csv_writer.writerow(row)
        self.csv_file.flush()

    # --------------------------------------------------------
    # ITERABLE LOGGING (for dataloaders)
    # --------------------------------------------------------
    def log_every(self, iterable, print_freq: int = 10, header: str = None):
        """
        Iterate over an iterable and print running logs periodically.
        Mimics torch.utils MetricLogger behavior.

        A line with an ETA is written after every ``print_freq``-th item
        and after the last one, followed by a total-time summary. An empty
        iterable yields nothing and reports a rate of 0 s/it.
        """
        start_time = time.time()
        end = time.time()
        total = len(iterable)
        header = header or ""

        for i, obj in enumerate(iterable):
            yield obj
            if i % print_freq == 0 or i == total - 1:
                # Duration of the last iteration times the iterations still to run.
                eta_seconds = (time.time() - end) * (total - i - 1)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                log_line = f"[{i:>{len(str(total))}}/{total}] ETA: {eta_string}"
                self.write(f"{header} {log_line}")
            end = time.time()

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        per_iteration = total_time / total if total else 0.0
        self.write(f"{header} Total time: {total_time_str} ({per_iteration:.4f} s/it)")

    # --------------------------------------------------------
    # SHUTDOWN
    # --------------------------------------------------------
    def close(self):
        """Close log files cleanly."""
        self.txt_file.close()
        self.csv_file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    # --------------------------------------------------------
    # STRING REPRESENTATION
    # --------------------------------------------------------
    def __str__(self):
        parts = [f"{k}: {v[-1]:.4f}" for k, v in self.meters.items() if len(v) > 0]
        return self.delimiter.join(parts)


# ============================================================
# 🔧 UTILITY HELPERS
# ============================================================

def mkdir(path):
    """
    Create ``path`` together with any missing parent directories.

    If something already exists at ``path`` (errno EEXIST), nothing
    happens; every other OS error is propagated.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        # OSError with errno EEXIST is always raised as FileExistsError.
        pass


def current_date():
    """Return today's local date as an 8-character ``YYYYMMDD`` string."""
    return f"{datetime.date.today():%Y%m%d}"


def get_date_time():
    """Return ``(date, time)`` strings: local date as ``DD/MM/YYYY`` and local time as ``HH:MM:SS``."""
    date_part = f"{datetime.date.today():%d/%m/%Y}"
    time_part = f"{datetime.datetime.now():%H:%M:%S}"
    return date_part, time_part


def vdir(obj):
    """Return the ``dir(obj)`` names that do not start with ``__`` (in ``dir``'s sorted order)."""
    return list(filter(lambda attr: not attr.startswith("__"), dir(obj)))


In [12]:
# ============================================================
# 🎯 FOCAL LOSS - Clean Reimplementation (HerdNet v0.2.1)
# ============================================================
# Self-contained version of Delplanque's focal loss used in HerdNet.
# Preserves original logic and hyperparameters: alpha, beta,
# reduction, weights, density weighting, normalization, and eps.
# ============================================================

import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    """
    Focal Loss module (Delplanque, 2024 - HerdNet v0.2.1)

    Implements the modified focal loss used in HerdNet, originally inspired by
    the CenterNet implementation:
    https://github.com/xingyizhou/CenterNet/blob/master/src/lib/models/losses.py

    With ``p`` the clamped prediction and ``y`` the target heatmap, every
    (sample, channel) map contributes

    * ``pos = sum(log(p) * (1 - p)**alpha)`` over locations where ``y == 1``
    * ``neg = sum(log(1 - p) * p**alpha * (1 - y)**beta)`` over ``y < 1``

    Its loss is ``-neg`` when the map has no positive location. Otherwise it is
    ``density * -(pos + neg)``, optionally divided by the positive count.

    Parameters
    ----------
    alpha : int, optional
        Exponent for modulating factor (1 - p_t). Default = 2.
    beta : int, optional
        Exponent for weighting negative examples. Default = 4.
    reduction : str, optional
        Reduction over the [B, C] loss matrix ('sum' or 'mean'). Default = 'sum'.
    weights : torch.Tensor, optional
        Optional per-channel weighting tensor of length C.
    density_weight : str, optional
        None (no scaling), 'linear', 'squared' or 'cubic'. For maps that
        contain positives, the loss is multiplied by the positive count
        raised to the power 1, 2 or 3 respectively.
    normalize : bool, optional
        Divide the loss of each map that contains positives by its number of
        positive locations. Default = False.
    eps : float, optional
        Clamp margin for predictions (avoids log(0)) and stabilizer for the
        normalization. Default = 1e-6.

    Raises
    ------
    AssertionError
        If ``reduction`` is not 'mean' or 'sum'.
    ValueError
        If ``density_weight`` is not None, 'linear', 'squared' or 'cubic'.
    """

    # Exponent applied to the positive count for each density_weight option.
    _DENSITY_EXPONENTS = {"linear": 1, "squared": 2, "cubic": 3}

    def __init__(
        self,
        alpha: int = 2,
        beta: int = 4,
        reduction: str = "sum",
        weights: torch.Tensor = None,
        density_weight: str = None,
        normalize: bool = False,
        eps: float = 1e-6
    ):
        super(FocalLoss, self).__init__()

        assert reduction in ["mean", "sum"], \
            f"Reduction must be 'mean' or 'sum', got {reduction}"

        # Reject typos such as 'Linear' instead of silently disabling density weighting.
        if density_weight is not None and density_weight not in self._DENSITY_EXPONENTS:
            raise ValueError(
                "density_weight must be None, 'linear', 'squared' or 'cubic', "
                f"got {density_weight!r}"
            )

        self.alpha = alpha
        self.beta = beta
        self.reduction = reduction
        self.weights = weights
        self.density_weight = density_weight
        self.normalize = normalize
        self.eps = eps

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute the focal loss between prediction and target.

        Parameters
        ----------
        output : torch.Tensor
            Predicted heatmap tensor [B, C, H, W] (probabilities; clamped to
            [eps, 1 - eps] internally).
        target : torch.Tensor
            Ground truth heatmap tensor [B, C, H, W]; locations equal to 1 are
            positives.

        Returns
        -------
        torch.Tensor
            Scalar loss ('sum' or 'mean' of the per-(sample, channel) losses).
        """
        return self._neg_loss(output, target)

    def _neg_loss(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Core focal loss computation, vectorized over batch and channels."""
        _, C, _, _ = target.shape

        if self.weights is not None:
            assert self.weights.shape[0] == C, (
                f"Expected {C} channel weights, got {self.weights.shape[0]}"
            )

        # Clamp predictions to avoid log(0)
        output = torch.clamp(output, min=self.eps, max=1 - self.eps)

        # Masks
        pos_inds = target.eq(1).float()
        neg_inds = target.lt(1).float()
        neg_weights = torch.pow(1 - target, self.beta)

        # Per-location log terms
        pos_loss = torch.log(output) * torch.pow(1 - output, self.alpha) * pos_inds
        neg_loss = torch.log(1 - output) * torch.pow(output, self.alpha) * neg_weights * neg_inds

        # Sum spatially -> [B, C]
        num_pos = pos_inds.sum(dim=(2, 3))
        pos_loss = pos_loss.sum(dim=(2, 3))
        neg_loss = neg_loss.sum(dim=(2, 3))

        if self.density_weight is None:
            density = torch.ones_like(num_pos)
        else:
            density = num_pos ** self._DENSITY_EXPONENTS[self.density_weight]

        # Loss for maps that contain at least one positive location.
        with_pos = density * -(pos_loss + neg_loss)
        if self.normalize:
            with_pos = with_pos / (num_pos + self.eps)

        # Maps without positives only pay the (unscaled) negative term. The
        # torch.where replaces a Python loop over [B, C] that forced one
        # device sync per element.
        loss = torch.where(num_pos > 0, with_pos, -neg_loss)

        # Apply per-channel weights (broadcast over the batch dimension)
        if self.weights is not None:
            loss = loss * self.weights.to(loss.device)

        # Final reduction
        if self.reduction == "mean":
            return loss.mean()
        return loss.sum()


In [10]:
# ============================================================
# 📦 CSVDataset — Dataset class for images + CSV annotations
# ============================================================
# Compatible with:
#  - Point annotations (for HerdNet or DensityMap models)
#  - Bounding boxes (for FasterRCNN, etc.)
#  - Albumentations transforms
# ============================================================

import os
import cv2
import torch
import pandas as pd
from PIL import Image

from torch.utils.data import Dataset
from albumentations import Compose

# ============================================================
# 🧩 DATASET CLASS
# ============================================================
class CSVDataset(Dataset):
    """
    CSVDataset
    ----------
    Reads image paths and annotations from a CSV file.

    Each distinct ``image_id`` in the CSV is one sample. All rows sharing that
    ``image_id`` are its annotations.

    Supported annotation types:
    - Points (x, y, label)
    - Bounding boxes (x_min, y_min, x_max, y_max, label)

    The dataset can apply Albumentations transforms. It returns the image as a
    float32 ``[C, H, W]`` tensor plus a target dict of tensors.

    Parameters
    ----------
    csv_file : str
        Path to the CSV file containing annotations.
    img_dir : str
        Directory containing the images.
    transforms : albumentations.Compose, optional
        Transformation pipeline applied to each sample.
    anno_type : str, optional
        Type of annotation: 'point' or 'bbox'.
        Defaults to 'point'.
    """

    def __init__(self, csv_file: str, img_dir: str, transforms: Compose = None, anno_type: str = "point") -> None:
        assert os.path.exists(csv_file), f"CSV file not found: {csv_file}"
        assert os.path.exists(img_dir), f"Image directory not found: {img_dir}"
        assert anno_type in ["point", "bbox"], f"Invalid annotation type: {anno_type}"

        self.csv_file = csv_file
        self.img_dir = img_dir
        self.transforms = transforms
        self.anno_type = anno_type

        self.df = pd.read_csv(csv_file)
        if "image_id" not in self.df.columns:
            raise ValueError("CSV must contain 'image_id' column")

        # One sample per distinct image, sorted for a stable index order.
        self._img_names = sorted(self.df["image_id"].unique())
        self.num_samples = len(self._img_names)

        # Positional row indices of each image's annotations. They are computed
        # once so that __getitem__ does not rescan the whole dataframe with a
        # boolean mask for every sample.
        self._rows_by_image = self.df.groupby("image_id", sort=False).indices

    # --------------------------------------------------------
    # 🔹 LENGTH
    # --------------------------------------------------------
    def __len__(self) -> int:
        return self.num_samples

    # --------------------------------------------------------
    # 🔹 GET ITEM
    # --------------------------------------------------------
    def __getitem__(self, idx: int):
        """Load one sample.

        Returns
        -------
        image : torch.Tensor
            float32 tensor laid out [C, H, W] (RGB order before any transform).
        target : dict
            'points' [N, 2] (or 'boxes' [N, 4]) float32 and 'labels' [N] int64
            tensors. When a transform is set, these are the values it returned
            for those keys.
        """
        # 1️⃣ Get image path
        img_name = self._img_names[idx]
        img_path = os.path.join(self.img_dir, img_name)
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        # 2️⃣ Load image (cv2 reads BGR; convert to RGB)
        image = cv2.imread(img_path)
        if image is None:
            raise ValueError(f"Failed to read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # 3️⃣ Rows belonging to this image (original CSV order preserved)
        subset = self.df.iloc[self._rows_by_image.get(img_name, [])]

        # 4️⃣ Parse annotations
        if self.anno_type == "point":
            target = self._parse_points(subset)
        else:
            target = self._parse_bboxes(subset)

        # 5️⃣ Apply Albumentations transforms (if any)
        if self.transforms is not None:
            # NOTE(review): annotations are passed under their own keys
            # ('points' / 'boxes' / 'labels'). Confirm the Compose pipeline is
            # configured to transform these keys; otherwise geometric
            # augmentations will leave the annotations unchanged.
            transformed = self.transforms(image=image, **target)
            image = transformed["image"]
            target = {k: transformed[k] for k in target.keys() if k in transformed}

        # 6️⃣ Convert to a float CHW torch.Tensor
        if isinstance(image, torch.Tensor):
            # The pipeline already produced a tensor (e.g. Albumentations'
            # ToTensorV2 returns channels-first); permuting again would
            # scramble the axes.
            image = image.float()
        else:
            # HWC numpy array -> CHW tensor
            image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)

        return image, target

    # --------------------------------------------------------
    # 🔸 PARSE METHODS
    # --------------------------------------------------------
    def _parse_points(self, subset: pd.DataFrame) -> dict:
        """Parse point annotations into 'points' [N, 2] (x, y) and 'labels' [N]."""
        required_cols = {"x", "y", "label"}
        missing = required_cols - set(subset.columns)
        if missing:
            raise ValueError(f"Missing columns for point annotations: {sorted(missing)}")

        points = subset[["x", "y"]].values.tolist()
        labels = subset["label"].astype(int).tolist()

        return {
            # reshape keeps the [N, 2] layout even when N == 0
            "points": torch.tensor(points, dtype=torch.float32).reshape(-1, 2),
            "labels": torch.tensor(labels, dtype=torch.int64),
        }

    def _parse_bboxes(self, subset: pd.DataFrame) -> dict:
        """Parse box annotations into 'boxes' [N, 4] (x_min, y_min, x_max, y_max) and 'labels' [N]."""
        required_cols = {"x_min", "y_min", "x_max", "y_max", "label"}
        missing = required_cols - set(subset.columns)
        if missing:
            raise ValueError(f"Missing columns for bbox annotations: {sorted(missing)}")

        boxes = subset[["x_min", "y_min", "x_max", "y_max"]].values.tolist()
        labels = subset["label"].astype(int).tolist()

        return {
            # reshape keeps the [N, 4] layout even when N == 0
            "boxes": torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.tensor(labels, dtype=torch.int64),
        }

    # --------------------------------------------------------
    # 🔹 IMAGE UTILITIES
    # --------------------------------------------------------
    def get_image_name(self, idx: int) -> str:
        """Return image filename by index."""
        return self._img_names[idx]

    def get_image_path(self, idx: int) -> str:
        """Return the image path (img_dir joined with the filename) by index."""
        return os.path.join(self.img_dir, self._img_names[idx])

    # --------------------------------------------------------
    # 🔹 REPR
    # --------------------------------------------------------
    def __repr__(self) -> str:
        return f"CSVDataset(n={self.num_samples}, type={self.anno_type}, dir='{self.img_dir}')"

In [11]:
# ============================================================
# ⚙️ REGISTRY.PY — Dynamic Class Registration System
# ============================================================
# This module provides a simple registry pattern to dynamically
# register and instantiate modules (models, trainers, datasets, etc.)
# ============================================================

import inspect
from typing import Any, Callable, Dict, Optional


class Registry:
    """
    Registry
    --------
    Maps string keys to classes or functions, so that components (networks,
    trainers, transforms, ...) can be looked up and instantiated by name.

    Example:
    --------
    MODELS = Registry("models")

    @MODELS.register()
    class MyModel:
        pass

    model = MODELS.build("MyModel")
    """

    def __init__(self, name: str, module_key: Optional[str] = None):
        self._name = name
        # Falls back to the registry name when no (truthy) key is supplied.
        self._module_key = module_key if module_key else name
        self._registry: Dict[str, Any] = dict()

    # --------------------------------------------------------
    # 🔹 REGISTRATION
    # --------------------------------------------------------
    def register(self, name: Optional[str] = None) -> Callable:
        """
        Return a decorator that stores the decorated object in the registry.

        Args:
            name (str, optional): Key to store it under. When empty or None,
                the object's ``__name__`` is used.

        The decorator returns the object unchanged. It raises KeyError if the
        key is already taken.
        """

        def _store(target: Any):
            entry_key = name if name else target.__name__
            if entry_key in self._registry:
                raise KeyError(f"'{entry_key}' is already registered in {self._name}")
            self._registry[entry_key] = target
            return target

        return _store

    # --------------------------------------------------------
    # 🔹 RETRIEVAL
    # --------------------------------------------------------
    def get(self, key: str) -> Any:
        """
        Look up a registered object.

        Args:
            key (str): The key it was registered under.

        Returns:
            The registered class or function.

        Raises:
            KeyError: if nothing is registered under ``key``.
        """
        if key in self._registry:
            return self._registry[key]
        raise KeyError(f"'{key}' is not registered in {self._name}")

    # --------------------------------------------------------
    # 🔹 BUILD INSTANCE
    # --------------------------------------------------------
    def build(self, key: str, **kwargs) -> Any:
        """
        Call the registered object with ``**kwargs``.

        A registered class is instantiated and a registered function is
        called; its return value is passed back.

        Raises:
            KeyError: if ``key`` is not registered.
            TypeError: if the registered object is not callable.
        """
        factory = self.get(key)
        # Classes are callable as well, so one check covers both cases.
        if not callable(factory):
            raise TypeError(f"Object '{key}' is not callable in {self._name}")
        return factory(**kwargs)

    # --------------------------------------------------------
    # 🔹 LIST UTILITIES
    # --------------------------------------------------------
    @property
    def registry_names(self):
        """All registered keys, sorted."""
        return sorted(self._registry)

    def __len__(self) -> int:
        return len(self._registry)

    def __contains__(self, key: str) -> bool:
        return key in self._registry

    def __repr__(self) -> str:
        entries = list(self._registry)
        return f"Registry(name='{self._name}', entries={entries})"


In [13]:
# ============================================================
# ⚙️ DEFAULT LOSS CONFIGURATION BUILDER
# ============================================================
# Generates the exact same losses list used in HerdNet training:
#   - FocalLoss for localization heatmap
#   - CrossEntropyLoss for classification map
# Can be passed directly to LossWrapper(model, losses=losses)
# ============================================================

import torch
from torch.nn import CrossEntropyLoss


def build_default_losses(num_classes: int = 4, device: str = None):
    """
    Build the default HerdNet loss configuration.

    Parameters
    ----------
    num_classes : int, optional
        Number of classes used in the classification map (must be >= 1).
        Default = 4 (as in HerdNet repo: background + 3 animal classes).
    device : str, optional
        Device to place weight tensors on. If None, auto-detects CUDA.

    Returns
    -------
    list of dict
        Two entries, each with keys 'loss', 'idx', 'idy', 'lambda' and 'name'.
        The first holds a FocalLoss and the second a class-weighted
        CrossEntropyLoss; both have lambda 1.0. The list is intended for
        LossWrapper initialization.
        NOTE(review): 'idx' / 'idy' presumably select the model output and
        target entry each loss applies to; confirm against LossWrapper.

    Raises
    ------
    ValueError
        If ``num_classes`` < 1. The CE weight vector could not then have one
        entry per class.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Class weights (repo default): class 0 (background) is down-weighted to
    # 0.1, every other class gets 1.0.
    weight = torch.tensor([0.1] + [1.0] * (num_classes - 1), device=device)

    # Focal loss for density / localization
    focal_loss = FocalLoss(
        alpha=2,
        beta=4,
        reduction="mean",
        density_weight=None,
        normalize=False
    )

    # Cross-entropy loss for per-pixel classification
    ce_loss = CrossEntropyLoss(
        reduction="mean",
        weight=weight
    )

    losses = [
        {"loss": focal_loss, "idx": 0, "idy": 0, "lambda": 1.0, "name": "focal_loss"},
        {"loss": ce_loss, "idx": 1, "idy": 1, "lambda": 1.0, "name": "ce_loss"},
    ]

    return losses


In [14]:
# Default HerdNet loss list for 4 classes (matches the model below).
losses = build_default_losses(num_classes=4)

# Create the base HerdNet model and wrap it together with those losses.
herdnet = LossWrapper(
    HerdNet(num_layers=34, num_classes=4, down_ratio=2),
    losses=losses,
)

In [15]:
# Display the wrapped model's module tree as this cell's output.
herdnet

LossWrapper(
  (model): HerdNet(
    (base_0): DLA(
      (base_layer): Sequential(
        (0): Conv2d(3, 16, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (level0): Sequential(
        (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (level1): Sequential(
        (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (level2): Tree(
        (tree1): BasicBlock(
          (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-0

In [16]:
# ============================================================
# 🧠 HERDNET TRAINER (Faithful Reimplementation)
# ============================================================
# Rewritten from scratch following the structure of
# animaloc.train.trainers (Delplanque, 2024).
# Handles training, validation, checkpointing, and metrics.
# ============================================================

import os
import torch
import time
from tqdm import tqdm
import torch.nn.functional as F


class HerdNetTrainer:
    """
    HerdNetTrainer
    --------------
    Manages full training and validation loops for HerdNet models
    wrapped with LossWrapper. Includes checkpoint saving, per-epoch loss
    reporting, and early stopping.

    Parameters
    ----------
    model : torch.nn.Module
        Wrapped model (LossWrapper(HerdNet, losses)). It is moved to
        ``device`` on construction.
    optimizer : torch.optim.Optimizer
        Optimizer instance (e.g., Adam, SGD) built over ``model``'s parameters.
    train_loader : iterable
        Training data loader yielding ``(images, targets)`` batches.
    val_loader : iterable, optional
        Validation data loader yielding ``(images, targets)`` batches. When
        None, validation, best-loss checkpointing and early stopping are
        skipped.
    device : torch.device or str, optional
        Device for computation. Defaults to CUDA when available, else CPU.
    save_dir : str
        Directory where checkpoints will be saved (created if missing).
    save_every : int
        Save a checkpoint every N epochs.
    patience : int
        Stop training early if val loss does not improve after N epochs.
    """

    def __init__(
        self,
        model,
        optimizer,
        train_loader,
        val_loader=None,
        device=None,
        save_dir="checkpoints",
        save_every=5,
        patience=10,
    ):
        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.save_dir = save_dir
        self.save_every = save_every
        self.patience = patience

        # Batches are moved to self.device in the loops, so the model must live
        # there too. nn.Module.to() modifies the module in place, so the
        # parameters already handed to the optimizer are the ones trained.
        self.model.to(self.device)

        os.makedirs(save_dir, exist_ok=True)

    # --------------------------------------------------------
    # 🔹 Device / output helpers
    # --------------------------------------------------------
    def _to_device(self, obj):
        """Move tensors to ``self.device``, recursing into dicts, lists and tuples.

        Lists and tuples come back as lists; non-tensor leaves are returned
        unchanged.
        """
        if isinstance(obj, torch.Tensor):
            return obj.to(self.device)
        if isinstance(obj, dict):
            return {key: self._to_device(value) for key, value in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [self._to_device(item) for item in obj]
        return obj

    @staticmethod
    def _as_loss_dict(result):
        """Extract the loss dict from a model call.

        Accepts either a bare dict or a sequence whose last element is the
        dict (e.g. ``(outputs, loss_dict)``).
        """
        # NOTE(review): LossWrapper is defined outside this cell. The training
        # step appears to receive a dict and the eval step an (outputs, dict)
        # pair; confirm against LossWrapper.forward. Both forms are accepted
        # in both phases.
        if isinstance(result, (tuple, list)):
            return result[-1]
        return result

    # --------------------------------------------------------
    # 🔹 Save checkpoint
    # --------------------------------------------------------
    def save_checkpoint(self, epoch, val_loss):
        """Save model/optimizer state to ``<save_dir>/herdnet_epoch_<epoch>.pth``."""
        checkpoint_path = os.path.join(self.save_dir, f"herdnet_epoch_{epoch}.pth")
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "val_loss": val_loss,
            },
            checkpoint_path,
        )
        print(f"[CHECKPOINT] Saved: {checkpoint_path}")

    # --------------------------------------------------------
    # 🔹 Training loop
    # --------------------------------------------------------
    def train(self, num_epochs):
        """Train for up to ``num_epochs`` epochs.

        A checkpoint is saved whenever the validation loss improves and every
        ``save_every`` epochs. Training stops early after ``patience`` epochs
        without improvement.
        """
        best_val_loss = float("inf")
        epochs_no_improve = 0

        for epoch in range(1, num_epochs + 1):
            print(f"\n🟩 Epoch [{epoch}/{num_epochs}]")

            # --------------------------
            # Training phase
            # --------------------------
            start_time = time.time()
            train_loss = self._train_one_epoch()
            duration = time.time() - start_time
            print(f"[TRAIN] Loss: {train_loss:.4f} | Time: {duration:.1f}s")

            # --------------------------
            # Validation phase
            # --------------------------
            val_loss = None
            saved_this_epoch = False
            stop_early = False
            if self.val_loader is not None:
                val_loss = self.validate()
                print(f"[VAL]   Loss: {val_loss:.4f}")

                # Check for improvement
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    epochs_no_improve = 0
                    self.save_checkpoint(epoch, val_loss)
                    saved_this_epoch = True
                else:
                    epochs_no_improve += 1
                    stop_early = epochs_no_improve >= self.patience

            # Regular checkpoint. It is labelled with this epoch's validation
            # loss, or with the best value so far (inf) when there is no
            # validation. It is skipped if the file was just written above,
            # and still written on the epoch that triggers early stopping.
            if epoch % self.save_every == 0 and not saved_this_epoch:
                recorded = val_loss if val_loss is not None else best_val_loss
                self.save_checkpoint(epoch, recorded)

            if stop_early:
                print("[EARLY STOPPING] Validation loss did not improve.")
                break

    def _train_one_epoch(self):
        """Run one optimization pass over ``train_loader``; return the mean batch loss."""
        self.model.train()
        running_loss = 0.0
        num_batches = 0

        for images, targets in tqdm(self.train_loader, desc="Training", leave=False):
            images = self._to_device(images)
            targets = self._to_device(targets)

            # Forward + loss
            self.optimizer.zero_grad()
            loss_dict = self._as_loss_dict(self.model(images, targets))
            total_loss = sum(loss_dict.values())

            total_loss.backward()
            self.optimizer.step()
            running_loss += total_loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")
        return running_loss / num_batches

    # --------------------------------------------------------
    # 🔹 Validation loop
    # --------------------------------------------------------
    def validate(self):
        """Return the mean total loss over ``val_loader`` (no gradient tracking)."""
        self.model.eval()
        val_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for images, targets in tqdm(self.val_loader, desc="Validation", leave=False):
                images = self._to_device(images)
                targets = self._to_device(targets)

                loss_dict = self._as_loss_dict(self.model(images, targets))
                val_loss += float(sum(loss_dict.values()))
                num_batches += 1

        if num_batches == 0:
            raise ValueError("val_loader yielded no batches")
        return val_loss / num_batches


In [None]:
# 1. Crear modelo y p√©rdidas
herdnet = HerdNet(num_layers=34, num_classes=4, down_ratio=2)
losses = build_default_losses(num_classes=4)
wrapped_model = LossWrapper(herdnet, losses=losses)

# 2. Preparar optimizador
optimizer = torch.optim.Adam(wrapped_model.parameters(), lr=1e-4)

# 3. Instanciar entrenador
trainer = HerdNetTrainer(
    model=wrapped_model,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    save_dir="checkpoints",
    save_every=2,
    patience=5
)

# 4. Entrenar
trainer.train(num_epochs=50)

NameError: name 'train_loader' is not defined