In [None]:
import os
import glob
import xml.etree.ElementTree as ET
import random
import math
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches
from PIL import Image
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler # For Mixed Precision
import timm # For ViT Enhancer
from torch.optim.lr_scheduler import LambdaLR
import math
from transformers import get_scheduler

# Silence every UserWarning for the whole process (originally added to hide
# warnings emitted around torchvision NMS).
# NOTE(review): the filter is not scoped to torchvision, so unrelated
# UserWarnings are hidden as well.
warnings.filterwarnings("ignore", category=UserWarning)

# ===== CONFIGURATION (YOLO head variant) =====
# Central hyper-parameter table. Several model classes in this file read it at
# class-definition time to fill in default argument values, so it must be
# defined before them.
CONFIG = {
    # Fusion & Backbone Architecture
    'd_model': 512,
    'nheads': 8,
    # Scale name -> channel count. Only "layer4" is enabled; the earlier
    # entries "layer2": 128 and "layer3": 256 are currently switched off.
    'backbone_embed_dims': {"layer4": 512},
    # Scale name -> spatial side length of that feature map ("layer2": 64 and
    # "layer3": 32 are switched off to match the entry above).
    'backbone_feature_sizes': {"layer4": 16},
    'img_size': 512,

    # YOLO Architecture
    'num_classes': 1,
    # One list of anchor pairs per detection scale. YOLOHead and YOLOLoss
    # divide these values by the stride.
    # NOTE(review): presumably (width, height) in input pixels - confirm.
    'anchors': [
        [[10, 13], [16, 30], [33, 23]],
    ],
    'yolo_strides': [32],

    # ViT Enhancer
    'vit_model_name': 'deit_small_patch16_224',

    # Loss & Optimization
    'lambda_deep_supervision': 0.1,
    'optimizer': 'SGD',                   # optimizer name
    'lr': 0.001,                           # learning rate chosen for SGD
    'lr_backbone': 1e-5,                 # backbone learning rate chosen for SGD
    'sgd_momentum': 0.937,                # momentum for SGD
    'dropout_rate': 0.2,                  # default dropout rate of the fusion modules
    'weight_decay': 5e-4,
    'clip_norm_head': 10.0,
    'clip_norm_backbone': 1.0,
    'loss_box_weight': 0.05,               # NOTE(review): an older comment said "changed from 0.05", i.e. the same value - confirm the intended weight
    'loss_obj_weight': 1.0,
    'loss_cls_weight': 0.5,
    'label_smoothing': 0.1,

    # Training Settings
    'batch_size': 8,
    'num_epochs': 100,
    'subset_size': 4000,
    'ACCUMULATION_STEPS': 4,
    'FREEZE_EPOCHS': 25,                  # NOTE(review): an older comment said unfreezing was delayed to epoch 50, but the value is 25 - confirm against the training loop
    'early_stop_patience': 20,            # NOTE(review): an older comment mentioned a patience of 10 - confirm which value is intended
    'num_workers': 2,
    'conf_threshold': 0.25,
    'iou_threshold_nms': 0.45,
}

# ===== DATASET HANDLING =====
class LLVIPDataset(Dataset):
    """Paired visible (RGB) / infrared LLVIP images with 'person' boxes.

    Each item is a dict with the keys ``'rgb'``, ``'ir'``, ``'bboxes'`` and
    ``'img_id'``. Boxes are ``[xmin, ymin, xmax, ymax]`` rows normalised by
    the original image size and clamped to ``[0, 1]``.

    A sample that cannot be loaded yields ``None`` instead of raising.
    NOTE(review): presumably the collate function drops ``None`` items -
    confirm against the DataLoader setup.

    Args:
        root_dir: Directory that contains the ``LLVIP`` folder.
        split: Sub-folder name below ``visible`` / ``infrared``; the value
            ``'train'`` additionally enables augmentation.
        subset_size: If truthy and positive, keep only the first
            ``subset_size`` image ids (in sorted order).
        img_size: Side length both images are resized to.
    """

    def __init__(self, root_dir, split, subset_size=None, img_size=512):
        self.root_dir = root_dir
        self.img_size = img_size
        self.is_train = split == 'train'

        self.rgb_dir = os.path.join(root_dir, 'LLVIP', 'visible', split)
        self.ir_dir = os.path.join(root_dir, 'LLVIP', 'infrared', split)
        self.ann_dir = os.path.join(root_dir, 'LLVIP', 'Annotations')

        all_ids = [
            os.path.splitext(os.path.basename(path))[0]
            for path in glob.glob(os.path.join(self.rgb_dir, '*.jpg'))
        ]
        # Keep purely numeric ids only, so stray files are not picked up.
        self.image_ids = sorted(img_id for img_id in all_ids if img_id.isdigit())

        if subset_size and subset_size > 0:
            self.image_ids = self.image_ids[:subset_size]
            print(f"Using subset of {len(self.image_ids)} images for split '{split}'.")

        # Colour jitter is applied to the RGB image of the training split only.
        rgb_steps = [
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        if self.is_train:
            rgb_steps.insert(
                0, T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)
            )
        self.transform_rgb = T.Compose(rgb_steps)

        self.transform_ir = T.Compose([
            T.ToTensor(),
            T.Normalize([0.5], [0.5])
        ])
        self.resize = T.Resize((self.img_size, self.img_size))

    def __len__(self):
        return len(self.image_ids)

    def _load_annotation(self, img_id, orig_w, orig_h):
        """Return the normalised 'person' boxes of one image.

        Args:
            img_id: Image id; the annotation is ``<ann_dir>/<img_id>.xml``.
            orig_w: Original image width used to normalise x coordinates.
            orig_h: Original image height used to normalise y coordinates.

        Returns:
            A list of ``[xmin, ymin, xmax, ymax]`` lists clamped to
            ``[0, 1]``. Empty when the file is missing or unparsable.
        """
        ann_path = os.path.join(self.ann_dir, f"{img_id}.xml")
        if not os.path.exists(ann_path):
            return []
        try:
            root = ET.parse(ann_path).getroot()
        except ET.ParseError:
            print(f"Warning: Could not parse annotation {ann_path}")
            return []

        bboxes_norm = []
        for obj in root.findall('object'):
            # A single broken <object> (missing tag, empty or non-numeric
            # text) must not discard the remaining valid boxes of the image.
            try:
                if obj.find('name').text.lower() != 'person':
                    continue
                bbox = obj.find('bndbox')
                corners = (
                    float(bbox.find('xmin').text) / orig_w,
                    float(bbox.find('ymin').text) / orig_h,
                    float(bbox.find('xmax').text) / orig_w,
                    float(bbox.find('ymax').text) / orig_h,
                )
            except (AttributeError, TypeError, ValueError):
                print(f"Warning: Skipping malformed object in annotation {ann_path}")
                continue
            bboxes_norm.append([max(0.0, min(1.0, value)) for value in corners])
        return bboxes_norm

    def __getitem__(self, idx):
        """Load, resize and (for training) randomly flip one RGB/IR pair."""
        img_id = self.image_ids[idx]
        rgb_path = os.path.join(self.rgb_dir, f"{img_id}.jpg")
        ir_path = os.path.join(self.ir_dir, f"{img_id}.jpg")

        try:
            # convert() returns a loaded copy, so the file handles can be
            # released immediately instead of waiting for garbage collection.
            with Image.open(rgb_path) as rgb_file:
                rgb_img = rgb_file.convert('RGB')
            with Image.open(ir_path) as ir_file:
                ir_img = ir_file.convert('L')
            orig_w, orig_h = rgb_img.size

            bboxes = self._load_annotation(img_id, orig_w, orig_h)
            bboxes = torch.tensor(bboxes, dtype=torch.float32).view(-1, 4)

            rgb_img = self.resize(rgb_img)
            ir_img = self.resize(ir_img)

            if self.is_train and random.random() < 0.5:
                rgb_img = T.functional.hflip(rgb_img)
                ir_img = T.functional.hflip(ir_img)
                if bboxes.shape[0] > 0:
                    # Mirror x and swap the two columns so xmin <= xmax holds.
                    bboxes[:, [0, 2]] = 1.0 - bboxes[:, [2, 0]]

            return {
                'rgb': self.transform_rgb(rgb_img),
                'ir': self.transform_ir(ir_img),
                'bboxes': bboxes,  # normalised xyxy boxes
                'img_id': img_id
            }
        except Exception as e:
            # Deliberate best-effort: report the sample and hand back None.
            print(f"Error loading image ID {img_id}: {e}")
            return None

# ===== MODEL COMPONENTS (Backbone and Fusion - Unchanged) =====

class TimmResNetBackbone(nn.Module):
    """Timm feature-extraction backbone (ResNet-18 by default).

    The wrapped model is created with ``features_only=True`` and
    ``out_indices=(2, 3, 4)``; ``forward`` exposes the three returned maps
    under the names ``"layer2"``, ``"layer3"`` and ``"layer4"``.
    """

    # Output names, in the order the wrapped model returns its feature maps.
    _STAGE_NAMES = ("layer2", "layer3", "layer4")

    def __init__(self, model_name='resnet18', pretrained=True):
        super().__init__()
        timm_kwargs = {
            'pretrained': pretrained,
            'features_only': True,
            'out_indices': (2, 3, 4),  # three feature stages, see class docstring
        }
        self.backbone = timm.create_model(model_name, **timm_kwargs)
        # Channel count of every returned feature map, as reported by timm.
        self.feature_dims = self.backbone.feature_info.channels()
        print(f"Initialized {model_name} with feature dimensions: {self.feature_dims}")

    def forward(self, x):
        """Return a dict mapping stage name to that stage's feature map."""
        stage_maps = self.backbone(x)
        return {
            name: stage_maps[position]
            for position, name in enumerate(self._STAGE_NAMES)
        }

