### Library

In [None]:
import torch
import torchvision
from torchvision.ops import RoIAlign
import torchvision.transforms.v2 as T
from torchvision.transforms.v2 import functional as F, Transform, Compose, ToDtype, Normalize
from torch.utils.data import DataLoader, Dataset, Subset
import os
import numpy as np
from PIL import Image
import random
import time
from tqdm import tqdm
import torchmetrics
from ultralytics import YOLO

### Data Path

In [17]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Dataset layout on disk: <DATA_PATH>/<split>/{images,labels}
DATA_PATH = './data/rare_fewshot_detection/'
TRAIN_IMG_DIR, TRAIN_LABEL_DIR, TEST_IMG_DIR, TEST_LABEL_DIR = (
    os.path.join(DATA_PATH, relative_dir)
    for relative_dir in ('train/images', 'train/labels', 'test/images', 'test/labels')
)

Using device: cuda


In [18]:
# Minimum number of images each class must appear in; RareDiseaseDataset.get_image_ids_by_class
# raises when a class has fewer than this many.
N_SHOTS = 5
# Class names in ID order: the position in this list determines the numeric ID assigned in CLASS_MAP.
CLASS_NAMES = ['cerscospora', 'healthy', 'leaf rust','miner' ,'phoma','nematode', 'pink disease']
NUM_CLASSES = len(CLASS_NAMES)
print(f"Number of classes: {NUM_CLASSES}")

Number of classes: 7


In [None]:
# ID 0 is reserved for "background"; the real classes are numbered from 1 in CLASS_NAMES order.
BACKGROUND_CLASS_ID = 0
CLASS_MAP = {name: class_id for class_id, name in enumerate(CLASS_NAMES, start=BACKGROUND_CLASS_ID + 1)}
# Reverse lookup (numeric ID -> class name), used when reporting per-class results.
INV_CLASS_MAP = {class_id: name for name, class_id in CLASS_MAP.items()}
# Print both directions so the cell's output documents the exact ID assignment.
print(f"Class map: {CLASS_MAP}")
print(f"Inverse class map: {INV_CLASS_MAP}")

Class map: {'cerscospora': 1, 'healthy': 2, 'leaf rust': 3, 'miner': 4, 'phoma': 5, 'nematode': 6, 'pink disease': 7}
Inverse class map: {1: 'cerscospora', 2: 'healthy', 3: 'leaf rust', 4: 'miner', 5: 'phoma', 6: 'nematode', 7: 'pink disease'}


In [26]:
# CLASS_MAP numbers classes from 1 (0 is BACKGROUND_CLASS_ID), so this set is {1, ..., 7}, not {0, ..., 6}.
# NOTE(review): RareDiseaseDataset drops every label whose ID is not in this set and uses the
# ID from the .txt file as-is. YOLO label files are conventionally 0-based — confirm the label
# files really use 1-based IDs, otherwise ID 0 is filtered out and the remaining names are shifted.
TARGET_CLASS_IDS = set(CLASS_MAP.values()) # {1, 2, 3, 4, 5, 6, 7}
print(f"Target Class IDs: {TARGET_CLASS_IDS}")

Target Class IDs: {1, 2, 3, 4, 5, 6, 7}


In [21]:
# Path (relative to the notebook's working directory) of the YOLO checkpoint to load.
# NOTE(review): the directory name suggests a detector trained on the full dataset — confirm.
YOLO_MODEL_PATH = '../object_detection_full_data/runs/detect/train/best.pt'

In [22]:
# Index of the layer whose output serves as the feature map. YOLOFeatureExtractor.forward treats
# negative values as counting back from the end of the layer sequence.
# NOTE(review): presumably passed as `feature_layer_index` when the extractor is built — confirm.
FEATURE_LAYER_INDEX = -4
# NOTE(review): presumably the (height, width) output size given to RoIAlign; the place where
# RoIAlign is constructed is not visible here — confirm.
ROI_OUTPUT_SIZE = (7, 7)

### Configurations

In [24]:
# Training configuration.
# NOTE(review): the driver code that consumes these constants is not visible in this part of the
# notebook; the descriptions below are inferred from the names and should be confirmed there.
NUM_EPOCHS = 10                     # presumably the number of passes over the query set
LEARNING_RATE = 1e-5                # presumably the optimizer learning rate
WEIGHT_DECAY = 0.0005               # presumably the optimizer weight decay
BATCH_SIZE = 4                      # presumably the DataLoader batch size
VALIDATION_SPLIT = 0.2              # presumably the fraction of data held out for validation
PROTO_UPDATE_FREQ = 5               # presumably how often (in epochs) prototypes are recomputed
FINETUNE_FEATURE_EXTRACTOR = True   # presumably forwarded as `finetune_extractor` to train_few_shot_yolo
# Output locations for the class prototypes and the fine-tuned weights.
SAVE_PROTOTYPES_PATH = './models/yolov11n_fewshot_prototypes.pt'
SAVE_FINETUNED_MODEL_PATH = './models/yolov11n_fewshot_finetuned.pt'

### Data Augmentation / Preprocessing

