In [2]:
import os
import sys

sys.path.append('/gpfs/helios/home/ploter/projects/MultiSensorDropout/')

from engine import evaluate

GDRIVE_BASE_PATH = '../not_tracked_dir/'
os.makedirs(GDRIVE_BASE_PATH, exist_ok=True)
CFG_EXPERIMENT_NAME = "output_yolo_v8_2025-04-11"

In [3]:
# from google.colab import drive
# import os

# drive.mount('/content/drive')

# # Define the base directory ON YOUR GOOGLE DRIVE where projects should be saved
# # IMPORTANT: Make sure this path exists in your Google Drive or create it.
# # Example: Create a folder named 'YOLOv8_Training' in the root of your 'MyDrive'
# GDRIVE_BASE_PATH = '/content/drive/MyDrive/YOLOv8_Training' # ADJUST THIS PATH AS NEEDED

# # Create the base directory on Drive if it doesn't exist
# os.makedirs(GDRIVE_BASE_PATH, exist_ok=True)
# print(f"Google Drive mounted. Using base path: {GDRIVE_BASE_PATH}")
# CFG_EXPERIMENT_NAME = "train_yolo_nano_24k_dataset3"

# !pip install -q ultralytics
# !pip install -q datasets
# !pip install -q torchmetrics

In [4]:
# Standard imports
import argparse
import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import resize, to_pil_image, to_tensor
from PIL import Image
from tqdm.auto import tqdm
from datasets import load_dataset
from types import SimpleNamespace # For config object

# Ultralytics imports
from ultralytics import YOLO
from ultralytics.models.yolo.detect.train import DetectionTrainer # To subclass
from ultralytics.utils import ops, checks


# --- NEW: Import Ultralytics LetterBox ---
try:
    # Common location for augmentations
    from ultralytics.data.augment import LetterBox
except ImportError:
    try:
        # Alternative location in some versions might be utils
        from ultralytics.utils.ops import LetterBox
        print("Note: Imported LetterBox from ultralytics.utils.ops")
    except ImportError:
        raise ImportError("Could not import LetterBox from 'ultralytics.data.augment' or 'ultralytics.utils.ops'. "
                          "Check Ultralytics version or installation.")
# ---



## DS

In [5]:
# === In Cell 2: HFYOLODataset Class Definition ===

# === Cell 2: HFYOLODataset Class Definition ===
import torch
import numpy as np
from torch.utils.data import Dataset

from tqdm.auto import tqdm
from datasets import load_dataset
import os
from pathlib import Path
from copy import deepcopy

# --- Ultralytics Imports ---
# Need YOLODataset only for accessing its collate_fn later
from ultralytics.data import YOLODataset
# Import necessary transform components used in build_transforms
from ultralytics.data.augment import LetterBox, Format, Compose # Standard components
# v8_transforms might require many more internal imports if used for augmentation
# from ultralytics.data.augment import v8_transforms
from ultralytics.utils import DEFAULT_CFG, LOGGER # Used in build_transforms logic
from ultralytics.utils.instance import Instances # Make sure this is imported


class HFYOLODataset(Dataset):

    # ... (Keep __init__ and _build_transforms_internal as before) ...
    def __init__(self, hf_dataset_split, imgsz=640, stride=32, augment=True, hyp=None, rect=False, task='detect', trust_remote_code=False, prefix=""):
        # ... (Same init as previous step) ...
        super().__init__() # Base Dataset init
        self.imgsz = imgsz; self.augment = augment; self.hyp = hyp if hyp is not None else DEFAULT_CFG; self.rect = rect; self.stride = stride; self.task = task; self.prefix = prefix; self.hf_dataset = hf_dataset_split; self.trust_remote_code = trust_remote_code
        self.use_segments = self.task == "segment"; self.use_keypoints = self.task == "pose"; self.use_obb = self.task == "obb"
        print(f"Initializing HFYOLODataset (Using Ultralytics Transforms)...")
        # Determine frame info & nc/names
        self.num_videos = len(self.hf_dataset); # ... (rest of frame count/dim logic) ...
        if self.num_videos == 0: raise ValueError("Empty dataset.")
        try: # Determine frame info
            first_vid = self.hf_dataset[0]; first_data = first_vid['video']; # ... get first_frames_shape_approx ...
            if not isinstance(first_data, np.ndarray):
                 if isinstance(first_data, list) and first_data: first_frame_np = np.array(first_data[0]); first_video_frames_shape_approx = (len(first_data), *first_frame_np.shape)
                 else: first_video_frames = np.array(first_data); first_video_frames_shape_approx = first_video_frames.shape
            else: first_video_frames_shape_approx = first_data.shape
            if len(first_video_frames_shape_approx) == 4: self.num_frames_per_video, self.frame_height, self.frame_width, _ = first_video_frames_shape_approx
            elif len(first_video_frames_shape_approx) == 3: self.num_frames_per_video, self.frame_height, self.frame_width = first_video_frames_shape_approx
            else: raise ValueError(f"Unexpected video dimensions: {first_video_frames_shape_approx}")
            if self.num_frames_per_video <= 0: raise ValueError(f"Non-positive frame count: {self.num_frames_per_video}")
            print(f"Frames/Video: {self.num_frames_per_video}, Dim: ({self.frame_height}, {self.frame_width})")
        except Exception as e: raise RuntimeError(f"Failed getting frame info: {e}") from e
        self.total_frames = self.num_videos * self.num_frames_per_video; print(f"Total train frames: {self.total_frames}")

        self.nc = 10
        self.num_classes = self.nc
        self.names = {i:str(i) for i in range(self.nc)}
        print(f"Discovered nc={self.nc}, names={self.names}")
        self.data = {'names': self.names, 'nc': self.nc}
        # Build Transforms
        self.transforms = self._build_transforms_internal(hyp=self.hyp)
        print(f"Transforms created: {self.transforms}")

    def _build_transforms_internal(self, hyp):
        # ... (Same as before, includes LetterBox and Format) ...
         if self.augment:
             LOGGER.warning("Augmentation enabled, but using simple LetterBox transform.")
             transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=getattr(hyp, 'scaleup', True), stride=self.stride)])
         else:
             transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False, stride=self.stride)])
         transforms.append(
             Format(bbox_format="xywh",
                    normalize=True, return_mask=self.use_segments, return_keypoint=self.use_keypoints, return_obb=self.use_obb, batch_idx=True, mask_ratio=getattr(hyp, 'mask_ratio', 4), mask_overlap=getattr(hyp, 'overlap_mask', True), bgr=getattr(hyp, 'bgr', 0.0) if self.augment else 0.0))
         return transforms

    def __len__(self):
        return self.total_frames

    def __getitem__(self, index):
        if index >= self.total_frames: raise IndexError(...)
        video_idx = index // self.num_frames_per_video
        frame_idx = index % self.num_frames_per_video
        try:
            video_example = self.hf_dataset[video_idx]
            # --- Load raw image (uint8 HWC NumPy) ---
            # ... (same logic) ...
            video_data = video_example['video']; img_np = np.array(video_data[frame_idx], dtype=np.uint8) if isinstance(video_data, list) else video_data[frame_idx].astype(np.uint8)
            if img_np.ndim == 2: img_np = np.stack([img_np]*3, axis=-1)
            elif img_np.shape[-1] == 1: img_np = np.concatenate([img_np]*3, axis=-1)
            original_h, original_w = img_np.shape[:2]

            # --- Load annotations -> ABSOLUTE PIXEL COORDS (xyxy) ---
            bboxes = video_example['bboxes'][frame_idx]
            labels = video_example['bboxes_labels'][frame_idx]
            bboxes_abs_xyxy_list = []
            cls_list_of_lists = [] # <<< Change variable name for clarity
            if bboxes and labels:
                 for bbox, label in zip(bboxes, labels):
                    x_min, y_min, w, h = map(float, bbox) # Ensure float
                    # Convert xywh to xyxy, clamp to image bounds
                    x1 = max(0.0, x_min)
                    y1 = max(0.0, y_min)
                    x2 = min(float(original_w), x_min + w)
                    y2 = min(float(original_h), y_min + h)
                    if x2 > x1 and y2 > y1: # Check for valid box area
                         bboxes_abs_xyxy_list.append([x1, y1, x2, y2])

                          # --- CHANGE 1: Append label as a list ---
                         cls_list_of_lists.append([int(label)])
                         # --- End Change 1 ---

            # === FIX: Ensure bboxes_np has shape [N, 4] even if N=0 ===
            if not bboxes_abs_xyxy_list:
                # If list is empty, create array with shape (0, 4)
                bboxes_np = np.zeros((0, 4), dtype=np.float32)
            else:
                # If list has items, convert normally
                bboxes_np = np.array(bboxes_abs_xyxy_list, dtype=np.float32)
            # cls_np is okay as shape (0,) if cls_list is empty
            # --- CHANGE 2: Create cls_np with shape (N, 1) or (0, 1) ---
            if not cls_list_of_lists:
                # Explicitly create shape (0, 1) for empty case
                cls_np = np.array([], dtype=np.int64).reshape(0, 1)
            else:
                # np.array([[0], [2], [0]]) directly creates shape (N, 1)
                cls_np = np.array(cls_list_of_lists, dtype=np.int64)
            # --- End Change 2 ---

            # --- Create Instances object ---
            # Instances expects bboxes in xyxy format by default if normalized=False
            segments = np.zeros((0, 1000, 2), dtype=np.float32)
            instances = Instances(bboxes=bboxes_np, segments=segments, bbox_format='xyxy', normalized=False)

            # === START MINIMAL CHANGE ===
            # Calculate simple ratio (height_ratio, width_ratio) based on target imgsz and original shape.
            # This mimics the structure added by YOLODataset.get_image_and_label before transforms.
            # Use float division.
            ratio_h = float(self.imgsz) / original_h
            ratio_w = float(self.imgsz) / original_w
            # Create the simple tuple (rh, rw)
            simple_ratio_pad = (ratio_h, ratio_w)
            # === END MINIMAL CHANGE ===

            # Format expects 'img', 'cls', 'instances'
            sample = {
                'img': img_np,           # uint8 HWC NumPy
                'instances': instances,  # Instances obj with abs pixel xyxy boxes
                'cls': cls_np,           # int64 [N] NumPy
                'ori_shape': (original_h, original_w), # Add original shape if needed by transforms
                'ratio_pad': simple_ratio_pad,
            }
            # ---------------------------------------------

            # --- Apply transforms ---
            # This pipeline includes LetterBox and Format
            transformed_sample = self.transforms(sample)
            # Output should have 'img' (CHW Tensor, uint8), 'cls', 'bboxes' (normalized xywh Tensor), 'batch_idx'
            # ------------------------

            # --- Add metadata needed by plotting ---
            transformed_sample['im_file'] = f"video_{video_idx}_frame_{frame_idx}.jpg"
            transformed_sample['ori_shape'] = (original_h, original_w) # Ensure ori_shape is present



            # --------------------------------------------

            return transformed_sample

        except Exception as e:
             # ... (error handling) ...
             print(f"Error in __getitem__ for index {index}: {e}")
             import traceback; traceback.print_exc()
             raise e


    # --- Keep compatibility properties ---
    @property
    def build_type(self): return 'build_detection_dataset'
    ## @property
    # def data(self): return {'names': self.names, 'nc': self.num_classes} # Trainer handles this now