class SimpleConvHead(nn.Module):
    """Small dense classification head used for deep supervision.

    A 3x3 conv + BatchNorm + ReLU refinement to 256 channels is followed by
    a 1x1 conv that produces ``num_classes`` logits per spatial location.
    """

    def __init__(self, in_channels, num_classes=2):
        super().__init__()
        hidden_channels = 256
        refine_layers = (
            nn.Conv2d(in_channels, hidden_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(inplace=True),
        )
        self.refine = nn.Sequential(*refine_layers)
        self.class_head = nn.Conv2d(hidden_channels, num_classes, 1)

    def forward(self, x):
        """Return ``{'pred_logits': logits}``; no box output is produced."""
        refined = self.refine(x)
        logits = self.class_head(refined)
        return {'pred_logits': logits}

class TimmViTEnhancer(nn.Module):
    """Runs the transformer blocks of a timm ViT over a CNN feature map.

    Every spatial location of the input map becomes one token. The map is
    projected to the ViT width with a 1x1 conv when the channel counts
    differ, a learned positional embedding is added, the ViT blocks and a
    LayerNorm are applied, and the result is projected back to
    ``embed_dim`` channels.

    Args:
        embed_dim: Channel count of the incoming feature map.
        img_size: Expected height and width of the incoming feature map.
        model_name: Name of the timm model to build.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, embed_dim, img_size, model_name=CONFIG['vit_model_name'], pretrained=True):
        super().__init__()
        self.vit = timm.create_model(model_name, pretrained=pretrained)
        self.img_size = img_size

        # Classification head and (if present) final norm are replaced by
        # identities; only the block stack of the ViT is used in forward().
        self.vit.head = nn.Identity()
        if hasattr(self.vit, 'norm'):
            self.vit.norm = nn.Identity()

        vit_dim = self.vit.embed_dim
        if embed_dim != vit_dim:
            self.input_proj = nn.Conv2d(embed_dim, vit_dim, kernel_size=1)
        else:
            self.input_proj = nn.Identity()
        if embed_dim != vit_dim:
            self.output_proj = nn.Conv2d(vit_dim, embed_dim, kernel_size=1)
        else:
            self.output_proj = nn.Identity()

        self.norm = nn.LayerNorm(vit_dim)
        # One learned positional vector per spatial location of the map.
        self.pos_embed_enhancer = nn.Parameter(torch.zeros(1, img_size*img_size, vit_dim))
        nn.init.trunc_normal_(self.pos_embed_enhancer, std=.02)

    def forward(self, x):
        """Return an enhanced map with the same shape as ``x``."""
        batch, _, height, width = x.shape
        assert height == self.img_size and width == self.img_size, f"Input size mismatch: {height}x{width} vs expected {self.img_size}x{self.img_size}"

        projected = self.input_proj(x)
        vit_channels = projected.shape[1]

        # [B, C, H, W] -> [B, H*W, C] token sequence
        tokens = projected.flatten(2).transpose(1, 2)
        tokens = self.vit.blocks(tokens + self.pos_embed_enhancer)
        tokens = self.norm(tokens)

        # Back to a spatial map before the output projection.
        spatial = tokens.transpose(1, 2).view(batch, vit_channels, height, width)
        return self.output_proj(spatial)

class BiDirectionalCrossAttentionFusion(nn.Module):
    """Fuses two same-shaped feature maps with cross-attention in both directions.

    The first input queries the second and vice versa. Each attended stream
    passes through a residual + GroupNorm + 1x1-conv FFN stage. A gating
    network then predicts four per-pixel softmax weights that blend, in
    this channel order: the processed first stream, the processed second
    stream, the raw first input and the raw second input.

    Args:
        embed_dim: Channel count of both inputs.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability inside the two FFNs.
    """

    def __init__(self, embed_dim, num_heads, dropout_rate=CONFIG['dropout_rate']):
        super().__init__()
        self.cross_attn_rgb = nn.MultiheadAttention(embed_dim, num_heads, batch_first=False)
        self.cross_attn_ir = nn.MultiheadAttention(embed_dim, num_heads, batch_first=False)

        # Extra linear projections applied before each attention module.
        self.query_proj_rgb = nn.Linear(embed_dim, embed_dim)
        self.key_proj_ir = nn.Linear(embed_dim, embed_dim)
        self.value_proj_ir = nn.Linear(embed_dim, embed_dim)
        self.query_proj_ir = nn.Linear(embed_dim, embed_dim)
        self.key_proj_rgb = nn.Linear(embed_dim, embed_dim)
        self.value_proj_rgb = nn.Linear(embed_dim, embed_dim)

        ffn_dim = embed_dim * 2
        self.ffn_rgb = nn.Sequential(
            nn.Conv2d(embed_dim, ffn_dim, 1),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Conv2d(ffn_dim, embed_dim, 1),
        )
        self.ffn_ir = nn.Sequential(
            nn.Conv2d(embed_dim, ffn_dim, 1),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Conv2d(ffn_dim, embed_dim, 1),
        )

        self.norm_rgb_1 = nn.GroupNorm(1, embed_dim)
        self.norm_rgb_2 = nn.GroupNorm(1, embed_dim)
        self.norm_ir_1 = nn.GroupNorm(1, embed_dim)
        self.norm_ir_2 = nn.GroupNorm(1, embed_dim)

        self.gate_conv = nn.Sequential(
            nn.Conv2d(embed_dim * 2, embed_dim, 3, padding=1, bias=False),
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(embed_dim, 4, 1)
        )
        # Disabled experiment: the bias of gate_conv[-1] could be initialised
        # to 0 for the two RGB gate channels and 1 for the two IR gate
        # channels so the softmax initially favours IR. It is not applied.

    def forward(self, feat_rgb, feat_ir):
        """Return ``(fused_map, gate_weights)``.

        ``fused_map`` has the shape of ``feat_rgb``; ``gate_weights`` has
        four channels that sum to one at every pixel.
        """
        batch, channels, height, width = feat_rgb.shape

        def as_sequence(feature_map):
            # [B, C, H, W] -> [H*W, B, C] (sequence-first layout)
            return feature_map.flatten(2).permute(2, 0, 1)

        def as_map(sequence):
            # [H*W, B, C] -> [B, C, H, W]
            return sequence.permute(1, 2, 0).view(batch, channels, height, width)

        rgb_tokens = as_sequence(feat_rgb)
        ir_tokens = as_sequence(feat_ir)

        # --- first stream: RGB queries attend to IR keys/values ---
        rgb_query = self.query_proj_rgb(rgb_tokens)
        ir_key = self.key_proj_ir(ir_tokens)
        ir_value = self.value_proj_ir(ir_tokens)
        rgb_attended, _ = self.cross_attn_rgb(rgb_query, ir_key, ir_value)
        rgb_stage1 = self.norm_rgb_1(feat_rgb + as_map(rgb_attended))
        fused_rgb_processed = self.norm_rgb_2(rgb_stage1 + self.ffn_rgb(rgb_stage1))

        # --- second stream: IR queries attend to RGB keys/values ---
        ir_query = self.query_proj_ir(ir_tokens)
        rgb_key = self.key_proj_rgb(rgb_tokens)
        rgb_value = self.value_proj_rgb(rgb_tokens)
        ir_attended, _ = self.cross_attn_ir(ir_query, rgb_key, rgb_value)
        ir_stage1 = self.norm_ir_1(feat_ir + as_map(ir_attended))
        fused_ir_processed = self.norm_ir_2(ir_stage1 + self.ffn_ir(ir_stage1))

        # --- per-pixel gate computed from the raw inputs ---
        gate_logits = self.gate_conv(torch.cat([feat_rgb, feat_ir], dim=1))
        spatial_weights = F.softmax(gate_logits, dim=1)
        w_fused_rgb, w_fused_ir, w_res_rgb, w_res_ir = spatial_weights.split(1, dim=1)

        fused_output = (
            w_fused_rgb * fused_rgb_processed
            + w_fused_ir * fused_ir_processed
            + w_res_rgb * feat_rgb
            + w_res_ir * feat_ir
        )
        return fused_output, spatial_weights

class AxialAttention(nn.Module):
    """Self-attention along one spatial axis of a ``[B, C, H, W]`` map.

    With ``axis == 1`` every column is treated as an independent sequence
    of length ``H``; any other value attends along every row (length
    ``W``). The output has the same shape as the input.
    """

    def __init__(self, embed_dim, num_heads, axis):
        super().__init__()
        self.axis = axis
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=False)

    def forward(self, x):
        B, C, H, W = x.shape
        if self.axis == 1:
            # Lines are columns: W*B sequences of length H.
            to_lines, n_lines, line_len, restore = (3, 0, 2, 1), W, H, (1, 3, 2, 0)
        else:
            # Lines are rows: H*B sequences of length W.
            to_lines, n_lines, line_len, restore = (2, 0, 3, 1), H, W, (1, 3, 0, 2)

        # [n_lines*B, line_len, C] -> sequence-first [line_len, n_lines*B, C]
        sequences = x.permute(*to_lines).reshape(n_lines * B, line_len, C).permute(1, 0, 2)
        attended, _ = self.attn(sequences, sequences, sequences)
        # Undo the layout change to get back to [B, C, H, W].
        lines = attended.permute(1, 0, 2).reshape(n_lines, B, line_len, C)
        return lines.permute(*restore)

class AxialAttentionBlock(nn.Module):
    """Height-then-width axial attention followed by a 1x1-conv FFN.

    Both stages are wrapped in a residual connection and a single-group
    GroupNorm. ``dropout_rate`` is accepted for signature compatibility
    but no dropout layer is created in this block.
    """

    def __init__(self, embed_dim, num_heads, dropout_rate=CONFIG['dropout_rate']):
        super().__init__()
        self.axial_attn_height = AxialAttention(embed_dim, num_heads, axis=1)
        self.axial_attn_width = AxialAttention(embed_dim, num_heads, axis=2)
        self.norm1 = nn.GroupNorm(1, embed_dim)

        ffn_dim = embed_dim * 2
        self.ffn = nn.Sequential(
            nn.Conv2d(embed_dim, ffn_dim, 1),
            nn.GELU(),
            nn.Conv2d(ffn_dim, embed_dim, 1),
        )
        self.norm2 = nn.GroupNorm(1, embed_dim)

    def forward(self, x):
        # Attention stage: the width pass consumes the height pass output.
        attended = self.axial_attn_width(self.axial_attn_height(x))
        x = self.norm1(x + attended)
        # Feed-forward stage with its own residual.
        return self.norm2(x + self.ffn(x))

class MultiLevelFusion(nn.Module):
    """Fuses two feature maps with cross-attention and optional axial attention.

    Args:
        embed_dim: Channel count of both inputs.
        num_heads: Number of attention heads for every attention module.
        use_axial: When True an ``AxialAttentionBlock`` refines the fused
            map; otherwise the fused map is passed through unchanged.
        dropout_rate: Dropout probability forwarded to the sub-modules.
            Previously this argument was accepted but never forwarded, so
            the sub-modules always fell back to their own default.
            ``AxialAttentionBlock`` accepts the argument but creates no
            dropout layer, so only the cross-attention FFNs are affected.
    """

    def __init__(self, embed_dim, num_heads, use_axial: bool = True, dropout_rate=CONFIG['dropout_rate']):
        super().__init__()
        self.cross_attn = BiDirectionalCrossAttentionFusion(
            embed_dim, num_heads, dropout_rate=dropout_rate
        )
        self.use_axial = use_axial
        if self.use_axial:
            self.axial_block = AxialAttentionBlock(
                embed_dim, num_heads, dropout_rate=dropout_rate
            )
            print(f"MultiLevelFusion: Initialized WITH AxialAttention (dim={embed_dim})")
        else:
            self.axial_block = nn.Identity()

    def forward(self, feat1, feat2):
        """Return ``(fused_map, gate_weights)``.

        ``gate_weights`` are the per-pixel blending weights produced by the
        cross-attention module, passed through untouched for visualisation.
        """
        # 1. Cross-attention yields the fused features and the gate weights.
        cross_attn_output, spatial_weights = self.cross_attn(feat1, feat2)

        # 2. Axial attention (or identity) is applied to the feature map only.
        final_output = self.axial_block(cross_attn_output)

        return final_output, spatial_weights

class EnhancedHybridFeatureExtractor(nn.Module):
    """Two-stream (RGB + IR) feature extractor with per-scale fusion.

    Two ``TimmResNetBackbone`` instances process the RGB and the
    single-channel IR image. For every configured scale the two feature
    maps are fused by a ``MultiLevelFusion`` module; optionally each
    modality is first enhanced by a ViT and fused with its own raw
    features. A ``SimpleConvHead`` on the RGB ``"layer4"`` features
    provides a deep-supervision signal.

    Args:
        embed_dims: Mapping scale name -> channel count; its keys define
            the scales that are fused.
        feature_sizes: Mapping scale name -> spatial size, used as
            ``img_size`` of the ViT enhancer.
        num_heads: Attention heads for every fusion module.
        d_model_head: Accepted for configuration compatibility; not used
            inside this class.
        use_vit_enhancer: Enable the ViT enhancement / intra-modal fusion.
        use_axial_fusion: Forwarded as ``use_axial`` to every fusion module.
        dropout_rate: Forwarded to every fusion module.
    """

    def __init__(self,
                 embed_dims=CONFIG['backbone_embed_dims'],
                 feature_sizes=CONFIG['backbone_feature_sizes'],
                 num_heads=CONFIG['nheads'],
                 d_model_head=CONFIG['d_model'],
                 use_vit_enhancer=False,
                 use_axial_fusion=True,
                 dropout_rate=CONFIG['dropout_rate']):
        super().__init__()
        print(f"Initializing PROGRESSIVE FUSION Feature Extractor: ViT Enhancer={'ON' if use_vit_enhancer else 'OFF'}")
        self.use_vit_enhancer = use_vit_enhancer
        self.scales = list(embed_dims.keys())

        # --- Backbone Modules ---
        self.backbone_rgb = TimmResNetBackbone()
        self.backbone_ir = TimmResNetBackbone()

        # Adapt the IR backbone to 1-channel input: the new first conv is
        # initialised with the channel-wise mean of the original weights.
        original_conv1_weights = self.backbone_ir.backbone.conv1.weight.clone()
        new_conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        new_conv1.weight.data = original_conv1_weights.mean(dim=1, keepdim=True)
        self.backbone_ir.backbone.conv1 = new_conv1
        print("Adapted IR backbone's first convolutional layer for 1-channel input.")

        # --- Optional Enhancement & Intra-Modal Fusion Modules ---
        if self.use_vit_enhancer:
            self.shared_vit_enhancers = nn.ModuleDict()
            self.fusions_intra_rgb = nn.ModuleDict()
            self.fusions_intra_ir = nn.ModuleDict()
            for scale in self.scales:
                dim, size = embed_dims[scale], feature_sizes[scale]
                # One enhancer per scale, shared by both modalities.
                self.shared_vit_enhancers[scale] = TimmViTEnhancer(dim, img_size=size)
                self.fusions_intra_rgb[scale] = MultiLevelFusion(
                    dim, num_heads, use_axial=use_axial_fusion, dropout_rate=dropout_rate
                )
                self.fusions_intra_ir[scale] = MultiLevelFusion(
                    dim, num_heads, use_axial=use_axial_fusion, dropout_rate=dropout_rate
                )

        # --- Progressive Cross-Modal Fusion Modules ---
        self.progressive_fusions = nn.ModuleDict({
            scale: MultiLevelFusion(
                embed_dims[scale],
                num_heads,
                use_axial=use_axial_fusion,
                dropout_rate=dropout_rate
            )
            for scale in self.scales
        })

        # --- Deep Supervision Head ---
        # Input channels match the last (deepest) backbone feature map.
        self.deep_supervision_head = SimpleConvHead(
            in_channels=self.backbone_rgb.feature_dims[-1],
            num_classes=2
        )

    def freeze_backbone(self, freeze=True):
        """Set ``requires_grad`` of all parameters of both backbones."""
        print(f"Setting backbone requires_grad = {not freeze}")
        for param in self.backbone_rgb.parameters():
            param.requires_grad = not freeze
        for param in self.backbone_ir.parameters():
            param.requires_grad = not freeze

    def unfreeze_backbone_layer(self, layer_name: str):
        """Enable gradients for parameters whose name contains ``layer_name``.

        Only ``"layer2"``, ``"layer3"`` and ``"layer4"`` are accepted; any
        other value prints a warning and changes nothing.
        """
        if layer_name not in ["layer2", "layer3", "layer4"]:
            print(f"Warning: Invalid layer name '{layer_name}' for unfreezing.")
            return
        print(f"--- Unfreezing backbone layer: {layer_name} ---")
        for name, param in self.backbone_rgb.named_parameters():
            if layer_name in name:
                param.requires_grad = True
        for name, param in self.backbone_ir.named_parameters():
            if layer_name in name:
                param.requires_grad = True

    def forward(self, rgb, ir):
        """Return ``(main_head_input, deep_supervision_preds, gate_weights)``.

        ``main_head_input`` is the fused ``"layer4"`` map,
        ``deep_supervision_preds`` the output dict of the deep-supervision
        head and ``gate_weights`` a dict scale -> fusion gate weights.
        """
        # 1. Features of both backbones, keyed by stage name.
        rgb_features_by_layer = self.backbone_rgb(rgb)
        ir_features_by_layer = self.backbone_ir(ir)

        # 2. Deep supervision on the raw RGB layer4 features.
        deep_supervision_preds = self.deep_supervision_head(rgb_features_by_layer["layer4"])

        # 3. Per-scale fusion.
        final_fused_outputs = {}
        spatial_weights_outputs = {}

        for scale in self.scales:
            rgb_feat_raw = rgb_features_by_layer[scale]
            ir_feat_raw = ir_features_by_layer[scale]

            processed_rgb = rgb_feat_raw
            processed_ir = ir_feat_raw

            if self.use_vit_enhancer:
                vit_enhanced_rgb = self.shared_vit_enhancers[scale](rgb_feat_raw) + rgb_feat_raw
                vit_enhanced_ir = self.shared_vit_enhancers[scale](ir_feat_raw) + ir_feat_raw
                # MultiLevelFusion returns (features, gate_weights). Only the
                # features feed the cross-modal fusion; passing the tuple on
                # made the next module fail when it read ``.shape``.
                processed_rgb, _ = self.fusions_intra_rgb[scale](rgb_feat_raw, vit_enhanced_rgb)
                processed_ir, _ = self.fusions_intra_ir[scale](ir_feat_raw, vit_enhanced_ir)

            fused_feature, weights = self.progressive_fusions[scale](processed_rgb, processed_ir)
            final_fused_outputs[scale] = fused_feature
            spatial_weights_outputs[scale] = weights

        # 4. The detection head consumes the fused layer4 map.
        main_head_input = final_fused_outputs["layer4"]

        return main_head_input, deep_supervision_preds, spatial_weights_outputs

class YOLOHead(nn.Module):
    """Single-scale YOLO-style head with one 1x1 conv per prediction task.

    Box, objectness and class logits come from three separate convolutions
    and are concatenated into one tensor of shape
    ``[B, H, W, num_anchors, 5 + num_classes]`` with the channel order
    (4 box values, 1 objectness, ``num_classes`` class logits).

    Args:
        in_channels: Channel count of the incoming feature map.
        num_classes: Number of object classes.
        anchors: Sequence of anchor pairs for this scale.
        stride: Value the anchors are divided by before being stored.
    """

    def __init__(self, in_channels, num_classes, anchors, stride):
        super().__init__()
        self.num_classes = num_classes
        self.num_anchors = len(anchors)
        self.stride = stride

        # One output group per anchor for every task.
        box_channels = self.num_anchors * 4
        obj_channels = self.num_anchors * 1
        cls_channels = self.num_anchors * self.num_classes
        self.conv_box = nn.Conv2d(in_channels, box_channels, kernel_size=1)
        self.conv_obj = nn.Conv2d(in_channels, obj_channels, kernel_size=1)
        self.conv_cls = nn.Conv2d(in_channels, cls_channels, kernel_size=1)

        self._initialize_biases()

        # Anchors are stored divided by the stride; anchor_grid is the same
        # data reshaped so it broadcasts over a 5-D tensor.
        anchors_in_cells = torch.tensor(anchors).float().view(self.num_anchors, 2) / self.stride
        self.register_buffer('anchors', anchors_in_cells)
        self.register_buffer('anchor_grid', self.anchors.clone().view(1, -1, 1, 1, 2))

    def _initialize_biases(self):
        """Set the starting biases of the objectness and class convs."""
        # -4.59 is approximately log(p / (1 - p)) for p = 0.01, i.e. the
        # objectness sigmoid starts close to 0.01.
        self.conv_obj.bias.data.fill_(-4.59)

        if self.num_classes > 0:
            if self.num_classes > 1:
                initial_bias = -math.log((1 - 0.01) / 0.01 / (self.num_classes - 1))
            else:
                initial_bias = 0
            self.conv_cls.bias.data.fill_(initial_bias)

    def _to_grid_layout(self, raw, values_per_anchor):
        """Reshape ``[B, A*K, H, W]`` conv output to ``[B, H, W, A, K]``."""
        batch, _, height, width = raw.shape
        grouped = raw.view(batch, self.num_anchors, values_per_anchor, height, width)
        return grouped.permute(0, 3, 4, 1, 2).contiguous()

    def forward(self, x):
        """Return raw predictions of shape ``[B, H, W, A, 5 + num_classes]``."""
        box_part = self._to_grid_layout(self.conv_box(x), 4)
        obj_part = self._to_grid_layout(self.conv_obj(x), 1)
        cls_part = self._to_grid_layout(self.conv_cls(x), self.num_classes)
        return torch.cat([box_part, obj_part, cls_part], dim=-1)

class YOLOLoss(nn.Module):
    """Single-scale YOLO loss (CIoU box + BCE objectness + BCE class).

    Predictions are expected in the layout produced by ``YOLOHead``:
    ``[B, H, W, num_anchors, 5 + num_classes]``. Targets are rows of
    ``[image_index, class, x, y, w, h]`` where the box values are
    multiplied by the grid size, i.e. they are expected to be normalised
    to ``[0, 1]``.

    Args:
        config: Mapping providing ``num_classes``, ``yolo_strides``,
            ``anchors``, ``loss_box_weight``, ``loss_obj_weight``,
            ``loss_cls_weight`` and optionally ``label_smoothing``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_classes = config['num_classes']
        self.stride = config['yolo_strides'][0]

        # Anchors of the first (only) scale, divided by the stride.
        scaled_anchors = torch.tensor(config['anchors'][0]).float() / self.stride
        self.register_buffer('anchors', scaled_anchors)

        self.num_anchors = len(self.anchors)
        self.balance = [4.0]
        self.bce_cls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.0]))
        # Positive objectness cells are up-weighted by balance[0].
        self.bce_obj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.balance[0]]))
        self.box_weight = config['loss_box_weight']
        self.obj_weight = config['loss_obj_weight']
        self.cls_weight = config['loss_cls_weight']
        # Maximum allowed width/height ratio between a target and an anchor.
        self.anchor_t = 4.0
        self.label_smoothing = config.get('label_smoothing', 0.0)

    def forward(self, preds, targets):
        """Return a dict with ``total_loss``, ``loss_box``, ``loss_obj``, ``loss_cls``.

        Every entry is a 1-element tensor multiplied by the batch size.
        """
        device = preds.device
        lbox = torch.zeros(1, device=device)
        lobj = torch.zeros(1, device=device)
        lcls = torch.zeros(1, device=device)

        p = preds
        tcls, tbox, indices, anchors = self.build_targets(p, targets.to(device))

        tcls = tcls.to(device)
        tbox = tbox.to(device)
        indices = tuple(i.to(device) for i in indices)
        anchors = anchors.to(device)

        B, H, W, A, _ = p.shape
        tobj = torch.zeros_like(p[..., 4], device=device)  # target objectness

        b, a, gj, gi = indices
        if b.shape[0] > 0:
            ps = p[b, gj, gi, a]  # predictions at the matched cells/anchors

            # Box loss (CIoU) on decoded cell-relative boxes.
            pxy = ps[:, :2].sigmoid() * 2 - 0.5
            pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors
            pbox = torch.cat((pxy, pwh), 1)
            iou = bbox_iou(pbox, tbox, x1y1x2y2=False, CIoU=True)
            lbox += (1.0 - iou).mean()

            # Matched anchors get a hard objectness target of 1.
            tobj[b, gj, gi, a] = 1.0

            # Classification loss with label smoothing.
            if self.num_classes > 0:
                t = torch.full_like(ps[:, 5:], self.label_smoothing / (self.num_classes - 1) if self.num_classes > 1 else 0, device=device)
                t[range(b.shape[0]), tcls.long()] = 1.0 - self.label_smoothing
                lcls += self.bce_cls(ps[:, 5:], t)

        # Objectness loss over all cells and anchors.
        lobj += self.bce_obj(p[..., 4], tobj)

        total_loss = self.box_weight * lbox + self.obj_weight * lobj + self.cls_weight * lcls
        return {
            'total_loss': total_loss * B,
            'loss_box': lbox * B,
            'loss_obj': lobj * B,
            'loss_cls': lcls * B
        }

    @staticmethod
    def _empty_targets(device):
        """Return the 'no matched target' result of ``build_targets``."""
        def empty_index():
            return torch.tensor([], device=device, dtype=torch.long)

        return (
            empty_index(),
            torch.tensor([], device=device),
            (empty_index(), empty_index(), empty_index(), empty_index()),
            torch.tensor([], device=device),
        )

    def build_targets(self, p, targets):
        """Assign targets to grid cells and anchors.

        Args:
            p: Prediction tensor ``[B, H, W, A, C]``; only its shape and
                device are used.
            targets: ``[nt, 6]`` rows ``[image_index, class, x, y, w, h]``.

        Returns:
            ``(tcls, tbox, (b, a, gj, gi), anchors)`` where ``tbox`` holds
            the cell-relative centre and the grid-scaled width/height of
            every assignment, ``gj``/``gi`` are the row/column of the
            assigned cell and ``anchors`` the matched anchor sizes.
        """
        na = self.num_anchors
        nt = targets.shape[0]
        if nt == 0:
            return self._empty_targets(p.device)

        device = targets.device
        # p is [B, H, W, A, C]: rows (y) live in dim 1, columns (x) in dim 2.
        grid_h, grid_w = int(p.shape[1]), int(p.shape[2])

        # Replicate every target once per anchor and append the anchor index.
        ai = torch.arange(na, device=device).float().view(na, 1)
        targets = torch.cat((targets.repeat(na, 1, 1), ai.repeat(1, nt)[:, :, None]), 2)

        g = 0.5
        off = torch.tensor([[0, 0], [1, 0], [0, 1], [-1, 0], [0, -1]], device=device).float() * g

        anchors = self.anchors.to(device)

        # Scale x/w by the grid width and y/h by the grid height. The former
        # index list [1, 0, 1, 0] picked (H, B, H, B), so y and h were scaled
        # by the batch size and rows were clamped to B - 1.
        gain = torch.ones(7, device=device)
        gain[2:6] = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=device, dtype=gain.dtype)

        t = targets * gain

        # Keep target/anchor pairs whose size ratio is below anchor_t.
        r = t[:, :, 4:6] / anchors[:, None]
        j = torch.max(r, 1. / r).max(2)[0] < self.anchor_t
        t = t[j]

        if t.shape[0] == 0:
            return self._empty_targets(p.device)

        # Also assign each target to the neighbouring cells it lies close to.
        gxy = t[:, 2:4]
        gxi = gain[2:4] - gxy
        j, k = ((gxy % 1. < g) & (gxy > 1.)).T
        l, m = ((gxi % 1. < g) & (gxi > 1.)).T
        j = torch.stack((torch.ones_like(j), j, k, l, m))
        t = t.repeat((5, 1, 1))[j]
        offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]

        b, c = t[:, :2].long().T
        gxy = t[:, 2:4]
        gwh = t[:, 4:6]
        gij = (gxy - offsets).long()
        gi, gj = gij.T

        a = t[:, 6].long()
        # clamp_ works in place on views of gij, so the clamped cell indices
        # are also the ones used for the cell-relative centre below.
        indices = (b, a, gj.clamp_(0, grid_h - 1), gi.clamp_(0, grid_w - 1))
        tbox = torch.cat((gxy - gij, gwh), 1)

        return c, tbox, indices, anchors[a]


