In [None]:
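# Install pinned dependencies into the kernel's environment.
# NOTE(review): `wandb` is imported in the next cell but is not in this list - confirm it is preinstalled or add a pinned version.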

%pip install loguru==0.7.3 python-dotenv==1.0.1 PyYAML==6.0.2 torch==2.6.0 torchvision==0.21.0 tqdm==4.67.1 typer==0.15.1 matplotlib==3.10.0 pyarrow==18.1.0 setuptools==75.1.0 protobuf==4.25.3 ultralytics==8.3.107 ray==2.43.0 albumentations==2.0.5 shortuuid==1.0.13

In [None]:
# Imports

from pathlib import Path
from PIL import Image
from torch.optim.lr_scheduler import LRScheduler
from ultralytics import YOLO, settings
from ultralytics.data.dataset import YOLODataset
from ultralytics.models.yolo.detect import DetectionTrainer, DetectionValidator
from ultralytics.nn.tasks import DetectionModel
from ultralytics.utils import colorstr, LOGGER, SettingsManager
from ultralytics.utils.torch_utils import one_cycle
import gc
import glob
import io
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import shortuuid
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.utils as vutils
import wandb
import yaml

In [None]:
# Utils

def load_config(config_file: str) -> dict:
    """Load a YAML configuration file and return its parsed contents.

    Args:
        config_file: Path of the YAML file to read (anything ``open`` accepts).

    Returns:
        The object produced by ``yaml.safe_load`` for the document; a dict
        when the document's top level is a mapping.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(config_file, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

def _get_wandb_key_colab() -> str:
    try:
        from google.colab import userdata # type: ignore

        if userdata.get("WANDB_API_KEY") is not None:
            return userdata.get("WANDB_API_KEY")
        else:
            raise ValueError("No WANDB key found")
    except:
        return None

def _get_wandb_env(path: Path) -> str:
    try:
        from dotenv import dotenv_values # type: ignore

        """Get W&B API key from Colab userdata or environment variable"""

        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Could not find .env file at {path}")

        print(f"Loading secrets from {path}")

        secrets = dotenv_values(path)
        print(f"Found keys: {list(secrets.keys())}")

        if "WANDB_API_KEY" not in secrets:
            raise KeyError(f"WANDB_API_KEY not found in {path}. Available keys: {list(secrets.keys())}")

        return secrets['WANDB_API_KEY']
    except:
        return None

def get_wandb_key(path: Path = "../.env") -> str:
    """Return the W&B API key, preferring Colab secrets over a ``.env`` file.

    Args:
        path: Location of the ``.env`` file used when no Colab secret is found.

    Returns:
        The key from Colab user secrets if available, otherwise whatever
        ``_get_wandb_env(path)`` returns (which may be ``None``).
    """
    # Query Colab once; the lookup may prompt for notebook access, so it
    # should not be repeated just to test for None.
    colab_key = _get_wandb_key_colab()
    return colab_key if colab_key is not None else _get_wandb_env(path)

def clear_cache():
    """Free unreachable Python objects, then release PyTorch's cached CUDA memory."""
    # Collect first: tensors kept alive only by reference cycles are freed
    # here, which returns their blocks to PyTorch's caching allocator ...
    gc.collect()

    # ... so that emptying the cache afterwards can also release those blocks.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def get_device() -> str:
    """Pick the device identifier passed to the train/val configuration.

    Returns:
        ``0`` (an int, i.e. the first CUDA device) when CUDA is available,
        otherwise ``"cpu"``. If the CUDA availability check itself raises,
        the error is printed and ``"cpu"`` is returned. Note that the
        ``-> str`` annotation does not cover the integer case.
    """
    try:
        return 0 if torch.cuda.is_available() else "cpu"
    except Exception as e:
        print(f"Error setting device: {e}")
        # Explicit fallback instead of implicitly returning None.
        return "cpu"

def remove_models():
    """Delete every ``*.pt`` file found in the current working directory.

    The list of matched files is printed before anything is removed.
    """
    doomed_weights = glob.glob("*.pt")
    print("Files to be removed:", doomed_weights)

    for weight_file in doomed_weights:
        os.remove(weight_file)

In [None]:
# Root path

# Project root is taken to be the parent of the current working directory,
# i.e. the notebook is expected to run from a first-level sub-folder.
ROOT_PATH = Path.cwd().parents[0]
# The remaining locations are built with os.path.join and are therefore plain strings.
DATASET_PATH = os.path.join(ROOT_PATH, "data", "raw") # Should also change the path in visdrone.yaml to point to the new location
WEIGHTS_PATH = os.path.join(ROOT_PATH, "models")
RUNS_PATH = os.path.join(ROOT_PATH, "runs")
WANDB_REPORT_PATH = os.path.join(ROOT_PATH, "reports")  # used as the W&B `dir` when the config is loaded

In [None]:
# Switch off Ultralytics' built-in W&B integration.
# The update is applied to the `settings` object imported from `ultralytics`,
# i.e. the instance the library itself consults in this process. Creating a
# fresh `SettingsManager()` would persist the change to the settings file but
# leave the already-loaded in-process settings untouched.
settings.update(
    # runs_dir=RUNS_PATH, 
    # weights_dir=WEIGHTS_PATH,
    # datasets_dir=DATASET_PATH,
    wandb=False
)

In [None]:
# Config

# Inline YAML configuration with wandb / distillation / train / val sections.
# `freeze` uses YAML's `null`: a bare `None` is not a YAML null and would be
# parsed as the string 'None'.
config_data = """
wandb: # Mandatory to use wandb to track the metrics
  project: "EyeInTheSky_merged"
  group: "distillation"
  dir: "reports"
distillation:
  use_cyclical_lr: False
  feature_layers: [11, 14, 17, 20] # Layers to extract features from
  visualize_feature_plot: False
  feature_plot_interval: 20 # Interval for plotting features
  temperature: 2.0 # Temperature for distillation
  alpha: 0.5 # Weight for the distillation loss
  max_lr: 0.01 # Maximum learning rate for the cyclical learning rate
  cycle_size: 20 # Number of epochs for each cycle
  group_scalers: {0: 1.0, 1: 1.0, 2: 1.0}
  teacher_model_path: "francescoperagine-universit-degli-studi-di-bari-aldo-moro/EyeInTheSky_merged/run_sdrn2wmo_model:v0" # Path to the teacher model. Do not leave it in the current directory as it will be removed.
train:
  project: "EyeInTheSky"
  data: "VisDrone.yaml"
  model: "yolo12n.pt" # Model with weights. To train it from scratch, use "yolo12n.yaml with pretrained: False"
  pretrained: True # Read above
  patience: 5
  task: detect
  epochs: 150
  weight_decay: 0.0
  lr0: 0.001
  lrf: 0.001
  nbs: 64
  seed: 42
  plots: True
  exist_ok: False
  save: True
  save_period: 5
  val: True
  warmup_epochs: 5
  visualize: True
  show: True
  single_cls: False
  rect: False
  resume: False
  fraction: 1.0
  freeze: null
  cache: False
  verbose: False
  amp: True
  save_crop: True
  save_conf: True
  save_txt: True
  save_json: True
val:
  project: "EyeInTheSky"
  task: detect
  data: "VisDrone.yaml"
  half: True
  conf: 0.25
  iou: 0.6
  split: "test"
  rect: True
  plots: True
  visualize: True
  verbose: False
  save_crop: True
  save_conf: True
  save_txt: True
  save_json: True
"""

In [None]:
# Load config
# config_path = root / "config" / "config.yaml"
# config = load_config(config_path)

# Parse the inline YAML defined in the previous cell
config = yaml.safe_load(config_data)

# Replace the relative W&B directory from the YAML with the project-level reports folder
config["wandb"]["dir"] = WANDB_REPORT_PATH

# Training and validation both run on the auto-detected device
# (a `save_dir` override pointing at RUNS_PATH is intentionally left out)
config["train"]["device"] = get_device()
config["val"]["device"] = get_device()


In [None]:
# Dataset, Trainer, Validator

class VisDroneDataset(YOLODataset):
    """
    YOLO dataset for VisDrone in which 'pedestrian' (0) and 'people' (1) are one class.

    Right after the parent class has loaded the labels, get_labels() rewrites
    the class ids in place: id 1 becomes 0 and every id above 1 is decreased
    by one. Everything downstream of the dataset therefore only ever sees the
    merged ids. The resulting id -> name mapping is exposed as a class
    attribute so the trainer and validator can read it without an instance.

    Class attributes:
        merged_names (dict): class id -> name mapping after the merge
    """

    # Shared by the trainer/validator, hence defined on the class
    merged_names = {
        0: 'persona',
        1: 'bicicletta',
        2: 'auto',
        3: 'furgone',
        4: 'camion',
        5: 'triciclo',
        6: 'triciclo-tendato',
        7: 'autobus',
        8: 'motociclo'
    }

    def __init__(self, *args, **kwargs):
        """Build the dataset via the parent constructor, then log the merged mapping."""
        super().__init__(*args, **kwargs)
        LOGGER.info(f"{colorstr('VisDroneDataset:')} Using merged classes: {self.merged_names}")

    def get_labels(self):
        """
        Load the labels through the parent class and remap their class ids.

        Returns:
            The label list from the parent, with each non-empty 'cls' array
            modified in place (1 -> 0, ids > 1 shifted down by one).
        """
        labels = super().get_labels()

        # Running totals, only used for the log lines at the end
        remapped_people = 0
        shifted_others = 0

        for label in labels:
            cls = label['cls']
            if len(cls) == 0:
                continue

            # 'people' (1) is folded into class 0
            is_people = cls == 1
            remapped_people += np.sum(is_people)
            cls[is_people] = 0

            # Ids above 1 move down to close the gap left by the merge
            above_people = cls > 1
            shifted_others += np.sum(above_people)
            cls[above_people] -= 1

            label['cls'] = cls

        # Keep the dataset's own metadata in sync with the remapped ids
        if hasattr(self, 'data'):
            self.data['names'] = self.merged_names
            self.data['nc'] = len(self.merged_names)

        merged_total = sum(np.sum(label['cls'] == 0) for label in labels)
        LOGGER.info(f"\n{colorstr('VisDroneDataset:')} Remapped {remapped_people} 'people' instances to {self.merged_names[0]}")
        LOGGER.info(f"{colorstr('VisDroneDataset:')} Total 'persona' instances after merge: {merged_total}")
        LOGGER.info(f"{colorstr('VisDroneDataset:')} Shifted {shifted_others} instances of other classes")

        return labels

class MergedClassDetectionTrainer(DetectionTrainer):
    """
    Detection trainer wired to the merged-class VisDroneDataset.

    Three hooks of DetectionTrainer are overridden:
    - build_dataset() returns a VisDroneDataset instead of the stock dataset;
    - get_model() builds a DetectionModel and optionally loads weights into it;
    - set_model_attributes() overwrites the model's and the data dict's
      class names / class count with VisDroneDataset.merged_names.
    """

    def build_dataset(self, img_path, mode="train", batch=None):
        """
        Create the VisDroneDataset for one split.

        Augmentation, zero padding and the configured `fraction` apply only
        when mode is "train"; any other mode forces rect batches, pad 0.5 and
        the full split.
        """
        training = mode == "train"
        args = self.args
        return VisDroneDataset(
            img_path=img_path,
            imgsz=args.imgsz,
            batch_size=batch or self.batch_size,
            augment=training,
            hyp=args,
            rect=args.rect if training else True,
            cache=args.cache or None,
            single_cls=args.single_cls,
            stride=self.stride,
            pad=0.0 if training else 0.5,
            prefix=colorstr(f"{mode}: "),
            task=args.task,
            classes=None,
            data=self.data,
            fraction=args.fraction if training else 1.0,
        )

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Build a DetectionModel for the current data config, loading `weights` when given."""
        # NOTE(review): `nc` is read from self.data as it is at this point; whether it
        # already equals the merged class count depends on the dataset YAML - confirm.
        model = DetectionModel(cfg=cfg, nc=self.data["nc"], verbose=verbose)
        model.args = self.args

        if weights:
            LOGGER.info("Loading weights into model")
            model.load(weights)

        return model

    def set_model_attributes(self):
        """Run the parent hook, then force the merged class names/count onto model and data."""
        super().set_model_attributes()

        if not hasattr(self.model, 'names'):
            return

        merged = VisDroneDataset.merged_names
        self.model.names = merged
        self.model.nc = len(merged)

        # Mirror the same mapping in the data dictionary
        if hasattr(self, 'data'):
            self.data['names'] = merged
            self.data['nc'] = len(merged)

class MergedClassDetectionValidator(DetectionValidator):
    """
    Validator counterpart of MergedClassDetectionTrainer.

    build_dataset() returns a VisDroneDataset so that evaluation sees the same
    merged class ids as training. set_model_attributes() mirrors the trainer's
    method of the same name and copies VisDroneDataset.merged_names onto the
    model and the data dictionary when those are present.
    """

    def build_dataset(self, img_path, mode="val", batch=None):
        """Build the (non-augmented) VisDroneDataset used for validation."""
        return VisDroneDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch or self.args.batch,
            augment=False,
            hyp=self.args,
            rect=self.args.rect,
            cache=None,
            single_cls=self.args.single_cls,
            stride=self.stride,
            pad=0.5,
            prefix=colorstr(f"{mode}: "),
            task=self.args.task,
            classes=None,
            data=self.data,
        )

    def set_model_attributes(self):
        """
        Apply the merged class names/count to the model and data dict, when available.

        NOTE(review): nothing in this file calls this method, and the
        Ultralytics validator base classes do not appear to define the hook or
        a `model` attribute - confirm. Both are therefore accessed defensively
        so that calling this method cannot raise AttributeError.
        """
        # Delegate to the parent only if it actually offers this hook
        parent_hook = getattr(super(), "set_model_attributes", None)
        if callable(parent_hook):
            parent_hook()

        model = getattr(self, "model", None)
        if model is not None and hasattr(model, 'names'):
            # Use the merged names directly from the dataset class
            model.names = VisDroneDataset.merged_names
            model.nc = len(VisDroneDataset.merged_names)

            # Also update data dictionary
            if hasattr(self, 'data'):
                self.data['names'] = VisDroneDataset.merged_names
                self.data['nc'] = len(VisDroneDataset.merged_names)

In [None]:
# FeatureAdaptation

class FeatureAdaptation(nn.Module):
    """
    Maps student feature maps to the teacher's channel count for feature distillation.

    The module applies, in order:

    1. A channel gate: global average pooling followed by two 1x1 convolutions
       and a sigmoid, multiplied onto the input.
    2. A 3x3 grouped convolution on the gated features, added back as a residual.
       Layers listed in the hard-coded set [8, 11, 14, 17, 20] use a depthwise
       convolution (groups == in_channels); all other layers use roughly
       in_channels // 4 groups.
    3. A bias-free 1x1 convolution changing the channel count to `out_channels`.

    Args:
        in_channels (int): Number of input channels (from student model)
        out_channels (int): Number of output channels (to match teacher model)
        layer_idx (int, optional): Layer index selecting the spatial-convolution variant
    """

    def __init__(self, in_channels, out_channels, layer_idx=None):
        super().__init__()

        # Layers that get the depthwise variant of the spatial convolution
        use_spatial = layer_idx in [8, 11, 14, 17, 20]

        # Channel gate, used for every layer
        self.gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # Global pooling to get channel-wise statistics
            nn.Conv2d(in_channels, in_channels, kernel_size=1),
            nn.ReLU(inplace=True),    # ReLU for non-linearity
            nn.Conv2d(in_channels, in_channels, kernel_size=1),
            nn.Sigmoid()              # Output in range [0,1] to act as gates
        )

        if use_spatial:
            groups = in_channels
        else:
            # Target in_channels // 4 groups, but Conv2d requires `groups` to be
            # >= 1 and to divide in_channels. The gcd keeps the original value
            # whenever it was valid and falls back to the largest common
            # divisor otherwise (e.g. 18 channels -> 2 groups, 3 channels -> 1).
            groups = math.gcd(in_channels, max(1, in_channels // 4))
        self.spatial = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=groups)

        # Channel adapter (student channels -> teacher channels)
        self.adapter = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

        # Kaiming initialization for the adapter
        nn.init.kaiming_normal_(self.adapter.weight)

        # Small initial weights so the residual branch starts close to identity
        nn.init.xavier_normal_(self.spatial.weight, gain=0.1)

        # Small weights and zero bias: the gate starts near sigmoid(0) = 0.5 for every channel
        for m in self.gate.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight, gain=0.1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Gate, spatially refine (residual) and channel-adapt the input feature map."""
        # Per-channel importance in [0, 1]
        importance = self.gate(x)

        # Emphasize the most informative channels
        gated_features = x * importance

        # Spatial context with residual connection (on gated features)
        gated_features = self.spatial(gated_features) + gated_features

        # Channel adaptation
        return self.adapter(gated_features)

In [None]:
# DecayingCyclicalLR scheduler

class DecayingCyclicalLRSchedulerCallback:
    """
    Trainer callback that installs a DecayingCyclicalLR scheduler.

    The keyword arguments given at construction are kept as-is and forwarded
    to DecayingCyclicalLR when set_scheduler() is invoked with a trainer.

    Args:
        **kwargs: Scheduler options (expected to come from config["distillation"]), e.g.:
            max_lr (float): Maximum learning rate peak
            cycle_size (int): Number of epochs per cycle
            group_scalers (dict): Per-parameter group learning rate multipliers
    """

    def __init__(self, **kwargs):
        # Stored untouched; interpreted by DecayingCyclicalLR later on
        self.kwargs = kwargs

    def set_scheduler(self, trainer):
        """Create the scheduler for `trainer` and attach it as `trainer.scheduler`."""
        scheduler = DecayingCyclicalLR(trainer=trainer, **self.kwargs)
        trainer.scheduler = scheduler

class DecayingCyclicalLR(LRScheduler):
    """
    Learning rate scheduler that combines a cyclical pattern with a linear decay.

    Behaviour of get_lr(), as a function of `last_epoch`:

    1. Warmup (`last_epoch < warmup_epochs`): linear ramp from `warmup_start_lr`
       to `min_lr` (the trainer's `lr0`).
    2. Afterwards: the lr follows 0.5 * (1 - cos(2*pi*t)) within each cycle of
       `cycle_size` epochs, i.e. it starts at the lower bound, peaks at the
       upper bound mid-cycle and returns to the lower bound.
    3. Both bounds (`min_lr`, `max_lr`) are scaled by a factor that decays
       linearly from 1.0 to `lrf` over the post-warmup epochs.
    4. Each parameter group's value is multiplied by `group_scalers[i]`
       (1.0 when the group index is not listed).

    Args:
        trainer: Provides `optimizer` and `args` (`lr0`, `lrf`, `epochs`, and
            optionally `warmup_epochs` / `warmup_start_lr`).
        **kwargs: Configuration parameters from config["distillation"], including:
            max_lr (float): Peak learning rate; defaults to `lr0 * 10`
            cycle_size (int): Number of epochs in each lr cycle; defaults to 20,
                values below 1 are clamped to 1
            group_scalers (dict): Parameter-group index -> lr multiplier
    """
    def __init__(self, 
                 trainer, 
                 **kwargs
            ):
        args = trainer.args

        self.optimizer = trainer.optimizer
        self.min_lr = args.lr0
        max_lr = kwargs.get("max_lr")
        self.max_lr = args.lr0 * 10 if max_lr is None else max_lr
        self.lrf = args.lrf

        # A cycle must span at least one epoch, otherwise get_lr() divides by zero
        cycle_size = kwargs.get("cycle_size")
        self.cycle_size = 20 if cycle_size is None else max(1, cycle_size)

        self.warmup_epochs = getattr(args, 'warmup_epochs', 0)
        self.warmup_start_lr = getattr(args, 'warmup_start_lr', self.min_lr * 0.1)
        self.epochs = args.epochs

        group_scalers = kwargs.get("group_scalers")
        self.group_scalers = group_scalers if group_scalers is not None else {0: 1.0, 1: 1.0, 2: 1.0}

        # Debug logging
        LOGGER.info(f"DecayingCyclicalLR: Min lr: {self.min_lr}, Max lr: {self.max_lr}, Cycle size: {self.cycle_size}, lrf: {self.lrf}, group_scalers: {self.group_scalers}, warmup_epochs: {self.warmup_epochs}, warmup_start_lr: {self.warmup_start_lr}, epochs: {self.epochs}")

        # Must come last: the base constructor performs an initial step, which
        # calls get_lr() and therefore needs all attributes above.
        super().__init__(self.optimizer)

    def _scale_per_group(self, base_lr):
        """Return `base_lr` multiplied by each parameter group's scaler, as plain floats."""
        return [
            float(base_lr * self.group_scalers.get(i, 1.0))
            for i in range(len(self.optimizer.param_groups))
        ]

    def get_lr(self):
        """Compute the learning rate of every parameter group for `last_epoch`."""
        # During warmup: gradually increase from warmup_start_lr to min_lr
        if self.last_epoch < self.warmup_epochs:
            warmup_factor = self.last_epoch / max(1, self.warmup_epochs)
            base_lr = self.warmup_start_lr + (self.min_lr - self.warmup_start_lr) * warmup_factor
            return self._scale_per_group(base_lr)

        # After warmup: cyclical pattern with decay
        post_warmup_epoch = self.last_epoch - self.warmup_epochs
        remaining_epochs = self.epochs - self.warmup_epochs

        # Across-cycles decay (outer schedule), capped at 1.0
        progress = min(1.0, post_warmup_epoch / max(1, remaining_epochs))

        # Linear decay of both bounds from 1.0 down to lrf
        decay_factor = 1.0 - (1.0 - self.lrf) * progress
        current_min_lr = self.min_lr * decay_factor
        current_max_lr = self.max_lr * decay_factor

        # Position inside the current cycle, normalized to [0, 1)
        normalized = (post_warmup_epoch % self.cycle_size) / self.cycle_size

        # 0 at the cycle start, 1 at half cycle, back to 0 at the cycle end
        factor = 0.5 * (1 - math.cos(2 * math.pi * normalized))

        base_lr = current_min_lr + (current_max_lr - current_min_lr) * factor

        # Same return type as during warmup (plain floats, no numpy scalars)
        return self._scale_per_group(base_lr)

In [None]:
# Distillation Callback

class DistillationCallback:
    """
    Core callback that implements the knowledge distillation process.
    
    This callback manages the entire feature-based distillation mechanism by:
    
    1. Capturing intermediate features from both teacher and student models
    2. Creating and managing feature adaptation layers to align dimensions
    3. Computing cosine similarity losses between adapted features
    4. Balancing detection loss with feature distillation loss
    5. Learning optimal layer importance weights and loss component weights
    
    The callback hooks into multiple phases of training:
    - on_train_start: Initializes parameters and adds them to the optimizer
    - on_train_epoch_start: Sets up feature extraction hooks and custom loss function
    - on_fit_epoch_end: Logs metrics and visualizations to W&B
    
    Args:
        teacher_model: Pretrained YOLO model to use as knowledge source
        temperature (float): Controls softness of weight distribution (from config)
        alpha (float): Balance between detection and distillation loss (from config)
        feature_layers (list): Layer indices to extract features from (from config)
        layer_channels (dict): Pre-computed channel dimensions for each layer
    
    The callback dynamically learns two sets of weights during training:
    - layer_weights: Importance of each feature layer in the distillation loss
    - component_weights: Automatic balancing of box/class/DFL loss components
    """

    def __init__(self, 
            teacher_model, 
            temperature = 2.0, 
            alpha = 0.5, 
            feature_layers = [11, 14, 17, 20], 
            layer_channels = {}
        ):

        self.teacher_model = teacher_model
        self.temperature = temperature
        self.alpha = alpha
        self.feature_layers = feature_layers
        self.layer_channels = layer_channels
        self.current_features = {}
        self.teacher_hooks = {}
        self.student_hooks = {}
        self.adapters = {}
        self.layer_weight_values = {layer_idx: 0.0 for layer_idx in self.feature_layers}
        self.cosine_losses = []
        self.detection_losses = []
        self.distillation_losses = []
        self.layer_weights = None
        self.component_weights = None
        self.component_weight_history = {'box': [], 'cls': [], 'dfl': []}
        self.initialized = False

    def init_parameters(self, trainer):
        """Initialize parameters and register them with optimizer once"""
        if self.initialized:
            return
            
        # Move teacher to student's device/dtype
        device = next(self.teacher_model.parameters()).device
        dtype = next(self.teacher_model.parameters()).dtype
        self.teacher_model = self.teacher_model.to(device=device, dtype=dtype)
        
        # Initialize component weights
        self.component_weights = nn.Parameter(torch.zeros(3, device=device))
        self._add_optimizer_param_group(
            trainer=trainer,
            params=[self.component_weights],
            name='loss_component_weights',
            weight_decay_multiplier=0.2
        )
            
        # Initialize layer weights
        equal_weight = 1.0 / len(self.feature_layers)
        weights_tensor = torch.tensor([equal_weight for _ in self.feature_layers], device=device)
        self.layer_weights = nn.Parameter(weights_tensor, requires_grad=True)
        self._add_optimizer_param_group(
            trainer=trainer,
            params=[self.layer_weights],
            name='layer_weight_params',
            weight_decay_multiplier=1.0
        )
        
        # Create feature adapters
        for layer_idx in self.feature_layers:
            adapter_key = f'adapter_{layer_idx}'
            
            # Get channel dimensions from pre-computed values
            if layer_idx in self.layer_channels:
                student_channels = self.layer_channels[layer_idx]['student']
                teacher_channels = self.layer_channels[layer_idx]['teacher']
                
                # Create feature adaptation layer with normalization
                self.adapters[adapter_key] = FeatureAdaptation(student_channels, teacher_channels, layer_idx=layer_idx).to(device)
                
                # Add adapter parameters to optimizer with stronger regularization
                self._add_optimizer_param_group(
                    trainer=trainer,
                    params=list(self.adapters[adapter_key].parameters()),
                    name=f'adapter_params_{layer_idx}',
                    weight_decay_multiplier=5.0
                )
            else:
                LOGGER.warning(f"Missing channel dimensions for layer {layer_idx}")
                
        self.initialized = True
        
        # Set original criterion if not already done
        if not hasattr(trainer.model, "original_criterion"):
            trainer.model.original_criterion = trainer.model.init_criterion()
    
    def setup_criterion(self, trainer):
        """Register hooks for feature extraction"""
        # Clear features dictionary
        self.current_features = {}

        def get_features(name, layer_idx):
            def hook(module, input, output):
                if name not in self.current_features:
                    self.current_features[name] = {}
                self.current_features[name][layer_idx] = output.detach()
            return hook
       
        # Register hooks for both models at each feature layer
        for layer_idx in self.feature_layers:
            self.teacher_hooks[layer_idx] = self.teacher_model.model.model[layer_idx].register_forward_hook(
                get_features('teacher', layer_idx)
            )
            self.student_hooks[layer_idx] = trainer.model.model[layer_idx].register_forward_hook(
                get_features('student', layer_idx)
            )
        
        # Define distillation criterion
        def distillation_criterion(preds, batch):

            # Get original loss
            original_loss, loss_items = trainer.model.original_criterion(preds, batch)

            # Apply uncertainty weighting (Kendall et al. 2018)
            # Converting parameters to precisions using exponential - ensures positive weights
            precision = torch.exp(-self.component_weights)

            # Weight losses - dividing by precision penalizes high uncertainty components
            weighted_loss = original_loss * precision + 0.5 * self.component_weights  # The 0.5 * log_var term is from the math derivation
            weighted_detection_loss = weighted_loss.sum()

            # Run teacher forward pass to capture features
            with torch.no_grad():
                self.teacher_model.eval()
                teacher_input = batch["img"].float()
                _ = self.teacher_model.model(teacher_input)
            
            # Calculate cosine loss for each layer and combine
            layer_losses = []
            
            for i, layer_idx in enumerate(self.feature_layers):
                if ('teacher' in self.current_features and 
                    'student' in self.current_features and
                    layer_idx in self.current_features['teacher'] and
                    layer_idx in self.current_features['student']):
                    
                    t_feat = self.current_features['teacher'][layer_idx]
                    s_feat = self.current_features['student'][layer_idx]
                    
                    # Apply adaptation to student features
                    adapter_key = f'adapter_{layer_idx}'
                    if adapter_key in self.adapters:
                        s_feat_adapted = self.adapters[adapter_key](s_feat)
                        
                        # Normalize features
                        t_feat = F.normalize(t_feat.flatten(1), p=2, dim=1)
                        s_feat_adapted = F.normalize(s_feat_adapted.flatten(1), p=2, dim=1)
                        
                        # Cosine similarity loss for this layer
                        layer_cosine_loss = 1 - F.cosine_similarity(s_feat_adapted, t_feat, dim=1).mean()
                        layer_losses.append(layer_cosine_loss)

            # Apply softmax to weights with minimum weight guarantee
            alpha = 0.2  # This is your minimum weight parameter (not the same as distillation alpha)
            uniform_weights = torch.ones_like(self.layer_weights) / len(self.layer_weights)
            normalized_weights = (1 - alpha) * F.softmax(self.layer_weights / self.temperature, dim=0) + alpha * uniform_weights
            # normalized_weights = F.softmax(self.layer_weights / self.temperature, dim=0)

            # Calculate weighted loss
            total_cosine_loss = sum(weight * loss for weight, loss in zip(normalized_weights, layer_losses))

            # Combined loss using alpha parameter
            distillation_loss = (1 - self.alpha) * weighted_detection_loss + self.alpha * total_cosine_loss
            
            # Calculate current component weights for logging
            current_weights = precision.detach().cpu()
            
            # Safe scalar extraction for logging
            orig_loss_val = original_loss.sum().item()
            weighted_det_loss_val = weighted_detection_loss.item()
            cosine_loss_val = total_cosine_loss.item()
            dist_loss_val = distillation_loss.item()

            # Store current weights for history
            self.component_weight_history['box'].append(current_weights[0].item())
            self.component_weight_history['cls'].append(current_weights[1].item())
            self.component_weight_history['dfl'].append(current_weights[2].item())
            
            # Track metrics
            self.detection_losses.append(weighted_det_loss_val)
            self.cosine_losses.append(cosine_loss_val)
            self.distillation_losses.append(dist_loss_val)

            # Store weight values for monitoring
            for i, layer_idx in enumerate(self.feature_layers):
                self.layer_weight_values[layer_idx] = normalized_weights[i].item()

            return distillation_loss, loss_items
 
        # Set original criterion if not already done
        if not hasattr(trainer.model, "original_criterion"):
            trainer.model.original_criterion = trainer.model.init_criterion()
            
        trainer.model.criterion = distillation_criterion

    def _add_optimizer_param_group(self, trainer, params, name, weight_decay_multiplier=1.0):
        """
        Append a named parameter group to the trainer's optimizer.

        The new group inherits every option of the optimizer's first parameter
        group (everything except ``'params'`` and ``'name'``); only
        ``weight_decay`` may be rescaled.

        Args:
            trainer: Object exposing an ``optimizer`` attribute.
            params: Parameters to optimize. Accepts a single tensor, any iterable
                of tensors (list, tuple, generator, ...) or an object exposing
                ``parameters()``.
            name (str): Label stored under the group's ``'name'`` key; used to
                detect groups that were already added.
            weight_decay_multiplier (float): Factor applied to the inherited
                weight decay (0.005 is used when the base group defines none).
                A multiplier of 0 leaves the inherited value untouched.

        Returns:
            bool: True when the group was added; False when the trainer has no
            optimizer yet or a group called ``name`` already exists.
        """
        optimizer = getattr(trainer, 'optimizer', None)
        if optimizer is None:
            return False

        # Named groups are added at most once
        if any(group.get('name') == name for group in optimizer.param_groups):
            LOGGER.info(f"Parameter group '{name}' already exists, skipping addition")
            return False

        # Materialize the parameters once so that tuples, generators and single
        # tensors can be both registered and counted for the log message below.
        if isinstance(params, torch.Tensor):
            param_list = [params]
        elif hasattr(params, 'parameters'):
            param_list = list(params.parameters())
        else:
            param_list = list(params)

        # Inherit every option from the first parameter group
        base_group = optimizer.param_groups[0]
        group_options = {k: v for k, v in base_group.items() if k not in ('params', 'name')}

        # Rescale the weight decay (skipped for a multiplier of 0)
        if weight_decay_multiplier != 0:
            group_options['weight_decay'] = group_options.get('weight_decay', 0.005) * weight_decay_multiplier

        optimizer.add_param_group({
            'params': param_list,
            'name': name,
            **group_options
        })

        LOGGER.info(f"Added parameter group '{name}' with {len(param_list)} parameters")
        return True
    
    def log_metrics(self, trainer):
        """
        Remove the feature hooks and log the epoch's distillation metrics to W&B.

        Averages the detection, cosine and combined distillation losses collected
        since the previous call and logs them together with ``trainer.metrics``,
        ``trainer.fitness``, the per-layer importance weights / adapter norms and
        the most recent box/cls/dfl component weights. Nothing is logged when no
        loss was recorded. Any error raised while building or sending the metrics
        is caught and reported as a warning. The per-epoch accumulators are reset
        at the end in every case.

        Args:
            trainer: Trainer providing ``metrics`` (a mapping) and ``fitness``.
        """
        # Remove all hooks
        for hook in self.teacher_hooks.values():
            hook.remove()
        for hook in self.student_hooks.values():
            hook.remove()

        if len(self.detection_losses) > 0:

            # Epoch averages of the tracked losses
            avg_detection = sum(self.detection_losses) / max(len(self.detection_losses), 1)
            avg_cosine = sum(self.cosine_losses) / max(len(self.cosine_losses), 1)
            avg_distillation = sum(self.distillation_losses) / max(len(self.distillation_losses), 1)

            # Log to wandb with per-layer metrics
            try:
                metrics_dict = {
                    **trainer.metrics,
                    "metrics/fitness": trainer.fitness,
                    "distillation/detection_loss": avg_detection,
                    "distillation/cosine_loss": avg_cosine,
                    "distillation/distillation_loss": avg_distillation,
                }

                # The importance-weight optimizer is shared by all layers, so its
                # learning rate is reported once, under the first configured layer.
                # (Comparing against the literal index 0 never matched unless
                # layer 0 itself was a distillation layer.)
                first_layer = self.feature_layers[0] if len(self.feature_layers) > 0 else None
                importance_optimizer = getattr(self, 'layer_weight_optimizer', None)

                for layer_idx in self.feature_layers:
                    layer_prefix = f"distillation/layer_{layer_idx}"

                    adapter_key = f'adapter_{layer_idx}'
                    if adapter_key in self.adapters:
                        # Norm of the adapter's first parameter tensor
                        adapter_norm = torch.linalg.norm(next(self.adapters[adapter_key].parameters())).item()
                        metrics_dict[f"{layer_prefix}/adapter_norm"] = adapter_norm

                    # Contribution weight of this layer
                    metrics_dict[f"{layer_prefix}/importance"] = self.layer_weight_values[layer_idx]

                    if importance_optimizer is not None and layer_idx == first_layer:
                        # With several parameter groups the last group's lr is the one kept
                        for pg in importance_optimizer.param_groups:
                            metrics_dict[f"{layer_prefix}/importance_optimizer"] = pg.get('lr', 0)

                # Latest loss-component weights recorded during the epoch
                if len(self.component_weight_history['box']) > 0:
                    metrics_dict["distillation/box_weight"] = self.component_weight_history['box'][-1]
                    metrics_dict["distillation/cls_weight"] = self.component_weight_history['cls'][-1]
                    metrics_dict["distillation/dfl_weight"] = self.component_weight_history['dfl'][-1]

                wandb.log(metrics_dict, commit=False)

            except Exception as e:
                LOGGER.warning(f"Failed to log metrics to wandb: {e}")

        # Clear metrics for next epoch
        self.detection_losses = []
        self.cosine_losses = []
        self.distillation_losses = []
        self.component_weight_history = {'box': [], 'cls': [], 'dfl': []}


In [None]:
# Feature Visualization Callback

class FeatureVisualizationCallback:
    """
    Callback for visualizing feature maps during distillation training.

    Periodically captures feature activations from the given layers of
    ``trainer.model`` and logs them to Weights & Biases as grid images.

    The callback works by:
    1. Registering forward hooks at the start of every visualization epoch
       (``set_hooks``); each hook keeps the first item of the latest batch
    2. Removing the hooks at the end of that epoch and rendering the captured
       channels as grids of ``grid_size x grid_size`` images (``plot_figures``)
    3. Logging every grid to W&B with ``commit=False``

    Args:
        layers (list): Layer indices to visualize features from (from config)
        interval (int): Epoch interval between visualizations (from config)
        figsize (tuple): Size of the generated matplotlib figures
        grid_size (int): Number of feature channels per row/column in the grid

    Visualization only occurs on epochs where ``trainer.epoch % interval == 0``.
    """

    def __init__(self, layers, interval=20, figsize=(8.0, 8.0), grid_size=2):
        self.layers = layers
        self.interval = interval
        self.features = {}  # layer index -> captured feature tensor
        self.hooks = {}     # layer index -> forward-hook handle
        self.figsize = figsize
        self.grid_size = grid_size  # Number of images per row and column in each grid

    def _remove_hooks(self):
        """Detach every registered forward hook and forget its handle."""
        for handle in self.hooks.values():
            handle.remove()
        self.hooks = {}

    def set_hooks(self, trainer):
        """
        Register feature-capturing hooks when ``trainer.epoch`` is a visualization epoch.

        Previously captured features and any hooks still registered from an
        earlier call are always discarded first, so repeated calls never stack
        duplicate hooks on the same layer.
        """
        # Clear previous features and hooks
        self.features = {}
        self._remove_hooks()

        # Only register hooks on visualization epochs
        if (trainer.epoch) % self.interval != 0:
            return

        # Hook factory: binds the layer index the captured tensor is stored under
        def get_features(layer_idx):
            def hook(module, input, output):
                # Store just one image from batch for visualization
                self.features[layer_idx] = output[0].detach().clone()
            return hook

        for layer_idx in self.layers:
            self.hooks[layer_idx] = trainer.model.model[layer_idx].register_forward_hook(
                get_features(layer_idx)
            )

    def plot_figures(self, trainer):
        """
        Remove the hooks and log the captured feature maps to W&B as grid images.

        Does nothing unless ``trainer.epoch`` is a visualization epoch. Each
        channel is min-max normalized independently (constant channels are left
        as they are) and the channels are split over as many grids as needed.
        """
        # Skip if not visualization epoch
        if (trainer.epoch) % self.interval != 0:
            return

        self._remove_hooks()

        # Create visualizations for each layer
        for layer_idx, feature_maps in self.features.items():
            # All channels are visualized
            num_channels = feature_maps.shape[0]
            feature_subset = feature_maps.cpu()

            # Normalize each channel for better visualization
            for i in range(feature_subset.size(0)):
                min_val = feature_subset[i].min()
                max_val = feature_subset[i].max()
                if max_val > min_val:
                    feature_subset[i] = (feature_subset[i] - min_val) / (max_val - min_val)

            # Calculate how many grids we need to display all channels
            channels_per_grid = self.grid_size * self.grid_size
            num_grids = (num_channels + channels_per_grid - 1) // channels_per_grid  # Ceiling division

            # Create and log multiple grid images
            for grid_idx in range(num_grids):
                start_idx = grid_idx * channels_per_grid
                end_idx = min(start_idx + channels_per_grid, num_channels)
                current_channels = feature_subset[start_idx:end_idx]

                # Skip if no channels in this grid
                if len(current_channels) == 0:
                    continue

                # Each channel becomes a single-channel image in the grid
                grid = vutils.make_grid(
                    current_channels.unsqueeze(1),
                    nrow=self.grid_size,
                    padding=2
                )

                # Convert grid to numpy (H, W, C) for matplotlib
                grid_np = grid.permute(1, 2, 0).numpy().astype(np.float32)

                # Render into an explicit figure so exactly that figure is closed afterwards
                fig = plt.figure(figsize=self.figsize)
                plt.imshow(grid_np)
                plt.title(f"Layer {layer_idx} Features - Epoch {trainer.epoch+1} - Group {grid_idx+1}/{num_grids}")
                plt.axis('off')
                plt.tight_layout()

                # Save figure to a buffer
                buf = io.BytesIO()
                plt.savefig(buf, format='png', dpi=100)
                plt.close(fig)
                buf.seek(0)

                # Convert buffer to PIL Image
                feature_image = Image.open(buf)

                # Log to wandb
                wandb.log({f"features/layer_{layer_idx}_epoch_{trainer.epoch+1}_group_{grid_idx+1}": wandb.Image(feature_image)}, commit=False)

In [None]:
# Distillation Orchestrator

class DistillationOrchestrator:
    """
    Orchestrates the knowledge distillation process between teacher and student YOLO models.

    This class handles the setup and coordination of all components needed for distillation:
    - Loads teacher and student models from the checkpoints configured on each trainer
    - Probes both models once to learn the channel count at every feature layer
    - Registers callbacks for distillation loss calculation and visualization
    - Manages learning rate scheduling (standard or cyclical)

    The orchestrator follows a callback-based architecture where different
    components hook into the training process at specific points (epoch start/end,
    training start, etc.) without modifying the core training loop.

    Most configuration parameters are passed through kwargs, which typically come from
    a YAML config file (e.g., config["distillation"] section). Unknown keys are ignored.

    When ``feature_layers`` is empty or None, no callback at all (distillation,
    visualization or learning-rate scheduling) is registered on the student trainer.

    Args:
        teacher_trainer (MergedClassDetectionTrainer): Trainer instance with teacher model
        student_trainer (MergedClassDetectionTrainer): Trainer instance with student model
        **kwargs: Configuration parameters typically from config["distillation"], including:
            feature_layers (list): Layer indices from which to extract features for distillation
            temperature (float): Temperature parameter for softening importance weights
            alpha (float): Weight balancing detection loss vs. distillation loss (0-1)
            use_cyclical_lr (bool): Whether to use cyclical learning rate scheduling
            max_lr (float): Maximum learning rate for cyclical scheduler
            cycle_size (int): Number of epochs per learning rate cycle
            group_scalers (dict): Per-parameter group learning rate multipliers
            visualize_feature_plot (bool): Whether to log feature map visualizations
            feature_plot_interval (int): Interval (in epochs) for feature map visualization

    Example:
        ```
        teacher_trainer = MergedClassDetectionTrainer(teacher_config)
        student_trainer = MergedClassDetectionTrainer(student_config)

        # Load distillation settings from config
        orchestrator = DistillationOrchestrator(
            teacher_trainer=teacher_trainer,
            student_trainer=student_trainer,
            **config["distillation"]
        )

        # Start distillation training
        results = orchestrator.start()
        ```
    """

    def __init__(
            self,
            teacher_trainer: MergedClassDetectionTrainer,
            student_trainer: MergedClassDetectionTrainer,
            **kwargs
        ):
        """Store the configuration, load both models and register the training callbacks."""
        self.teacher_trainer = teacher_trainer
        self.student_trainer = student_trainer

        self.feature_layers = kwargs.get("feature_layers", [20])
        self.temperature = kwargs.get("temperature", 2.0)
        self.alpha = kwargs.get("alpha", 0.5)
        self.use_cyclical_lr = kwargs.get("use_cyclical_lr", False)
        self.max_lr = kwargs.get("max_lr", 0.01)
        self.cycle_size = kwargs.get("cycle_size", 20)
        self.group_scalers = kwargs.get("group_scalers", {0: 1.0, 1: 1.0, 2: 1.0})
        self.visualize_feature_plot = kwargs.get("visualize_feature_plot", False)
        self.feature_plot_interval = kwargs.get("feature_plot_interval", 20)

        # Standalone model wrappers loaded from each trainer's configured checkpoint
        self.teacher_model = YOLO(self.teacher_trainer.args.model)
        self.student_model = YOLO(self.student_trainer.args.model)

        # Nothing to distill: skip the (costly) channel probe and register no callback
        if not self.feature_layers:
            self.layer_channels = {}
            return

        self.layer_channels = self._compute_all_channel_dimensions(self.feature_layers)

        # Create and register the distillation callback
        self.distill_callback = DistillationCallback(
            teacher_model=self.teacher_model,
            temperature=self.temperature,
            alpha=self.alpha,
            feature_layers=self.feature_layers,
            layer_channels=self.layer_channels
        )

        # Add the callback to the student trainer
        self.student_trainer.add_callback("on_train_start", self.distill_callback.init_parameters)
        self.student_trainer.add_callback("on_train_epoch_start", self.distill_callback.setup_criterion)
        self.student_trainer.add_callback("on_fit_epoch_end", self.distill_callback.log_metrics)

        # Visualization callback
        if self.visualize_feature_plot:
            visualization_callback = FeatureVisualizationCallback(
                layers=self.feature_layers,
                interval=self.feature_plot_interval
            )

            self.student_trainer.add_callback("on_train_epoch_start", visualization_callback.set_hooks)
            self.student_trainer.add_callback("on_train_epoch_end", visualization_callback.plot_figures)

        if self.use_cyclical_lr:
            cyclical_lr_scheduler_callback = DecayingCyclicalLRSchedulerCallback(
                max_lr=self.max_lr,
                cycle_size=self.cycle_size,
                group_scalers=self.group_scalers
            )

            self.student_trainer.add_callback("on_pretrain_routine_end", cyclical_lr_scheduler_callback.set_scheduler)
        else:
            # Refresh default scheduler to ensure all parameter groups are included

            def refresh_scheduler_callback(trainer):
                """
                Refresh the scheduler once after all parameter groups have been added
                to ensure proper learning rate scheduling for all groups.
                """
                # Default YOLO scheduler with cosine lr - one_cycle
                trainer.lf = one_cycle(1, trainer.args.lrf, trainer.epochs)  # cosine 1->hyp['lrf']
                trainer.scheduler = optim.lr_scheduler.LambdaLR(trainer.optimizer, lr_lambda=trainer.lf)

                # Restore scheduler state
                trainer.scheduler.last_epoch = trainer.start_epoch - 1

                LOGGER.info(f"Refreshed scheduler to include all {len(trainer.optimizer.param_groups)} parameter groups")

            self.student_trainer.add_callback("on_train_start", refresh_scheduler_callback)

    def _compute_all_channel_dimensions(self, feature_layers):
        """
        Pre-compute channel dimensions for all feature layers.

        Runs one forward pass per model on an all-zero input of the student's
        training image size and records ``output.shape[1]`` at each requested layer.

        Args:
            feature_layers (list): Layer indices to probe on both models.

        Returns:
            dict: ``{layer_idx: {'student': channels, 'teacher': channels}}`` for
            every layer captured on both models; layers missing on either side are
            skipped with a warning.
        """
        LOGGER.info(f"Computing channel dimensions for {len(feature_layers)} layers...")

        def get_model_channels(model):
            """Return ``{layer_idx: channel count}`` observed on ``model``."""
            first_param = next(model.parameters())

            # All-zero batch of one image, placed on the model's device and dtype
            dummy_input = torch.zeros(
                1, 3, self.student_trainer.args.imgsz, self.student_trainer.args.imgsz,
                device=first_param.device, dtype=first_param.dtype
            )

            # Outputs captured by the hooks, keyed by layer index
            activations = {}

            def hook_func(idx):
                def hook(module, input, output):
                    activations[idx] = output
                return hook

            handles = []
            try:
                for idx in feature_layers:
                    handles.append(model.model.model[idx].register_forward_hook(hook_func(idx)))

                # Do a single forward pass with no gradient computation
                with torch.no_grad():
                    model.eval()  # Ensure model is in eval mode
                    _ = model(dummy_input)
            finally:
                # Always detach the probe hooks, even if the forward pass failed
                for handle in handles:
                    handle.remove()

            return {idx: activation.shape[1] for idx, activation in activations.items()}

        # Get channels for both models
        student_channels = get_model_channels(self.student_model)
        teacher_channels = get_model_channels(self.teacher_model)

        # Combine and format results
        layer_channels = {}
        for layer_idx in feature_layers:
            s_ch = student_channels.get(layer_idx)
            t_ch = teacher_channels.get(layer_idx)

            if s_ch is not None and t_ch is not None:
                layer_channels[layer_idx] = {
                    'student': s_ch,
                    'teacher': t_ch
                }
                LOGGER.info(f"Layer {layer_idx}: Student channels = {s_ch}, Teacher channels = {t_ch}")
            else:
                LOGGER.warning(f"Missing channel info for layer {layer_idx}")

        return layer_channels

    def start(self):
        """Run distillation training by invoking the student trainer once and return its result."""
        LOGGER.info(f"Distillation trainer initialized with temperature={self.temperature}, alpha={self.alpha}")
        results = self.student_trainer.train()

        return results

In [None]:
# Wandb init

# Turn on the "wandb" flag in the Ultralytics settings
settings.update({"wandb": True})

# TODO: W&B sometimes warns about out-of-sync steps, most likely because of the feature image logging.
# All wandb logging should be aggregated into a single call at the end of each fit/train epoch.

# Authenticate, then open a run under a short random id (reused later to resume the run)
key = get_wandb_key()
wandb_run_id = shortuuid.uuid()[:8]
wandb.login(key=key, relogin=True, force=True)

run = wandb.init(id=wandb_run_id, save_code=True, **config["wandb"])

In [None]:
# Load teacher model

# To use a checkpoint path taken directly from the config instead of a W&B artifact, uncomment:
# teacher_model_path = config["distillation"]["teacher_model_path"]

# Download the W&B model artifact named in the config and point at its best.pt
teacher_artifact_name = config["distillation"]["teacher_model_path"]
teacher_artifact = run.use_artifact(teacher_artifact_name, type='model')
teacher_artifact_dir = teacher_artifact.download()
teacher_model_path = f"{teacher_artifact_dir}/best.pt"

In [None]:
# Teacher trainer init

# Shallow copy of the training overrides with the teacher checkpoint swapped in
# NOTE(review): "freeze": 21 presumably freezes the first 21 layers — confirm against the trainer's `freeze` argument
teacher_config = {
    **config["train"],
    "model": teacher_model_path,
    "freeze": 21,
}
teacher_trainer = MergedClassDetectionTrainer(overrides=teacher_config)

In [None]:
# Config update

# Distillation settings for this run (override the values loaded from the config file)
config["distillation"].update({
    "use_cyclical_lr": True,
    "feature_layers": [20],
    "visualize_feature_plot": True,
    "temperature": 4,
    "alpha": 0.7,
    "max_lr": 0.005,
    "cycle_size": 3,
    "group_scalers": {0: 1.0, 1: 1.0, 2: 1.0}
})

# Trainer config

config["train"].update({
    "pretrained": True,
    "epochs": 1,
    "imgsz": 640,
    "optimizer": "AdamW",
    "lr0": 0.005,
    "lrf": 0.001,
    "warmup_epochs": 0,
    # Previously written as `config["distillation"]["use_cyclical_lr"] | 5`: a bitwise OR
    # of a bool with 5, which evaluates to 5 for both True and False. Spelled out here.
    "patience": 5,
    "batch": 16,
    "workers": 8,
    "momentum": 0.937,
    "box": 3.5,
    "cls": 0.3,
    "dfl": 1,
})

In [None]:
# WandB log config

# Merge the three config sections into one flat dict (later sections win on
# duplicate keys) and log it without committing the step
logged_config = {}
for section in ("train", "val", "distillation"):
    logged_config.update(config[section])

wandb.log(logged_config, commit=False)

In [None]:
# Student trainer init

# The student is built straight from config["train"]; unlike the teacher, no
# "model" or "freeze" override is applied on top of it here.
student_trainer = MergedClassDetectionTrainer(overrides=config["train"])

In [None]:
# Create the orchestrator

# Every key of the distillation config section is forwarded as a keyword argument
distillation_kwargs = dict(config["distillation"])
orchestrator = DistillationOrchestrator(
    teacher_trainer=teacher_trainer,
    student_trainer=student_trainer,
    **distillation_kwargs,
)

In [None]:
# Run distillation training

# start() calls the student trainer's train() once and returns its result
train_results = orchestrator.start()

In [None]:
# Evaluate the best student checkpoint from the student trainer's save directory

student_save_dir = orchestrator.student_trainer.args.save_dir
student_model_path = os.path.join(student_save_dir, "weights", "best.pt")

student_model_to_test = YOLO(student_model_path)

# Validation settings come from the "val" config section
val_kwargs = dict(config['val'])
test_results = student_model_to_test.val(validator=MergedClassDetectionValidator, **val_kwargs)

In [None]:
# Log the test metrics to the run opened at the top of the notebook

# Validation results, namespaced under "test/"
prefixed_results_dict = {
    f"test/{metric_name}": metric_value
    for metric_name, metric_value in test_results.results_dict.items()
}

# Re-open the run by id; resume="must" requires that the run already exists
wandb.init(id=run.id, resume="must")

# Collect everything first so it can be sent in a single log call
metrics = {"test/metrics/fitness": test_results.fitness}
metrics.update(prefixed_results_dict)

# Speed metrics
metrics.update({
    f"speed/{stage}": elapsed
    for stage, elapsed in test_results.speed.items()
})

# Class-wise mAP values, only for indices that have a class name
metrics.update({
    f"test/metrics/mAP/{test_results.names[class_idx]}": float(class_map)
    for class_idx, class_map in enumerate(test_results.maps)
    if class_idx in test_results.names
})

wandb.log(metrics)

In [None]:
# Close the W&B run, then call the notebook's remove_models() helper.
# NOTE(review): remove_models() is defined outside this section — presumably it deletes local model files; confirm.
wandb.finish()
remove_models()