In [None]:
# Install the notebook's dependencies with every version pinned, using the %pip magic
# so the packages land in the environment of the running kernel.
%pip install loguru==0.7.3 python-dotenv==1.0.1 PyYAML==6.0.2 torch==2.5.1 tqdm==4.67.1 typer==0.15.1 matplotlib==3.10.0 pyarrow==18.1.0 setuptools==75.1.0 protobuf==4.25.3 wandb==0.19.7 ultralytics==8.3.90 ray==2.43.0 albumentations==2.0.5

In [None]:
import locale
import os
import sys
from pathlib import Path
from typing import Dict, Optional, Union

import numpy as np
import torch
import wandb
import yaml
from ray import tune
from ultralytics import YOLO, settings
from ultralytics.data.dataset import YOLODataset
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.utils import colorstr, LOGGER

# Do not write .pyc files / __pycache__ directories for modules imported from here on.
sys.dont_write_bytecode = True

# Force "UTF-8" as the preferred encoding reported by the locale module.
# The replacement accepts the same optional ``do_setlocale`` argument as the real
# function, so callers that invoke ``locale.getpreferredencoding(False)`` keep working
# instead of failing with a TypeError.
locale.getpreferredencoding = lambda do_setlocale=True: "UTF-8"

# NOTE(review): presumably opts out of the Ray Train V2 code path - confirm against
# the pinned Ray version.
os.environ["RAY_TRAIN_V2_ENABLED"] = "0"

In [None]:
# Get device

def get_device() -> Union[int, str]:
    """Pick the device identifier used for training.

    Returns:
        ``0`` (index of the first CUDA device) when ``torch.cuda.is_available()`` is
        true, otherwise the string ``"cpu"``. If probing CUDA raises, the error is
        printed and ``"cpu"`` is returned so callers never receive ``None``.
    """
    try:
        return 0 if torch.cuda.is_available() else "cpu"
    except Exception as e:
        print(f"Error setting device: {e}")
        # Fall back to the CPU instead of implicitly returning None.
        return "cpu"

In [None]:
# Wandb init

def get_wandb_key_colab() -> Optional[str]:
    """Look up the ``WANDB_API_KEY`` secret through ``google.colab.userdata``.

    Returns:
        The secret's value, or ``None`` when ``google.colab`` cannot be imported,
        the lookup raises, or the secret is not set.
    """
    try:
        from google.colab import userdata  # type: ignore

        # Single lookup; a missing secret (None) is passed straight through.
        return userdata.get("WANDB_API_KEY")
    except Exception:
        # Best-effort lookup: any failure means "no key available from Colab".
        # ``Exception`` (rather than a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        return None

def get_wandb_env(path: Path) -> Optional[str]:
    """Read the W&B API key from a ``.env`` file.

    Args:
        path: Location of the ``.env`` file; it is converted with ``Path(path)``,
            so a plain string works as well.

    Returns:
        The value stored under ``WANDB_API_KEY``, or ``None`` when ``dotenv``
        cannot be imported, the file does not exist, or the key is absent.
        The reason for a failure is printed rather than raised.
    """
    try:
        from dotenv import dotenv_values  # type: ignore

        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Could not find .env file at {path}")

        print(f"Loading secrets from {path}")

        secrets = dotenv_values(path)
        # Only the key names are printed, never the secret values.
        print(f"Found keys: {list(secrets.keys())}")

        if "WANDB_API_KEY" not in secrets:
            raise KeyError(f"WANDB_API_KEY not found in {path}. Available keys: {list(secrets.keys())}")

        return secrets['WANDB_API_KEY']
    except Exception as e:
        # Best-effort lookup: report why it failed instead of hiding the cause.
        print(f"Could not load WANDB_API_KEY from .env file: {e}")
        return None

def get_wandb_key(path: Union[str, Path] = "../.env") -> Optional[str]:
    """Resolve the W&B API key, trying Colab first and a ``.env`` file second.

    Args:
        path: ``.env`` file handed to ``get_wandb_env`` when
            ``get_wandb_key_colab`` returns ``None``.

    Returns:
        The key from the first source that provides one, or ``None`` if neither does.
    """
    # Query Colab exactly once and reuse the result.
    colab_key = get_wandb_key_colab()
    if colab_key is not None:
        return colab_key
    return get_wandb_env(path)

In [None]:
# Dataset, Trainer, Validator