# ===== UTILITIES (Visualization, Gradients, Dataloader, Box ops) =====

def box_cxcywh_to_xyxy(x):
    """Convert ``(cx, cy, w, h)`` boxes to ``(x1, y1, x2, y2)``.

    The last dimension of ``x`` holds the four box values. The resulting
    corner coordinates are clamped to ``[0, 1]``.
    """
    centre_x, centre_y, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = torch.stack(
        [centre_x - half_w, centre_y - half_h, centre_x + half_w, centre_y + half_h],
        dim=-1,
    )
    return corners.clamp(min=0.0, max=1.0)

def box_xyxy_to_cxcywh(x):
    """Convert ``(x1, y1, x2, y2)`` boxes to ``(cx, cy, w, h)``.

    The last dimension of ``x`` holds the four box values; no clamping is
    applied.
    """
    left, top, right, bottom = x.unbind(-1)
    centre_x = (left + right) / 2
    centre_y = (top + bottom) / 2
    width = right - left
    height = bottom - top
    return torch.stack([centre_x, centre_y, width, height], dim=-1)

def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    """Return the element-wise IoU (or GIoU / DIoU / CIoU) of two box sets.

    Args:
        box1: Tensor whose last dimension holds four box values.
        box2: Tensor with the same (or a broadcastable) shape as ``box1``.
        x1y1x2y2: True for corner format ``(x1, y1, x2, y2)``, False for
            centre format ``(cx, cy, w, h)``.
        GIoU: Return the Generalized IoU.
        DIoU: Return the Distance IoU (takes precedence over ``CIoU``).
        CIoU: Return the Complete IoU.
        eps: Small constant guarding against division by zero.

    Returns:
        A tensor with the leading shape of the inputs (last dim removed).
        Both formats accept any number of leading dimensions; the centre
        format used to require exactly 2-D input.
    """
    if x1y1x2y2:
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.unbind(-1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.unbind(-1)
    else:
        # Centre format -> corners, indexing the last dimension only so that
        # 1-D and batched inputs work as well as [N, 4].
        cx1, cy1, bw1, bh1 = box1.unbind(-1)
        cx2, cy2, bw2, bh2 = box2.unbind(-1)
        b1_x1, b1_y1 = cx1 - bw1 / 2, cy1 - bh1 / 2
        b1_x2, b1_y2 = cx1 + bw1 / 2, cy1 + bh1 / 2
        b2_x1, b2_y1 = cx2 - bw2 / 2, cy2 - bh2 / 2
        b2_x2, b2_y2 = cx2 + bw2 / 2, cy2 + bh2 / 2

    # Intersection area (zero when the boxes do not overlap).
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union area; eps on the heights keeps the aspect-ratio term finite.
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if not (CIoU or DIoU or GIoU):
        return iou

    # Smallest box enclosing both inputs.
    cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)
    ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)

    if CIoU or DIoU:  # https://arxiv.org/abs/1911.08287v1
        c2 = cw ** 2 + ch ** 2 + eps  # enclosing-box diagonal squared
        rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # centre distance squared
        if DIoU:
            return iou - rho2 / c2
        # CIoU adds an aspect-ratio consistency term to DIoU.
        v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
        with torch.no_grad():
            # alpha is treated as a constant during back-propagation.
            alpha = v / (v - iou + (1 + eps))
        return iou - (rho2 / c2 + v * alpha)

    # GIoU https://arxiv.org/pdf/1902.09630.pdf
    c_area = cw * ch + eps
    return iou - (c_area - union) / c_area

