# 🟣 HVAC-Specific SAM Fine-Tuning Pipeline (v2 - Advanced)
## 🔧 Enhanced with Multi-Prompt Training, Resumability, and T4 Optimizations

This notebook has been upgraded to include advanced features for creating a more robust and flexible model, inspired by recent research (e.g., SAM-PAR) and optimized for T4 GPU runtime.

### Phase 1: Initial Setup and T4 Optimizations

In [None]:
# Mount Google Drive so the dataset archive and the output/checkpoint paths
# under /content/drive/MyDrive (see CONFIG) are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install required dependencies
!pip install torch torchvision --quiet
!pip install opencv-python pycocotools matplotlib onnxruntime onnx --quiet
!pip install git+https://github.com/facebookresearch/segment-anything.git --quiet

In [None]:
import os
import zipfile
import json
import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from tqdm import tqdm
from statistics import mean
from pycocotools.coco import COCO
from pycocotools import mask as mask_utils
import random
import gc
import contextlib
from torch.cuda.amp import autocast, GradScaler

# SAM imports
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
from segment_anything.utils.transforms import ResizeLongestSide

# Seed the Python, NumPy and PyTorch RNGs (in that order) with one fixed value
# so prompt sampling, shuffling and weight init are reproducible across runs.
SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)

# Release cached GPU memory and collect garbage left over from earlier cells.
torch.cuda.empty_cache()
gc.collect()

