In [None]:
%pip install loguru==0.7.3 python-dotenv==1.0.1 PyYAML==6.0.2 torch==2.5.1 tqdm==4.67.1 typer==0.15.1 matplotlib==3.10.0 pyarrow==18.1.0 setuptools==75.1.0 protobuf==4.25.3 ultralytics ray==2.43.0 albumentations==2.0.5 pandas

In [None]:
# Star import is kept first on purpose: every explicit import below then wins
# on any name clash with whatever eyeinthesky.utils happens to export.
from eyeinthesky.utils import *

from copy import copy, deepcopy

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import wandb
import yaml
from ultralytics import YOLO, settings
from ultralytics.cfg import get_cfg
from ultralytics.data.dataset import YOLODataset
from ultralytics.models.yolo.detect import DetectionTrainer, DetectionValidator
from ultralytics.nn.tasks import DetectionModel
from ultralytics.utils import colorstr, LOGGER, DEFAULT_CFG
from ultralytics.utils.loss import v8DetectionLoss

In [None]:
# Config

# Inline YAML configuration, parsed with yaml.safe_load in a later cell.
# Sections: `wandb` (run bookkeeping), `distillation` (teacher + loss mixing),
# `train` (overrides handed to the trainer) and `val` (final evaluation).
# `freeze` is spelled `null`: YAML resolves a bare `None` to the *string*
# "None", whereas `null` is loaded as Python None.
config_data = """
wandb:
  project: "EyeInTheSky_merged"
  group: "distillation"
data: "VisDrone.yaml"
k_samples: 5
distillation:
  teacher_model: "yolo12m.pt"
  temperature: 2.0
  alpha: 0.5
train:
  project: "EyeInTheSky"
  data: "VisDrone.yaml"
  pretrained: True
  patience: 5
  task: detect
  epochs: 500
  seed: 42
  plots: True
  exist_ok: False
  save: True
  save_period: 10
  val: True
  warmup_epochs: 10
  visualize: True
  show: True
  single_cls: False
  rect: False
  resume: False
  fraction: 1.0
  freeze: null
  cache: False
  verbose: False
  amp: True
val:
  project: "EyeInTheSky"
  half: True
  conf: 0.25
  iou: 0.6
  split: "test"
  rect: True
  plots: True
  visualize: True
"""

In [None]:
# Get device

def get_device() -> "int | str":
    """Choose the device identifier passed to the trainer.

    Returns:
        0 (index of the first CUDA device) when CUDA is available,
        otherwise the string "cpu". If probing CUDA raises, the error is
        printed and "cpu" is returned so callers never receive None.
    """
    try:
        return 0 if torch.cuda.is_available() else "cpu"
    except Exception as e:
        print(f"Error setting device: {e}")
        # Fall back explicitly; the function previously fell through and
        # returned None on this path.
        return "cpu"

In [None]:
# Load config

# config = Config.load("../config/config.yaml")
# Parse the inline YAML, then record the detected device in the training options.
config = yaml.safe_load(config_data)
config["train"]["device"] = get_device()

In [None]:
# # Wandb

# def wandb_start(key, run_config, wandb_config):
#     settings.update({"wandb": True})
#     if wandb.run is None: 
#         wandb.login(key=key, relogin=True)

In [None]:
# Dataset, Trainer, Validator