def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False, max_det=300):
    """Performs Non-Maximum Suppression (NMS) on inference results.

    Args:
        prediction: Tensor indexed as [B, N, 5 + num_classes]. For each row,
            columns 0:4 are handed to ``box_cxcywh_to_xyxy``, column 4 is the
            objectness score compared against ``conf_thres``, and columns 5:
            hold the per-class scores.
        conf_thres: Threshold applied first to objectness, then to the class
            score(s).
        iou_thres: IoU threshold passed to ``torchvision.ops.nms``.
        classes: Optional class index (or iterable of indices) to keep.
            ``None`` keeps every class.
        agnostic: When True, boxes of different classes suppress each other.
        multi_label: When True, a box emits one detection per class whose
            score exceeds ``conf_thres``; otherwise only its best class.
        max_det: Maximum number of detections kept per image.

    Returns:
        A list with one tensor per image, each of shape [n, 6] laid out as
        (x1, y1, x2, y2, conf, class). Images without detections get an
        empty [0, 6] tensor.
    """
    xc = prediction[..., 4] > conf_thres  # objectness candidates

    # One independent empty tensor per image: `[t] * B` would put the *same*
    # tensor object in every slot, so an in-place edit by a caller would leak
    # into all images.
    output = [torch.zeros((0, 6), device=prediction.device) for _ in range(prediction.shape[0])]
    for xi, x in enumerate(prediction):  # image index, image inference
        x = x[xc[xi]]  # keep rows that pass the objectness threshold

        if not x.shape[0]:
            continue

        box = box_cxcywh_to_xyxy(x[:, :4])
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Optional class filter (the argument used to be accepted but ignored)
        if classes is not None:
            wanted = torch.as_tensor(classes, device=x.device, dtype=x.dtype).view(1, -1)
            x = x[(x[:, 5:6] == wanted).any(1)]

        if not x.shape[0]:
            continue

        # Shift boxes by a per-class offset so that NMS only compares boxes of
        # the same class; the un-shifted rows of `x` are what gets returned.
        c = x[:, 5:6] * (0 if agnostic else 4096)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]

        output[xi] = x[i]

    return output