class VisDroneDataset(YOLODataset):
    """
    YOLO dataset for VisDrone that merges class 1 ('people') into class 0.

    Labels are loaded by the parent class and then remapped in ``get_labels``:
    class 1 becomes class 0 and every class id greater than 1 is shifted down by
    one, which leaves the nine classes listed in ``merged_names``.
    """

    # Names of the nine classes after the merge; read by the trainer/validator too.
    merged_names = {
        0: 'persona',
        1: 'bicicletta',
        2: 'auto',
        3: 'furgone',
        4: 'camion',
        5: 'triciclo',
        6: 'triciclo-tendato',
        7: 'autobus',
        8: 'motociclo'
    }

    # Ten-class name table handed to the parent while it loads the label files.
    _original_names = {
        0: 'pedestrian',
        1: 'people',
        2: 'bicycle',
        3: 'car',
        4: 'van',
        5: 'truck',
        6: 'tricycle',
        7: 'awning-tricycle',
        8: 'bus',
        9: 'motor'
    }

    def __init__(self, *args, **kwargs):
        """
        Initialise the parent dataset with a ten-class name table.

        All arguments are forwarded to ``YOLODataset``. When a ``data`` dict with a
        ``names`` entry is supplied, the parent receives a shallow copy of it in
        which ``names`` has ten entries, so the caller's dict is not modified here.
        """
        data = kwargs.get('data')
        # Shallow copy of the caller's data dict; None when no data (or data=None)
        # was passed, instead of failing on ``None.copy()``.
        self.original_data = data.copy() if data is not None else None

        if self.original_data and 'names' in self.original_data:
            temp_data = self.original_data.copy()
            # Per the original author's note, the parent's verification needs all
            # ten source class names (assumes label files hold ids 0-9 - confirm).
            if len(temp_data.get('names', {})) != 10:
                temp_data['names'] = dict(self._original_names)
            kwargs['data'] = temp_data

        super().__init__(*args, **kwargs)

        LOGGER.info(f"{colorstr('VisDroneDataset:')} Using merged classes: {self.merged_names}")

    def get_labels(self):
        """
        Load labels through the parent and remap their class ids in place.

        Returns:
            The list returned by ``YOLODataset.get_labels`` with every ``'cls'``
            array rewritten: 1 -> 0, and each id > 1 decremented by one.
        """
        labels = super().get_labels()

        people_count = 0
        shifted_count = 0

        for label in labels:
            cls = label['cls']
            if len(cls) == 0:
                continue

            # Merge class 1 into class 0. The mask is computed before the shift so
            # the original class 2 (which becomes 1) is not merged as well.
            people_mask = cls == 1
            people_count += int(np.sum(people_mask))
            cls[people_mask] = 0

            # Close the gap left by the removed class.
            gt1_mask = cls > 1
            shifted_count += int(np.sum(gt1_mask))
            cls[gt1_mask] -= 1

            label['cls'] = cls

        # Expose the merged class table on this dataset's data dict. A copy is
        # stored so later edits to the data dict cannot alter the class attribute.
        if getattr(self, 'data', None) is not None:
            self.data['names'] = dict(self.merged_names)
            self.data['nc'] = len(self.merged_names)

        person_count = sum(np.sum(label['cls'] == 0) for label in labels)
        LOGGER.info(f"\n{colorstr('VisDroneDataset:')} Remapped {people_count} 'people' instances to {self.merged_names[0]}")
        LOGGER.info(f"{colorstr('VisDroneDataset:')} Total 'persona' instances after merge: {person_count}")
        LOGGER.info(f"{colorstr('VisDroneDataset:')} Shifted {shifted_count} instances of other classes")

        return labels

class MergedClassDetectionTrainer(DetectionTrainer):
    """
    Detection trainer whose datasets are ``VisDroneDataset`` instances, so training
    runs on the merged class set.
    """

    def build_dataset(self, img_path, mode="train", batch=None):
        """
        Create a ``VisDroneDataset`` for the given image path.

        Args:
            img_path: Image location forwarded to the dataset.
            mode: ``"train"`` turns on augmentation and uses the configured
                ``rect``/``fraction``; any other value forces ``rect=True``,
                ``pad=0.5`` and ``fraction=1.0``.
            batch: Batch size; falls back to ``self.batch_size`` when falsy.
        """
        training = mode == "train"
        dataset_kwargs = dict(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch or self.batch_size,
            augment=training,
            hyp=self.args,
            rect=self.args.rect if training else True,
            cache=self.args.cache or None,
            single_cls=self.args.single_cls,
            stride=self.stride,
            pad=0.0 if training else 0.5,
            prefix=colorstr(f"{mode}: "),
            task=self.args.task,
            classes=None,
            data=self.data,
            fraction=self.args.fraction if training else 1.0,
        )
        return VisDroneDataset(**dataset_kwargs)

    def set_model_attributes(self):
        """
        Run the parent hook, then overwrite the class names and class count on the
        model and on ``self.data`` with the merged table.
        """
        super().set_model_attributes()

        # Nothing to override when the model carries no ``names`` attribute.
        if not hasattr(self.model, 'names'):
            return

        names = VisDroneDataset.merged_names
        class_count = len(names)

        self.model.names = names
        self.model.nc = class_count

        if hasattr(self, 'data'):
            self.data['names'] = names
            self.data['nc'] = class_count