class VisDroneDataset(YOLODataset):
    """
    YOLODataset variant for VisDrone that folds 'people' (1) into 'pedestrian' (0).

    The remapping is applied to the label arrays as soon as they are loaded
    (see get_labels), and every class id above 1 moves down by one so the ids
    stay contiguous.
    """

    # Class id -> display name after the merge. Kept on the class so the
    # trainer and validator can read it without a dataset instance.
    merged_names = {
        0: 'persona',
        1: 'bicicletta',
        2: 'auto',
        3: 'furgone',
        4: 'camion',
        5: 'triciclo',
        6: 'triciclo-tendato',
        7: 'autobus',
        8: 'motociclo'
    }

    def __init__(self, *args, **kwargs):
        """Forward every argument to YOLODataset, then log the merged class table."""
        super().__init__(*args, **kwargs)
        LOGGER.info(f"{colorstr('VisDroneDataset:')} Using merged classes: {self.merged_names}")

    def get_labels(self):
        """
        Return the parent's labels with class ids rewritten for the merged scheme.

        Each entry's 'cls' array is edited in place. When the instance has a
        `data` attribute, its 'names' and 'nc' entries are overwritten with
        the merged class table. Remapping statistics are logged.
        """
        labels = super().get_labels()

        # Running totals for the log lines at the end
        people_count = 0
        shifted_count = 0

        for entry in labels:
            class_ids = entry['cls']
            if len(class_ids) == 0:
                continue

            # 'people' (1) becomes class 0
            is_people = class_ids == 1
            people_count += np.sum(is_people)
            class_ids[is_people] = 0

            # Everything above 1 slides down one slot; this has to run after
            # the merge so former class 2 is not mistaken for 'people'.
            above_merged = class_ids > 1
            shifted_count += np.sum(above_merged)
            class_ids[above_merged] -= 1

            entry['cls'] = class_ids

        # Publish the merged class table on the shared dataset description
        if hasattr(self, 'data'):
            self.data['names'] = self.merged_names
            self.data['nc'] = len(self.merged_names)

        person_count = sum(np.sum(entry['cls'] == 0) for entry in labels)
        tag = colorstr('VisDroneDataset:')
        LOGGER.info(f"\n{tag} Remapped {people_count} 'people' instances to {self.merged_names[0]}")
        LOGGER.info(f"{tag} Total 'persona' instances after merge: {person_count}")
        LOGGER.info(f"{tag} Shifted {shifted_count} instances of other classes")

        return labels

class MergedClassDetectionTrainer(DetectionTrainer):
    """
    DetectionTrainer that builds VisDroneDataset instances, so training and
    in-loop validation both see the merged class ids.
    """

    def build_dataset(self, img_path, mode="train", batch=None):
        """Return a VisDroneDataset for `img_path`, configured for `mode`.

        In "train" mode augmentation is on and rect / fraction come from the
        trainer arguments with no padding; in any other mode rect is forced
        on, the full split is used and padding is 0.5.
        """
        is_train = mode == "train"
        dataset_kwargs = {
            "img_path": img_path,
            "imgsz": self.args.imgsz,
            "batch_size": batch or self.batch_size,
            "augment": is_train,
            "hyp": self.args,
            "rect": self.args.rect if is_train else True,
            "cache": self.args.cache or None,
            "single_cls": self.args.single_cls,
            "stride": self.stride,
            "pad": 0.0 if is_train else 0.5,
            "prefix": colorstr(f"{mode}: "),
            "task": self.args.task,
            "classes": None,
            "data": self.data,
            "fraction": self.args.fraction if is_train else 1.0,
        }
        return VisDroneDataset(**dataset_kwargs)

    def set_model_attributes(self):
        """Run the parent hook, then stamp the merged class table on model and data."""
        super().set_model_attributes()

        # Nothing to override on a model that exposes no `names`
        if not hasattr(self.model, 'names'):
            return

        names = VisDroneDataset.merged_names
        self.model.names = names
        self.model.nc = len(names)

        # Keep the trainer's dataset description in step with the model
        if hasattr(self, 'data'):
            self.data['names'] = names
            self.data['nc'] = len(names)