def prepare_deep_supervision_targets(targets, feature_map_size, device):
    """Rasterise ground-truth boxes into dense per-cell targets.

    Args:
        targets: Sequence of dicts, each holding a ``'gt_boxes_xyxy'`` tensor.
            The boxes are scaled by ``feature_map_size`` to get grid cells, so
            they are presumably normalised to [0, 1] - confirm with the caller.
        feature_map_size: Side length of the (square) feature map.
        device: Device on which the returned tensors are allocated.

    Returns:
        ``(target_labels, target_boxes)`` where ``target_labels`` is a long
        tensor of shape [B, S, S] (1 inside any box, 0 elsewhere) and
        ``target_boxes`` is a float32 tensor of shape [B, 4, S, S] carrying the
        cxcywh box written into every cell the box covers.
    """
    batch_size = len(targets)
    grid_h = grid_w = feature_map_size

    # Allocate the outputs directly on the requested device
    target_labels = torch.zeros((batch_size, grid_h, grid_w), dtype=torch.long, device=device)
    target_boxes = torch.zeros((batch_size, 4, grid_h, grid_w), dtype=torch.float32, device=device)

    for img_idx, target in enumerate(targets):
        # The per-box loop below runs on CPU copies of the boxes
        boxes_xyxy = target['gt_boxes_xyxy'].cpu()
        if boxes_xyxy.shape[0] == 0:
            continue

        boxes_cxcywh = box_xyxy_to_cxcywh(boxes_xyxy)
        cell_corners = (boxes_xyxy * grid_h).long()
        for box_idx, corners in enumerate(cell_corners):
            x0, y0, x1, y1 = corners.tolist()
            # Clamp the covered cell range to the grid
            x0 = max(0, x0)
            y0 = max(0, y0)
            x1 = min(grid_w - 1, x1)
            y1 = min(grid_h - 1, y1)
            if x0 > x1 or y0 > y1:
                continue

            # Label 1 marks 'person'; a later box overwrites an earlier one on overlap
            target_labels[img_idx, y0:y1 + 1, x0:x1 + 1] = 1
            target_boxes[img_idx, :, y0:y1 + 1, x0:x1 + 1] = boxes_cxcywh[box_idx][:, None, None]

    return target_labels, target_boxes