from ultralytics.models.yolo.detect import DetectionValidator

class MergedClassDetectionValidator(DetectionValidator):
    """
    Detection validator that builds ``VisDroneDataset`` instances, so validation and
    testing run on the merged class set.
    """

    def build_dataset(self, img_path, mode="val", batch=None):
        """
        Create a non-augmented, rectangular ``VisDroneDataset`` for validation.

        Args:
            img_path: Image location forwarded to the dataset.
            mode: Only used for the log prefix.
            batch: Batch size; falls back to ``self.args.batch`` when falsy.
        """
        return VisDroneDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch or self.args.batch,
            augment=False,  # no augmentation during validation
            hyp=self.args,
            rect=True,  # rectangular validation
            cache=None,
            single_cls=self.args.single_cls,
            stride=self.stride,
            pad=0.5,
            prefix=colorstr(f"{mode}: "),
            task=self.args.task,
            classes=self.args.classes,
            data=self.data,
        )

    def set_model_attributes(self):
        """
        Overwrite the class names on a PyTorch model and on the data dict with the
        merged table.
        """
        # NOTE(review): the base validator may not define this hook - confirm.
        # Calling it only when present keeps this method usable either way.
        parent_hook = getattr(super(), 'set_model_attributes', None)
        if callable(parent_hook):
            parent_hook()

        # Only models exposing both ``names`` and ``model`` are updated (the original
        # author's note: PyTorch models, not exported ones).
        if hasattr(self.model, 'names') and hasattr(self.model, 'model'):
            self.model.names = VisDroneDataset.merged_names
            # ``self.data`` is subscripted like a dict, so test for the key; the
            # previous ``hasattr(self.data, 'names')`` check was never true for a dict.
            if isinstance(self.data, dict) and 'names' in self.data:
                self.data['names'] = VisDroneDataset.merged_names
                self.data['nc'] = len(VisDroneDataset.merged_names)

In [None]:
# Inline YAML configuration. Only the "wandb_project" key and the "train" section are
# read by the cells below; "val" and "tune" are kept for reference.
config_data = """
wandb_project: "EyeInTheSky_merged_tuning"
data: "VisDrone.yaml"
train:
  model: "yolo12n.pt"
  project: "EyeInTheSky"
  data: "VisDrone.yaml"
  optimizer: "AdamW"
  task: detect
  epochs: 3
  batch: 8
  workers: 4
  seed: 42
  plots: True
  imgsz: 640
  exist_ok: False
  save: True
  save_period: 10
  val: True
  # warmup_epochs: 10
  visualize: False
  show: False
  single_cls: False # (bool) train multi-class data as single-class
  rect: False # (bool) rectangular training if mode='train' or rectangular validation if mode='val'
  cos_lr: False
  resume: False
  amp: True # (bool) Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check
  fraction: 1.0
  freeze: null # YAML null; a bare "None" would be parsed as the string "None"
  cache: False
val:
  project: "EyeInTheSky"
  name: "YOLOv12-VisDrone-Validation"
  half: True
  conf: 0.25
  iou: 0.6
  split: "test"
  rect: True
  plots: True
  visualize: True
tune:
  project: "EyeInTheSky_tuned"
  name: "YOLOv12-VisDrone-Tuning"
  optimizer: "AdamW"
  batch: 16
  seed: 42
  val: False
  imgsz: 640
  cache: False
  visualize: True
  show: True

"""
config = yaml.safe_load(config_data)

In [None]:
# Load config

# config = Config.load("../config/config.yaml")
# Parse the inline YAML afresh so re-running this cell always starts from a clean
# config, then record the device chosen by get_device() in the training section.
config = yaml.safe_load(config_data)
config["train"]["device"] = get_device()

In [None]:
# Search space