class MergedClassDetectionValidator(DetectionValidator):
    """
    Custom validator that uses VisDroneDataset for validation/testing with merged classes.
    """

    def build_dataset(self, img_path, mode="val", batch=None):
        """Build a non-augmented, rectangular VisDroneDataset for `img_path`.

        Args:
            img_path: Location of the images for the split being evaluated.
            mode: Only used for the log prefix.
            batch: Batch size; falls back to `self.args.batch` when falsy.
        """
        return VisDroneDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch or self.args.batch,
            augment=False,
            hyp=self.args,
            rect=True,
            cache=None,
            single_cls=self.args.single_cls,
            stride=self.stride,
            pad=0.5,
            prefix=colorstr(f"{mode}: "),
            task=self.args.task,
            classes=self.args.classes,
            data=self.data,
        )

    def set_model_attributes(self):
        """Update model attributes for merged classes if using a PyTorch model."""
        # NOTE(review): not every validator base class exposes this hook, so it
        # is looked up defensively instead of assumed - confirm against the
        # installed Ultralytics version.
        parent_hook = getattr(super(), 'set_model_attributes', None)
        if callable(parent_hook):
            parent_hook()

        # The validator may not hold a model reference at all; treat that as
        # "nothing to update" rather than raising AttributeError.
        model = getattr(self, 'model', None)

        # Update model names if it's a PyTorch model (not for exported models)
        if hasattr(model, 'names') and hasattr(model, 'model'):
            model.names = VisDroneDataset.merged_names
            # `self.data` is used as a mapping (it is subscripted below), so the
            # check must be key membership: hasattr() on a dict never sees its keys.
            if isinstance(self.data, dict) and 'names' in self.data:
                self.data['names'] = VisDroneDataset.merged_names
                self.data['nc'] = len(VisDroneDataset.merged_names)

In [None]:
# class DistillationLoss(v8DetectionLoss):
#     def __init__(self, model, temperature=2.0, alpha=0.5, tal_topk=10):
#         """Initialize the distillation loss with model parameters."""
#         super().__init__(model, tal_topk=tal_topk)
#         self.teacher_model = model.teacher_model
#         self.temperature = temperature
#         self.alpha = alpha
#         self.step = 0

#     def __call__(self, preds, batch):
#         """Calculate original detection loss and distillation loss."""
#         # Calculate standard detection loss using parent class
#         original_loss, loss_items = super().__call__(preds, batch)
        
#         # Initialize distillation loss as zero tensor on same device
#         distillation_loss = torch.tensor(0.0, device=original_loss.device)
        
#         # Only calculate distillation loss if teacher model exists
#         if self.teacher_model is not None:
#             with torch.no_grad():
#                 # Get teacher predictions
#                 self.teacher_model.eval()
#                 teacher_preds = self.teacher_model(batch["img"])
            
#             # Get feature maps for distillation
#             student_feats = preds[1] if isinstance(preds, tuple) else preds
#             teacher_feats = teacher_preds[1] if isinstance(teacher_preds, tuple) else teacher_preds
            
#             # Calculate distillation loss on each feature map
#             for s_feat, t_feat in zip(student_feats, teacher_feats):
#                 # Apply temperature scaling
#                 s_logits = s_feat / self.temperature
#                 t_logits = t_feat / self.temperature
                
#                 # Apply log softmax to student and softmax to teacher
#                 # Flatten from channel dimension for KL div calculation
#                 s_log_prob = F.log_softmax(s_logits.flatten(2), dim=-1)
#                 t_prob = F.softmax(t_logits.flatten(2), dim=-1)
                
#                 # KL divergence loss
#                 feat_loss = F.kl_div(s_log_prob, t_prob, reduction='batchmean')
#                 distillation_loss += feat_loss * (self.temperature**2)  # Scale by temperature squared
        
#         # Combine losses using alpha weighting
#         combined_loss = (1 - self.alpha) * original_loss + self.alpha * distillation_loss
        
#         # # Increment step counter for consistent logging
#         # self.step += 1
        
#         if wandb.run is not None:
#             metrics = {
#                 "distill/combined_loss": combined_loss.item(),
#                 "distill/detection_loss": original_loss.item(),
#                 "distill/distillation_loss": distillation_loss.item(),
#                 "distill/box": loss_items[0].item(),
#                 "distill/cls": loss_items[1].item(),
#                 "distill/dfl": loss_items[2].item() if len(loss_items) > 2 else 0,
#                 "step": self.step  # or use epoch if that’s more appropriate
#             }
#             wandb.log(metrics, step=self.step)
        
#         return combined_loss, loss_items