## VALIDATOR

In [6]:
# === Add Custom Validator Class (e.g., Cell 4a) ===
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, emojis # For logging/errors if needed
from copy import copy
import torch

from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.ops import Profile
from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode

class CustomDetectionValidator(DetectionValidator):
    @smart_inference_mode()
    def __call__(self, trainer=None, model=None):
        """
        Execute validation process, running inference on dataloader and computing performance metrics.

        Args:
            trainer (object, optional): Trainer object that contains the model to validate.
            model (nn.Module, optional): Model to validate if not using a trainer.

        Returns:
            stats (dict): Dictionary containing validation statistics.
        """
        self.training = trainer is not None
        augment = self.args.augment and (not self.training)
        if self.training:
            self.device = trainer.device
            self.data = trainer.data
            # Force FP16 val during training
            self.args.half = self.device.type != "cpu" and trainer.amp
            model = trainer.ema.ema or trainer.model
            model = model.half() if self.args.half else model.float()
            # self.model = model
            self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
            self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
            model.eval()
        else:
            if str(self.args.model).endswith(".yaml") and model is None:
                LOGGER.warning("WARNING ⚠️ validating an untrained model YAML will result in 0 mAP.")
            callbacks.add_integration_callbacks(self)
            model = AutoBackend(
                weights=model or self.args.model,
                device=select_device(self.args.device, self.args.batch),
                dnn=self.args.dnn,
                data=self.args.data,
                fp16=self.args.half,
            )
            # self.model = model
            self.device = model.device  # update device
            self.args.half = model.fp16  # update half
            stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
            imgsz = check_imgsz(self.args.imgsz, stride=stride)
            if engine:
                self.args.batch = model.batch_size
            elif not pt and not jit:
                self.args.batch = model.metadata.get("batch", 1)  # export.py models default to batch-size 1
                LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})")

            if str(self.args.data).split(".")[-1] in {"yaml", "yml"}:
                self.data = {} # check_det_dataset(self.args.data)
            elif self.args.task == "classify":
                self.data = check_cls_dataset(self.args.data, split=self.args.split)
            else:
                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

            if self.device.type in {"cpu", "mps"}:
                self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
            if not pt:
                self.args.rect = False
            self.stride = model.stride  # used in get_dataloader() for padding
            self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)

            model.eval()
            model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup

        self.run_callbacks("on_val_start")
        dt = (
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
        )
        bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
        self.init_metrics(de_parallel(model))
        self.jdict = []  # empty before each val
        for batch_i, batch in enumerate(bar):
            self.run_callbacks("on_val_batch_start")
            self.batch_i = batch_i
            # Preprocess
            with dt[0]:
                batch = self.preprocess(batch)

            # Inference
            with dt[1]:
                preds = model(batch["img"], augment=augment)

            # Loss
            with dt[2]:
                if self.training:
                    self.loss += model.loss(batch, preds)[1]

            # Postprocess
            with dt[3]:
                preds = self.postprocess(preds)

            self.update_metrics(preds, batch)
            if self.args.plots and batch_i < 3:
                self.plot_val_samples(batch, batch_i)
                self.plot_predictions(batch, preds, batch_i)

            self.run_callbacks("on_val_batch_end")
        stats = self.get_stats()
        self.check_stats(stats)
        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
        self.finalize_metrics()
        self.print_results()
        self.run_callbacks("on_val_end")
        if self.training:
            model.float()
            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
        else:
            LOGGER.info(
                "Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per image".format(
                    *tuple(self.speed.values())
                )
            )
            if self.args.save_json and self.jdict:
                with open(str(self.save_dir / "predictions.json"), "w", encoding="utf-8") as f:
                    LOGGER.info(f"Saving {f.name}...")
                    json.dump(self.jdict, f)  # flatten and save
                stats = self.eval_json(stats)  # update stats
            if self.args.plots or self.args.save_json:
                LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
            return stats

    # --- Keep the _prepare_batch override if it was needed for the 0-D tensor error ---
    # def _prepare_batch(self, si, batch):
    #      ... (implementation that handles 0-D cls tensor) ...