# Ray Tune search space: the first four entries are picked from fixed candidate
# lists, the last three are drawn uniformly from a range.
search_space = dict(
    lr0=tune.choice([1e-4, 1e-3]),
    lrf=tune.choice([0.01, 0.1]),
    momentum=tune.choice([0.8, 0.9, 0.95]),
    weight_decay=tune.choice([0.0, 0.001]),
    box=tune.uniform(3.0, 7.0),
    cls=tune.uniform(0.5, 2.0),
    dfl=tune.uniform(3.0, 6.0),
)

In [None]:
def train_with_merged_classes(ray_config: Dict, train_config: Optional[Dict] = None):
    """
    Ray Tune trainable that runs YOLO tuning with the merged-class trainer.

    Parameters:
        ray_config (dict): Hyperparameters passed by Ray Tune. They override
            entries of the same name in ``train_config``.
        train_config (dict): Fixed training configuration; must contain ``"model"``.

    Returns:
        Whatever ``model.tune(...)`` returns.

    Raises:
        KeyError: If ``train_config`` has no ``"model"`` entry.
    """
    if train_config is None:
        train_config = {}
    if "model" not in train_config:
        raise KeyError("train_config must contain a 'model' entry (e.g. 'yolo12n.pt')")

    # Hyperparameters from ray_config win over the fixed configuration.
    merged_config = {**train_config, **ray_config}

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    try:
        model = YOLO(train_config["model"], verbose=False)

        # Expose the custom dataset class on the ultralytics dataset module.
        # NOTE(review): presumably needed so the class can be resolved by name in
        # worker processes - confirm.
        setattr(sys.modules["ultralytics.data.dataset"], "VisDroneDataset", VisDroneDataset)

        # NOTE(review): this function already runs as a Ray Tune trial (see
        # tune_with_merged_classes) and model.tune(use_ray=True) is called from
        # inside it - confirm this nesting is intended rather than model.train().
        return model.tune(use_ray=True, grace_period=2, trainer=MergedClassDetectionTrainer, **merged_config)
    finally:
        # Release cached GPU memory even when tuning fails.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

def tune_with_merged_classes(search_space, train_config, *, num_samples: int = 1,
                             resources: Optional[Dict] = None, **kwargs):
    """
    Run a Ray Tune search over ``search_space`` using ``train_with_merged_classes``.

    Parameters:
        search_space (dict): Ray Tune parameter space.
        train_config (dict): Fixed training configuration given to every trial.
        num_samples (int): Number of samples drawn from the search space
            (default 1, the previously hard-coded value).
        resources (dict): Per-trial resource request; defaults to
            ``{"cpu": 8, "gpu": 1}``, the previously hard-coded value.
        **kwargs: Forwarded unchanged to ``tune.RunConfig``.

    Returns:
        The result of ``Tuner.fit()``.
    """
    if resources is None:
        resources = {"cpu": 8, "gpu": 1}

    trainable_with_resources = tune.with_resources(train_with_merged_classes, resources)

    # tune.with_parameters "bakes in" the fixed training configuration so each trial
    # receives it as the ``train_config`` argument.
    trainable = tune.with_parameters(trainable_with_resources, train_config=train_config)

    tuner = tune.Tuner(
        trainable,
        param_space=search_space,
        tune_config=tune.TuneConfig(
            num_samples=num_samples,
            max_concurrent_trials=1,  # Run one at a time to avoid OOM
            mode="max"
        ),
        run_config=tune.RunConfig(**kwargs)
    )
    return tuner.fit()


In [None]:
# Authenticate with Weights & Biases and open a run for this tuning session.
wandb_api_key = get_wandb_key()
wandb.login(key=wandb_api_key, relogin=True)
wandb.init(
    project=config["wandb_project"],
    # The model name is stored under the "train" section; the config has no
    # top-level "model" key, so config['model'] raised KeyError.
    name=f"{config['train']['model']}_VisDrone_tune",
)
settings.update({"wandb": True})

results_path = os.path.abspath("../reports")
try:
    results = tune_with_merged_classes(
        search_space,
        train_config=config["train"],
        storage_path=Path(results_path)
    )
finally:
    # Close the W&B run even when tuning fails.
    wandb.finish()

print(results)

In [None]:
# import shutil

# drive.mount('/content/drive')

# timestamp = time.strftime("%Y%m%d_%H%M%S")
# source_folder = '/content/EyeInTheSky/tune'
# destination_folder = f'/content/drive/My Drive/EyeInTheSky/tune_{timestamp}'

# shutil.copytree(source_folder, destination_folder)