In [None]:
# LayerNorm2d from SAM repository
class LayerNorm2d(nn.Module):
    """Layer normalization over the channel axis of an NCHW tensor.

    Every spatial location is normalized independently across its channels
    (mean and biased variance over dim 1). The result is then scaled and shifted
    by learnable per-channel ``weight`` / ``bias``, initialized to 1 and 0.

    Args:
        num_channels: Number of channels C (size of dim 1).
        eps: Added to the variance before the square root for numerical stability.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        centered = x - x.mean(1, keepdim=True)
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        # Shape the per-channel affine parameters as (C, 1, 1) so they broadcast over H, W.
        gamma = self.weight.view(-1, 1, 1)
        beta = self.bias.view(-1, 1, 1)
        return gamma * normalized + beta

In [None]:
# Custom HVAC-Optimized Decoder (Blog #2 Inspired)
class HvacOptimizedDecoder(nn.Module):
    """Prompt-free segmentation head stacked on top of SAM's image encoder.

    ``forward`` runs ``sam_encoder`` on the input. It then upsamples the resulting
    256-channel feature map 16x with four stride-2 transposed convolutions. Each
    stage is followed by LayerNorm2d + ReLU, with dropout after the 1st and 3rd
    stage. A final 1x1 convolution projects to a single channel.

    Returns raw mask LOGITS of shape (B, 1, 16*h, 16*w), where (h, w) is the
    encoder's spatial output size. This is presumably 64x64 for SAM at 1024x1024
    input; confirm for other image sizes.

    No sigmoid is applied here. The notebook's ``combined_loss`` uses
    ``BCEWithLogitsLoss`` and applies its own sigmoid for the Dice term, so a
    sigmoid in this module would be applied twice. The doubly squashed
    predictions could only lie in (0.5, 0.73) and their gradients would be much
    weaker. Apply ``torch.sigmoid`` yourself when probabilities are needed.

    Args:
        sam_encoder: Module mapping an image batch to a (B, 256, h, w) feature
            map (SAM's ``image_encoder`` in this notebook).
    """

    def __init__(self, sam_encoder):
        super().__init__()
        self.sam_encoder = sam_encoder
        self.dropout = nn.Dropout(p=0.1)

        # Four 2x upsampling stages: 256 -> 128 -> 64 -> 32 -> 16 channels.
        # Attribute names are kept stable so existing checkpoints still load.
        self.conv1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.norm1 = LayerNorm2d(128)
        self.conv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.norm2 = LayerNorm2d(64)
        self.conv3 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.norm3 = LayerNorm2d(32)
        self.conv4 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        self.norm4 = LayerNorm2d(16)
        self.conv5 = nn.Conv2d(16, 1, kernel_size=1)  # 1x1 projection to one mask channel

    def forward(self, x):
        relu = torch.nn.functional.relu
        x = self.sam_encoder(x)
        x = self.dropout(relu(self.norm1(self.conv1(x))))
        x = relu(self.norm2(self.conv2(x)))
        x = self.dropout(relu(self.norm3(self.conv3(x))))
        x = relu(self.norm4(self.conv4(x)))
        # Logits (no sigmoid) — see class docstring.
        return self.conv5(x)

### Phase 2: Configuration and Dataset Preparation

In [None]:
import torch
from pathlib import Path

# Every Drive-backed path below hangs off these two roots; deriving the file
# paths from them keeps the checkpoint locations in sync with 'output_dir'.
_DRIVE_ROOT = '/content/drive/MyDrive'
_RESULTS_DIR = f'{_DRIVE_ROOT}/sam_finetuning_results'

CONFIG = {
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    # --- STARTING MODEL ---
    # Local path of the pre-trained SAM checkpoint (fetched by the download cell if missing).
    'model_path': '/content/sam_vit_h_4b8939.pth',

    # --- DATASET ---
    'dataset_zip_path': f'{_DRIVE_ROOT}/hvac_dataset_coco.zip',
    'unzip_path': f'{_DRIVE_ROOT}/hvac_dataset_coco',
    # COCO annotation file expected inside every split directory (train/valid/test).
    'annotations_file_name': '_annotations.coco.json',

    # --- OUTPUT & CHECKPOINTING ---
    'output_dir': _RESULTS_DIR,
    'best_model_save_path': f'{_RESULTS_DIR}/best_model_multiprompt_v1.pth',
    'latest_checkpoint_save_path': f'{_RESULTS_DIR}/latest_checkpoint_multiprompt_v1.pth',
    'resume_training': False,  # set True to resume from 'latest_checkpoint_save_path'

    # --- PROMPT ENGINEERING STRATEGY ---
    # One of 'multi_prompt', 'point', 'noisy_box', 'perfect_box' (see HvacSamDataset).
    'prompt_strategy': 'multi_prompt',
    # Max box jitter as a fraction of box width/height ('noisy_box' and 'multi_prompt').
    'bbox_noise_factor': 0.1,

    # --- MODEL & TRAINING HYPERPARAMETERS (T4 Optimized) ---
    'model_type': 'vit_h',
    'image_size': 1024,
    'batch_size': 1,   # critical for T4 memory
    'num_workers': 2,  # reduced for Colab RAM
    'num_epochs': 1,
    'learning_rate': 1e-4,
    'weight_decay': 0,
    'early_stopping_patience': 10,       # epochs without val-IoU improvement
    'checkpoint_batch_interval': 300,    # batches between crash-recovery checkpoints
    'min_mask_area': 100,                # annotations with fewer mask pixels are skipped

    # --- T4 SPECIFIC OPTIMIZATIONS ---
    'mixed_precision': True,
    'gradient_checkpointing': True,  # NOTE(review): no cell in this notebook reads this flag — confirm
    'use_custom_decoder': True,      # use HvacOptimizedDecoder instead of SAM's mask decoder
    'encoder_frozen': True,          # NOTE(review): the setup cell freezes the image encoder regardless of this flag
    'neck_unfrozen': True,           # with the custom decoder, also train the encoder's neck
    'clear_cache_every': 5,          # intended CUDA-cache clearing period, in batches
}

Path(CONFIG['output_dir']).mkdir(parents=True, exist_ok=True)
print(f"‚úì Configuration loaded. Using device: {CONFIG['device']}")

In [None]:
# Unzip the dataset if it hasn't been already.
# The archive is extracted into a sibling staging directory. That directory is
# renamed into place only after extractall() finishes. An interrupted unzip
# (e.g. a Colab disconnect) therefore cannot leave a half-populated folder that
# the existence check below would later mistake for a complete dataset.
import shutil

if not os.path.exists(CONFIG['unzip_path']):
    print(f"üìÅ Unzipping dataset from {CONFIG['dataset_zip_path']}...")
    # normpath strips a trailing slash so the staging dir is a sibling, not a child.
    staging_dir = os.path.normpath(CONFIG['unzip_path']) + '.partial'
    if os.path.exists(staging_dir):
        # Leftovers from a previously interrupted extraction.
        shutil.rmtree(staging_dir)
    with zipfile.ZipFile(CONFIG['dataset_zip_path'], 'r') as zip_ref:
        zip_ref.extractall(staging_dir)
    os.replace(staging_dir, CONFIG['unzip_path'])
    print("‚úÖ Unzipping complete.")
else:
    print("‚úÖ Dataset already unzipped.")

### Phase 3: Dataset Loading and DataLoader Creation

In [None]:
def load_coco_split(dataset_root_path: str, split_name: str, annotations_file: str) -> Tuple[COCO, str]:
    """Load the COCO annotation index for one dataset split.

    Args:
        dataset_root_path: Directory holding one sub-directory per split.
        split_name: Split sub-directory name (e.g. ``'train'``, ``'valid'``, ``'test'``).
        annotations_file: Name of the annotation JSON inside the split directory.

    Returns:
        ``(coco, split_path)``: the loaded ``COCO`` object and the split directory.

    Raises:
        FileNotFoundError: If the annotations file does not exist.
    """
    split_dir = os.path.join(dataset_root_path, split_name)
    ann_path = os.path.join(split_dir, annotations_file)
    if os.path.exists(ann_path):
        print(f"üîÑ Loading '{split_name}' annotations from: {ann_path}")
        coco_index = COCO(ann_path)
        print(f"üìä Found {len(coco_index.getImgIds())} images in '{split_name}'.")
        return coco_index, split_dir
    raise FileNotFoundError(f"Annotations file not found for '{split_name}' at: {ann_path}")

def get_image_path(split_path: str, img_info: dict) -> str:
    """Return the on-disk path of a COCO image entry inside ``split_path``.

    Raises:
        FileNotFoundError: If ``img_info['file_name']`` does not exist there.
    """
    file_name = img_info['file_name']
    candidate = os.path.join(split_path, file_name)
    if os.path.exists(candidate):
        return candidate
    raise FileNotFoundError(f"Image file not found: {file_name} in {split_path}")

class HvacSamDataset(Dataset):
    """One item per image: a SAM-preprocessed image plus one (mask, prompt) pair per annotation.

    ``__getitem__`` returns a dict with:

    * ``'image'``: float tensor (3, S, S), S = ``CONFIG['image_size']``. The RGB
      image is resized by ``ResizeLongestSide`` so its longest side is S,
      normalized with SAM's pixel mean/std, then zero-padded on the bottom/right.
    * ``'masks'``: list of boolean numpy masks at the ORIGINAL image resolution,
      one per kept annotation. An annotation is kept if it has a
      ``'segmentation'``, is not crowd, has area >= ``CONFIG['min_mask_area']``,
      and a prompt could be generated for it.
    * ``'prompts'``: list of prompt dicts aligned with ``'masks'`` (see ``_generate_prompt``).
    * ``'original_size'`` / ``'input_size'``: (H, W) before resizing and after
      resizing (before padding).
    """

    def __init__(self, coco: COCO, image_ids: List[int], split_path: str, is_training: bool = True):
        self.coco = coco
        self.image_ids = image_ids
        self.split_path = split_path
        # Training draws randomized prompts; evaluation always uses the exact box.
        self.is_training = is_training
        self.resize_transform = ResizeLongestSide(CONFIG['image_size'])
        # SAM's pixel normalization constants (RGB, 0-255 scale), shaped (3, 1, 1).
        self.pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
        self.pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
        self.prompt_strategy = CONFIG.get('prompt_strategy', 'perfect_box')

    def __len__(self) -> int:
        return len(self.image_ids)

    @staticmethod
    def _coco_bbox_to_xyxy(bbox: List[float], image_hw, noise_factor: float = 0.0) -> np.ndarray:
        """Convert a COCO ``[x, y, w, h]`` box to SAM's ``[x0, y0, x1, y1]`` box format.

        When ``noise_factor`` is non-zero, each of x, y, w, h is jittered
        uniformly by up to ``noise_factor`` times the box width (x, w) or height
        (y, h). The corners are then clipped to the image bounds
        ``image_hw = (H, W)`` and ordered so that x0 <= x1 and y0 <= y1.
        """
        x, y, w, h = (float(v) for v in bbox)
        if noise_factor:
            x += w * noise_factor * (random.random() - 0.5) * 2
            y += h * noise_factor * (random.random() - 0.5) * 2
            w += w * noise_factor * (random.random() - 0.5) * 2
            h += h * noise_factor * (random.random() - 0.5) * 2
        height, width = image_hw
        xs = np.clip([x, x + w], 0, width)
        ys = np.clip([y, y + h], 0, height)
        return np.array([xs.min(), ys.min(), xs.max(), ys.max()], dtype=float)

    def _generate_prompt(self, mask: np.ndarray, bbox: List[float]):
        """Build the prompt for one annotation.

        Args:
            mask: Binary mask of the annotation at original image resolution.
            bbox: COCO bounding box ``[x, y, w, h]`` in original pixels.

        Returns:
            One of the following:

            * ``{'box': (4,) array}`` in XYXY original-pixel coordinates (the
              format SAM's box prompts use).
            * ``{'point_coords': (N, 2) array of (x, y), 'point_labels': (N,) ones}``
              for point / scribble prompts.
            * ``None`` when a point prompt is requested for an empty mask.
        """
        # Validation/testing: always the exact box, for consistent evaluation.
        if not self.is_training:
            return {'box': self._coco_bbox_to_xyxy(bbox, mask.shape[:2])}

        if self.prompt_strategy == 'multi_prompt':
            prompt_type = random.choice(['box', 'point', 'scribble'])
        elif self.prompt_strategy == 'point':
            prompt_type = 'point'
        else:  # 'noisy_box' or 'perfect_box'
            prompt_type = 'box'

        if prompt_type == 'box':
            jitter = self.prompt_strategy in ('noisy_box', 'multi_prompt')
            noise_factor = CONFIG.get('bbox_noise_factor', 0) if jitter else 0
            return {'box': self._coco_bbox_to_xyxy(bbox, mask.shape[:2], noise_factor)}

        points = np.argwhere(mask)  # (row, col) of every foreground pixel
        if len(points) == 0:
            return None

        if prompt_type == 'point':
            row, col = points[random.randint(0, len(points) - 1)]
            return {'point_coords': np.array([[col, row]]), 'point_labels': np.array([1])}

        # 'scribble': up to 5 distinct foreground pixels as positive points.
        num_points = min(5, len(points))
        chosen = points[np.random.choice(len(points), num_points, replace=False)]
        # (row, col) -> (x, y); made contiguous because [:, ::-1] is a negative-stride view.
        point_coords = np.ascontiguousarray(chosen[:, ::-1])
        return {'point_coords': point_coords, 'point_labels': np.ones(num_points)}

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        img_id = self.image_ids[idx]
        img_info = self.coco.loadImgs([img_id])[0]
        image_path = get_image_path(self.split_path, img_info)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports unreadable/corrupt files by returning None instead of raising.
            raise ValueError(f"Could not decode image file: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        ann_ids = self.coco.getAnnIds(imgIds=[img_id])
        annotations = self.coco.loadAnns(ann_ids)

        masks, prompts = [], []
        for ann in annotations:
            if 'segmentation' not in ann or ann.get('iscrowd', 0) == 1:
                continue
            mask = self.coco.annToMask(ann)
            if mask.sum() < CONFIG['min_mask_area']:
                continue
            prompt = self._generate_prompt(mask, ann['bbox'])
            if prompt is not None:
                masks.append(mask.astype(bool))
                prompts.append(prompt)

        original_size = image.shape[:2]
        resized_image = self.resize_transform.apply_image(image)
        input_image_torch = torch.as_tensor(resized_image, dtype=torch.float32).permute(2, 0, 1).contiguous()
        input_image_torch = (input_image_torch - self.pixel_mean) / self.pixel_std
        h, w = input_image_torch.shape[-2:]
        padh, padw = CONFIG['image_size'] - h, CONFIG['image_size'] - w
        # Pad only right/bottom so resized-image pixel coordinates stay unchanged.
        input_image_padded = torch.nn.functional.pad(input_image_torch, (0, padw, 0, padh))

        return {
            'image': input_image_padded,
            'masks': masks,
            'prompts': prompts,
            'original_size': original_size,
            'input_size': (h, w),
        }

def custom_collate_fn(batch: List[Dict]) -> Dict[str, Any]:
    """Collate dataset items whose mask/prompt lists vary in length per image.

    Only the (equal-sized, padded) image tensors are stacked into one
    (B, C, H, W) tensor. Every other field is gathered into a plain Python list
    with one entry per sample.
    """
    collated: Dict[str, Any] = {'image': torch.stack([sample['image'] for sample in batch])}
    for key in ('masks', 'prompts', 'original_size', 'input_size'):
        collated[key] = [sample[key] for sample in batch]
    return collated

# Load datasets
# Each split is read from CONFIG['unzip_path']/<split>/<annotations_file_name>.
train_coco, train_path = load_coco_split(CONFIG['unzip_path'], 'train', CONFIG['annotations_file_name'])
val_coco, val_path = load_coco_split(CONFIG['unzip_path'], 'valid', CONFIG['annotations_file_name'])
# Every image id in an annotation file becomes one dataset item.
train_ids, val_ids = train_coco.getImgIds(), val_coco.getImgIds()

# Create Datasets and DataLoaders
# is_training=False makes the validation set use exact (un-jittered) box prompts.
train_dataset = HvacSamDataset(train_coco, train_ids, train_path, is_training=True)
val_dataset = HvacSamDataset(val_coco, val_ids, val_path, is_training=False)

# custom_collate_fn stacks only the padded images and keeps the per-image
# mask/prompt lists (variable length) as Python lists.
train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=CONFIG['num_workers'], collate_fn=custom_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=CONFIG['num_workers'], collate_fn=custom_collate_fn)

print(f"\n‚úÖ Training dataset initialized with {len(train_dataset)} samples.")
print(f"‚úÖ Validation dataset initialized with {len(val_dataset)} samples.")

# Debug aid: list the extracted dataset tree (IPython shell magic; notebook-only).
print(f"\nInspecting dataset directory: {CONFIG['unzip_path']}")
!ls -R {CONFIG['unzip_path']}

In [None]:
import os
import urllib.request

# Official SAM ViT-H checkpoint.
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
# Download to exactly the path the model-loading cell reads, so the two cannot drift apart.
CHECKPOINT_PATH = CONFIG['model_path']

if os.path.exists(CHECKPOINT_PATH):
    print("‚úÖ File already exists.")
else:
    print(f"‚¨áÔ∏è Downloading official SAM ViT-H weights to {CHECKPOINT_PATH}...")
    # Download to a temporary sibling first. An interrupted urlretrieve would
    # otherwise leave a truncated file that the existence check above accepts
    # on the next run, and torch.load would then fail on it.
    partial_path = CHECKPOINT_PATH + ".part"
    try:
        urllib.request.urlretrieve(CHECKPOINT_URL, partial_path)
        os.replace(partial_path, CHECKPOINT_PATH)
    finally:
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("‚úÖ Download complete.")

### Phase 4: Model Preparation and Training Setup (T4 Optimized)

In [None]:
# Initialize base SAM model (architecture only; weights are loaded just below).
sam_model = sam_model_registry[CONFIG['model_type']]()
sam_model.to(CONFIG['device'])

# Load pre-trained weights
print(f"üîÑ Loading pre-trained model: {CONFIG['model_path']}")
sam_model.load_state_dict(torch.load(CONFIG['model_path'], map_location=CONFIG['device']))

# --- T4 OPTIMIZATION: FREEZE & CUSTOM DECODER SETUP ---
if CONFIG['use_custom_decoder']:
    print("üîß Switching to custom HVAC-optimized decoder...")
    # Freeze the whole ViT image encoder...
    for param in sam_model.image_encoder.parameters():
        param.requires_grad = False

    # ...then re-enable only its convolutional neck, if requested.
    if CONFIG['neck_unfrozen']:
        for name, param in sam_model.image_encoder.named_parameters():
            if 'neck' in name:
                param.requires_grad = True

    # The custom head wraps the (mostly frozen) encoder. SAM's prompt encoder and
    # mask decoder are not used on this path.
    hvac_decoder = HvacOptimizedDecoder(sam_model.image_encoder)
    hvac_decoder.to(CONFIG['device'])
    model_to_train = hvac_decoder

    # Optimize only what is trainable: the new head plus (optionally) the neck.
    trainable_params = [p for p in hvac_decoder.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'])

    print(f"‚úÖ Custom decoder initialized. Trainable parameters: {sum(p.numel() for p in trainable_params):,}")
else:
    print("üîß Using standard SAM mask decoder...")
    # Standard fine-tuning: freeze the image and prompt encoders, train the mask decoder.
    for name, param in sam_model.named_parameters():
        if name.startswith("image_encoder") or name.startswith("prompt_encoder"):
            param.requires_grad = False
    optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'])
    model_to_train = sam_model
    trainable_params = sum(p.numel() for p in model_to_train.parameters() if p.requires_grad)
    print(f"‚úÖ Standard SAM configured for fine-tuning. Trainable parameters: {trainable_params:,}")

# Create the AMP grad scaler BEFORE the resume logic so that its saved state
# (loss scale and growth tracker) can be restored from the checkpoint below.
scaler = GradScaler() if CONFIG['mixed_precision'] and CONFIG['device'] == 'cuda' else None

start_epoch = 0
# --- RESUME TRAINING LOGIC ---
if CONFIG['resume_training']:
    print(f"üîÑ Attempting to resume training from {CONFIG['latest_checkpoint_save_path']}")
    if os.path.exists(CONFIG['latest_checkpoint_save_path']):
        checkpoint = torch.load(CONFIG['latest_checkpoint_save_path'], map_location=CONFIG['device'])
        model_to_train.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Checkpoints written without AMP store None under this key.
        scaler_state = checkpoint.get('scaler_state_dict')
        if scaler is not None and scaler_state is not None:
            scaler.load_state_dict(scaler_state)
        # NOTE(review): the training loop writes this checkpoint every
        # CONFIG['checkpoint_batch_interval'] batches, i.e. usually mid-epoch.
        # Resuming at epoch + 1 therefore skips that epoch's remaining batches.
        start_epoch = checkpoint['epoch'] + 1
        print(f"‚úÖ Resumed successfully. Starting from epoch {start_epoch}.")
    else:
        print("‚ö†Ô∏è Resume checkpoint not found. Starting training from scratch with pre-trained SAM.")
        # The pre-trained SAM weights loaded above are used as-is.

model_to_train.train()
print(f"‚úÖ Model and optimizer configured for T4 runtime.")

In [None]:
import torch
import torch.nn as nn

# Halve the learning rate once the metric passed to scheduler.step() (the
# validation loss, in the training loop) has not improved for 5 epochs.
scheduler = ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=0.5, patience=5)

def combined_loss(pred_masks: torch.Tensor, true_masks: torch.Tensor) -> torch.Tensor:
    """Weighted BCE + Dice loss on mask LOGITS.

    Args:
        pred_masks: Raw logits (no sigmoid applied), same shape as ``true_masks``.
        true_masks: Binary ground truth (bool or 0/1 values).

    Returns:
        Scalar tensor ``0.8 * BCE + 0.2 * (1 - soft Dice)``. The Dice term is
        computed over all elements at once (not per sample) with a 1e-8
        smoothing term.
    """
    bce_term = torch.nn.functional.binary_cross_entropy_with_logits(pred_masks, true_masks.float())
    probs = pred_masks.sigmoid().reshape(-1)
    flat_targets = true_masks.reshape(-1)
    overlap = (probs * flat_targets).sum()
    smooth = 1e-8
    dice_term = 1 - (2. * overlap + smooth) / (probs.sum() + flat_targets.sum() + smooth)
    return 0.8 * bce_term + 0.2 * dice_term

print("‚úÖ Scheduler and loss function configured.")

### Phase 5: Memory-Efficient Training and Validation Loop (T4 Optimized)

In [None]:
### Phase 5: Memory-Efficient Training and Validation Loop (T4 Optimized)

def _to_original_resolution(mask_logits, padded_size, input_size, orig_size):
    """Map full-canvas mask logits (B, C, h, w) back onto the original image grid.

    Same recipe as SAM's ``postprocess_masks``:
    1. Resize to the padded model-input size.
    2. Crop off the bottom/right padding, keeping ``input_size``.
    3. Resize to ``orig_size``.
    """
    padded_size = tuple(int(s) for s in padded_size)
    if tuple(mask_logits.shape[-2:]) != padded_size:
        mask_logits = torch.nn.functional.interpolate(
            mask_logits, size=padded_size, mode="bilinear", align_corners=False)
    h, w = (int(s) for s in input_size)
    cropped = mask_logits[..., :h, :w]
    return torch.nn.functional.interpolate(
        cropped, size=tuple(int(s) for s in orig_size), mode="bilinear", align_corners=False)


def _mask_iou(mask_logits, gt_mask):
    """IoU between ``mask_logits > 0`` (i.e. sigmoid > 0.5) and a binary ground-truth mask."""
    pred = mask_logits.detach() > 0
    gt = gt_mask.to(device=pred.device).bool()
    union = (pred | gt).sum().item()
    if union == 0:
        return 1.0  # both empty: perfect agreement
    return (pred & gt).sum().item() / union


def process_single_annotation(model, image_embeddings, gt_mask, prompt, orig_size, input_size, device,
                              return_masks=False):
    """Compute the loss for one ground-truth mask and its prompt.

    Args:
        model: ``HvacOptimizedDecoder`` (prompt-free; ``image_embeddings`` is fed
            to it directly) or a SAM model (its prompt encoder and mask decoder are used).
        image_embeddings: For SAM, the output of ``model.image_encoder`` for a
            single image. For the custom decoder, the padded image batch.
        gt_mask: Boolean numpy mask at original image resolution.
        prompt: ``{'box': (4,) array}`` or ``{'point_coords', 'point_labels'}`` in
            original-image pixel coordinates. SAM interprets boxes as XYXY.
        orig_size: (H, W) of the original image.
        input_size: (h, w) of the resized image before padding.
        device: Device for the ground-truth and prompt tensors.
        return_masks: If True, also return the predicted logits at original resolution.

    Returns:
        The ``combined_loss`` value, or ``(loss, mask_logits)`` if ``return_masks``.
    """
    gt_mask_torch = torch.from_numpy(gt_mask).unsqueeze(0).unsqueeze(0).to(device)

    if isinstance(model, HvacOptimizedDecoder):
        # The decoder's output lives on the padded canvas. Bring it to the GT's
        # resolution instead of comparing tensors of different sizes.
        logits = model(image_embeddings)
        mask_logits = _to_original_resolution(logits, image_embeddings.shape[-2:], input_size, orig_size)
    else:
        transform = ResizeLongestSide(CONFIG['image_size'])
        box_torch, points = None, None
        if 'box' in prompt:
            box_torch = torch.as_tensor(transform.apply_boxes(prompt['box'].reshape(1, 4), orig_size),
                                        dtype=torch.float, device=device)
        elif 'point_coords' in prompt:
            coords = torch.as_tensor(transform.apply_coords(prompt['point_coords'], orig_size),
                                     dtype=torch.float, device=device).unsqueeze(0)
            labels = torch.as_tensor(prompt['point_labels'], dtype=torch.float, device=device).unsqueeze(0)
            points = (coords, labels)

        # The prompt encoder is frozen: no graph needed.
        with torch.no_grad():
            sparse_embeddings, dense_embeddings = model.prompt_encoder(points=points, boxes=box_torch, masks=None)

        low_res_masks, _iou_predictions = model.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        mask_logits = model.postprocess_masks(low_res_masks, input_size, orig_size)

    loss = combined_loss(mask_logits, gt_mask_torch)
    return (loss, mask_logits) if return_masks else loss


def _forward_batch(model, images, batch, device):
    """Forward one collated batch; return ``(mean_loss, num_terms, ious)``.

    ``mean_loss`` is the average of all loss terms in the batch (``None`` when
    there are none), ``num_terms`` the number of terms averaged, and ``ious`` a
    list of per-prediction IoUs (floats).
    """
    all_masks = batch['masks']
    if not any(all_masks):
        # No usable annotation in any image: skip the expensive encoder pass.
        return None, 0, []

    losses, ious = [], []
    if isinstance(model, HvacOptimizedDecoder):
        # The custom head has no prompt input and predicts ONE mask per image, so
        # the consistent target is the union of that image's instance masks.
        # Scoring the same prediction against every instance separately would
        # push each foreground pixel towards the fraction of instances covering it.
        logits_batch = model(images)
        for b, gt_masks in enumerate(all_masks):
            if not gt_masks:
                continue
            logits = _to_original_resolution(logits_batch[b:b + 1], images.shape[-2:],
                                             batch['input_size'][b], batch['original_size'][b])
            target = torch.from_numpy(np.logical_or.reduce(gt_masks)).unsqueeze(0).unsqueeze(0).to(device)
            losses.append(combined_loss(logits, target))
            ious.append(_mask_iou(logits, target))
    else:
        # Standard SAM: encode each image once, then decode one mask per prompt.
        image_embeddings = model.image_encoder(images)
        for b, (gt_masks, prompts) in enumerate(zip(all_masks, batch['prompts'])):
            for gt_mask, prompt in zip(gt_masks, prompts):
                loss, logits = process_single_annotation(
                    model, image_embeddings[b:b + 1], gt_mask, prompt,
                    batch['original_size'][b], batch['input_size'][b], device,
                    return_masks=True,
                )
                losses.append(loss)
                ious.append(_mask_iou(logits, torch.from_numpy(gt_mask)))

    if not losses:
        return None, 0, ious
    return sum(losses) / len(losses), len(losses), ious


def run_epoch_optimized(model, dataloader, optimizer, is_training, device, epoch, scaler):
    """Run one memory-conscious training or evaluation pass over ``dataloader``.

    Behaviour:
    * Training: one optimizer step per batch, on the mean loss over all masks
      in the batch.
    * Mixed precision: used for the forward pass when ``scaler`` is given.
    * Crash-recovery checkpoint: written every
      ``CONFIG['checkpoint_batch_interval']`` training batches.
    * OOM: a CUDA out-of-memory error skips the batch instead of aborting.

    Args:
        optimizer: Required when ``is_training``; may be ``None`` otherwise.
        epoch: Stored in the crash-recovery checkpoint.
        scaler: ``GradScaler`` for AMP, or ``None`` to run in full precision.

    Returns:
        ``{'loss': mean per-batch loss, 'iou': mean IoU over all predictions}``
        (0 for either when nothing was measured). Predictions are thresholded at
        logit 0.
    """
    model.train(is_training)
    epoch_losses, iou_scores = [], []
    cache_every = max(1, int(CONFIG.get('clear_cache_every', 1)))

    desc = "Training" if is_training else "Validation"
    for batch_idx, batch in enumerate(tqdm(dataloader, desc=desc)):
        try:
            images = batch['image'].to(device, non_blocking=True)
            if is_training:
                optimizer.zero_grad(set_to_none=True)

            with torch.set_grad_enabled(is_training):
                amp_context = autocast() if scaler is not None else contextlib.nullcontext()
                with amp_context:
                    batch_loss, num_terms, batch_ious = _forward_batch(model, images, batch, device)

                # Backward/step outside autocast, as recommended for AMP. Images
                # without annotations contribute no loss, so there is nothing to step.
                if is_training and num_terms > 0:
                    if scaler is not None:
                        scaler.scale(batch_loss).backward()
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        batch_loss.backward()
                        optimizer.step()

            if num_terms > 0:
                epoch_losses.append(batch_loss.item())
            iou_scores.extend(batch_ious)

            # Memory cleanup. Emptying the cache and running GC every batch is
            # slow, so it runs only every CONFIG['clear_cache_every'] batches.
            del images, batch_loss
            if (batch_idx + 1) % cache_every == 0:
                torch.cuda.empty_cache()
                gc.collect()

            # Crash-recovery checkpoint (overwritten each time).
            if is_training and (batch_idx + 1) % CONFIG['checkpoint_batch_interval'] == 0:
                chk_path = CONFIG['latest_checkpoint_save_path']
                torch.save({
                    'epoch': epoch,
                    'batch_idx': batch_idx,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scaler_state_dict': scaler.state_dict() if scaler is not None else None,
                }, chk_path)
                tqdm.write(f"\nüíæ Overwrote latest checkpoint (for crash recovery): {os.path.basename(chk_path)}")

        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            print(f"‚ö†Ô∏è OOM at batch {batch_idx}, cleaning up...")
            # optimizer is None during validation/testing.
            if optimizer is not None:
                optimizer.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()
            gc.collect()
            continue  # skip this batch

    return {'loss': mean(epoch_losses) if epoch_losses else 0, 'iou': mean(iou_scores) if iou_scores else 0}


# --- Main Training Loop ---
# Start below any achievable IoU so the first completed epoch always writes a
# "best" checkpoint. With a start of 0 and a strict '>' comparison, a run whose
# validation IoU stays at 0 would never save one. The final test-set evaluation
# would then find no model to load.
best_val_iou = float('-inf')
patience_counter = 0
history = {'train_loss': [], 'val_loss': [], 'val_iou': []}

print(f"\nüöÄ Starting training from epoch {start_epoch} for {CONFIG['num_epochs']} total epochs...")
for epoch in range(start_epoch, CONFIG['num_epochs']):
    print(f"\n--- Epoch {epoch+1}/{CONFIG['num_epochs']} ---")

    train_metrics = run_epoch_optimized(model_to_train, train_loader, optimizer, is_training=True, device=CONFIG['device'], epoch=epoch, scaler=scaler)
    history['train_loss'].append(train_metrics['loss'])

    # Validation runs without an optimizer and without AMP.
    val_metrics = run_epoch_optimized(model_to_train, val_loader, None, is_training=False, device=CONFIG['device'], epoch=epoch, scaler=None)
    history['val_loss'].append(val_metrics['loss'])
    history['val_iou'].append(val_metrics['iou'])

    print(f"Train Loss: {train_metrics['loss']:.4f}")
    print(f"Val Loss: {val_metrics['loss']:.4f} | Val IoU: {val_metrics['iou']:.4f}")

    # LR schedule follows validation loss; model selection/early stopping follow validation IoU.
    scheduler.step(val_metrics['loss'])

    if val_metrics['iou'] > best_val_iou:
        best_val_iou = val_metrics['iou']
        patience_counter = 0
        best_model_path = CONFIG['best_model_save_path']
        torch.save(model_to_train.state_dict(), best_model_path)
        print(f"üèÜ New best model saved with IoU: {best_val_iou:.4f}")
    else:
        patience_counter += 1

    if patience_counter >= CONFIG['early_stopping_patience']:
        print(f"üõë Early stopping triggered after {patience_counter} epochs with no improvement.")
        break

print("\n‚úÖ Training completed!")

### Phase 6: Results and Export

In [None]:
# Plot per-epoch loss curves (left) and validation IoU (right), save the figure
# next to the checkpoints, then print the best validation IoU.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

for series_key, series_label, series_color in (
    ('train_loss', 'Training Loss', 'blue'),
    ('val_loss', 'Validation Loss', 'red'),
):
    ax1.plot(history[series_key], label=series_label, color=series_color)
ax1.set(title='Training and Validation Loss', xlabel='Epoch', ylabel='Loss')

ax2.plot(history['val_iou'], label='Validation IoU', color='green')
ax2.set(title='Validation Metrics (IoU)', xlabel='Epoch', ylabel='Score')

for axis in (ax1, ax2):
    axis.legend()
    axis.grid(True)

plt.tight_layout()
plt.savefig(os.path.join(CONFIG['output_dir'], 'training_metrics.png'))
plt.show()

print(f"\nüìä FINAL METRICS:")
recorded_ious = history['val_iou']
print(f"Best Validation IoU: {max(recorded_ious):.4f}" if recorded_ious else "No validation metrics recorded.")

In [None]:
# Save the weights as they are after the last completed epoch.
final_model_path = os.path.join(CONFIG['output_dir'], 'final_model.pth')
torch.save(model_to_train.state_dict(), final_model_path)
print(f"‚úÖ Final model state saved to: {final_model_path}")
# Only claim a best-model checkpoint if the training loop actually wrote one.
if os.path.exists(CONFIG['best_model_save_path']):
    print(f"‚úÖ Best performing model (by IoU) saved to: {CONFIG['best_model_save_path']}")
else:
    print(f"No best-model checkpoint was written to: {CONFIG['best_model_save_path']}")

In [None]:
### Phase 7: Final, Unbiased Evaluation on the Test Set

print("\n--- Final Model Evaluation on Unseen Test Data ---")

# 1. Load the best performing model that was saved during training.
best_model_path = CONFIG['best_model_save_path']
eval_model = None
if os.path.exists(best_model_path):
    print(f"üîÑ Loading best model from: {best_model_path}")
    # Re-create the same architecture that produced the checkpoint before loading it.
    if CONFIG['use_custom_decoder']:
        eval_model = HvacOptimizedDecoder(sam_model_registry[CONFIG['model_type']]().image_encoder)
    else:
        eval_model = sam_model_registry[CONFIG['model_type']]()
    # Load the state dict on the CPU (where the fresh model lives). Mapping it
    # straight to the GPU would temporarily hold an extra full copy of the
    # weights there while the training model is still resident.
    # NOTE(review): on a 16 GB T4, consider `del model_to_train` +
    # torch.cuda.empty_cache() before this cell to avoid OOM.
    eval_model.load_state_dict(torch.load(best_model_path, map_location='cpu'))
    eval_model.to(CONFIG['device'])
    eval_model.eval()
else:
    print("‚ùå Best model file not found. Cannot perform final evaluation.")

# 2. Load the (optional) test split.
# Reset first so a loader left over from an earlier run of this cell is never reused.
test_loader = None
try:
    test_coco, test_path = load_coco_split(CONFIG['unzip_path'], 'test', CONFIG['annotations_file_name'])
    test_ids = test_coco.getImgIds()

    # is_training=False: exact box prompts, same as validation.
    test_dataset = HvacSamDataset(test_coco, test_ids, test_path, is_training=False)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG['batch_size'], shuffle=False,
                             num_workers=CONFIG['num_workers'], collate_fn=custom_collate_fn)

    print(f"\n‚úÖ Test dataset loaded with {len(test_dataset)} samples.")

except FileNotFoundError as e:
    print(f"\n‚ö†Ô∏è Test split not found: {e}. Skipping final evaluation.")

# 3. Run a single evaluation pass on the test data.
if eval_model is None:
    print("\n‚ùå Skipping final evaluation due to missing model.")
elif test_loader is not None:
    print("\nüöÄ Running final evaluation on the test set...")
    final_test_metrics = run_epoch_optimized(eval_model, test_loader, optimizer=None, is_training=False, device=CONFIG['device'], epoch=0, scaler=None)

    print("\n" + "="*50)
    print("      üéâ FINAL UNBIASED PERFORMANCE METRICS üéâ")
    print("="*50)
    print(f"Final Test IoU:   {final_test_metrics['iou']:.4f}")
    print(f"Final Test Loss:  {final_test_metrics['loss']:.4f}")
    print("="*50)
    print("\nThis is the true expected performance of your model on new data.")