## TRAINER

In [7]:
# from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
# Import YOLODataset
from ultralytics.data.dataset import YOLODataset

class CustomDetectionTrainer(DetectionTrainer):
    def __init__(self, trust_remote_code=False, debug=False, *args, **kwargs): # Removed hf_dataset_identifier from signature
        print("CustomDetectionTrainer __init__ started...")
        self.custom_trust_code = trust_remote_code
        self.hf_train_split = None; self.hf_val_split = None
        self._debug = debug
        # We rely on args.data being the HF identifier now
        super().__init__(*args, **kwargs) # Calls overridden get_dataset
        print("CustomDetectionTrainer __init__ finished.")
        # Verification AFTER get_dataset has run (called by super init)
        # if not hasattr(self.args, 'nc') or not isinstance(self.args.nc, int) or self.args.nc <= 0:
        #      raise RuntimeError(f"Trainer self.args.nc not set correctly after get_dataset. Check get_dataset override.")
        # if not hasattr(self.args, 'names') or not isinstance(self.args.names, dict):
        #      raise RuntimeError(f"Trainer self.args.names not set correctly after get_dataset.")
        # print(f"Trainer initialized: nc={self.args.nc}, names={self.args.names}")

    # --- Override get_dataset ---
    def get_dataset(self):
        """
        Overrides the base method. Loads HF dataset using args.data.
        Instantiates HFYOLODataset (which discovers nc/names).
        Sets args.nc and args.names based on the loaded dataset info.
        Sets self.trainset and self.testset.
        """
        print("****** Custom get_dataset called ******")
        if not hasattr(self, 'args'): raise RuntimeError("Trainer arguments missing.")

        hf_dataset_identifier = 'Max-Ploter/detection-moving-mnist-easy' #self.args.data # Get HF ID from data arg
        trust_remote_code = self.custom_trust_code # Get from instance attribute

        if not hf_dataset_identifier or not isinstance(hf_dataset_identifier, str):
             raise ValueError(f"HF dataset identifier '{hf_dataset_identifier}' (from args.data) is invalid.")

        # --- Load HF Splits ---
        # ... (Load self.hf_train_split / self.hf_val_split using hf_dataset_identifier) ...
        if self.hf_train_split is None:
            try: self.hf_train_split = load_dataset(hf_dataset_identifier, split='train', trust_remote_code=trust_remote_code); print(f"Loaded train split: {len(self.hf_train_split)} samples.")
            except Exception as e: raise RuntimeError(f"Failed load HF train '{hf_dataset_identifier}': {e}") from e
        if self.hf_val_split is None:
            try: # Try val then test
                try: self.hf_val_split = load_dataset(hf_dataset_identifier, split='validation', trust_remote_code=trust_remote_code); print("Using 'validation' split.")
                except ValueError: self.hf_val_split = load_dataset(hf_dataset_identifier, split='test', trust_remote_code=trust_remote_code); print("Using 'test' split.")
                if not self.hf_val_split: raise ValueError("No val/test split found.")
                print(f"Loaded val/test split: {len(self.hf_val_split)} samples.")
            except Exception as e: # Fallback split train
                 print(f"Warning: Failed loading val/test: {e}. Splitting train.")
                 # ... (Split train logic) ...
                 if self.hf_train_split is None: raise RuntimeError("Train split not available.")
                 if len(self.hf_train_split) < 2: raise RuntimeError("Train split too small.")
                 splits = self.hf_train_split.train_test_split(test_size=0.2, seed=getattr(self.args, 'seed', 42))
                 self.hf_train_split, self.hf_val_split = splits['train'], splits['test']
                 print(f"Used 80/20 split of 'train' for train ({len(self.hf_train_split)})/validation ({len(self.hf_val_split)}).")

        if self._debug:
          # reduce train and val split sizes
          print("Reducing train/val split sizes for debugging...")
          self.hf_train_split = self.hf_train_split.select(range(1))
          self.hf_val_split = self.hf_val_split.select(range(1))

        # --- Instantiate HFYOLODatasets (which find nc/names) ---
        print("Instantiating HFYOLODataset for trainset...")
        self.trainset = HFYOLODataset(self.hf_train_split, imgsz=self.args.imgsz, trust_remote_code=trust_remote_code)
        print("Instantiating HFYOLODataset for testset (validation)...")
        self.testset = HFYOLODataset(self.hf_val_split, imgsz=self.args.imgsz, trust_remote_code=trust_remote_code)

        # --- Get nc/names FROM the instantiated dataset ---
        known_nc = getattr(self.trainset, 'num_classes', 0)
        known_names = getattr(self.trainset, 'names', {})
        if known_nc <= 0: # Fallback to testset if trainset failed
            known_nc = getattr(self.testset, 'num_classes', 0)
            known_names = getattr(self.testset, 'names', {})

        if known_nc <= 0:
            raise ValueError("Could not determine number of classes from loaded HFYOLODataset instances.")
        # --------------------------------------------------

        # --- Set nc and names on self.args ---
        # This is the CRUCIAL step - fulfilling the presumed responsibility
        print(f"Setting trainer args: nc={known_nc}, names={known_names}")
        # self.args.nc = known_nc
        # self.args.names = known_names

        # --- <<< NEW: Explicitly set self.data attribute >>> ---
        # This dictionary is expected by the original get_model method
        self.data = {'nc': known_nc, 'names': known_names}
        # Add other keys if get_model relies on them, e.g. 'path' (can be dummy)
        # self.data['path'] = '.' # Example if path is needed
        print(f"Set self.data attribute: {self.data}")
        # ----------------------------------------------------

        # Optional: Update dataset.data for consistency
        if hasattr(self.trainset, 'data'): self.trainset.data = self.data
        if hasattr(self.testset, 'data'): self.testset.data = self.data


        # ------------------------------------

        print(f"****** Custom get_dataset finished. Set args.nc={known_nc}. ******")
        return self.trainset, self.testset


    # --- get_dataloader and plot_training_labels overrides remain the same ---
    def get_dataloader(self, dataset=None, batch_size=16, rank=0, mode='train'):
        # ... (implementation is unchanged) ...
        print(f"****** Custom get_dataloader called for mode: {mode} ******")
        if dataset is None: dataset = self.trainset if mode == 'train' else self.testset
        if not isinstance(dataset, HFYOLODataset): print(f"Warning: Dataset type {type(dataset)}.")
        batch_size_arg = getattr(self.args, 'batch', batch_size); batch_size_to_use = batch_size_arg if isinstance(batch_size_arg, int) and batch_size_arg > 0 else batch_size
        if mode != 'train': batch_size_to_use *= 2
        workers = getattr(self.args, 'workers', 0); shuffle = (mode == 'train')
        print(f"Creating DataLoader for mode '{mode}' with batch_size={batch_size_to_use}, workers={workers}...")
        loader = DataLoader(dataset, batch_size=batch_size_to_use, shuffle=shuffle, num_workers=workers, pin_memory=True,
                            collate_fn=YOLODataset.collate_fn
                            )
        print(f"****** Custom DataLoader created for {mode} ******")
        return loader

    def plot_training_labels(self): print("Skipping plot_training_labels in Custom Trainer."); pass

    def get_validator(self):
        """Returns a CustomDetectionValidator instance."""
        print("****** Custom get_validator called (Returning CustomDetectionValidator) ******")
        # Ensure validation dataloader exists
        if not hasattr(self, 'test_loader') or self.test_loader is None:
             print("Creating validation dataloader within get_validator...")
             if not hasattr(self, 'testset') or self.testset is None: raise RuntimeError("Validation dataset missing.")
             val_batch_size = getattr(self.args, 'batch', 16) * 2
             self.test_loader = self.get_dataloader(self.testset, batch_size=val_batch_size, mode='val')

        validator_args = copy(self.args) # Pass copy of trainer args

        # Instantiate OUR custom validator
        validator = CustomDetectionValidator( # Use the custom class
            dataloader=self.test_loader,
            save_dir=self.save_dir,
            args=validator_args,
            # --- FIX: Pass the main self.callbacks dict directly ---
            _callbacks=self.callbacks
            # --- End Fix ---
        )
        # Link model and data dict
        validator.model = self.model
        validator.data = self.data
        print(f"****** Custom get_validator finished. Validator created. Using data: {validator.data} ******")
        return validator