def visualize_boxes(rgb_tensor, pred_boxes_xyxy, gt_boxes_xyxy, score_thresh=0.5, denormalize=True, title_suffix=""):
    """Show an image with predicted (red) and ground-truth (green) boxes.

    Args:
        rgb_tensor: Image tensor in channel-first layout (it is transposed to
            H x W x C for display).
        pred_boxes_xyxy: Iterable of rows whose first four entries are xyxy
            coordinates and whose entry 4 is the score.
        gt_boxes_xyxy: Tensor of xyxy ground-truth boxes.
        score_thresh: Predictions are drawn only if their score exceeds this.
        denormalize: If True, undo the mean/std normalisation before display.
        title_suffix: Text appended to the figure title.

    Box coordinates are multiplied by the image width/height before drawing.
    """
    img = rgb_tensor.cpu().numpy().transpose(1, 2, 0)
    if denormalize:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img = np.clip(((img * std) + mean), 0, 1)
    H, W, _ = img.shape

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(img)
    ax.axis('off')

    def _draw_rect(x0, y0, x1, y1, colour):
        # Outline one box given its corner coordinates in pixels
        rect = patches.Rectangle(
            (x0.item(), y0.item()),
            (x1 - x0).item(),
            (y1 - y0).item(),
            lw=2, edgecolor=colour, facecolor='none',
        )
        ax.add_patch(rect)

    # Predicted boxes (red), annotated with their score
    for det in pred_boxes_xyxy:
        score = det[4]
        if not (score > score_thresh):
            continue
        xmin, ymin, xmax, ymax = det[:4] * torch.tensor([W, H, W, H], device=det.device)
        _draw_rect(xmin, ymin, xmax, ymax, 'r')
        ax.text(
            xmin.item(), ymin.item() - 5, f'{score:.2f}',
            color='r', fontsize=8,
            bbox=dict(facecolor='white', alpha=0.7, pad=0.1),
        )

    # Ground-truth boxes (green)
    for gt in gt_boxes_xyxy.cpu():
        xmin, ymin, xmax, ymax = gt * torch.tensor([W, H, W, H], device=gt.device)
        _draw_rect(xmin, ymin, xmax, ymax, 'g')

    ax.set_title(f"Pred(r,>{score_thresh}) vs GT(g) {title_suffix}")
    plt.show()

def print_grad_norms(model, name):
    """Print the total L2 gradient norm over a model's parameters.

    Args:
        model: Module to inspect; an ``nn.DataParallel`` wrapper is unwrapped.
        name: Label used in the printed line.

    Only parameters that require grad and currently hold a gradient count.
    """
    core = model.module if isinstance(model, nn.DataParallel) else model
    squared_norms = [
        param.grad.norm(2).item() ** 2
        for _, param in core.named_parameters()
        if param.grad is not None and param.requires_grad
    ]
    if not squared_norms:
        print(f"--> {name}: No parameters with gradients found.")
        return
    total_norm = sum(squared_norms) ** 0.5
    print(f"--> {name} grad norm: {total_norm:.4f} ({len(squared_norms)} params with grad)")

def custom_collate_fn(batch):
    """
    Collate dataset samples whose box tensors have varying lengths.

    ``None`` samples are dropped; if nothing remains, ``None`` is returned.
    Otherwise the result is a dict with stacked ``'rgb'`` / ``'ir'`` tensors,
    per-sample lists ``'img_id'`` and ``'gt_boxes_xyxy'``, and
    ``'yolo_targets'``: a single [num_boxes, 6] tensor whose rows are
    [batch_idx, class_idx, cx, cy, w, h] (shape [0, 6] when no sample has boxes).
    """
    samples = [sample for sample in batch if sample is not None]
    if not samples:
        return None

    collated = {
        # Fixed-size tensors are stacked along a new batch dimension
        'rgb': torch.stack([sample['rgb'] for sample in samples]),
        'ir': torch.stack([sample['ir'] for sample in samples]),
        # Everything else stays a plain per-sample list
        'img_id': [sample['img_id'] for sample in samples],
        'gt_boxes_xyxy': [sample['bboxes'] for sample in samples],  # original xyxy, kept for visualization
    }

    # --- YOLO targets: one row per box, tagged with its image's batch index ---
    per_image_rows = []
    for img_idx, sample in enumerate(samples):
        boxes_xyxy = sample['bboxes']
        if boxes_xyxy.shape[0] == 0:
            continue
        boxes_cxcywh = box_xyxy_to_cxcywh(boxes_xyxy)
        rows = torch.zeros((boxes_cxcywh.shape[0], 6))
        rows[:, 0] = img_idx       # batch index
        rows[:, 1] = 0             # class index (0 for 'person')
        rows[:, 2:] = boxes_cxcywh
        per_image_rows.append(rows)

    collated['yolo_targets'] = torch.cat(per_image_rows, 0) if per_image_rows else torch.zeros((0, 6))
    return collated

def get_dataloader(root, split, bs, subset=None, img_sz=512, workers=2):
    """Build a ``DataLoader`` over ``LLVIPDataset`` for the given split.

    Args:
        root: Dataset root passed to ``LLVIPDataset``.
        split: Split name; shuffling is enabled only when it equals 'train'.
        bs: Batch size.
        subset: Forwarded as ``subset_size``.
        img_sz: Forwarded as ``img_size``.
        workers: Number of loader worker processes.

    Returns:
        A ``DataLoader`` using ``custom_collate_fn`` with pinned memory.
    """
    dataset = LLVIPDataset(root, split, subset_size=subset, img_size=img_sz)
    shuffle_samples = split == 'train'
    return DataLoader(
        dataset,
        batch_size=bs,
        shuffle=shuffle_samples,
        num_workers=workers,
        collate_fn=custom_collate_fn,
        pin_memory=True,
    )

def visualize_attention_weights(rgb_tensor, ir_tensor, attention_weights, epoch_num, score_thresh=0.5):
    """
    Plot the RGB/IR inputs and four spatial attention maps on a 2x3 grid.

    Args:
        rgb_tensor: Channel-first RGB image tensor (mean/std normalised).
        ir_tensor: IR image tensor; it is squeezed and rescaled with
            ``x * 0.5 + 0.5`` (assumes [0.5], [0.5] normalisation).
        attention_weights: Tensor indexed as [4, H, W], one map per title below.
        epoch_num: Used in the figure title and the output file name.
        score_thresh: Unused here; kept for interface compatibility.

    The figure is saved as ``attention_weights_epoch_<epoch_num>.png`` and shown.
    """
    # Undo the RGB normalisation for display
    rgb_mean = np.array([0.485, 0.456, 0.406])
    rgb_std = np.array([0.229, 0.224, 0.225])
    rgb_hwc = rgb_tensor.cpu().numpy().transpose(1, 2, 0)
    rgb_img = np.clip(((rgb_hwc * rgb_std) + rgb_mean), 0, 1)

    # Undo the IR normalisation for display
    ir_img = np.clip(ir_tensor.cpu().numpy().squeeze() * 0.5 + 0.5, 0, 1)

    weights = attention_weights.cpu().numpy()
    weight_titles = [
        'Weight on Fused RGB', 'Weight on Fused IR',
        'Weight on Original RGB', 'Weight on Original IR'
    ]

    fig, axes = plt.subplots(2, 3, figsize=(20, 12))
    fig.suptitle(f'Spatial Attention Weights - Epoch {epoch_num}', fontsize=20)
    panels = axes.flat  # flat indexing over the 2x3 grid

    # Panels 0 and 1: the input images
    rgb_panel, ir_panel = panels[0], panels[1]
    rgb_panel.imshow(rgb_img)
    rgb_panel.set_title("Original RGB")
    rgb_panel.axis('off')

    ir_panel.imshow(ir_img, cmap='gray')
    ir_panel.set_title("Original IR")
    ir_panel.axis('off')

    # Panels 2..5: one heatmap per attention weight
    target_size = rgb_img.shape[:2]
    for map_idx, title in enumerate(weight_titles):
        panel = panels[map_idx + 2]

        # Upsample the low-resolution map to image size for readability
        low_res = torch.from_numpy(weights[map_idx]).unsqueeze(0).unsqueeze(0)
        high_res = F.interpolate(low_res, size=target_size, mode='bilinear', align_corners=False)
        im = panel.imshow(high_res.squeeze().numpy(), cmap='viridis')

        panel.set_title(title)
        panel.axis('off')
        fig.colorbar(im, ax=panel, fraction=0.046, pad=0.04)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.savefig(f"attention_weights_epoch_{epoch_num}.png")
    plt.show()
    print(f"Saved attention weight visualization for epoch {epoch_num}.")
    
from torchmetrics.detection import MeanAveragePrecision