In [None]:
class DistillationLoss(v8DetectionLoss):
    """Detection loss blended with a teacher/student KL-divergence term.

    The returned loss is `(1 - alpha) * detection + alpha * distillation`.
    Per-batch values are summed into `epoch_metrics`; averaging, logging and
    resetting them is left to the trainer's epoch-end callback
    (DistillationTrainer.log_distillation_metrics).
    """

    def __init__(self, model, temperature=2.0, alpha=0.5, tal_topk=10):
        """
        Args:
            model: Student model; must expose a `teacher_model` attribute (may be None).
            temperature: Divisor applied to both feature maps before the softmax.
            alpha: Weight of the distillation term in the combined loss.
            tal_topk: Forwarded unchanged to v8DetectionLoss.
        """
        super().__init__(model, tal_topk=tal_topk)
        self.teacher_model = model.teacher_model
        self.temperature = temperature
        self.alpha = alpha
        # Accumulators for epoch metrics. `count` starts at 0 (it used to start
        # at 1, which made the first epoch's averages divide by n + 1).
        self.epoch_metrics = self._empty_epoch_metrics()

    @staticmethod
    def _empty_epoch_metrics():
        """Return a fresh accumulator dict with the keys the trainer callback reads."""
        return {"combined_loss": 0.0, "detection_loss": 0.0, "distillation_loss": 0.0, "count": 0}

    def __call__(self, preds, batch):
        """Compute the combined loss for one batch.

        Args:
            preds: Student outputs; when a tuple, element 1 is used as the
                sequence of feature maps, otherwise `preds` itself is.
            batch: Batch dict; `batch["img"]` is fed to the teacher.

        Returns:
            (combined_loss, loss_items) where `loss_items` is exactly what the
            parent loss returned.
        """
        # Compute standard detection loss using the parent class
        original_loss, loss_items = super().__call__(preds, batch)
        # NOTE(review): depending on the Ultralytics release the parent appears
        # to return either a scalar or a per-component vector here - confirm.
        # sum() is a no-op on a scalar and keeps the .item() calls below valid
        # in the vector case.
        detection_loss = original_loss.sum()
        distillation_loss = torch.zeros((), device=detection_loss.device)

        if self.teacher_model is not None:
            with torch.no_grad():
                # Re-assert eval mode on every call before running the teacher
                self.teacher_model.eval()
                teacher_preds = self.teacher_model(batch["img"])

            # Get feature maps for distillation
            student_feats = preds[1] if isinstance(preds, tuple) else preds
            teacher_feats = teacher_preds[1] if isinstance(teacher_preds, tuple) else teacher_preds

            for s_feat, t_feat in zip(student_feats, teacher_feats):
                # Temperature-scaled distributions over the flattened trailing dims
                s_log_prob = F.log_softmax((s_feat / self.temperature).flatten(2), dim=-1)
                t_prob = F.softmax((t_feat / self.temperature).flatten(2), dim=-1)
                feat_loss = F.kl_div(s_log_prob, t_prob, reduction='batchmean')
                # Out-of-place accumulation; T^2 rescales the softened term
                distillation_loss = distillation_loss + feat_loss * (self.temperature ** 2)

        combined_loss = (1 - self.alpha) * detection_loss + self.alpha * distillation_loss

        # Accumulate losses for this epoch
        self.epoch_metrics["combined_loss"] += combined_loss.item()
        self.epoch_metrics["detection_loss"] += detection_loss.item()
        self.epoch_metrics["distillation_loss"] += distillation_loss.item()
        self.epoch_metrics["count"] += 1

        return combined_loss, loss_items