## TRAIN

In [8]:
# === Configuration Cell in Jupyter Notebook ===
from types import SimpleNamespace
import os

# --- Define YAML Content and Filename ---
FAKE_YAML_FILENAME = "fake.yaml" # Name of the file to create
YAML_SHIM_CONTENT = """
# Minimal YAML to satisfy Ultralytics checks during validation
path: ./ignored_path # Ignored, but path key might be checked
train: images/train   # Ignored
val: images/val       # Ignored

# --- Important Part ---
nc: 10 # Your known number of classes
names:
  0: '0'
  1: '1'
  2: '2'
  3: '3'
  4: '4'
  5: '5'
  6: '6'
  7: '7'
  8: '8'
  9: '9'
# --- End Important Part ---
"""
# ---------------------------------------

# --- Configuration Variables ---
CFG_HF_DATASET_IDENTIFIER = "Max-Ploter/detection-moving-mnist-easy" # Your HF dataset path/name
CFG_MODEL_NAME = 'yolov8n.pt'
# CFG_NUM_CLASSES = 10  # No longer needed here
CFG_EPOCHS = 100
CFG_BATCH_SIZE = 64
CFG_IMG_SIZE = 320
CFG_WORKERS = 2
CFG_TRUST_REMOTE_CODE = True # Custom flag for trainer
CFG_PROJECT_NAME = "yolo_hf_custom_trainer"

CFG_PLOTS = True
CFG_SAVE_PERIOD = 1 # <<<<<<< ADDED: Set save_period to 1


# --- Construct the path to the last checkpoint ---
# !!! This path MUST exist from your previous training run !!!
checkpoint_dir = os.path.join(GDRIVE_BASE_PATH, CFG_EXPERIMENT_NAME, 'weights')
resume_checkpoint_path = os.path.join(checkpoint_dir, 'last.pt')

print(f"Attempting to resume training from: {resume_checkpoint_path}")

# --- Check if the checkpoint exists ---
if not os.path.exists(resume_checkpoint_path):
    raise FileNotFoundError(f"Checkpoint file not found at: {resume_checkpoint_path}. Cannot resume.")
else:
    print("Checkpoint file found.")
    # --- Set the model path to the checkpoint for resuming ---
    CFG_MODEL_TO_LOAD = resume_checkpoint_path


# --- Prepare config_args with standard args ---
# 'data' now holds the HF identifier
# 'nc' is NOT set here
config_args = SimpleNamespace(
    model = CFG_MODEL_NAME,
    data = FAKE_YAML_FILENAME, #CFG_HF_DATASET_IDENTIFIER, # Pass HF ID as data arg
    epochs = CFG_EPOCHS,
    batch = CFG_BATCH_SIZE,
    imgsz = CFG_IMG_SIZE,
    project = GDRIVE_BASE_PATH,
    name = CFG_EXPERIMENT_NAME,
    workers = CFG_WORKERS,
    device = None,
    plots = CFG_PLOTS,
    resume=CFG_MODEL_TO_LOAD,
    # trust_remote_code is handled separately
    # nc is handled by get_dataset
)

# with open(FAKE_YAML_FILENAME, 'w') as f:
#     f.write(YAML_SHIM_CONTENT)
# print(f"Created fake YAML file: {FAKE_YAML_FILENAME}")

# --- Sanity Check ---
if config_args.data == "your_huggingface_dataset_identifier": # Or similar check
     print(f"🛑 Error: Please set HF dataset identifier in CFG_HF_DATASET_IDENTIFIER.")

Attempting to resume training from: ../not_tracked_dir/output_yolo_v8_2025-04-11/weights/last.pt
Checkpoint file found.


In [10]:
# === Execution Cell in Jupyter Notebook ===
# Ensure CustomDetectionTrainer etc are defined

if config_args.data != "your_huggingface_dataset_identifier": # Or similar check
    print("🚀 Instantiating CustomDetectionTrainer manually...")
    try:
        trainer = CustomDetectionTrainer(
            overrides=vars(config_args), # Contains standard args + HF ID in 'data'
            # Pass ONLY custom args directly
            trust_remote_code=CFG_TRUST_REMOTE_CODE,
            debug=False
        )

        print("\n🚀 Starting trainer.train()...")
        trainer.train()
        print("\n✅ Training finished successfully!")
        # print(f"   Best model weights: {trainer.best}")

    except Exception as e:
        print(f"\n❌ Training failed with an error:")
        import traceback
        traceback.print_exc()
else:
    print("⚠️ Training not started. Please set HF dataset identifier.")

🚀 Instantiating CustomDetectionTrainer manually...
CustomDetectionTrainer __init__ started...
Ultralytics 8.3.109 🚀 Python-3.11.11 torch-2.5.1 CUDA:0 (Tesla V100-PCIE-16GB, 16144MiB)

❌ Training failed with an error:


Traceback (most recent call last):
  File "/gpfs/helios/home/ploter/.conda/envs/sensor_dropout/lib/python3.11/pathlib.py", line 1116, in mkdir
    os.mkdir(self, mode)
FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/YOLOv8_Training/train_yolo_nano_24k_dataset3/weights'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/gpfs/helios/home/ploter/.conda/envs/sensor_dropout/lib/python3.11/pathlib.py", line 1116, in mkdir
    os.mkdir(self, mode)
FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/YOLOv8_Training/train_yolo_nano_24k_dataset3'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/gpfs/helios/home/ploter/.conda/envs/sensor_dropout/lib/python3.11/pathlib.py", line 1116, in mkdir
    os.mkdir(self, mode)
FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/YOLOv8_Training'

## EVAL