@torch.no_grad()
def evaluate(feature_extractor, yolo_head, criterion, val_loader, device, return_sample_for_vis=False):
    """Run one validation pass and report the mean loss and mAP metrics.

    Args:
        feature_extractor: Model called as ``feature_extractor(rgb, ir)``; its
            first output feeds the YOLO head and its third output is indexed
            with ``['layer4']`` for the attention visualisation.
        yolo_head: Detection head (optionally wrapped in ``nn.DataParallel``)
            exposing ``anchor_grid`` and ``stride``; its output is unpacked as
            [B, H, W, A, C].
        criterion: Loss returning a dict with a ``'total_loss'`` entry.
        val_loader: Iterable of collated batches (``None`` batches are skipped).
        device: Device used for inputs, decoding and the metric.
        return_sample_for_vis: If True, one randomly chosen batch is kept for
            visualisation.

    Returns:
        ``(avg_loss, map_results, vis_sample)``. ``avg_loss`` is averaged over
        the batches that actually produced a loss (0 if none did),
        ``map_results`` is the dict from ``MeanAveragePrecision.compute()``
        converted to Python values, and ``vis_sample`` is a dict or ``None``.
    """
    feature_extractor.eval()
    yolo_head.eval()
    criterion.eval()

    run_loss = 0
    loss_batches = 0  # batches that contributed to run_loss
    vis_sample = None
    vis_idx = -1
    if return_sample_for_vis and len(val_loader) > 0:
        vis_idx = random.randint(0, len(val_loader) - 1)

    # Accumulates detections / ground truth across the whole validation set
    metric = MeanAveragePrecision(box_format='xyxy').to(device)
    yolo_head_module = yolo_head.module if isinstance(yolo_head, nn.DataParallel) else yolo_head

    for i, data in enumerate(val_loader):
        if data is None:
            continue
        rgb = data['rgb'].to(device, non_blocking=True)
        ir = data['ir'].to(device, non_blocking=True)
        yolo_targets = data['yolo_targets'].to(device)
        gt_boxes_xyxy_list = data['gt_boxes_xyxy']  # list of per-image tensors

        # --- Forward pass and loss ---
        main_feats, _, spatial_weights = feature_extractor(rgb, ir)
        preds_raw = yolo_head(main_feats)
        if yolo_targets.shape[0] > 0:
            loss_dict = criterion(preds_raw, yolo_targets)
            run_loss += loss_dict['total_loss'].item()
            loss_batches += 1

        # --- Decode raw head outputs into boxes ---
        B, H, W, A, C = preds_raw.shape
        # meshgrid(indexing='ij') returns (row, col) = (y, x). Stack them as
        # (x, y) so the cell offsets line up with prediction channels 0 (cx)
        # and 1 (cy); stacking (row, col) directly would add the row index to x.
        grid_y, grid_x = torch.meshgrid(
            torch.arange(H, device=device), torch.arange(W, device=device), indexing='ij'
        )
        grid = torch.stack((grid_x, grid_y), dim=-1).view(1, H, W, 1, 2)
        anchor_grid = yolo_head_module.anchor_grid.to(device).view(1, 1, 1, A, 2)

        pred_xy = (preds_raw[..., 0:2].sigmoid() * 2 - 0.5 + grid) * yolo_head_module.stride
        pred_wh = (preds_raw[..., 2:4].sigmoid() * 2) ** 2 * anchor_grid * yolo_head_module.stride
        # Divide by the image size so boxes are expressed as fractions of the image
        pred_cxcywh = torch.cat([pred_xy, pred_wh], dim=-1) / CONFIG['img_size']
        pred_for_nms = torch.cat((pred_cxcywh, preds_raw[..., 4:].sigmoid()), dim=-1).reshape(B, H * W * A, C)
        final_preds_list = non_max_suppression(
            pred_for_nms, conf_thres=CONFIG['conf_threshold'], iou_thres=CONFIG['iou_threshold_nms']
        )

        # --- Format predictions and ground truth for the metric ---
        preds_formatted = []
        gts_formatted = []
        for idx in range(B):
            dets = final_preds_list[idx]  # [N, 6] -> x1, y1, x2, y2, conf, class
            preds_formatted.append({
                "boxes": dets[:, :4],
                "scores": dets[:, 4],
                "labels": dets[:, 5].int()
            })

            # NOTE(review): ground-truth boxes are used as provided; they are
            # assumed to be in the same units as the decoded predictions.
            gt_boxes_for_img = gt_boxes_xyxy_list[idx].to(device)
            gts_formatted.append({
                "boxes": gt_boxes_for_img,
                "labels": torch.zeros(gt_boxes_for_img.shape[0], dtype=torch.int, device=device)  # all labels are 0
            })

        metric.update(preds_formatted, gts_formatted)

        if i == vis_idx:
            vis_sample = {
                'rgb': data['rgb'],
                'ir': data['ir'],
                'gt_boxes_xyxy': data['gt_boxes_xyxy'],
                'final_preds': final_preds_list,
                'attention_weights': spatial_weights['layer4'][0].detach().cpu()
            }

    # Average only over batches that produced a loss; skipped batches (None or
    # without targets) would otherwise drag the mean towards zero.
    avg_loss = run_loss / loss_batches if loss_batches > 0 else 0
    try:
        # Scalars become floats; any multi-element entry is kept as a list so
        # that it cannot make the whole conversion fail.
        map_results = {
            k: (v.item() if v.numel() == 1 else v.tolist())
            for k, v in metric.compute().items()
        }
    except Exception as e:
        print(f"Could not compute mAP, likely no detections or GT boxes found in validation set. Error: {e}")
        map_results = {k: 0.0 for k in ['map', 'map_50', 'map_75', 'map_large', 'map_medium', 'map_small']}

    # vis_sample is only ever populated when return_sample_for_vis is True
    return avg_loss, map_results, vis_sample
        
    