In [None]:
class SquarePad(Transform):
    """Pad an image to a square canvas, keeping the original content centred.

    The shorter side is padded symmetrically (any odd leftover pixel goes to the
    right/bottom edge) with YOLO's conventional grey value 114.

    Works on PIL images as well as tensors: the size is queried through
    ``F.get_size`` instead of ``inpt.shape`` (PIL images have no ``shape`` and
    batched tensors have more than three dimensions), and the fill value is
    expressed in the value range of the input (``114`` for PIL / integer
    tensors, ``114 / 255`` for floating-point tensors scaled to ``[0, 1]``).
    """

    def _transform(self, inpt, params):
        """Return ``inpt`` padded to ``max(h, w) x max(h, w)``; ``params`` is unused."""
        h, w = F.get_size(inpt)  # [height, width] for PIL images, tensors and tv_tensors
        max_wh = max(w, h)
        p_left, p_top = [(max_wh - s) // 2 for s in (w, h)]
        p_right, p_bottom = [max_wh - s - p for s, p in zip((w, h), (p_left, p_top))]
        padding = [p_left, p_top, p_right, p_bottom]

        # A float fill of 114/255 would truncate to (almost) black on 8-bit images.
        is_float_tensor = isinstance(inpt, torch.Tensor) and inpt.is_floating_point()
        fill = 114 / 255.0 if is_float_tensor else 114
        return F.pad(inpt, padding, fill=fill)

    # NOTE(review): newer torchvision releases document `transform` (no underscore) as the
    # method custom transforms override; exposing both names keeps the class usable either way.
    transform = _transform

def get_transform(target_size=(640, 640)):
    """Build the preprocessing pipeline that turns a loaded image into a model-ready tensor.

    Args:
        target_size: ``(height, width)`` every image is resized to.

    Returns:
        A ``T.Compose`` that converts the input (PIL image or tensor) to an image
        tensor, resizes it to ``target_size`` and converts it to ``float32`` scaled
        to ``[0, 1]``.
    """
    transforms = []
    # v2.ToDtype only touches tensors and passes PIL images through unchanged, so the PIL
    # image coming from the dataset must be converted to a tensor explicitly. Without this
    # step the pipeline returns a PIL image and the dataset's `.shape` access fails.
    transforms.append(T.ToImage())
    # transforms.append(SquarePad()) # Optional, depends on YOLO model needs
    transforms.append(T.Resize(target_size, antialias=True))
    transforms.append(T.ToDtype(torch.float32, scale=True))
    return T.Compose(transforms)

### Dataset class

In [None]:
class RareDiseaseDataset(Dataset):
    """Detection dataset backed by an image folder and YOLO-format ``.txt`` label files.

    Each sample is ``(image_tensor, target)`` where ``target`` contains:

    * ``boxes``  - ``float32`` tensor ``(N, 4)`` in ``xmin, ymin, xmax, ymax`` pixel
      coordinates **of the returned image tensor** (i.e. after ``transforms``),
    * ``labels`` - ``int64`` tensor ``(N,)`` with the class IDs exactly as written in
      the label file (only IDs contained in ``target_class_ids`` are kept),
    * ``image_id`` - the sample index,
    * ``original_size`` - ``[width, height]`` of the image on disk,
    * ``scale_factor`` - ``[scale_w, scale_h]`` from original to returned image size
      (divide the boxes by it to get back to original-image coordinates).

    Args:
        img_dir: Directory containing the images.
        label_dir: Directory containing one ``<image stem>.txt`` per image.
        target_class_ids: Collection of class IDs to keep; other labels are dropped.
        transforms: Optional callable applied to the PIL image; must return a tensor
            whose last two dimensions are ``(height, width)``.
        target_img_size: Stored for reference only; it is not used by this class.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, img_dir, label_dir, target_class_ids, transforms=None, target_img_size=(640, 640)):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.target_class_ids = target_class_ids
        self.transforms = transforms
        self.target_img_size = target_img_size

        # Case-insensitive match so files such as "IMG_001.JPG" are not silently ignored.
        self.image_files = sorted(
            f for f in os.listdir(img_dir) if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.img_ids = [os.path.splitext(f)[0] for f in self.image_files]

        print("="*30)
        print("Dataset Initialized:")
        print(f"Image directory: {self.img_dir}")
        print(f"Label directory: {self.label_dir} (Expecting YOLO .txt format)")
        print(f"Target Few-Shot Class IDs: {self.target_class_ids}")
        print("IMPORTANT: Assumes class IDs in .txt files directly correspond to the target IDs.")
        print("           Ensure sufficient examples exist for N-shot learning per class.")
        print("="*30)

    def __len__(self):
        return len(self.img_ids)

    @staticmethod
    def _parse_label_lines(lines, img_w, img_h, target_class_ids):
        """Convert YOLO label lines into clipped pixel boxes.

        Args:
            lines: Iterable of strings ``"<class_id> <xc> <yc> <w> <h>"`` with the
                four box values normalised to ``[0, 1]``.
            img_w, img_h: Size the normalised values are scaled to.
            target_class_ids: Class IDs to keep.

        Returns:
            ``(boxes, labels)`` - a list of ``[xmin, ymin, xmax, ymax]`` and the
            matching list of class IDs. Lines with the wrong field count, foreign
            class IDs or an empty box after clipping are skipped.
        """
        boxes, labels = [], []
        for line in lines:
            parts = line.strip().split()
            if len(parts) != 5:
                continue
            try:
                class_id = int(parts[0])
                # Filter first: only lines of target classes need valid coordinates.
                if class_id not in target_class_ids:
                    continue
                xc, yc, w, h = map(float, parts[1:])
            except ValueError:
                # One malformed line must not discard the remaining annotations.
                print(f"Warning: Skipping malformed label line: {line.strip()!r}")
                continue

            xmin = max(0, (xc - w / 2) * img_w)
            ymin = max(0, (yc - h / 2) * img_h)
            xmax = min(img_w, (xc + w / 2) * img_w)
            ymax = min(img_h, (yc + h / 2) * img_h)
            if xmax <= xmin or ymax <= ymin:
                continue
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(class_id)
        return boxes, labels

    @staticmethod
    def _scale_boxes(boxes, scale_w, scale_h):
        """Return ``boxes`` (``(N, 4)`` xyxy tensor) scaled by ``scale_w`` / ``scale_h``."""
        return boxes * boxes.new_tensor([scale_w, scale_h, scale_w, scale_h])

    def _read_annotations(self, label_path, img_w, img_h):
        """Read one label file; a missing file simply means "no objects"."""
        try:
            with open(label_path, 'r') as f:
                return self._parse_label_lines(f, img_w, img_h, self.target_class_ids)
        except FileNotFoundError:
            pass
        except Exception as e:
            print(f"Warning: Error reading {label_path}: {e}")
        return [], []

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        img_path = os.path.join(self.img_dir, self.image_files[idx])
        label_path = os.path.join(self.label_dir, f"{img_id}.txt")

        try:
            # convert() returns an independent image, so the file handle can be closed right away.
            with Image.open(img_path) as raw_img:
                img = raw_img.convert("RGB")
            original_w, original_h = img.size
        except FileNotFoundError:
            print(f"Error: Image file not found at {img_path}")
            return None, None # Filtered out by collate_fn

        boxes, labels = self._read_annotations(label_path, original_w, original_h)

        if not boxes:
            boxes_tensor = torch.zeros((0, 4), dtype=torch.float32)
            labels_tensor = torch.zeros((0,), dtype=torch.int64)
        else:
            boxes_tensor = torch.as_tensor(boxes, dtype=torch.float32)
            labels_tensor = torch.as_tensor(labels, dtype=torch.int64)

        if self.transforms:
            img_tensor = self.transforms(img)
            transformed_h, transformed_w = img_tensor.shape[-2:]
            scale_w = transformed_w / original_w
            scale_h = transformed_h / original_h
        else:
            img_tensor = T.functional.to_tensor(img)
            scale_w = scale_h = 1.0

        # The consumers map these boxes onto the feature map using the shape of the
        # *returned* image, so the boxes must follow the image through the resize.
        boxes_tensor = self._scale_boxes(boxes_tensor, scale_w, scale_h)

        target = {"boxes": boxes_tensor, "labels": labels_tensor,
                  "image_id": torch.tensor([idx]), "original_size": torch.tensor([original_w, original_h]),
                  "scale_factor": torch.tensor([scale_w, scale_h])}

        return img_tensor, target

    def get_image_ids_by_class(self):
        """Map every target class ID to the indices of the images that contain it.

        Only the label files are read (no image decoding). The degenerate-box check
        is scale-invariant, so a unit-sized canvas is enough to decide which labels
        survive the same filtering as in ``__getitem__``.

        Returns:
            ``dict`` ``{class_id: [dataset index, ...]}``.

        Raises:
            ValueError: if any target class appears in fewer than ``N_SHOTS`` images.
        """
        class_to_image_idxs = {cls_id: [] for cls_id in self.target_class_ids}
        print("Mapping images to classes (using .txt labels)...")
        for idx in tqdm(range(len(self))):
            label_path = os.path.join(self.label_dir, f"{self.img_ids[idx]}.txt")
            _, labels = self._read_annotations(label_path, 1.0, 1.0)
            for class_id in set(labels):
                if class_id in class_to_image_idxs: # Check if it's one of our target classes
                    class_to_image_idxs[class_id].append(idx)
        print("Mapping complete.")

        insufficient_classes = []
        for class_id in self.target_class_ids:
            class_name = INV_CLASS_MAP.get(class_id, f"ID_{class_id}")
            count = len(class_to_image_idxs.get(class_id, []))
            print(f"  Class '{class_name}' (ID: {class_id}) found in {count} images.")
            if count < N_SHOTS:
                print(f"ERROR: Class '{class_name}' needs {N_SHOTS} images for support set, but only found {count}.")
                insufficient_classes.append(class_name)
        if insufficient_classes:
            raise ValueError(f"Insufficient images for N-shot learning for classes: {', '.join(insufficient_classes)}")
        return class_to_image_idxs

### Utility Functions

In [None]:
def collate_fn(batch):
    """Discard unusable samples, then regroup the rest field by field.

    A sample is unusable when it is ``None`` or when its image or its target is
    ``None``. Returns ``(None, None)`` if nothing is left, otherwise a tuple with
    one tuple per sample field (images, targets, ...).
    """
    usable = [
        sample for sample in batch
        if sample is not None and sample[0] is not None and sample[1] is not None
    ]
    if len(usable) == 0:
        return None, None
    return tuple(zip(*usable))

def img_coords_to_feature_coords(boxes, img_shape, feature_map_shape):
    """Project xyxy boxes from image pixels onto the feature-map grid.

    Args:
        boxes: ``(N, 4)`` tensor of ``xmin, ymin, xmax, ymax`` in image pixels.
        img_shape: Shape whose last two entries are the image ``(height, width)``.
        feature_map_shape: Shape whose last two entries are the feature-map
            ``(height, width)``.

    Returns:
        A new ``(N, 4)`` tensor in feature-map units, clamped to the feature-map
        extent. ``boxes`` itself is left untouched.
    """
    image_h, image_w = img_shape[-2:]
    fmap_h, fmap_w = feature_map_shape[-2:]

    # Per-axis ratio with a floor of 1e-6, which also covers a zero-sized image axis.
    ratio_x = max(fmap_w / image_w, 1e-6) if image_w > 0 else 1e-6
    ratio_y = max(fmap_h / image_h, 1e-6) if image_h > 0 else 1e-6

    ratios = torch.tensor([ratio_x, ratio_y, ratio_x, ratio_y], device=boxes.device)
    projected = boxes * ratios
    projected[:, 0::2].clamp_(min=0, max=fmap_w)  # x coordinates
    projected[:, 1::2].clamp_(min=0, max=fmap_h)  # y coordinates
    return projected

### Feature Extraction

In [None]:
class YOLOFeatureExtractor(torch.nn.Module):
    """Run a YOLO layer sequence up to one layer and return that layer's output.

    Args:
        yolo_model: Object exposing the layer sequence as ``yolo_model.model.model``.
        feature_layer_index: Index of the layer whose output is returned; negative
            values count from the end of the sequence.

    The layers are not simply chained: a layer carrying an ``f`` ("from") attribute
    different from ``-1`` receives the saved output(s) of the referenced earlier
    layer(s) - this is how YOLO graphs express skip connections / Concat inputs.
    Layers without ``f`` (or with ``f == -1``) receive the previous layer's output.
    """

    def __init__(self, yolo_model, feature_layer_index):
        super().__init__()
        # Keep the wrapper reachable as `self.yolo_model` without registering it as a child
        # module. NOTE(review): recent ultralytics wrappers are themselves nn.Modules whose
        # `train()` is the training entry point; as a registered child, `.train()`/`.eval()`
        # on this extractor would recurse into it. The layers are still registered below.
        object.__setattr__(self, 'yolo_model', yolo_model)
        try:
            self.model_sequence = yolo_model.model.model
        except AttributeError:
            print("Error: Could not access `model.model` attribute in the loaded YOLO object.")
            print("       The model structure might be different. Try inspecting `yolo_model` directly.")
            raise
        self.feature_layer_index = feature_layer_index

    def _resolve_layer_index(self):
        """Translate ``feature_layer_index`` into a validated non-negative index."""
        num_layers = len(self.model_sequence)
        if num_layers == 0:
            raise ValueError("Model sequence not initialized")
        index = self.feature_layer_index
        target_layer = index + num_layers if index < 0 else index
        if not 0 <= target_layer < num_layers:
            # Fail before running the network instead of silently returning the input.
            raise IndexError(f"Feature layer index {index} (->{target_layer}) is out of bounds for sequence length {num_layers}")
        return target_layer

    def forward(self, x):
        """Return the output of the configured layer for the image batch ``x``.

        If that layer yields a list/tuple, its last element is returned.

        Raises:
            IndexError: if ``feature_layer_index`` is outside the layer sequence.
            ValueError: if the layer sequence is empty.
        """
        target_layer = self._resolve_layer_index()

        layer_outputs = []  # output of every layer run so far, indexable by layer number
        current = x
        try:
            for i, module in enumerate(self.model_sequence):
                source = getattr(module, 'f', -1)
                if source == -1:
                    layer_input = current
                elif isinstance(source, int):
                    layer_input = layer_outputs[source]
                else:
                    # Several inputs: -1 means "previous layer", anything else a saved output.
                    layer_input = [current if j == -1 else layer_outputs[j] for j in source]
                current = module(layer_input)
                if i == target_layer:
                    break
                layer_outputs.append(current)
        except Exception as e:
            print(f"ERROR during feature extraction at layer {target_layer}: {e}")
            print(f"Input shape: {x.shape}")
            raise

        features = current
        # Handle potential list output: take the last feature map (common heuristic)
        if isinstance(features, (list, tuple)):
            features = features[-1]
        return features

### Prototype

In [None]:
def calculate_prototypes(support_dataloader, feature_extractor, roi_align, target_class_ids, device):
    """Compute one prototype vector per class from the support set.

    For every ground-truth box the feature map is pooled with ``roi_align`` and
    flattened; the prototype of a class is the mean of all its pooled vectors.

    Args:
        support_dataloader: Yields ``(images, targets)`` batches as produced by
            ``collate_fn`` (``(None, None)`` batches are skipped).
        feature_extractor: Module mapping an image batch to a feature-map batch.
        roi_align: Callable ``roi_align(feature_map, boxes)`` pooling boxes given in
            feature-map coordinates (boxes are scaled here before the call).
        target_class_ids: Class IDs a prototype is required for.
        device: Device the forward pass runs on.

    Returns:
        ``dict`` ``{class_id: 1-D CPU tensor}``.

    Raises:
        ValueError: if no feature could be collected for at least one target class.
    """
    feature_extractor.eval()
    prototypes = {cls_id: [] for cls_id in target_class_ids}
    print("Calculating prototypes...")
    with torch.no_grad():
        for images, targets in tqdm(support_dataloader, desc="Processing Support Set"):
            if images is None or targets is None: continue
            try: images = torch.stack(images).to(device)
            except Exception as e: print(f"Skipping batch due to image stacking error: {e}"); continue

            feature_maps = feature_extractor(images)
            if feature_maps is None: print("Feature extractor returned None"); continue

            for i, target in enumerate(targets):
                gt_boxes = target['boxes'].to(device)
                gt_labels = target['labels'].to(device)
                if gt_boxes.shape[0] == 0: continue

                # Handle case where feature map might have different batch dim if extractor behaves unexpectedly
                fm_idx = i if feature_maps.shape[0] == images.shape[0] else 0
                if fm_idx >= feature_maps.shape[0]: continue
                feature_map = feature_maps[fm_idx]

                feature_boxes = img_coords_to_feature_coords(gt_boxes, images[i].shape, feature_map.shape)
                # Boxes can collapse to zero area after scaling/clamping; drop those.
                non_degenerate = (feature_boxes[:, 2] > feature_boxes[:, 0]) & (feature_boxes[:, 3] > feature_boxes[:, 1])
                if not non_degenerate.any(): continue
                feature_boxes = feature_boxes[non_degenerate]
                gt_labels = gt_labels[non_degenerate]

                fm_input = feature_map.unsqueeze(0)  # RoIAlign expects a batch dimension
                try:
                    # torchvision's RoIAlign takes either one Tensor[K, 5] (batch index + box)
                    # or a list with one Tensor[L, 4] per image - not a list of 5-column tensors.
                    roi_features = roi_align(fm_input, [feature_boxes])
                except Exception as e:
                    print(f"Error during RoIAlign: {e}")
                    print(f"Feature map shape: {fm_input.shape}, RoI boxes shape: {feature_boxes.shape}")
                    continue

                roi_features_flat = roi_features.flatten(start_dim=1)

                for j, label in enumerate(gt_labels):
                    class_id = label.item()
                    if class_id in prototypes:
                        prototypes[class_id].append(roi_features_flat[j].cpu())

    final_prototypes = {}
    print("Averaging features for final prototypes...")
    for class_id, features_list in prototypes.items():
        if not features_list:
            print(f"Warning: No support features found for class ID {class_id}.")
            continue
        try:
            final_prototypes[class_id] = torch.stack(features_list).mean(dim=0)
            print(f"  Class ID {class_id}: {final_prototypes[class_id].shape}")
        except Exception as e:
            print(f"Error processing features for class {class_id}: {e}")

    # set(...) so that any iterable of IDs (list, tuple, set) is accepted here.
    missing_ids = set(target_class_ids) - set(final_prototypes.keys())
    if missing_ids:
        print(f"ERROR: Failed to create prototypes for all target classes. Missing IDs: {missing_ids}")
        raise ValueError("Failed to create all prototypes")

    return final_prototypes

### Training

In [None]:
def train_few_shot_yolo(feature_extractor, prototypes, optimizer, scheduler, query_dataloader, roi_align, device, epoch, finetune_extractor=True):
    """Run one epoch of prototype-based fine-tuning of the feature extractor.

    Every ground-truth box is pooled from the feature map with ``roi_align``; the
    cosine similarities between the pooled vector and all class prototypes are used
    as logits of a cross-entropy loss against the box's class.

    Args:
        feature_extractor: Module mapping an image batch to a feature-map batch.
        prototypes: ``{class_id: 1-D tensor}``; boxes of other classes are ignored.
        optimizer: Optimizer stepped once per batch that produced a finite loss.
        scheduler: Optional LR scheduler, stepped once at the end of the epoch.
        query_dataloader: Yields ``(images, targets)`` batches from ``collate_fn``.
        roi_align: Callable ``roi_align(feature_map, boxes)`` pooling boxes given in
            feature-map coordinates.
        device: Device the computation runs on.
        epoch: Zero-based epoch number (used for logging only).
        finetune_extractor: If ``False`` the extractor is put in eval mode and no
            loss is accumulated or back-propagated.

    Returns:
        Mean loss over the batches that triggered an optimizer step, or ``0`` / ``0.0``
        when there were none.
    """
    if finetune_extractor: feature_extractor.train()
    else: feature_extractor.eval()

    total_loss = 0.0
    batch_count = 0
    progress_bar = tqdm(query_dataloader, desc=f"Epoch {epoch+1} Training", leave=False)

    if not prototypes: print("Error: Prototypes dict empty. Skipping training."); return 0.0
    proto_class_ids = sorted(prototypes.keys())

    try: proto_tensor = torch.stack([prototypes[cid].to(device) for cid in proto_class_ids])
    except Exception as e: print(f"Error stacking proto tensors: {e}"); return 0.0
    # Prototypes stay fixed during the epoch, so normalise them once.
    proto_tensor_norm = torch.nn.functional.normalize(proto_tensor, p=2, dim=1)
    class_id_to_proto_idx = {cls_id: idx for idx, cls_id in enumerate(proto_class_ids)}

    criterion = torch.nn.CrossEntropyLoss()

    for images, targets in progress_bar:
        if images is None or targets is None: continue
        try: images = torch.stack(images).to(device)
        except Exception as e: print(f"Skipping batch due to image stacking error: {e}"); continue

        # No autograd graph is needed when the extractor is frozen.
        with torch.set_grad_enabled(finetune_extractor):
            feature_maps = feature_extractor(images)
        if feature_maps is None: print("Feature extractor returned None"); continue

        image_losses = []  # one loss per image that contributed valid boxes

        for i, target in enumerate(targets):
            gt_boxes = target['boxes'].to(device)
            gt_labels = target['labels'].to(device)
            if gt_boxes.shape[0] == 0: continue

            # Filter GT labels for those we have prototypes for
            valid_indices = [idx for idx, label in enumerate(gt_labels) if label.item() in prototypes]
            if not valid_indices: continue
            gt_boxes = gt_boxes[valid_indices]
            gt_labels = gt_labels[valid_indices]

            fm_idx = i if feature_maps.shape[0] == images.shape[0] else 0
            if fm_idx >= feature_maps.shape[0]: continue
            feature_map = feature_maps[fm_idx]

            feature_boxes = img_coords_to_feature_coords(gt_boxes, images[i].shape, feature_map.shape)
            # Boxes can collapse to zero area after scaling/clamping; drop those.
            non_degenerate = (feature_boxes[:, 2] > feature_boxes[:, 0]) & (feature_boxes[:, 3] > feature_boxes[:, 1])
            if not non_degenerate.any(): continue
            feature_boxes = feature_boxes[non_degenerate]
            gt_labels = gt_labels[non_degenerate]

            fm_input = feature_map.unsqueeze(0)  # RoIAlign expects a batch dimension
            try:
                # torchvision's RoIAlign takes either one Tensor[K, 5] (batch index + box)
                # or a list with one Tensor[L, 4] per image - not a list of 5-column tensors.
                with torch.set_grad_enabled(finetune_extractor):
                    roi_features = roi_align(fm_input, [feature_boxes])
            except Exception as e:
                print(f"Error during RoIAlign in training: {e}")
                print(f"Feature map shape: {fm_input.shape}, RoI boxes shape: {feature_boxes.shape}")
                continue

            roi_features_flat = roi_features.flatten(start_dim=1)

            # Check feature dimension consistency
            if proto_tensor.shape[1] != roi_features_flat.shape[1]:
                print(f"WARNING: Feature dimension mismatch! Proto: {proto_tensor.shape[1]}, RoI: {roi_features_flat.shape[1]}. Skipping batch.")
                # This often indicates the FEATURE_LAYER_INDEX is wrong or features changed unexpectedly.
                continue

            # Cosine similarity between every RoI and every prototype, used as class logits.
            roi_features_norm = torch.nn.functional.normalize(roi_features_flat, p=2, dim=1)
            similarities = torch.mm(roi_features_norm, proto_tensor_norm.t())

            try: target_proto_indices = torch.tensor([class_id_to_proto_idx[label.item()] for label in gt_labels], device=device, dtype=torch.long)
            except KeyError as e: print(f"Error: GT label {e} not in proto map. Skipping loss calc."); continue

            loss = criterion(similarities, target_proto_indices)

            if not torch.isfinite(loss):
                print(f"Warning: Non-finite loss: {loss.item()}. Skipping.")
            elif finetune_extractor:
                image_losses.append(loss)

        if finetune_extractor and image_losses:
            avg_batch_loss = torch.stack(image_losses).mean()
            if torch.isfinite(avg_batch_loss):
                optimizer.zero_grad()
                avg_batch_loss.backward()
                optimizer.step()
                total_loss += avg_batch_loss.item()
                batch_count += 1
                progress_bar.set_postfix(loss=avg_batch_loss.item())
            else: print(f"Warning: Avg batch loss non-finite: {avg_batch_loss}. Skipping update.")

    if scheduler: scheduler.step()
    avg_epoch_loss = total_loss / batch_count if batch_count > 0 else 0
    if avg_epoch_loss > 0 or epoch == 0: print(f"Epoch {epoch+1} Training Summary: Fine-tuning Loss: {avg_epoch_loss:.4f}")
    return avg_epoch_loss

### Inference

In [None]:
@torch.no_grad()
def inference_yolo_few_shot(full_yolo_model, feature_extractor, prototypes, dataloader, roi_align, device, conf_threshold=0.3, nms_threshold=0.45, proposal_conf=0.05):
    """Detect with YOLO proposals that are re-classified against class prototypes.

    For each image the detector's boxes are used as proposals; each proposal is
    pooled from the feature map with ``roi_align`` and assigned the class of the
    most similar prototype (cosine similarity), which also becomes its score.

    Args:
        full_yolo_model: Detector exposing ``predict(images, verbose=..., conf=...)``
            whose results provide ``boxes.xyxy``.
        feature_extractor: Module mapping an image batch to a feature-map batch.
        prototypes: ``{class_id: 1-D tensor}``.
        dataloader: Yields ``(images, targets)`` batches from ``collate_fn``.
        roi_align: Callable ``roi_align(feature_map, boxes)`` pooling boxes given in
            feature-map coordinates.
        device: Device the computation runs on.
        conf_threshold: Minimum prototype similarity for a proposal to be kept.
        nms_threshold: IoU threshold of the final (class-agnostic) NMS.
        proposal_conf: Detector confidence threshold for the initial proposals.

    Returns:
        ``(predictions, targets)`` - two lists of dicts with CPU tensors
        (``boxes``/``scores``/``labels`` and ``boxes``/``labels``). Exactly one
        prediction dict is produced per detector result, also when an image
        yields no detections.
    """
    # NOTE(review): set eval mode on the wrapped network rather than on the wrapper itself;
    # for ultralytics wrappers `.eval()` may be routed to their own `train()` entry point
    # (version dependent) - confirm against the installed ultralytics version.
    detector_module = getattr(full_yolo_model, 'model', None)
    if isinstance(detector_module, torch.nn.Module): detector_module.eval()
    else: full_yolo_model.eval()
    feature_extractor.eval()

    if not prototypes: print("Error: Prototypes dict empty for inference."); return [], []
    proto_class_ids = sorted(prototypes.keys())

    try: proto_tensor = torch.stack([prototypes[cid].to(device) for cid in proto_class_ids])
    except Exception as e: print(f"Error stacking proto tensors for inference: {e}"); return [], []
    proto_tensor_norm = torch.nn.functional.normalize(proto_tensor, p=2, dim=1)
    # Row i of the prototype matrix belongs to class proto_class_ids[i].
    class_id_lookup = torch.tensor(proto_class_ids, dtype=torch.int64, device=device)

    def _empty_prediction():
        """Prediction dict for an image without any kept detection (CPU, like the others)."""
        return {'boxes': torch.empty((0, 4)), 'scores': torch.empty((0,)), 'labels': torch.empty((0,), dtype=torch.int64)}

    all_preds, all_targets_for_metric = [], []
    print("Running inference with prototype re-classification...")
    progress_bar = tqdm(dataloader, desc="Inference")

    for images, targets in progress_bar:
        if images is None or targets is None: continue
        try: image_tensors = torch.stack(images).to(device)
        except Exception as e: print(f"Skipping batch due to image stacking error: {e}"); continue

        feature_maps = feature_extractor(image_tensors)
        if feature_maps is None: print("Feature extractor returned None in inference"); continue

        # Use lower confidence for initial proposals, verbose=False reduces spam
        yolo_results = full_yolo_model.predict(image_tensors, verbose=False, conf=proposal_conf)

        batch_preds = []
        for i, result in enumerate(yolo_results):
            fm_idx = i if feature_maps.shape[0] == image_tensors.shape[0] else 0
            if fm_idx >= feature_maps.shape[0]:
                # Still emit a prediction so predictions and targets stay aligned.
                batch_preds.append(_empty_prediction())
                continue
            feature_map = feature_maps[fm_idx]

            if result.boxes is None or len(result.boxes) == 0:
                batch_preds.append(_empty_prediction())
                continue
            proposal_boxes = result.boxes.xyxy.to(device)

            feature_boxes = img_coords_to_feature_coords(proposal_boxes, image_tensors[i].shape, feature_map.shape)
            # Boxes can collapse to zero area after scaling/clamping; drop those.
            non_degenerate = (feature_boxes[:, 2] > feature_boxes[:, 0]) & (feature_boxes[:, 3] > feature_boxes[:, 1])
            if not non_degenerate.any():
                batch_preds.append(_empty_prediction())
                continue
            feature_boxes = feature_boxes[non_degenerate]
            proposal_boxes = proposal_boxes[non_degenerate]

            try:
                # torchvision's RoIAlign takes either one Tensor[K, 5] (batch index + box)
                # or a list with one Tensor[L, 4] per image - not a list of 5-column tensors.
                roi_features = roi_align(feature_map.unsqueeze(0), [feature_boxes])
            except Exception as e:
                print(f"Error during RoIAlign in inference: {e}")
                batch_preds.append(_empty_prediction())
                continue # Skip to next image in batch

            roi_features_flat = roi_features.flatten(start_dim=1)

            if proto_tensor.shape[1] != roi_features_flat.shape[1]:
                print(f"WARNING: Feature dimension mismatch inference! Proto: {proto_tensor.shape[1]}, RoI: {roi_features_flat.shape[1]}.")
                batch_preds.append(_empty_prediction())
                continue

            roi_features_norm = torch.nn.functional.normalize(roi_features_flat, p=2, dim=1)
            similarities = torch.mm(roi_features_norm, proto_tensor_norm.t())

            # Best-matching prototype per proposal: similarity becomes the score,
            # the prototype's class ID becomes the label.
            new_scores, proto_indices = torch.max(similarities, dim=1)
            new_labels = class_id_lookup[proto_indices]

            keep = new_scores > conf_threshold
            final_boxes = proposal_boxes[keep]
            final_scores = new_scores[keep]
            final_labels = new_labels[keep]

            if final_boxes.shape[0] > 0: # Only apply NMS if boxes remain
                nms_indices = torchvision.ops.nms(final_boxes, final_scores, nms_threshold)
                final_boxes = final_boxes[nms_indices]
                final_scores = final_scores[nms_indices]
                final_labels = final_labels[nms_indices]

            batch_preds.append({'boxes': final_boxes.cpu(), 'scores': final_scores.cpu(), 'labels': final_labels.cpu().to(torch.int64)})

        all_preds.extend(batch_preds)
        for target in targets: all_targets_for_metric.append({'boxes': target['boxes'].cpu(), 'labels': target['labels'].cpu().to(torch.int64)})

    return all_preds, all_targets_for_metric


### Evaluate

In [None]:
def evaluate_mAP(predictions, ground_truths):
    """Compute and print bounding-box mean average precision for a set of images.

    Args:
        predictions: Sequence with one dict per image containing at least the
            keys 'boxes', 'scores' and 'labels'.
        ground_truths: Sequence with one dict per image containing at least the
            keys 'boxes' and 'labels'. Entry ``i`` is paired with
            ``predictions[i]``; if the lengths differ, the surplus tail of the
            longer sequence is ignored (a warning is printed).

    Returns:
        The dict produced by ``MeanAveragePrecision.compute()`` with tensor
        values moved to the CPU, or ``None`` when there is nothing valid to
        evaluate, when a prediction/ground-truth pair is only half valid, or
        when the metric computation raises.
    """
    if not predictions or not ground_truths:
        print("Eval failed: No preds/GTs.")
        return None
    if len(predictions) != len(ground_truths):
        min_len = min(len(predictions), len(ground_truths))
        print(f"Eval Warning: Pred/GT mismatch ({len(predictions)} vs {len(ground_truths)}). Using {min_len} samples.")

    def _has_keys(entry, keys):
        # An entry is usable only if it is a dict carrying every required key.
        return isinstance(entry, dict) and all(k in entry for k in keys)

    def _flatten(value):
        # Normalise tensors (any rank, including 0-d), lists and scalars to a flat Python list.
        if isinstance(value, torch.Tensor):
            return value.reshape(-1).tolist()
        if isinstance(value, (list, tuple)):
            return [v.item() if isinstance(v, torch.Tensor) else v for v in value]
        return [value]

    # Validate prediction/ground-truth entries *as pairs*. Filtering the two lists
    # independently can leave them equally long but shifted against each other,
    # which would silently score predictions against the wrong image.
    valid_pairs = []
    for pred, gt in zip(predictions, ground_truths):  # zip truncates to the shorter input
        pred_ok = _has_keys(pred, ('boxes', 'scores', 'labels'))
        gt_ok = _has_keys(gt, ('boxes', 'labels'))
        if pred_ok != gt_ok:
            print("Eval failed: Mismatch after formatting.")
            return None
        if pred_ok:
            valid_pairs.append((pred, gt))
    if not valid_pairs:
        print("Eval failed: No valid formatted data.")
        return None

    # Metric and inputs must live on the same device.
    formatted_preds = [{k: torch.as_tensor(v).to(DEVICE) for k, v in pred.items()} for pred, _ in valid_pairs]
    formatted_gts = [{k: torch.as_tensor(v).to(DEVICE) for k, v in gt.items()} for _, gt in valid_pairs]
    metric = torchmetrics.detection.MeanAveragePrecision(iou_type="bbox", class_metrics=True).to(DEVICE)

    try:
        metric.update(formatted_preds, formatted_gts)
        computed_metrics = metric.compute()
        # Move results back to the CPU so they can be printed/logged/saved freely.
        computed_metrics = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in computed_metrics.items()}

        print("\nEvaluation Metrics:")
        class_ids = _flatten(computed_metrics.get('classes', []))
        for key, value in computed_metrics.items():
            if key == 'classes':
                class_names = [INV_CLASS_MAP.get(cid, f"ID_{cid}") for cid in class_ids]
                print(f"  Classes Evaluated (IDs): {class_ids}")
                print(f"  Classes Evaluated (Names): {class_names}")
            elif isinstance(key, str) and key.endswith('_per_class'):
                # Per-class values arrive as tensors (or lists); handle them before the
                # scalar branch so a single-class result is still labelled by class.
                print(f"  {key}:")
                for i, ap in enumerate(_flatten(value)):
                    if i < len(class_ids):
                        class_id = class_ids[i]
                        class_name = INV_CLASS_MAP.get(class_id, f"Unknown_ID_{class_id}")
                        # -1 is skipped: it is treated here as "no value for this class".
                        if ap != -1: print(f"    - {class_name} (ID: {class_id}): {ap:.4f}")
                    else:
                        print(f"    - Index {i} (Class ID unknown): {ap:.4f}")
            elif isinstance(value, torch.Tensor) and value.numel() == 1:
                print(f"  {key}: {value.item():.4f}")
            else:
                print(f"  {key}: {value}")
        return computed_metrics
    except Exception as e:
        print(f"Error computing mAP metrics: {e}")
        return None

### Main

In [None]:
if __name__ == "__main__":
    print(f"Starting Few-Shot YOLO (yolov11n) Detection - {NUM_CLASSES} Classes, {N_SHOTS}-Shot...")
    print(f"CUDA Available: {torch.cuda.is_available()}")

    # 1. Load Base YOLO Model
    print(f"\nLoading base YOLO model from: {YOLO_MODEL_PATH}")
    try:
        full_yolo_model = YOLO(YOLO_MODEL_PATH)
        full_yolo_model.to(DEVICE)
        print("Base YOLO model loaded successfully.")
        # Print the layer list so FEATURE_LAYER_INDEX can be checked against it.
        # Comment out the next three lines to keep the output short.
        print("\nInspecting model structure (adjust FEATURE_LAYER_INDEX if needed):")
        try: print(full_yolo_model.model.model)
        except AttributeError: print("Could not access default model structure.")
        print("\n")
        print(f"Using feature layer index: {FEATURE_LAYER_INDEX}")
    except Exception as e:
        print(f"Error loading YOLO model '{YOLO_MODEL_PATH}': {e}")
        # Raise SystemExit instead of calling the interactive exit() helper so that
        # execution stops right here and reports a non-zero status.
        raise SystemExit(1)

    # 2. Create Feature Extractor Wrapper
    # Wraps the loaded detector so that the layer selected by FEATURE_LAYER_INDEX
    # is passed to YOLOFeatureExtractor along with the model.
    try:
        feature_extractor = YOLOFeatureExtractor(full_yolo_model, FEATURE_LAYER_INDEX).to(DEVICE)
    except Exception as e:
        print(f"Error creating feature extractor: {e}")
        # Stop immediately; nothing below can run without a feature extractor.
        # SystemExit is raised directly rather than via the interactive exit() helper.
        raise SystemExit(1)

    # 3. Prepare Data
    # Builds the train/validation split, the N-shot support set, the query set and
    # the validation/test loaders. Any failure aborts the run.
    target_img_size = (640, 640)
    transforms = get_transform(target_size=target_img_size)
    try:
        full_train_dataset = RareDiseaseDataset(TRAIN_IMG_DIR, TRAIN_LABEL_DIR, TARGET_CLASS_IDS, transforms, target_img_size)
        if len(full_train_dataset) == 0: raise ValueError("No training images loaded.")

        # Train/validation split
        dataset_size = len(full_train_dataset)
        indices = list(range(dataset_size))
        min_required_train = N_SHOTS * NUM_CLASSES
        can_split = dataset_size > min_required_train  # more images than the support set needs?

        if can_split and VALIDATION_SPLIT > 0:
            split = int(np.floor(VALIDATION_SPLIT * dataset_size))
            # Keep at least one validation image when possible ...
            split = max(1, split) if dataset_size > 1 else 0
            # ... but never at the cost of starving the training split.
            if dataset_size - split < min_required_train:
                print(f"Warning: Validation split ({split}) leaves too few images ({dataset_size - split}) for {N_SHOTS}-shot training. Disabling validation.")
                split = 0  # Disable validation split

            np.random.shuffle(indices)
            train_indices, val_indices = indices[split:], indices[:split]

        else:  # Cannot split or validation disabled
            print("Warning: Dataset too small or validation split is 0. Using all data for training, no validation.")
            train_indices = indices
            val_indices = []

        train_subset_indices = train_indices
        val_subset = Subset(full_train_dataset, val_indices) if val_indices else None
        print(f"Total train samples available: {len(train_subset_indices)}, Val samples: {len(val_subset) if val_subset else 0}")

        # Select Support Set
        img_ids_by_class = full_train_dataset.get_image_ids_by_class()

        support_indices = []
        print(f"\nSelecting {N_SHOTS}-shot support set from training split...")
        # Set lookup keeps the per-class filtering linear instead of quadratic.
        train_index_set = set(train_subset_indices)
        possible_support_indices = {cls_id: [idx for idx in ids if idx in train_index_set]
                                    for cls_id, ids in img_ids_by_class.items()}

        for cls_id in TARGET_CLASS_IDS:
            available_indices = possible_support_indices.get(cls_id, [])
            # Re-check the shot count inside the training split: the validation
            # split may have removed images of this class.
            if len(available_indices) < N_SHOTS:
                raise ValueError(f"Insufficient images ({len(available_indices)}) for class {INV_CLASS_MAP.get(cls_id, cls_id)} (ID: {cls_id}) within the training split after validation split.")
            support_indices.extend(random.sample(available_indices, N_SHOTS))
        # The same index can be drawn for several classes; deduplicate.
        support_indices = sorted(set(support_indices))
        print(f"Total unique support images selected: {len(support_indices)}")

        support_dataset = Subset(full_train_dataset, support_indices)
        support_dataloader = DataLoader(support_dataset, batch_size=max(1, BATCH_SIZE // 2), shuffle=False, collate_fn=collate_fn, num_workers=2)

        # Query Set: every training-split image that is not in the support set.
        support_index_set = set(support_indices)
        query_indices = [idx for idx in train_subset_indices if idx not in support_index_set]
        if not query_indices and len(train_subset_indices) > 0:
            print("Warning: No query images left. Using support set also as query set.")
            query_indices = support_indices
        elif not query_indices:
            raise ValueError("Error: No training images available for query set.")

        query_dataset = Subset(full_train_dataset, query_indices)
        query_dataloader = DataLoader(query_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=2)
        print(f"Support set size: {len(support_dataset)}, Query set size: {len(query_dataset)}")

        # Validation Loader
        val_dataloader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=2) if val_subset else None

        # Test Loader
        test_dataset = RareDiseaseDataset(TEST_IMG_DIR, TEST_LABEL_DIR, TARGET_CLASS_IDS, transforms, target_img_size)
        test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=2) if len(test_dataset) > 0 else None
        print(f"Test set size: {len(test_dataset)}")

    except Exception as e:
        print(f"Error during data loading or splitting: {e}")
        import traceback
        traceback.print_exc()
        # Stop immediately with a non-zero status; SystemExit is raised directly
        # rather than via the interactive exit() helper.
        raise SystemExit(1)

    # 4. Initialize RoIAlign
    # spatial_scale=1.0 applies no rescaling to box coordinates inside RoIAlign.
    # NOTE(review): this presumes boxes are already in feature-map coordinates
    # when they reach roi_align - confirm against the callers.
    roi_align = RoIAlign(output_size=ROI_OUTPUT_SIZE, spatial_scale=1.0, sampling_ratio=-1).to(DEVICE)

    # 5. Calculate Initial Prototypes
    try:
        prototypes = calculate_prototypes(support_dataloader, feature_extractor, roi_align, TARGET_CLASS_IDS, DEVICE)
        if len(prototypes) < NUM_CLASSES:
            # Missing prototypes are tolerated with a warning; having none at all is fatal.
            print(f"WARNING: Only generated prototypes for {len(prototypes)} out of {NUM_CLASSES} classes. Training/Inference might be affected.")
            if not prototypes: raise ValueError("No prototypes generated.")
    except Exception as e:
        print(f"Error calculating initial prototypes: {e}")
        # Stop immediately with a non-zero status; SystemExit is raised directly
        # rather than via the interactive exit() helper.
        raise SystemExit(1)

    # 6. Setup Fine-tuning
    # The optimizer and scheduler stay None unless fine-tuning is enabled and the
    # feature extractor actually exposes parameters.
    optimizer, scheduler = None, None
    if FINETUNE_FEATURE_EXTRACTOR:
        params_to_tune = [param for param in feature_extractor.parameters()]
        print(f"Setting up optimizer for {len(params_to_tune)} parameters...")
        if params_to_tune:
            optimizer = torch.optim.AdamW(params_to_tune, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        else:
            # Nothing to optimise: switch the flag off so the code below skips fine-tuning.
            FINETUNE_FEATURE_EXTRACTOR = False
            print("No params found, disabling fine-tuning.")
    if not FINETUNE_FEATURE_EXTRACTOR:
        print("Fine-tuning disabled.")

    # 7. Training Loop
    # Each epoch: optionally refresh the prototypes, train on the query set, then
    # (if a validation loader exists) score on validation data and checkpoint the
    # prototypes/model whenever validation mAP improves.
    print("\nStarting Training/Adaptation Phase...")
    best_val_map = -1.0
    try:
        for epoch in range(NUM_EPOCHS):
            # Refresh prototypes every PROTO_UPDATE_FREQ epochs while fine-tuning.
            if FINETUNE_FEATURE_EXTRACTOR and epoch > 0 and epoch % PROTO_UPDATE_FREQ == 0:
                print(f"\n--- Recalculating Prototypes (Epoch {epoch+1}) ---")
                current_prototypes = calculate_prototypes(support_dataloader, feature_extractor, roi_align, TARGET_CLASS_IDS, DEVICE)
                # Only swap in the new set when it has as many prototypes as the old one.
                if len(current_prototypes) == len(prototypes): prototypes = current_prototypes
                else: print("Warning: Failed to recalculate all prototypes. Keeping previous.")

            epoch_loss = train_few_shot_yolo(
                feature_extractor, prototypes, optimizer, scheduler, query_dataloader,
                roi_align, DEVICE, epoch, FINETUNE_FEATURE_EXTRACTOR
            )

            if val_dataloader:
                print(f"\n--- Running Validation after Epoch {epoch+1} ---")
                val_preds, val_gts = inference_yolo_few_shot(
                    full_yolo_model, feature_extractor, prototypes, val_dataloader,
                    roi_align, DEVICE
                )
                val_metrics = evaluate_mAP(val_preds, val_gts)
                # -1.0 is the "no usable mAP" sentinel for this epoch.
                current_map = -1.0
                if val_metrics and 'map' in val_metrics:
                    val_map_tensor = val_metrics['map']
                    if isinstance(val_map_tensor, torch.Tensor) and val_map_tensor.numel() == 1: current_map = val_map_tensor.item()

                if current_map != -1.0:  # Check if mAP calculation was successful
                    if current_map > best_val_map:
                        best_val_map = current_map
                        print(f"*** New best validation mAP: {best_val_map:.4f}. Saving prototypes and model... ***")
                        # os.makedirs('') raises, so only create a directory when the
                        # save path actually has a directory component.
                        proto_dir = os.path.dirname(SAVE_PROTOTYPES_PATH)
                        if proto_dir: os.makedirs(proto_dir, exist_ok=True)
                        torch.save({cls_id: p.cpu() for cls_id, p in prototypes.items()}, SAVE_PROTOTYPES_PATH)
                        if FINETUNE_FEATURE_EXTRACTOR:
                            # The model checkpoint may live in a different directory.
                            model_dir = os.path.dirname(SAVE_FINETUNED_MODEL_PATH)
                            if model_dir: os.makedirs(model_dir, exist_ok=True)
                            torch.save(full_yolo_model.state_dict(), SAVE_FINETUNED_MODEL_PATH)
                    else: print(f"Validation mAP: {current_map:.4f} (Best: {best_val_map:.4f})")
                else: print("Validation mAP could not be calculated.")
                print("-------------------------------------------\n")
            # NOTE: nothing is saved by this loop when there is no validation loader.

    except Exception as e:
        # Deliberate best-effort: report the failure and fall through to the steps below.
        print(f"\nAn error occurred during training loop: {e}")
        import traceback
        traceback.print_exc()
        print("Attempting to proceed to final evaluation...")

    print("Training/Adaptation finished.")

    # 8. Final Evaluation
    # Evaluates on the test loader using the saved prototypes / fine-tuned weights
    # when the corresponding files exist, otherwise the in-memory state.
    # NOTE(review): files at the save paths are loaded whenever they exist, which
    # includes files left behind by an earlier run - confirm this is intended.
    if test_dataloader:
        print("\n--- Running Final Evaluation on Test Set ---")
        best_proto_path = SAVE_PROTOTYPES_PATH if os.path.exists(SAVE_PROTOTYPES_PATH) else f"{os.path.splitext(SAVE_PROTOTYPES_PATH)[0]}_epoch{NUM_EPOCHS}.pt" # Example fallback
        best_model_path = SAVE_FINETUNED_MODEL_PATH if os.path.exists(SAVE_FINETUNED_MODEL_PATH) else f"{os.path.splitext(SAVE_FINETUNED_MODEL_PATH)[0]}_epoch{NUM_EPOCHS}.pt"

        if os.path.exists(best_proto_path):
            print(f"Loading best prototypes from: {best_proto_path}")
            loaded_protos_cpu = torch.load(best_proto_path, map_location='cpu')
            prototypes = {cls_id: p.to(DEVICE) for cls_id, p in loaded_protos_cpu.items()}
            print(f"Loaded {len(prototypes)} prototypes.")
        else:
            print("Warning: Best prototypes file not found. Using prototypes from last training state.")
            if 'prototypes' not in locals() or not prototypes:
                print("Error: No prototypes available for final evaluation.")
                # Stop immediately with a non-zero status (not via the interactive exit()).
                raise SystemExit(1)

        if FINETUNE_FEATURE_EXTRACTOR and os.path.exists(best_model_path):
            print(f"Loading fine-tuned model state from: {best_model_path}")
            try:
                # Build the reloaded model under temporary names. Rebinding
                # full_yolo_model before the weights are known to load would leave a
                # fresh base detector paired with the old feature extractor on failure.
                reloaded_model = YOLO(YOLO_MODEL_PATH)  # Re-init architecture
                reloaded_model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
                reloaded_model.to(DEVICE)
                reloaded_extractor = YOLOFeatureExtractor(reloaded_model, FEATURE_LAYER_INDEX).to(DEVICE)  # Re-wrap
            except Exception as e:
                print(f"Error loading fine-tuned model state dict: {e}. Using last state from training.")
                # full_yolo_model / feature_extractor are untouched here, so the
                # training-time pair is what gets evaluated.
                if 'feature_extractor' not in locals():
                    print("Error: Feature extractor not available.")
                    raise SystemExit(1)
            else:
                # Commit both names together only after every step succeeded.
                full_yolo_model, feature_extractor = reloaded_model, reloaded_extractor
                print("Fine-tuned model loaded.")

        elif FINETUNE_FEATURE_EXTRACTOR: print("Warning: Fine-tuned model file not found. Using last state.")
        else: print("Using original pre-trained YOLO model features.")

        try:
            test_preds, test_gts = inference_yolo_few_shot(
                full_yolo_model, feature_extractor, prototypes, test_dataloader,
                roi_align, DEVICE
            )
            print("\n--- Final Test Set Performance ---")
            evaluate_mAP(test_preds, test_gts)
        except Exception as e:
            print(f"An error occurred during final evaluation: {e}")
            import traceback
            traceback.print_exc()
    else: print("\nSkipping final evaluation as no test data was loaded.")

    print("\nScript finished.")