In [9]:
import os
import torch
from types import SimpleNamespace
from ultralytics import YOLO  # Import the YOLO class
from ultralytics.data.dataset import YOLODataset # For collate_fn
from torch.utils.data import DataLoader
from datasets import load_dataset # Assuming this is used in HFYOLODataset

# --- 1. Define the Checkpoint to Evaluate ---
# Usually, you evaluate 'best.pt', but after 1 epoch, 'last.pt' is more likely
checkpoint_dir = os.path.join(GDRIVE_BASE_PATH, CFG_EXPERIMENT_NAME, 'weights')
checkpoint_to_eval = os.path.join(checkpoint_dir, 'best.pt') # Standard choice

# Check if the chosen checkpoint exists
best_checkpoint_path = os.path.join(checkpoint_dir, 'best.pt')
if os.path.exists(best_checkpoint_path):
    checkpoint_to_eval = best_checkpoint_path
    print(f"Found 'best.pt', evaluating it: {checkpoint_to_eval}")
elif os.path.exists(os.path.join(checkpoint_dir, 'last.pt')):
    checkpoint_to_eval = os.path.join(checkpoint_dir, 'last.pt')
    print(f"Found 'last.pt', evaluating it: {checkpoint_to_eval}")
else:
    raise FileNotFoundError(f"Neither 'best.pt' nor 'last.pt' found in {checkpoint_dir}")

if not os.path.exists(checkpoint_to_eval):
     raise FileNotFoundError(f"Checkpoint not found: {checkpoint_to_eval}")
print(f"Selected checkpoint for evaluation: {checkpoint_to_eval}")

# --- 2. Load the Trained Model ---
print("Loading trained model checkpoint...")
# This loads the model weights into the YOLO architecture
#model = YOLO(checkpoint_to_eval)
#print("Model loaded.")

# --- 3. Load the Test Dataset ---
print("Loading 'test' split from Hugging Face...")
try:
    hf_test_split = load_dataset(CFG_HF_DATASET_IDENTIFIER, split='test', trust_remote_code=CFG_TRUST_REMOTE_CODE)
    print(f"Loaded test split: {len(hf_test_split)} samples.")

    if True: # Debug
      print("Reducing test split sizes for debugging...")
      hf_test_split = hf_test_split.select(range(5000))
      print(f"Reduced test split to {len(hf_test_split)} samples.")

    # Create the HFYOLODataset instance for the test split
    # Ensure HFYOLODataset class is defined/imported
    test_dataset = HFYOLODataset(hf_test_split, imgsz=CFG_IMG_SIZE, trust_remote_code=CFG_TRUST_REMOTE_CODE)
    print("Test dataset created.")
except Exception as e:
    raise RuntimeError(f"Failed to load or process test split: {e}") from e

# --- 4. Create the Test DataLoader ---
print("Creating test dataloader...")
# Use a larger batch size for validation/testing is common
test_batch_size = CFG_BATCH_SIZE * 2
test_loader = DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    shuffle=False, # No need to shuffle for evaluation
    num_workers=CFG_WORKERS,
    pin_memory=True,
    collate_fn=YOLODataset.collate_fn # Use the standard collate_fn
)
print("Test dataloader created.")

# --- 5. Prepare Validator Arguments and Instantiate Validator ---
print("Preparing validator...")
# Define a directory to save test evaluation results/plots
test_eval_save_dir = os.path.join(GDRIVE_BASE_PATH, CFG_EXPERIMENT_NAME, f'test_eval_{os.path.basename(checkpoint_to_eval).split(".")[0]}')
os.makedirs(test_eval_save_dir, exist_ok=True)
print(f"Test evaluation results will be saved in: {test_eval_save_dir}")

# Get nc and names from the test dataset (should match training)
test_nc = getattr(test_dataset, 'num_classes', 0)
test_names = getattr(test_dataset, 'names', {})
if test_nc <= 0:
    raise ValueError("Could not determine number of classes from test dataset.")
test_data_dict = {'nc': test_nc, 'names': test_names}
print(f"Test data info: nc={test_nc}, names={test_names}")

# Prepare minimal arguments needed by the Validator
# Check your CustomDetectionValidator's __init__ and methods if it needs more args
validator_args = SimpleNamespace(
    data = FAKE_YAML_FILENAME,
    save_dir=test_eval_save_dir,
    device=None, # Use the device determined during training setup
    batch=test_batch_size,
    imgsz=CFG_IMG_SIZE,
    split='test',    # Specify the split being evaluated
    task='detect',   # Specify the task
    plots=True,      # Enable saving plots (e.g., confusion matrix, PR curves)
    save_json=False, # Set True if you need COCO format JSON output
    # save_hybrid=False,
    conf=0.001,      # Confidence threshold (adjust if needed)
    iou=0.6,         # IoU threshold for NMS (adjust if needed)
    max_det=10,
    # data=FAKE_YAML_FILENAME, # Maybe needed if validator parses it, but we set validator.data directly
    project=GDRIVE_BASE_PATH, # Not directly used by validator usually, but good practice
    name=os.path.basename(test_eval_save_dir), # Logical name for this eval run
)

# Instantiate your CustomDetectionValidator
# Ensure the CustomDetectionValidator class is defined/imported
validator = CustomDetectionValidator(
    dataloader=test_loader,
    save_dir=Path(test_eval_save_dir),
    args=validator_args,
    _callbacks={} # Pass empty callbacks if not needed for pure evaluation
)

# Link the loaded model (usually the internal torch model) and data info
# The YOLO object holds the model in '.model'
# validator.model = model
validator.data = test_data_dict # Provide nc/names directly

print("Validator instance created and configured.")