In [None]:
class DistillationModel(DetectionModel):
    """DetectionModel that carries a teacher model and trains with DistillationLoss."""

    def __init__(self, cfg='yolov12n.yaml', ch=3, nc=None, verbose=True, teacher_model=None, temperature=1.0, alpha=0.5):
        """Initialize distillation model with teacher model for knowledge transfer.

        Args:
            cfg, ch, nc, verbose: Forwarded positionally to DetectionModel.
            teacher_model: Model queried by the loss (and optionally by
                forward); None disables the distillation term.
            temperature: Passed to DistillationLoss.
            alpha: Passed to DistillationLoss.
        """
        super().__init__(cfg, ch, nc, verbose)

        self.teacher_model = teacher_model
        self.temperature = temperature
        self.alpha = alpha

        # The teacher is not frozen here: DistillationTrainer._setup_train
        # freezes it after the parent trainer setup has run.

        LOGGER.info(f"Initialized DistillationModel with temperature={self.temperature}, alpha={self.alpha}")

    def init_criterion(self):
        """Initialize the custom distillation loss criterion."""
        return DistillationLoss(self, temperature=self.temperature, alpha=self.alpha)

    def forward(self, x, *args, **kwargs):
        """Forward pass with optional teacher model evaluation.

        Args:
            x: Model input, forwarded to the parent forward.
            teacher_eval (keyword, default False): When true, the model is not
                in training mode and a teacher is attached, the teacher is run
                on the same input as well.

        Returns:
            The parent's result, or `{'student': ..., 'teacher': ...}` when
            teacher evaluation was requested and performed.
        """
        # Consume the flag before delegating: it is private to this class, so it
        # must not be forwarded to the parent or the teacher, which would
        # receive it as an unexpected keyword argument.
        teacher_eval = kwargs.pop('teacher_eval', False)

        # Standard forward pass
        result = super().forward(x, *args, **kwargs)

        # For validation, also get teacher predictions if requested
        if teacher_eval and not self.training and self.teacher_model is not None:
            with torch.no_grad():
                teacher_result = self.teacher_model(x, *args, **kwargs)
            # Return both student and teacher results
            return {'student': result, 'teacher': teacher_result}

        return result

