In [None]:
%pip install loguru==0.7.3 python-dotenv==1.0.1 PyYAML==6.0.2 torch==2.5.1 tqdm==4.67.1 typer==0.15.1 matplotlib==3.10.0 pyarrow==18.1.0 setuptools==75.1.0 protobuf==4.25.3 ultralytics==8.3.107 ray==2.43.0 albumentations==2.0.5 pandas

In [None]:
from pathlib import Path
from PIL import Image 
from ultralytics import YOLO, settings
from ultralytics.data.dataset import YOLODataset
from ultralytics.models.yolo.detect import DetectionTrainer, DetectionValidator
from ultralytics.nn.tasks import DetectionModel
from ultralytics.utils import colorstr, LOGGER
import glob
import io
import matplotlib.pyplot as plt
import numpy as np
import os
import os
import sys
import torch
import torchvision.utils as vutils
import wandb
import yaml

# Stop Python from writing .pyc files for modules imported from here on.
sys.dont_write_bytecode = True
# Turn on ultralytics' Weights & Biases integration.
# NOTE(review): ultralytics persists `settings` to its settings file, so this likely affects other sessions too — confirm.
settings.update({"wandb": True})

In [None]:
# Utils

def _get_wandb_key_colab() -> str:
    try:
        from google.colab import userdata # type: ignore

        if userdata.get("WANDB_API_KEY") is not None:
            return userdata.get("WANDB_API_KEY")
        else:
            raise ValueError("No WANDB key found")
    except:
        return None

def _get_wandb_env(path: Path) -> str:
    try:
        from dotenv import dotenv_values # type: ignore

        """Get W&B API key from Colab userdata or environment variable"""

        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Could not find .env file at {path}")

        print(f"Loading secrets from {path}")

        secrets = dotenv_values(path)
        print(f"Found keys: {list(secrets.keys())}")

        if "WANDB_API_KEY" not in secrets:
            raise KeyError(f"WANDB_API_KEY not found in {path}. Available keys: {list(secrets.keys())}")

        return secrets['WANDB_API_KEY']
    except:
        return None

def get_wandb_key(path: "str | Path" = "../.env") -> "str | None":
    """Return the W&B API key, preferring Colab secrets over a local ``.env`` file.

    Args:
        path: dotenv file consulted when Colab provides no key.

    Returns:
        The API key, or None if neither source provides one.
    """
    # Query Colab only once (the original called the lookup twice).
    key = _get_wandb_key_colab()
    if key is None:
        key = _get_wandb_env(path)
    return key

def get_device() -> "int | str":
    """Return the ultralytics device spec: CUDA index ``0`` if CUDA is available, else ``"cpu"``.

    If querying CUDA raises, the error is printed and ``"cpu"`` is returned
    (rather than None, which would otherwise be passed on as the device).
    """
    try:
        return 0 if torch.cuda.is_available() else "cpu"
    except Exception as e:
        print(f"Error setting device: {e}")
        return "cpu"