def train():
    """Train the RGB/IR feature extractor and YOLO head on LLVIP.

    All hyper-parameters are read from the module-level ``CONFIG`` dict. The
    routine builds the data loaders, models, loss, an SGD optimizer with three
    parameter groups (backbone, fusion head, YOLO head), a warmup + cosine
    scheduler and a ``GradScaler``. It then runs the epoch loop with gradient
    accumulation, per-group gradient clipping, staged backbone unfreezing,
    validation after every epoch, checkpointing on the best validation loss
    and early stopping.

    Returns:
        None. Checkpoints are written to ``./checkpoints_yolo`` and the loss
        curve to ``training_validation_loss_curve_yolo.png``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    data_root = '/kaggle/input/llvip-data'
    if not os.path.exists(os.path.join(data_root, 'LLVIP')):
        print(f"Error: LLVIP data not found at {data_root}. Please check path.")
        return

    train_loader = get_dataloader(data_root, 'train', CONFIG['batch_size'], CONFIG['subset_size'], CONFIG['img_size'], CONFIG['num_workers'])
    val_loader = get_dataloader(data_root, 'test', CONFIG['batch_size'], CONFIG['subset_size'] // 4, CONFIG['img_size'], CONFIG['num_workers'])
    print(f"Validation loader created with {len(val_loader.dataset)} samples.")

    # Initialize Models
    feature_extractor = EnhancedHybridFeatureExtractor(use_vit_enhancer=False, use_axial_fusion=True, dropout_rate=CONFIG['dropout_rate']).to(device)
    yolo_head = YOLOHead(
        in_channels=CONFIG['d_model'],
        num_classes=CONFIG['num_classes'],
        anchors=CONFIG['anchors'][0],
        stride=CONFIG['yolo_strides'][0]
    ).to(device)

    # model_*_opt always refer to the unwrapped modules (used for parameter
    # groups, freezing and checkpointing); the possibly wrapped ones run forward.
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs with nn.DataParallel.")
        feature_extractor = nn.DataParallel(feature_extractor)
        yolo_head = nn.DataParallel(yolo_head)
        model_fe_opt = feature_extractor.module
        model_yolo_opt = yolo_head.module
    else:
        model_fe_opt = feature_extractor
        model_yolo_opt = yolo_head

    FREEZE_EPOCHS = CONFIG['FREEZE_EPOCHS']
    STABILIZATION_EPOCHS = 5
    STABILIZATION_FACTOR = 100.0  # Reduce LR by 100x
    # Epoch at which each backbone layer is unfrozen
    unfreeze_schedule = {
        FREEZE_EPOCHS: "layer4",
        FREEZE_EPOCHS + 20: "layer3",
        FREEZE_EPOCHS + 50: "layer2",
    }
    model_fe_opt.freeze_backbone(freeze=True)
    print(f"Backbone is FULLY FROZEN for the first {FREEZE_EPOCHS} epochs.")

    # Loss Criterion
    criterion = YOLOLoss(CONFIG).to(device)

    # Optimizer
    # The backbone group is built right after freezing, so it must NOT be
    # filtered by requires_grad: a parameter left out here is never added to
    # the optimizer, and unfreezing it later would compute gradients that are
    # never applied (nor cleared by zero_grad). Parameters without a gradient
    # are simply skipped by the optimizer step, so listing frozen ones is safe.
    param_dicts = [
        {"params": [p for n, p in model_fe_opt.named_parameters() if "backbone" in n], "lr": CONFIG['lr_backbone'], "name": "backbone"},
        {"params": [p for n, p in model_fe_opt.named_parameters() if "backbone" not in n and p.requires_grad], "lr": CONFIG['lr'], "name": "fusion_head"},
        {"params": model_yolo_opt.parameters(), "lr": CONFIG['lr'], "name": "yolo_head"},
    ]
    optimizer = torch.optim.SGD(
        param_dicts,
        momentum=CONFIG['sgd_momentum'],
        nesterov=True,  # Nesterov momentum is often beneficial
        weight_decay=CONFIG['weight_decay']
    )

    # The scheduler is stepped once per optimizer step, so its horizon is
    # expressed in optimizer steps; warmup covers the first epoch.
    num_update_steps_per_epoch = math.ceil(len(train_loader) / CONFIG['ACCUMULATION_STEPS'])
    num_training_steps = CONFIG['num_epochs'] * num_update_steps_per_epoch
    num_warmup_steps = 1 * num_update_steps_per_epoch
    scheduler = get_scheduler(
        "cosine",  # Use a cosine decay curve after warmup
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )
    print(f"Scheduler created. Total steps: {num_training_steps}, Warmup steps: {num_warmup_steps}")
    scaler = GradScaler(enabled=torch.cuda.is_available())
    best_loss = float('inf')
    best_map = 0.0
    no_improve_epochs = 0
    losses_hist = []
    val_losses_hist = []
    ACCUMULATION_STEPS = CONFIG['ACCUMULATION_STEPS']

    print(f"Starting training for {CONFIG['num_epochs']} epochs.")
    print(f"Effective batch size: {CONFIG['batch_size'] * ACCUMULATION_STEPS}")

    for epoch in range(CONFIG['num_epochs']):
        layer_to_unfreeze = unfreeze_schedule.get(epoch)
        if layer_to_unfreeze is not None:
            model_fe_opt.unfreeze_backbone_layer(layer_to_unfreeze)
        # True during the first STABILIZATION_EPOCHS epochs after each unfreeze
        in_stabilization = any(
            start <= epoch < start + STABILIZATION_EPOCHS for start in unfreeze_schedule
        )

        feature_extractor.train()
        yolo_head.train()
        criterion.train()
        run_loss = 0
        batch_count = 0

        for batch_idx, data in enumerate(train_loader):
            if data is None:
                continue
            rgb = data['rgb'].to(device, non_blocking=True)
            ir = data['ir'].to(device, non_blocking=True)
            yolo_targets = data['yolo_targets'].to(device)

            # Batches without any ground-truth box are skipped entirely.
            # NOTE(review): a skipped batch that falls on an accumulation
            # boundary also skips that optimizer step, so gradients keep
            # accumulating until the next boundary.
            if yolo_targets.shape[0] == 0:
                continue

            with torch.amp.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
                main_feats, deep_supervision_preds, spatial_weights = feature_extractor(rgb, ir)
                preds = yolo_head(main_feats)
                loss_dict = criterion(preds, yolo_targets)
                total_loss_yolo = loss_dict['total_loss']
                total_loss = total_loss_yolo

                if epoch >= CONFIG['FREEZE_EPOCHS']:
                    # Deep supervision is only added once the backbone starts unfreezing
                    ds_targets_list = [{'gt_boxes_xyxy': b} for b in data['gt_boxes_xyxy']]
                    ds_target_labels, _ = prepare_deep_supervision_targets(
                        ds_targets_list,
                        CONFIG['backbone_feature_sizes']["layer4"],
                        device
                    )

                    loss_ds_ce = F.cross_entropy(deep_supervision_preds['pred_logits'], ds_target_labels)
                    total_loss = total_loss + CONFIG['lambda_deep_supervision'] * loss_ds_ce

                if not math.isfinite(total_loss.item()):
                    print(f"ERROR: Non-finite loss {total_loss.item()} at batch {batch_idx}. Skipping.")
                    continue

                # Scale down so that the accumulated gradient is an average
                loss = total_loss / ACCUMULATION_STEPS
                scaler.scale(loss).backward()

            if (batch_idx + 1) % ACCUMULATION_STEPS == 0:
                # Gradients must be unscaled before norms are printed or clipped
                if scaler.is_enabled():
                    scaler.unscale_(optimizer)

                optimizer_step = (batch_idx + 1) // ACCUMULATION_STEPS
                if optimizer_step % 20 == 0:  # Print every 20 optimizer steps
                    print("\n--- Before Clipping ---")
                    print_grad_norms(feature_extractor, "FeatureExtractor")
                    print_grad_norms(yolo_head, "YOLOHead")
                # Backbone and fusion groups share the backbone clip norm; the YOLO head has its own
                torch.nn.utils.clip_grad_norm_(optimizer.param_groups[0]['params'], CONFIG['clip_norm_backbone'])
                torch.nn.utils.clip_grad_norm_(optimizer.param_groups[1]['params'], CONFIG['clip_norm_backbone'])
                torch.nn.utils.clip_grad_norm_(optimizer.param_groups[2]['params'], CONFIG['clip_norm_head'])

                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                # Applied *after* the scheduler step so the override is the value
                # actually used by the next optimizer step.
                if in_stabilization:
                    # Learning rate the scheduler just set for the backbone group
                    scheduled_lr = scheduler.get_last_lr()[0]
                    # Override it with a much smaller, stabilized value
                    optimizer.param_groups[0]['lr'] = scheduled_lr / STABILIZATION_FACTOR

                optimizer.zero_grad(set_to_none=True)

            run_loss += total_loss.item()
            batch_count += 1
            if len(train_loader) > 0 and (batch_idx % 100 == 0 or batch_idx == len(train_loader) - 1):
                lr_bb, lr_head_fusion, lr_head_yolo = optimizer.param_groups[0]['lr'], optimizer.param_groups[1]['lr'], optimizer.param_groups[2]['lr']
                loss_box_val = loss_dict['loss_box'].item()
                loss_obj_val = loss_dict['loss_obj'].item()
                loss_cls_val = loss_dict['loss_cls'].item()
                log_msg = (f"[Epoch {epoch+1}/{CONFIG['num_epochs']} Batch {batch_idx}/{len(train_loader)}] "
                           f"Total Loss:{total_loss.item():.4f} | "
                           f"Box: {loss_box_val:.4f}, Obj: {loss_obj_val:.4f}, Cls: {loss_cls_val:.4f}")

                if epoch >= FREEZE_EPOCHS:
                    log_msg += f", DS: {loss_ds_ce:.4f}"  # Add DS loss to the log

                log_msg += (f" | LRs -> BB: {lr_bb:.2e}, Fusion: {lr_head_fusion:.2e}, YOLO: {lr_head_yolo:.2e}")
                print(log_msg)

        avg_loss = run_loss / batch_count if batch_count > 0 else float('inf')
        losses_hist.append(avg_loss)
        print(f"\n--- Epoch {epoch+1} Finished --- Average Training Loss: {avg_loss:.4f} ---")

        # `% 1` is always 0, i.e. a validation sample is visualized every epoch
        is_vis_epoch = (epoch + 1) % 1 == 0 or epoch == CONFIG['num_epochs'] - 1
        val_loss, val_metrics, vis_val_sample = evaluate(feature_extractor, yolo_head, criterion, val_loader, device, return_sample_for_vis=is_vis_epoch)
        print(f"--- Validation Loss: {val_loss:.4f} ---")
        print(f"--- Validation mAP@.50:.95: {val_metrics['map']:.4f} | mAP@.50: {val_metrics['map_50']:.4f} | mAP@.75: {val_metrics['map_75']:.4f} ---")
        val_losses_hist.append(val_loss)

        if is_vis_epoch and vis_val_sample:
            # Visualization is best-effort: a plotting failure must not stop training
            print("Generating validation visualization...")
            try:
                vis_rgb_val = vis_val_sample['rgb'][0].detach().cpu()
                vis_gt_boxes_val = vis_val_sample['gt_boxes_xyxy'][0].detach().cpu()
                vis_pred_boxes_val = vis_val_sample['final_preds'][0].detach().cpu()
                visualize_boxes(vis_rgb_val, vis_pred_boxes_val, vis_gt_boxes_val, score_thresh=CONFIG['conf_threshold'], title_suffix=f"VALIDATION Sample - Epoch {epoch+1}")
            except Exception as e:
                print(f"Validation visualization failed: {e}")
            print("Generating attention weight visualization...")
            try:
                visualize_attention_weights(
                    rgb_tensor=vis_val_sample['rgb'][0],
                    ir_tensor=vis_val_sample['ir'][0],
                    attention_weights=vis_val_sample['attention_weights'],
                    epoch_num=epoch + 1
                )
            except Exception as e:
                print(f"Attention visualization failed: {e}")

        # Checkpointing and early stopping are driven by the validation loss
        if val_loss < best_loss:
            print(f"Validation loss improved ({best_loss:.4f} --> {val_loss:.4f}). Saving model...")
            best_loss = val_loss
            no_improve_epochs = 0
            save_dir = "./checkpoints_yolo"
            os.makedirs(save_dir, exist_ok=True)
            torch.save(model_fe_opt.state_dict(), os.path.join(save_dir, 'feature_extractor_best.pth'))
            torch.save(model_yolo_opt.state_dict(), os.path.join(save_dir, 'yolo_head_best.pth'))
        else:
            no_improve_epochs += 1
            print(f"Validation loss did not improve. Epochs w/o improvement: {no_improve_epochs}/{CONFIG['early_stop_patience']}")
            if no_improve_epochs >= CONFIG['early_stop_patience']:
                print("Early stopping triggered.")
                break

        # The best mAP@.50 is tracked and reported only; it does not trigger a save
        current_map50 = val_metrics['map_50']
        if current_map50 > best_map:
            print(f"Validation mAP@.50 improved ({best_map:.4f}).")
            best_map = current_map50

    if losses_hist and val_losses_hist:
        plt.figure(figsize=(12, 6))
        plt.plot(losses_hist, 'b-o', label='Training Loss')
        plt.plot(val_losses_hist, 'r-o', label='Validation Loss')
        plt.title("Training & Validation Loss Progression")
        plt.xlabel("Epoch")
        plt.ylabel("Average Loss")
        plt.legend()
        plt.grid(True)
        plt.savefig("training_validation_loss_curve_yolo.png")
        print("Saved training and validation loss curve.")


In [None]:
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    train()