In [None]:
class DistillationTrainer(MergedClassDetectionTrainer):
    """Merged-class trainer that trains a DistillationModel against a teacher.

    The distillation settings are supplied through the class-level
    `set_config` before instantiation, because the constructor signature is
    the standard trainer one (cfg, overrides, _callbacks).
    """

    # Class-level defaults; __init__ copies the values from the shared config
    # onto the instance when set_config() has been called.
    distillation_config = None
    teacher_model_dir = None
    temperature = 1.0
    alpha = 0.5

    @staticmethod
    def set_config(distill_config):
        """Store the distillation settings used by subsequently created trainers.

        Args:
            distill_config: Mapping with "teacher_model", "temperature" and "alpha".
        """
        DistillationTrainer.distillation_config = distill_config

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize distillation trainer with teacher model."""
        if DistillationTrainer.distillation_config is not None:
            self.distillation_config = DistillationTrainer.distillation_config
            self.teacher_model_dir = DistillationTrainer.distillation_config["teacher_model"]
            self.temperature = DistillationTrainer.distillation_config["temperature"]
            self.alpha = DistillationTrainer.distillation_config["alpha"]

        self.args = get_cfg(cfg, overrides)
        super().__init__(cfg, overrides, _callbacks)

        self.teacher_model = self._setup_teacher_model()

        # Build the teacher's args from a copy: the caller's `overrides` dict is
        # left untouched, and overrides=None no longer raises a TypeError.
        # NOTE(review): the purpose of freeze=21 on the teacher's args is not
        # evident from this file - confirm it is still needed.
        teacher_overrides = dict(overrides or {})
        teacher_overrides['freeze'] = 21
        self.teacher_model.args = get_cfg(cfg, teacher_overrides)

        # self.add_callback("on_train_epoch_end", self.log_training_metrics)
        # self.add_callback("on_val_end", self.log_validation_metrics)
        # self.add_callback("on_val_end", self.log_teacher_metrics)
        self.add_callback("on_train_epoch_end", self.log_distillation_metrics)

    def _setup_teacher_model(self):
        """Load teacher model from checkpoint.

        Raises:
            ValueError: If no teacher checkpoint path has been configured.
        """
        if not self.teacher_model_dir:
            raise ValueError(
                "No teacher checkpoint configured; call DistillationTrainer.set_config() "
                "with a 'teacher_model' entry before creating the trainer."
            )
        print(f"Loading teacher model from {self.teacher_model_dir}")

        teacher = YOLO(self.teacher_model_dir)
        teacher_model = teacher.model
        teacher_model.names = VisDroneDataset.merged_names

        return teacher_model

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Create and return a DistillationModel with teacher model."""
        model = DistillationModel(
            cfg=cfg,
            nc=self.data["nc"],
            verbose=verbose,
            teacher_model=self.teacher_model,
            temperature=self.temperature,
            alpha=self.alpha
        )

        model.args = self.args

        if weights:
            model.load(weights)

        return model

    @staticmethod
    def _freeze_module(module):
        """Put `module` in eval mode and disable gradients on all of its parameters."""
        module.eval()
        for param in module.parameters():
            param.requires_grad = False

    def _setup_train(self, world_size):
        """Override parent method to handle teacher model freezing correctly"""
        # Call parent method to set up most things
        super()._setup_train(world_size)

        # Freeze after the parent setup so nothing it does to parameters can
        # leave the teacher trainable.
        LOGGER.info("Making sure teacher model parameters are frozen")
        self._freeze_module(self.teacher_model)

        # Also freeze the reference held by the student model itself
        student_teacher = getattr(self.model, 'teacher_model', None)
        if student_teacher is not None:
            LOGGER.info("Making sure teacher model in student is frozen")
            self._freeze_module(student_teacher)

    def log_distillation_metrics(self, trainer):
        """Log distillation metrics at the end of each epoch.

        Averages the accumulators kept by the model's criterion, sends them to
        W&B when a run is active, and resets the accumulators.

        Args:
            trainer: Trainer passed in by the callback machinery (unused; the
                bound instance is used instead).
        """
        criterion = getattr(self.model, 'criterion', None)
        metrics = getattr(criterion, 'epoch_metrics', None)
        if not metrics or metrics["count"] <= 0:
            return

        count = metrics["count"]
        if wandb.run is not None:
            # commit=False merges these values into the step that is currently
            # open instead of advancing W&B's step counter, so rows logged with
            # an explicit step elsewhere in the same epoch are not rejected as
            # out of order.
            wandb.log({
                "distill/avg_combined_loss": metrics["combined_loss"] / count,
                "distill/avg_detection_loss": metrics["detection_loss"] / count,
                "distill/avg_distillation_loss": metrics["distillation_loss"] / count,
                "epoch": self.epoch
            }, commit=False)

        # Reset metrics for next epoch (done even without an active W&B run so
        # the accumulators never span several epochs)
        criterion.epoch_metrics = {
            "combined_loss": 0.0,
            "detection_loss": 0.0,
            "distillation_loss": 0.0,
            "count": 0
        }
    # def log_training_metrics(self, trainer):
    #     """Log training metrics at the end of each epoch."""
    #     if wandb.run is not None:
    #         # The trainer already tracks losses
    #         training_metrics = {
    #             "box": self.loss_items[0].item(),
    #             "cls": self.loss_items[1].item(),
    #             "dfl": self.loss_items[2].item() if len(self.loss_items) > 2 else 0,
    #             "loss": self.loss.item(),
    #         }
    #         wandb.log(training_metrics, step=self.epoch)
            # self.model.criterion.log_epoch_metrics(self.epoch)
    
    # def log_validation_metrics(self, trainer):
    #     """Log validation metrics including teacher model performance."""
    #     if wandb.run is not None:
    #         # Student metrics
    #         student_metrics = {k: v for k, v in self.metrics.items()}
            
    #         # Teacher metrics (if you're validating the teacher model too)
    #         teacher_metrics = {}
    #         if hasattr(self, 'teacher_metrics'):
    #             teacher_metrics = {f"val/teacher_{k}": v for k, v in self.teacher_metrics.items()}
            
    #         # Log everything together with epoch
    #         wandb.log({
    #             **student_metrics,
    #             **teacher_metrics
    #         }, step=self.epoch)
       
    # def log_teacher_metrics(self, trainer):
    #     """Log teacher model performance metrics during validation."""
    #     if not hasattr(self, 'teacher_validator'):
    #         # Create a validator for the teacher model
    #         self.teacher_validator = MergedClassDetectionValidator(
    #             dataloader=self.test_loader,
    #             save_dir=self.save_dir / 'teacher_val',
    #             args=copy(self.args)
    #         )
        
    #     # Run validation on teacher model
    #     LOGGER.info("\n--- Evaluating Teacher Model ---")
    #     teacher_metrics = self.teacher_validator(model=self.teacher_model)
        
    #     # Log metrics to wandb
    #     if wandb.run is not None:
    #         teacher_log = {f"teacher/{k}": v for k, v in teacher_metrics.items()}
    #         student_log = {f"student/{k}": v for k, v in self.metrics.items()}
            
    #         # Calculate improvement/degradation for key metrics
    #         for k in ['mAP50', 'mAP50-95', 'precision', 'recall']:
    #             if k in teacher_metrics and k in self.metrics:
    #                 diff = self.metrics[k] - teacher_metrics[k]
    #                 teacher_log[f"diff/{k}"] = diff
            
    #         wandb.log({**teacher_log, **student_log})
            
    #         # Create a comparison table
    #         comparison_data = []
    #         for k in ['mAP50', 'mAP50-95', 'precision', 'recall']:
    #             if k in teacher_metrics and k in self.metrics:
    #                 comparison_data.append([
    #                     k, 
    #                     teacher_metrics[k], 
    #                     self.metrics[k], 
    #                     self.metrics[k] - teacher_metrics[k]
    #                 ])
            
    #         comparison_table = wandb.Table(
    #             data=comparison_data,
    #             columns=["Metric", "Teacher", "Student", "Difference"]
    #         )
    #         wandb.log({"comparison": comparison_table}, step=self.epoch)