def load_config(config_file: "str | os.PathLike") -> dict:
    """Load and return configuration from a YAML file.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default encoding. Uses ``yaml.safe_load`` (no arbitrary
    object construction).
    """
    with open(config_file, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

def remove_models(pattern: str = "*.pt") -> None:
    """Delete files matching ``pattern`` in the current working directory.

    The match is non-recursive (``glob.glob`` relative to the CWD). Only
    regular files are removed: a directory whose name happens to match is
    skipped instead of making ``os.remove`` raise.

    Args:
        pattern: glob pattern relative to the CWD. Defaults to ``"*.pt"``.
    """
    pt_files = [p for p in glob.glob(pattern) if os.path.isfile(p)]
    print("Files to be removed:", pt_files)

    for file in pt_files:
        os.remove(file)

In [None]:
# Project root, taken as the parent of the kernel's working directory.
# NOTE(review): assumes the notebook runs from a direct subfolder of the repo — confirm.
ROOT_PATH = Path.cwd().parents[0]
# W&B run files are written under <root>/reports (kept as a plain str).
WANDB_REPORT_PATH = str(ROOT_PATH / "reports")

In [None]:
# Config
#
# Inline YAML, parsed in the next cell:
#   wandb -> keyword arguments for wandb.init(); `dir` is replaced with an absolute path after parsing.
#   model -> W&B artifact reference holding the trained weights; replace the placeholder before running.
#   val   -> keyword arguments forwarded to YOLO.val(); `device` is added at runtime.
# NOTE(review): `visualize` is presumably a prediction-time option; confirm the validator actually uses it.

config_data = """
wandb:
  project: "EyeInTheSky_test"
  dir: "reports"
model: "<wandb_artifact_source>" # wandb artifact to download
val:
  project: "EyeInTheSky"
  data: "VisDrone.yaml"
  name: "test" 
  half: True
  conf: 0.15
  iou: 0.6
  split: "test"
  rect: True
  plots: True
  visualize: True
"""

In [None]:
# Load config: parse the inline YAML, then fill in values only known at runtime.
# (An earlier revision read ../config/config.yaml via `Config.load`; the inline
# `config_data` string is used instead.)
config = yaml.safe_load(config_data)
config["val"]["device"] = get_device()
config["wandb"]["dir"] = WANDB_REPORT_PATH

In [None]:
# Dataset, Trainer, Validator

class VisDroneDataset(YOLODataset):
    """
    YOLO dataset for VisDrone with 'pedestrian' (0) and 'people' (1) merged into one class.

    Class ids are rewritten as soon as the labels are loaded (in ``get_labels``),
    so all downstream processing works with the merged indices:

        0 (pedestrian), 1 (people)  -> 0 ('persona')
        k for k >= 2                -> k - 1

    The resulting id -> name mapping is the class attribute ``merged_names``, so the
    trainer and validator can read it without holding a dataset instance.

    Class attributes:
        merged_names (dict): id -> name mapping after the merge (9 classes).
    """

    # Read directly by MergedClassDetectionTrainer / MergedClassDetectionValidator.
    merged_names = {
        0: 'persona',
        1: 'bicicletta',
        2: 'auto',
        3: 'furgone',
        4: 'camion',
        5: 'triciclo',
        6: 'triciclo-tendato',
        7: 'autobus',
        8: 'motociclo'
    }

    def __init__(self, *args, **kwargs):
        """Forward every argument to ``YOLODataset`` and log the merged class mapping."""
        super().__init__(*args, **kwargs)
        LOGGER.info(f"{colorstr('VisDroneDataset:')} Using merged classes: {self.merged_names}")

    def get_labels(self):
        """
        Load labels through ``YOLODataset.get_labels`` and remap class ids in place.

        Also writes the merged names/count into ``self.data`` when present and
        logs how many instances were merged and shifted.

        Returns:
            The parent's label list, with each entry's ``'cls'`` array remapped.
        """
        labels = super().get_labels()

        people_count = 0
        shifted_count = 0
        for label in labels:
            classes = label['cls']
            if not len(classes):
                continue

            # 'people' (1) -> 'persona' (0); done before the shift so these are not shifted again.
            is_people = classes == 1
            people_count += np.sum(is_people)
            classes[is_people] = 0

            # Every class above the merged pair moves down one slot.
            is_shifted = classes > 1
            shifted_count += np.sum(is_shifted)
            classes[is_shifted] -= 1

            label['cls'] = classes

        # Publish the merged mapping on the data dict so consumers of it see the new class count.
        if hasattr(self, 'data'):
            self.data['names'] = self.merged_names
            self.data['nc'] = len(self.merged_names)

        person_count = sum(np.sum(label['cls'] == 0) for label in labels)
        tag = colorstr('VisDroneDataset:')
        LOGGER.info(f"\n{tag} Remapped {people_count} 'people' instances to {self.merged_names[0]}")
        LOGGER.info(f"{tag} Total 'persona' instances after merge: {person_count}")
        LOGGER.info(f"{tag} Shifted {shifted_count} instances of other classes")

        return labels

class MergedClassDetectionTrainer(DetectionTrainer):
    """
    YOLO detection trainer for VisDrone with 'pedestrian' and 'people' merged.

    Differences from ``DetectionTrainer``:
      * ``build_dataset`` returns a ``VisDroneDataset`` (labels remapped on load).
      * ``get_model`` builds the detection head with the merged class count
        (``len(VisDroneDataset.merged_names)``) and syncs ``self.data`` to it.
      * ``set_model_attributes`` stamps the merged names/count onto the model
        and ``self.data``.
    """

    def build_dataset(self, img_path, mode="train", batch=None):
        """Build a ``VisDroneDataset``; only ``mode == "train"`` augments, other modes use rect batches and 0.5 padding."""
        return VisDroneDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch or self.batch_size,
            augment=mode == "train",
            hyp=self.args,
            rect=self.args.rect if mode == "train" else True,
            cache=self.args.cache or None,
            single_cls=self.args.single_cls,
            stride=self.stride,
            pad=0.0 if mode == "train" else 0.5,
            prefix=colorstr(f"{mode}: "),
            task=self.args.task,
            classes=None,
            data=self.data,
            fraction=self.args.fraction if mode == "train" else 1.0,
        )

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Create a ``DetectionModel`` whose head predicts the merged VisDrone classes.

        Args:
            cfg: model config forwarded to ``DetectionModel``.
            weights: optional weights loaded with ``model.load`` after construction.
            verbose: forwarded to ``DetectionModel``.

        Returns:
            The new ``DetectionModel`` with ``args`` set to the trainer's args.
        """
        merged_nc = len(VisDroneDataset.merged_names)
        # self.data only receives the merged mapping once a VisDroneDataset is built.
        # NOTE(review): ultralytics creates the model before the dataloaders, so
        # data["nc"] would still be the un-merged count here and the head would get
        # a spurious extra class. Sync data first so model and data agree.
        self.data['names'] = VisDroneDataset.merged_names
        self.data['nc'] = merged_nc

        model = DetectionModel(
            cfg=cfg,
            nc=merged_nc,
            verbose=verbose,
        )
        model.args = self.args

        if weights:
            LOGGER.info("Loading weights into model")
            model.load(weights)

        return model

    def set_model_attributes(self):
        """Apply the standard attributes, then override class names/count with the merged mapping."""
        super().set_model_attributes()

        if hasattr(self.model, 'names'):
            self.model.names = VisDroneDataset.merged_names
            self.model.nc = len(VisDroneDataset.merged_names)

            # Keep the data dict consistent with the model.
            if hasattr(self, 'data'):
                self.data['names'] = VisDroneDataset.merged_names
                self.data['nc'] = len(VisDroneDataset.merged_names)

class MergedClassDetectionValidator(DetectionValidator):
    """
    Detection validator for VisDrone with 'pedestrian' and 'people' merged.

    ``build_dataset`` returns a ``VisDroneDataset``, whose ``get_labels`` remaps the
    ground-truth class ids and rewrites ``self.data['names']``/``['nc']`` to the
    merged mapping. This is the same remapping ``MergedClassDetectionTrainer`` uses,
    so evaluation labels match the training labels.
    """

    def build_dataset(self, img_path, mode="val", batch=None):
        """Build a non-augmented ``VisDroneDataset`` for evaluation (0.5 padding, rect per args)."""
        return VisDroneDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch or self.args.batch,
            augment=False,
            hyp=self.args,
            rect=self.args.rect,
            cache=None,
            single_cls=self.args.single_cls,
            stride=self.stride,
            pad=0.5,
            prefix=colorstr(f"{mode}: "),
            task=self.args.task,
            classes=None,
            data=self.data,
        )

    def set_model_attributes(self):
        """
        Stamp the merged class names/count onto ``self.model`` and ``self.data``.

        NOTE(review): ultralytics' validator does not appear to define or call this
        hook, so it is not run automatically by ``model.val()``. The parent
        implementation is therefore only invoked when one exists; the original
        unconditional ``super()`` call raised AttributeError.
        """
        parent_hook = getattr(super(), "set_model_attributes", None)
        if parent_hook is not None:
            parent_hook()

        model = getattr(self, 'model', None)
        if hasattr(model, 'names'):
            model.names = VisDroneDataset.merged_names
            model.nc = len(VisDroneDataset.merged_names)

            # Keep the data dict consistent with the model.
            if hasattr(self, 'data'):
                self.data['names'] = VisDroneDataset.merged_names
                self.data['nc'] = len(VisDroneDataset.merged_names)

In [None]:
# Remove models
# Clear leftover *.pt files from the working directory before the run.
# `remove_models` is defined once, in the Utils cell; redefining it here would silently shadow that version.
remove_models()

In [None]:
# WandB init

wandb_key = get_wandb_key()
# NOTE: if no key was found this is login(key=None); W&B presumably falls back to
# its own credential lookup / interactive prompt.
wandb.login(key=wandb_key, relogin=True)

# Validation settings are recorded as run config (filterable in the W&B UI)
# instead of being logged as a history row with wandb.log.
run = wandb.init(
    save_code=True,
    config=config["val"],
    **config["wandb"],
)

In [None]:
# Artifact download: resolve the model artifact named in config["model"] through
# the active run and fetch it; `artifact_dir` is the local download directory.
artifact = run.use_artifact(config["model"], type="model")
artifact_dir = artifact.download()

In [None]:
# The artifact is expected to contain the run's best checkpoint as best.pt.
model_file = f"{artifact_dir}/best.pt"
model = YOLO(model_file, task="detect")

In [None]:
def log_metrics(trainer):
    """
    Epoch-end callback: log the trainer's metrics plus ``metrics/fitness`` to W&B.

    Nothing is logged unless ``trainer.fitness`` is set and strictly positive.
    """
    fitness = trainer.fitness
    # `not fitness > 0.0` (rather than `<= 0.0`) keeps NaN fitness skipped.
    if fitness is None or not fitness > 0.0:
        return
    wandb.log({**trainer.metrics, "metrics/fitness": fitness})
      
# Register the epoch-end metrics logger on the model.
# NOTE(review): "on_fit_epoch_end" is a training-loop event and this notebook only runs
# model.val(), so the callback is presumably never triggered here — confirm.
# Author's note: visualization has commit=false in on_train_epoch_end.
model.add_callback("on_fit_epoch_end", log_metrics)

In [None]:
# Evaluate the downloaded checkpoint with the merged-class validator, so the
# ground-truth labels go through the same VisDroneDataset remapping.
# Assumes the checkpoint was trained on the merged classes (e.g. with MergedClassDetectionTrainer).
test_results = model.val(
    validator=MergedClassDetectionValidator,
    **config['val']
)

In [None]:
# Namespace the overall results under "test/".
prefixed_results_dict = {f"test/{k}": v for k, v in test_results.results_dict.items()}

# Log to the run opened in the WandB init cell. Only resume it if it is no longer the
# active run; pass the same project/dir as the first init so the run id is resolved
# in the right project (resume="must" fails if the id cannot be found).
if wandb.run is None or wandb.run.id != run.id:
    run = wandb.init(
        id=run.id,
        resume="must",
        **config["wandb"],
    )

metrics = {
    "test/metrics/fitness": test_results.fitness,
    **prefixed_results_dict,
}

# Per-stage timings.
for speed_name, speed_value in test_results.speed.items():
    metrics[f"speed/{speed_name}"] = speed_value

# Per-class mAP, keyed by class name.
for class_idx, class_map in enumerate(test_results.maps):
    if class_idx in test_results.names:
        class_name = test_results.names[class_idx]
        metrics[f"test/metrics/mAP/{class_name}"] = float(class_map)

# Single log call -> single W&B step.
run.log(metrics)

In [None]:
# Close the active W&B run, then delete *.pt files in the working directory.
# remove_models globs "*.pt" non-recursively, so files in subdirectories are left untouched.
wandb.finish()
remove_models()