Found 'best.pt', evaluating it: ../not_tracked_dir/output_yolo_v8_2025-04-11/weights/best.pt
Selected checkpoint for evaluation: ../not_tracked_dir/output_yolo_v8_2025-04-11/weights/best.pt
Loading trained model checkpoint...
Loading 'test' split from Hugging Face...
Loaded test split: 10000 samples.
Reducing test split sizes for debugging...
Reduced test split to 5000 samples.
Initializing HFYOLODataset (Using Ultralytics Transforms)...
Frames/Video: 20, Dim: (128, 128)
Total train frames: 100000
Discovered nc=10, names={0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9'}
Augmentation enabled, but using simple LetterBox transform.
Transforms created: Compose(<ultralytics.data.augment.LetterBox object at 0x150fcdd38b90>, <ultralytics.data.augment.Format object at 0x1510dd857090>)
Test dataset created.
Creating test dataloader...
Test dataloader created.
Preparing validator...
Test evaluation results will be saved in: ../not_tracked_dir/output_yolo_v8_2025-04

In [None]:
# --- 6. Run Evaluation ---
print("Starting evaluation on the test set...")
try:
    # The validator instance is callable and runs the evaluation loop
    results = validator(model = checkpoint_to_eval)
    print("Evaluation finished.")

    # --- 7. Print Results ---
    # The 'results' object (often a dictionary) contains the metrics
    # Refer to ultralytics documentation or inspect the 'results' keys for specifics
    print("\n--- Test Set Evaluation Metrics ---")
    # Common metrics for detection:
    map50_95 = results.maps[0] # mAP50-95 for class 0 (or overall if only 1 class reported directly)
    map50 = results.maps[50] # mAP50 for class 0 (or overall)
    print(f"mAP50-95: {map50_95:.4f}")
    print(f"mAP50:    {map50:.4f}")
    # Print all metrics found
    print("\nFull metrics dictionary:")
    print(results.metrics_data) # Or just print(results) depending on object type

except Exception as e:
    print(f"\n❌ Evaluation failed with an error:")
    import traceback
    traceback.print_exc()

Found 'best.pt', evaluating it: ../not_tracked_dir/output_yolo_v8_2025-04-11/weights/best.pt
Selected checkpoint for evaluation: ../not_tracked_dir/output_yolo_v8_2025-04-11/weights/best.pt
Loading trained model checkpoint...
Loading 'test' split from Hugging Face...


Generating train split: 24000 examples [00:06, 3754.70 examples/s]
Generating test split: 10000 examples [00:02, 4042.25 examples/s]

Loaded test split: 10000 samples.
Initializing HFYOLODataset (Using Ultralytics Transforms)...
Frames/Video: 20, Dim: (128, 128)
Total train frames: 200000
Discovered nc=10, names={0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9'}
Augmentation enabled, but using simple LetterBox transform.





Transforms created: Compose(<ultralytics.data.augment.LetterBox object at 0x154bae4eda30>, <ultralytics.data.augment.Format object at 0x154bae48f580>)
Test dataset created.
Creating test dataloader...
Test dataloader created.
Preparing validator...
Test evaluation results will be saved in: ../not_tracked_dir/output_yolo_v8_2025-04-11/test_eval_best
Test data info: nc=10, names={0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9'}
Validator instance created and configured.
Starting evaluation on the test set...
Ultralytics 8.3.107 🚀 Python-3.9.12 torch-2.5.1+cu124 CUDA:0 (Tesla V100-PCIE-16GB, 16144MiB)
Model summary (fused): 72 layers, 3,007,598 parameters, 0 gradients, 8.1 GFLOPs


                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95):   0%|          | 2/1563 [00:07<1:25:49,  3.30s/it]

Downloading https://ultralytics.com/assets/Arial.ttf to '/gpfs/helios/home/ploter/.config/Ultralytics/Arial.ttf'...



100%|██████████| 755k/755k [00:00<00:00, 14.3MB/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 1563/1563 [28:50<00:00,  1.11s/it]


                   all     200000     851740      0.968      0.919      0.971      0.939
                     0      65391      82934      0.966      0.944       0.98      0.959
                     1      73283      95499      0.962      0.909      0.966      0.922
                     2      70732      91277      0.981      0.921      0.972      0.928
                     3      67769      86001      0.964       0.92      0.972      0.938
                     4      66242      82726      0.967      0.926      0.972       0.94
                     5      61559      75659      0.968      0.921      0.973      0.938
                     6      63775      79171      0.975       0.92      0.974       0.95
                     7      69464      87833      0.959      0.908      0.967      0.936
                     8      67065      85789      0.973      0.923      0.974      0.943
                     9      66988      84851      0.967      0.898      0.964       0.94
Speed: 0.0ms preproce

Traceback (most recent call last):
  File "/tmp/ipykernel_3175597/2828092570.py", line 125, in <module>
    map50_95 = results.maps[0] # mAP50-95 for class 0 (or overall if only 1 class reported directly)
AttributeError: 'dict' object has no attribute 'maps'


Found 'best.pt', evaluating it: ../not_tracked_dir/output_yolo_v8_2025-04-11/weights/best.pt
Selected checkpoint for evaluation: ../not_tracked_dir/output_yolo_v8_2025-04-11/weights/best.pt
Loading trained model checkpoint...
Loading 'test' split from Hugging Face...
Generating train split: 24000 examples [00:06, 3754.70 examples/s]
Generating test split: 10000 examples [00:02, 4042.25 examples/s]
Loaded test split: 10000 samples.
Initializing HFYOLODataset (Using Ultralytics Transforms)...
Frames/Video: 20, Dim: (128, 128)
Total train frames: 200000
Discovered nc=10, names={0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9'}
Augmentation enabled, but using simple LetterBox transform.

Transforms created: Compose(<ultralytics.data.augment.LetterBox object at 0x154bae4eda30>, <ultralytics.data.augment.Format object at 0x154bae48f580>)
Test dataset created.
Creating test dataloader...
Test dataloader created.
Preparing validator...
Test evaluation results will be saved in: ../not_tracked_dir/output_yolo_v8_2025-04-11/test_eval_best
Test data info: nc=10, names={0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9'}
Validator instance created and configured.
Starting evaluation on the test set...
Ultralytics 8.3.107 🚀 Python-3.9.12 torch-2.5.1+cu124 CUDA:0 (Tesla V100-PCIE-16GB, 16144MiB)
Model summary (fused): 72 layers, 3,007,598 parameters, 0 gradients, 8.1 GFLOPs
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95):   0%|          | 2/1563 [00:07<1:25:49,  3.30s/it]
Downloading https://ultralytics.com/assets/Arial.ttf to '/gpfs/helios/home/ploter/.config/Ultralytics/Arial.ttf'...

100%|██████████| 755k/755k [00:00<00:00, 14.3MB/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 1563/1563 [28:50<00:00,  1.11s/it]
                   all     200000     851740      0.968      0.919      0.971      0.939
                     0      65391      82934      0.966      0.944       0.98      0.959
                     1      73283      95499      0.962      0.909      0.966      0.922
                     2      70732      91277      0.981      0.921      0.972      0.928
                     3      67769      86001      0.964       0.92      0.972      0.938
                     4      66242      82726      0.967      0.926      0.972       0.94
                     5      61559      75659      0.968      0.921      0.973      0.938
                     6      63775      79171      0.975       0.92      0.974       0.95
                     7      69464      87833      0.959      0.908      0.967      0.936
                     8      67065      85789      0.973      0.923      0.974      0.943
                     9      66988      84851      0.967      0.898      0.964       0.94
Speed: 0.0ms preprocess, 0.3ms inference, 0.0ms loss, 1.0ms postprocess per image
Results saved to ../not_tracked_dir/output_yolo_v8_2025-04-11/test_eval_best
Evaluation finished.

--- Test Set Evaluation Metrics ---

❌ Evaluation failed with an error:

## EVAL_OWN

In [10]:
import torch
import torchmetrics
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import numpy as np
import gc


# Assuming necessary imports from your environment are available:
# HFYOLODataset class (your original version)
# yolo_hf_collate_fn (or YOLODataset.collate_fn if compatible)
# YOLO class from ultralytics
# load_dataset from datasets
# ops from ultralytics.utils
# YOLODataset from ultralytics.data.dataset (for collate_fn reference)

from ultralytics import YOLO
from ultralytics.utils import ops
from ultralytics.data.dataset import YOLODataset # For collate_fn if using standard
from datasets import load_dataset # Assuming you use this

N = 50 # Collect garbage every 50 batches (adjust as needed)


# Make sure your original HFYOLODataset class definition is available
# Example placeholder - replace with your actual class definition if needed
# from your_dataset_module import HFYOLODataset, yolo_hf_collate_fn

# --- Custom Evaluation Loop for YOLO ---
def evaluate_yolo_with_torchmetrics(model, dataloader, device, conf_thres=0.001, iou_thres=0.5):
    """
    Evaluates a YOLO model using torchmetrics mAP with a custom loop.

    Args:
        model: The loaded YOLO model object.
        dataloader: DataLoader for the test set (using original HFYOLODataset).
        device: The torch device ('cuda' or 'cpu').
        conf_thres (float): Confidence threshold for predictions.
        iou_thres (float): IoU threshold for NMS used in model.predict.

    Returns:
        dict: A dictionary containing the computed mAP metrics.
    """

    print(f"conf_thres: {conf_thres}")
    print(f"iou_thres: {iou_thres}")

    # Ensure model is on the correct device and in evaluation mode
    model.to(device)
    model.eval()

    # Initialize the mAP metric
    # Ensure box_format matches the format of boxes you provide ('xyxy')
    map_metric = torchmetrics.detection.MeanAveragePrecision(box_format='xyxy', iou_type='bbox').to(device)

    print(f"Starting custom evaluation loop on device: {device}")
    progress_bar = tqdm(enumerate(dataloader), desc="Evaluating")

    for batcch_index_i, batch in progress_bar:
        # Move batch data to the device
        # Assumes collate_fn provides 'img', 'bboxes' (norm xywh), 'cls', 'batch_idx', 'ori_shape'
        # Use try-except for robustness against missing keys if collate fn varies
        try:
            samples = batch['img'].to(device).float() / 255.0 # Normalize images
            gt_bboxes_norm = batch['bboxes'].to(device) # Normalized xywh
            # Ensure gt_cls is 1D
            gt_cls = batch['cls'].to(device).squeeze() # Add squeeze()
            batch_idx = batch['batch_idx'].to(device)
            original_shapes = batch['ori_shape'] # List of tuples [(h, w), ...]
        except KeyError as e:
            print(f"\nError: Missing key {e} in batch. Check dataset and collate_fn.")
            print(f"Batch keys: {batch.keys()}")
            raise e
        except Exception as e:
            print(f"\nError processing batch data: {e}")
            raise e


        if samples.shape[0] == 0: # Handle empty batches if they occur
             print("Warning: Skipping empty batch.")
             continue

        batch_size = samples.shape[0]
        resized_shape = samples.shape[2:] # Shape after transforms (e.g., (320, 320))

        # --- Run Inference ---
        with torch.no_grad():
            # Use model.predict for easier handling of results and NMS
            preds = model.predict(samples, conf=conf_thres, iou=iou_thres, verbose=False)
            # 'preds' is typically a list of Results objects, one per image

        # --- Prepare Predictions and Targets for Metric ---
        preds_for_metric = []
        targets_for_metric = []

        # Process Predictions
        for i in range(batch_size):
            result = preds[i] # Ultralytics Results object for image i
            original_shape_i = original_shapes[i] # Original (h, w)

            # Clone the tensor to prevent in-place modification error
            pred_boxes_resized_xyxy = result.boxes.xyxy.clone() # Clone the tensor
            pred_scores = result.boxes.conf             # Tensor [N]
             # Ensure pred_labels is 1D and Long type
            pred_labels = result.boxes.cls.long().squeeze() # <<< FIX: Use .long() and ensure squeeze

            # Ensure pred_labels is 1D, even if squeeze resulted in a scalar
            if pred_labels.ndim == 0:
                # If prediction is scalar, it means single prediction. Unsqueeze.
                # Also handle case where result might be truly empty
                if result.boxes.cls.numel() > 0: # Check original number before squeeze
                     pred_labels = pred_labels.unsqueeze(0) # Convert scalar tensor to 1D tensor [1]
                else: # If original was empty, ensure label is empty 1D
                     pred_labels = torch.empty(0, dtype=torch.long, device=device)


            # Scale boxes from resized_shape to original_shape
            if pred_boxes_resized_xyxy.numel() > 0:
                 # ops.scale_boxes expects shape (h, w)
                 scaled_pred_boxes = ops.scale_boxes(resized_shape, pred_boxes_resized_xyxy, original_shape_i)
                 # Ensure labels/scores match boxes after potential scaling/filtering if needed
                 if scaled_pred_boxes.shape[0] == 0 and pred_labels.shape[0] > 0:
                     print(f"Warning: Scaled boxes became empty but labels exist for pred item {i}. Forcing labels/scores empty.")
                     pred_labels = torch.empty(0, dtype=torch.long, device=device) # Ensure it's size [0] and long
                     pred_scores = torch.empty(0, dtype=torch.float32, device=device)
                 elif scaled_pred_boxes.shape[0] != pred_labels.shape[0]:
                      print(f"Warning: Mismatch after scaling boxes ({scaled_pred_boxes.shape[0]}) vs labels ({pred_labels.shape[0]}) for pred item {i}. Check scaling logic.")
                      # Attempt to keep only labels/scores corresponding to potentially remaining boxes if possible, otherwise clear?
                      # For safety, maybe clear this prediction if inconsistent? Or trust result.boxes structure?
                      # Assuming result.boxes components are consistent in length N initially.
                      # If scaling removes boxes, it implies coords were invalid. Let's keep labels/scores for now.
                      pass # Let metric handle potential issues if shapes mismatch internally? Risky.

            else:
                 # If no boxes initially, ensure labels/scores are also empty and 1D ([0])
                 scaled_pred_boxes = torch.empty((0, 4), device=device)
                 pred_labels = torch.empty(0, dtype=torch.long, device=device) # Ensure it's size [0] and long
                 pred_scores = torch.empty(0, dtype=torch.float32, device=device) # Ensure scores are also size [0]


            preds_for_metric.append({
                'boxes': scaled_pred_boxes, # Absolute xyxy
                'scores': pred_scores,
                'labels': pred_labels, # Already ensured to be 1D Long
            })

        # Process Ground Truth Targets
        for i in range(batch_size):
            mask = (batch_idx == i)
            img_boxes_norm = gt_bboxes_norm[mask] # Normalized xywh
            img_labels = gt_cls[mask] # gt_cls is already squeezed.

            # Ensure img_labels is 1D, even if mask selects only one item
            if img_labels.ndim == 0:
                 # Check if original gt_cls[mask] had elements before making it 1D
                 if gt_cls[mask].numel() > 0:
                     img_labels = img_labels.unsqueeze(0) # Convert scalar tensor to 1D tensor [1]
                 else: # If original was empty, ensure label is empty 1D
                     img_labels = torch.empty(0, dtype=torch.long, device=device)

            img_h, img_w = original_shapes[i] # Original H, W

            if img_boxes_norm.numel() > 0:
                # Denormalize from xywh [0,1] to absolute xywh
                img_boxes_abs_xywh = img_boxes_norm * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32, device=device)
                # Convert absolute xywh to absolute xyxy
                # Assumes input 'bboxes' from dataset/collate are center_x, center_y, width, height
                gt_boxes_abs_xyxy = ops.xywh2xyxy(img_boxes_abs_xywh)

                # Clamp boxes to image dimensions
                gt_boxes_abs_xyxy = ops.clip_boxes(gt_boxes_abs_xyxy, (img_h, img_w))
            else:
                # If no boxes, ensure labels are also empty and 1D ([0])
                gt_boxes_abs_xyxy = torch.empty((0, 4), device=device)
                img_labels = torch.empty(0, dtype=torch.long, device=device) # Ensure it's size [0]


            targets_for_metric.append({
                'boxes': gt_boxes_abs_xyxy, # Absolute xyxy
                'labels': img_labels.long(), # <<< FIX: Ensure labels are torch.long >>>
            })

        # --- Update Metric ---
        try:
            # Ensure consistency in case of empty predictions/targets for an image
            if len(preds_for_metric) != len(targets_for_metric):
                 print(f"\nError: Mismatch between processed predictions ({len(preds_for_metric)}) and targets ({len(targets_for_metric)}) count in batch.")
                 # Skip update for this batch might be safest
                 continue

            map_metric.update(preds_for_metric, targets_for_metric)
        except Exception as e:
            print(f"\nError updating map_metric in batch: {e}")
            # Provide more context on error
            print(f"Batch Index: {progress_bar.n}")
            print(f"Number of preds: {len(preds_for_metric)}, Number of targets: {len(targets_for_metric)}")
            if preds_for_metric:
                 print(f"Preds[0] keys: {preds_for_metric[0].keys()}")
                 print(f"  boxes shape: {preds_for_metric[0]['boxes'].shape}, dtype: {preds_for_metric[0]['boxes'].dtype}")
                 print(f"  scores shape: {preds_for_metric[0]['scores'].shape}, dtype: {preds_for_metric[0]['scores'].dtype}")
                 print(f"  labels shape: {preds_for_metric[0]['labels'].shape}, dtype: {preds_for_metric[0]['labels'].dtype}")
            if targets_for_metric:
                 print(f"Targets[0] keys: {targets_for_metric[0].keys()}")
                 print(f"  boxes shape: {targets_for_metric[0]['boxes'].shape}, dtype: {targets_for_metric[0]['boxes'].dtype}")
                 print(f"  labels shape: {targets_for_metric[0]['labels'].shape}, dtype: {targets_for_metric[0]['labels'].dtype}")
            raise e # Stop evaluation on error

        # --- Explicit Garbage Collection ---
        # Check if it's time to collect garbage
        if (batcch_index_i + 1) % N == 0:
            # print(f"Batch {batcch_index_i + 1}: Triggering gc.collect()")
            # Optional: Explicitly delete large variables from this iteration if possible
            try:
                del preds # Example: delete model output tensor
                del batch   # Example: delete batch data
                del samples
                del gt_bboxes_norm
                del gt_cls
                del batch_idx
                del original_shapes
            except NameError:
                pass # In case they weren't created or already deleted

            gc.collect() # Force garbage collection



    try:
        del preds # Example: delete model output tensor
        del batch   # Example: delete batch data
        del samples
    except NameError:
        pass # In case they weren't created or already deleted

    gc.collect() # Force garbage collection


    # --- Compute Final Metrics ---
    print("Computing final metrics...")
    processed_results = {} # Initialize before try block
    try:
        map_results = map_metric.compute()
        # Process results into a flat dictionary for easier logging
        for k, v in map_results.items():
             if isinstance(v, torch.Tensor):
                 processed_results[k] = v.item() if v.numel() == 1 else v.tolist()
             else:
                 processed_results[k] = v
        print(f"Computed Metrics: {processed_results}")

    except Exception as e:
        print(f"Error computing final metrics: {e}")
        # Assign error message to the dict if compute fails
        processed_results = {"error": str(e)}


    return processed_results

In [11]:

# --- Main EVAL_OWN Function (using custom loop) ---
def main_fn(checkpoint_path, ds_loader):
    """
    Main function to evaluate a YOLO model using the custom evaluation loop.

    Args:
        checkpoint_path (str): Path to the trained YOLO model checkpoint (.pt file).
        hf_dataset_id (str): Hugging Face dataset identifier.
        dataset_split (str): Split to evaluate ('test', 'validation').
        imgsz (int): Image size used for evaluation.
        batch_size (int): Batch size for the dataloader.
        workers (int): Number of workers for the dataloader.
        device (str, optional): Device ('cuda', 'cpu'). Auto-detects if None.
        trust_remote_code (bool): Whether to trust remote code for HF dataset.

    Returns:
        dict: Dictionary with evaluation metrics (e.g., mAP).
    """
    print(f"--- Starting EVAL_OWN (custom loop) ---")
    # --- 1. Setup Device ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- 2. Load Model ---
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    print(f"Loading model from: {checkpoint_path}")
    model = YOLO(checkpoint_path)
    print("Model loaded.")

    # --- 4. Run Custom Evaluation ---
    print("Calling custom evaluation function...")
    results = evaluate_yolo_with_torchmetrics(
        model=model,
        dataloader=ds_loader,
        device=device
        # Pass conf/iou thresholds if you want to override defaults
    )

    # --- 5. Return Results ---
    print(f"--- EVAL_OWN Finished ---")
    return results

In [12]:

# --- Find Checkpoint ---

print(f"Using best checkpoint: {checkpoint_to_eval}")

# --- Run Evaluation ---
try:
    # Make sure load_dataset is defined (even if dummy)
    # if 'load_dataset' not in globals():
        #  def load_dataset(*args, **kwargs): return list(range(100))
    eval_metrics = main_fn(
        checkpoint_path=checkpoint_to_eval,
        ds_loader=test_loader
    )
    print("\nFinal Evaluation Metrics:")
    print(eval_metrics)
except FileNotFoundError as e:
    print(f"Evaluation failed: Checkpoint not found - {e}")
except NameError as e:
     print(f"Evaluation failed: A required class or function is not defined - {e}")
     print("Ensure HFYOLODataset and its collate function are available.")
except Exception as e:
    print(f"An unexpected error occurred during evaluation:")
    import traceback
    traceback.print_exc()


Using best checkpoint: ../not_tracked_dir/output_yolo_v8_2025-04-11/weights/best.pt
--- Starting EVAL_OWN (custom loop) ---
Using device: cuda
Loading model from: ../not_tracked_dir/output_yolo_v8_2025-04-11/weights/best.pt
Model loaded.
Calling custom evaluation function...
conf_thres: 0.001
iou_thres: 0.5
Starting custom evaluation loop on device: cuda


Evaluating: 782it [13:05,  1.00s/it]


Computing final metrics...
Computed Metrics: {'map': 0.9237134456634521, 'map_50': 0.9629759788513184, 'map_75': 0.9419248104095459, 'map_small': 0.9237366318702698, 'map_medium': -1.0, 'map_large': -1.0, 'mar_1': 0.7469073534011841, 'mar_10': 0.9397873878479004, 'mar_100': 0.9397911429405212, 'mar_small': 0.9397911429405212, 'mar_medium': -1.0, 'mar_large': -1.0, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'classes': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}
--- EVAL_OWN Finished ---

Final Evaluation Metrics:
{'map': 0.9237134456634521, 'map_50': 0.9629759788513184, 'map_75': 0.9419248104095459, 'map_small': 0.9237366318702698, 'map_medium': -1.0, 'map_large': -1.0, 'mar_1': 0.7469073534011841, 'mar_10': 0.9397873878479004, 'mar_100': 0.9397911429405212, 'mar_small': 0.9397911429405212, 'mar_medium': -1.0, 'mar_large': -1.0, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'classes': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}


On half dataset

Final Evaluation Metrics:
{'map': 0.9237134456634521, 'map_50': 0.9629759788513184, 'map_75': 0.9419248104095459, 'map_small': 0.9237366318702698, 'map_medium': -1.0, 'map_large': -1.0, 'mar_1': 0.7469073534011841, 'mar_10': 0.9397873878479004, 'mar_100': 0.9397911429405212, 'mar_small': 0.9397911429405212, 'mar_medium': -1.0, 'mar_large': -1.0, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'classes': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}