In [None]:
key = get_wandb_key()
settings.update({"wandb": True})
wandb.login(key=key, relogin=True)

# wandb_start(key, config["train"], config["wandb"])

# Deep copy: dict.copy() is shallow, so updating trial_config["train"] or
# trial_config["distillation"] would also rewrite the nested dicts inside
# `config`, making this cell non-idempotent on re-run.
trial_config = deepcopy(config)

trial_config["train"].update({
    "model": "yolo12n.pt",
    "pretrained": True,
    "imgsz": 640,
    "optimizer": "AdamW",
    "lr0": 0.005,
    "lrf": 0.001,
    "momentum": 0.937,
    "warmup_epochs": 20,
    "patience": 10,
    "batch": 16,
    "workers": 4,
    "box": 3.5,
    "cls": 0.3,
    "dfl": 1,
    "cos_lr": False,
})

run = wandb.init(
    project=trial_config["wandb"]["project"], 
    group=trial_config["wandb"]["group"]
)
wandb.log(trial_config["train"])

# Fetch the teacher checkpoint from the W&B artifact store
teacher_artifact = run.use_artifact('francescoperagine-universit-degli-studi-di-bari-aldo-moro/EyeInTheSky_merged/run_uao1xs95_model:v0', type='model')
teacher_artifact_dir = teacher_artifact.download()
teacher_model = f"{teacher_artifact_dir}/best.pt"

trial_config["distillation"].update({
    "teacher_model": teacher_model
})

# The trainer reads the distillation settings from class-level state, so they
# must be registered before the instance is created.
DistillationTrainer.set_config(trial_config["distillation"])
trainer = DistillationTrainer(overrides=trial_config["train"])

results = trainer.train()

In [None]:
# Log final results
# NOTE(review): `results` comes from calling train() on the trainer object
# itself, which may return None rather than a metrics mapping - confirm against
# the installed Ultralytics version. Fall back to the metrics the trainer keeps
# on itself so this cell does not fail on `None.items()`.
final_results = results if hasattr(results, "items") else (getattr(trainer, "metrics", None) or {})
final_metrics = {f"final/{k}": v for k, v in final_results.items()}
if final_metrics:
    wandb.log(final_metrics)

# Run final validation
LOGGER.info("Running final validation on student model...")
student_model = YOLO(str(trainer.best))  # Load best model
student_test_results = student_model.val(
    validator=MergedClassDetectionValidator,
    **config['val']
)

# Clean up
wandb.finish()
clear_cache()
# remove_models()

print(f"Training completed. Results: {final_results}")
print(f"Student test results: {student_test_results}")

In [None]:
# import glob
# import os

# def remove_models():
#     pt_files = glob.glob("*.pt")
#     print("Files to be removed:", pt_files)

#     for file in pt_files:
#         os.remove(file)