# üü£ HVAC-Specific SAM Fine-Tuning Pipeline (v2 - Advanced)
## üîß Enhanced with Multi-Prompt Training & Resumability

This notebook has been upgraded to include advanced features for creating a more robust and flexible model, inspired by recent research (e.g., SAM-PAR).

### Phase 1: Initial Setup

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Install required dependencies
!pip install torch torchvision --quiet
!pip install opencv-python pycocotools matplotlib onnxruntime onnx --quiet
!pip install git+https://github.com/facebookresearch/segment-anything.git --quiet

[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m17.4/17.4 MB[0m [31m135.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m18.1/18.1 MB[0m [31m132.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m46.0/46.0 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m86.8/86.8 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for segment_anything (setup.py) ... [?25l[?25hdone


In [3]:
import os
import zipfile
import json
import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from tqdm import tqdm
from statistics import mean
from pycocotools.coco import COCO
from pycocotools import mask as mask_utils
import random

# SAM imports
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

### Phase 2: Configuration and Dataset Preparation

In [4]:
import torch
from pathlib import Path

CONFIG = {
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    # --- STARTING MODEL ---
    # Path to the official pre-trained SAM weights. We start fresh for this new expert dataset.
    'model_path': '/content/sam_vit_h_4b8939.pth',

    # --- DATASET ---
    'dataset_zip_path': '/content/drive/MyDrive/hvac_dataset_coco.zip',
    'unzip_path': '/content/drive/MyDrive/hvac_dataset_coco',
    'annotations_file_name': '_annotations.coco.json',

    # --- OUTPUT & CHECKPOINTING ---
    'output_dir': '/content/drive/MyDrive/sam_finetuning_results',
    'best_model_save_path': '/content/drive/MyDrive/sam_finetuning_results/best_model_multiprompt_v1.pth',
    'latest_checkpoint_save_path': '/content/drive/MyDrive/sam_finetuning_results/latest_checkpoint_multiprompt_v1.pth',
    'resume_training': False, # SET TO TRUE TO RESUME FROM 'latest_checkpoint_save_path'

    # --- PROMPT ENGINEERING STRATEGY ---
    # Options: 'perfect_box', 'noisy_box', 'multi_prompt' (inspired by SAM-PAR paper)
    'prompt_strategy': 'multi_prompt',
    'bbox_noise_factor': 0.1, # How much to 'jiggle' the box in noisy_box or multi_prompt mode

    # --- MODEL & TRAINING HYPERPARAMETERS ---
    'model_type': 'vit_h',
    'image_size': 1024,
    'batch_size': 1,
    'num_workers': 0,
    'num_epochs': 1,
    'learning_rate': 1e-4,
    'weight_decay': 0,
    'early_stopping_patience': 10,
    'checkpoint_batch_interval': 300,
    'min_mask_area': 100,
}

Path(CONFIG['output_dir']).mkdir(parents=True, exist_ok=True)
print(f"‚úì Configuration loaded. Using device: {CONFIG['device']}")

‚úì Configuration loaded. Using device: cuda


In [5]:
# Unzip the dataset if it hasn't been already
if not os.path.exists(CONFIG['unzip_path']):
    print(f"üìÅ Unzipping dataset from {CONFIG['dataset_zip_path']}...")
    with zipfile.ZipFile(CONFIG['dataset_zip_path'], 'r') as zip_ref:
        zip_ref.extractall(CONFIG['unzip_path'])
    print("‚úÖ Unzipping complete.")
else:
    print("‚úÖ Dataset already unzipped.")

üìÅ Unzipping dataset from /content/drive/MyDrive/hvac_dataset_coco.zip...
‚úÖ Unzipping complete.


### Phase 3: Dataset Loading and DataLoader Creation

In [6]:
import os
import zipfile
import json
import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from tqdm import tqdm
from statistics import mean
from pycocotools.coco import COCO
from pycocotools import mask as mask_utils
import random

# SAM imports
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

def load_coco_split(dataset_root_path: str, split_name: str, annotations_file: str) -> Tuple[COCO, str]:
    split_path = os.path.join(dataset_root_path, split_name)
    annotations_path = os.path.join(split_path, annotations_file)
    if not os.path.exists(annotations_path):
        raise FileNotFoundError(f"Annotations file not found for '{split_name}' at: {annotations_path}")
    print(f"üîÑ Loading '{split_name}' annotations from: {annotations_path}")
    coco = COCO(annotations_path)
    print(f"üìä Found {len(coco.getImgIds())} images in '{split_name}'.")
    return coco, split_path

def get_image_path(split_path: str, img_info: dict) -> str:
    full_path = os.path.join(split_path, img_info['file_name'])
    if not os.path.exists(full_path):
        raise FileNotFoundError(f"Image file not found: {img_info['file_name']} in {split_path}")
    return full_path

class HvacSamDataset(Dataset):
    def __init__(self, coco: COCO, image_ids: List[int], split_path: str, is_training: bool = True):
        self.coco = coco
        self.image_ids = image_ids
        self.split_path = split_path
        self.is_training = is_training
        self.resize_transform = ResizeLongestSide(CONFIG['image_size'])
        self.pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
        self.pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
        self.prompt_strategy = CONFIG.get('prompt_strategy', 'perfect_box')

    def __len__(self) -> int:
        return len(self.image_ids)

    def _generate_prompt(self, mask: np.ndarray, bbox: List[float]):
        # For validation/testing, always use the perfect box for consistent evaluation
        if not self.is_training:
            return {'box': np.array(bbox)}

        # Determine the prompt type for this training item
        if self.prompt_strategy == 'multi_prompt':
            prompt_type = random.choice(['box', 'point', 'scribble'])
        elif self.prompt_strategy == 'point':
            prompt_type = 'point'
        elif self.prompt_strategy == 'noisy_box':
            prompt_type = 'box'
        else: # 'perfect_box'
            prompt_type = 'box'

        if prompt_type == 'box':
            noise_factor = CONFIG.get('bbox_noise_factor', 0) if self.prompt_strategy == 'noisy_box' or self.prompt_strategy == 'multi_prompt' else 0
            x, y, w, h = bbox
            x_noise = w * noise_factor * (random.random() - 0.5) * 2
            y_noise = h * noise_factor * (random.random() - 0.5) * 2
            w_noise = w * noise_factor * (random.random() - 0.5) * 2
            h_noise = h * noise_factor * (random.random() - 0.5) * 2
            box = np.array([x + x_noise, y + y_noise, w + w_noise, h + h_noise])
            return {'box': box}

        points = np.argwhere(mask)
        if len(points) == 0: return None

        if prompt_type == 'point':
            point = points[random.randint(0, len(points) - 1)]
            point_coords = np.array([[point[1], point[0]]]) # (x, y)
            point_labels = np.array([1])
            return {'point_coords': point_coords, 'point_labels': point_labels}

        elif prompt_type == 'scribble':
            num_points = min(5, len(points))
            point_indices = np.random.choice(len(points), num_points, replace=False)
            scribble_points = points[point_indices]
            point_coords = scribble_points[:, ::-1] # (row, col) -> (x, y)
            point_labels = np.ones(num_points)
            return {'point_coords': point_coords, 'point_labels': point_labels}

        return None

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        img_id = self.image_ids[idx]
        img_info = self.coco.loadImgs([img_id])[0]
        image_path = get_image_path(self.split_path, img_info)
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        ann_ids = self.coco.getAnnIds(imgIds=[img_id])
        annotations = self.coco.loadAnns(ann_ids)

        masks, prompts = [], []
        for ann in annotations:
            if 'segmentation' not in ann or ann.get('iscrowd', 0) == 1: continue
            mask = self.coco.annToMask(ann)
            if mask.sum() < CONFIG['min_mask_area']: continue

            prompt = self._generate_prompt(mask, ann['bbox'])
            if prompt:
                masks.append(mask.astype(bool))
                prompts.append(prompt)

        original_size = image.shape[:2]
        resized_image = self.resize_transform.apply_image(image)
        input_image_torch = torch.as_tensor(resized_image, dtype=torch.float32).permute(2, 0, 1).contiguous()
        input_image_torch = (input_image_torch - self.pixel_mean) / self.pixel_std
        h, w = input_image_torch.shape[-2:]
        padh, padw = CONFIG['image_size'] - h, CONFIG['image_size'] - w
        input_image_padded = torch.nn.functional.pad(input_image_torch, (0, padw, 0, padh))

        return {
            'image': input_image_padded,
            'masks': masks,
            'prompts': prompts,
            'original_size': original_size,
            'input_size': (h, w)
        }

def custom_collate_fn(batch: List[Dict]) -> Dict[str, Any]:
    return {
        'image': torch.stack([item['image'] for item in batch]),
        'masks': [item['masks'] for item in batch],
        'prompts': [item['prompts'] for item in batch],
        'original_size': [item['original_size'] for item in batch],
        'input_size': [item['input_size'] for item in batch]
    }

# Load datasets
train_coco, train_path = load_coco_split(CONFIG['unzip_path'], 'train', CONFIG['annotations_file_name'])
val_coco, val_path = load_coco_split(CONFIG['unzip_path'], 'valid', CONFIG['annotations_file_name'])
train_ids, val_ids = train_coco.getImgIds(), val_coco.getImgIds()

# Create Datasets and DataLoaders
train_dataset = HvacSamDataset(train_coco, train_ids, train_path, is_training=True)
val_dataset = HvacSamDataset(val_coco, val_ids, val_path, is_training=False)

train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=CONFIG['num_workers'], collate_fn=custom_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=CONFIG['num_workers'], collate_fn=custom_collate_fn)

print(f"\n‚úÖ Training dataset initialized with {len(train_dataset)} samples.")
print(f"‚úÖ Validation dataset initialized with {len(val_dataset)} samples.")

print(f"\nInspecting dataset directory: {CONFIG['unzip_path']}")
!ls -R {CONFIG['unzip_path']}

üîÑ Loading 'train' annotations from: /content/drive/MyDrive/hvac_dataset_coco/train/_annotations.coco.json
loading annotations into memory...
Done (t=0.33s)
creating index...
index created!
üìä Found 2802 images in 'train'.
üîÑ Loading 'valid' annotations from: /content/drive/MyDrive/hvac_dataset_coco/valid/_annotations.coco.json
loading annotations into memory...
Done (t=0.03s)
creating index...
index created!
üìä Found 341 images in 'valid'.

‚úÖ Training dataset initialized with 2802 samples.
‚úÖ Validation dataset initialized with 341 samples.

Inspecting dataset directory: /content/drive/MyDrive/hvac_dataset_coco
/content/drive/MyDrive/hvac_dataset_coco:
README.roboflow.txt  test  train  valid

/content/drive/MyDrive/hvac_dataset_coco/test:
03-Mechanical_1_page_4_grid_1x3_jpg.rf.9b42123bb3bd2742a707dd110f23bee5.jpg
03-Mechanical_1_page_4_grid_3x2_jpg.rf.ba0723861f9acef4423783880408a711.jpg
03-Mechanical_1_page_5_grid_2x1_jpg.rf.4a0c650e0552d94979ca72ee07532051.jpg
0510162022_

In [7]:
import os
import urllib.request

# Define URL for the official ViT-H model
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
CHECKPOINT_PATH = "/content/sam_vit_h_4b8939.pth"

print(f"‚¨áÔ∏è Downloading official SAM ViT-H weights to {CHECKPOINT_PATH}...")
if not os.path.exists(CHECKPOINT_PATH):
    urllib.request.urlretrieve(CHECKPOINT_URL, CHECKPOINT_PATH)
    print("‚úÖ Download complete.")
else:
    print("‚úÖ File already exists.")

‚¨áÔ∏è Downloading official SAM ViT-H weights to /content/sam_vit_h_4b8939.pth...
‚úÖ Download complete.


### Phase 4: Model Preparation and Training Setup

In [8]:
# Initialize model and optimizer first to allow for state loading
sam_model = sam_model_registry[CONFIG['model_type']]()
sam_model.to(CONFIG['device'])
optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'])

start_epoch = 0

# --- NEW: RESUME TRAINING LOGIC ---
if CONFIG['resume_training']:
    print(f"üîÑ Attempting to resume training from {CONFIG['latest_checkpoint_save_path']}")
    if os.path.exists(CONFIG['latest_checkpoint_save_path']):
        checkpoint = torch.load(CONFIG['latest_checkpoint_save_path'])
        sam_model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"‚úÖ Resumed successfully. Starting from epoch {start_epoch}.")
    else:
        print("‚ö†Ô∏è Resume checkpoint not found. Starting training from scratch with pre-trained SAM.")
        # Load the base pre-trained model if resume fails
        sam_model.load_state_dict(torch.load(CONFIG['model_path']))
else:
    print(f"üîÑ Starting new training session from pre-trained model: {CONFIG['model_path']}")
    sam_model.load_state_dict(torch.load(CONFIG['model_path']))

# Configure model for fine-tuning (freeze encoders)
sam_model.train()
for name, param in sam_model.named_parameters():
    if name.startswith("image_encoder") or name.startswith("prompt_encoder"):
        param.requires_grad = False

trainable_params = sum(p.numel() for p in sam_model.parameters() if p.requires_grad)
print(f"‚úÖ SAM model configured for fine-tuning. Trainable parameters: {trainable_params:,}")

üîÑ Starting new training session from pre-trained model: /content/sam_vit_h_4b8939.pth
‚úÖ SAM model configured for fine-tuning. Trainable parameters: 4,058,340


In [9]:
import torch
import torch.nn as nn

# Setup scheduler and loss function
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

def combined_loss(pred_masks: torch.Tensor, true_masks: torch.Tensor) -> torch.Tensor:
    bce_loss = nn.BCEWithLogitsLoss()(pred_masks, true_masks.float())
    pred_flat = torch.sigmoid(pred_masks).reshape(-1)
    true_flat = true_masks.reshape(-1)
    intersection = (pred_flat * true_flat).sum()
    dice_loss = 1 - (2. * intersection + 1e-8) / (pred_flat.sum() + true_flat.sum() + 1e-8)
    return 0.8 * bce_loss + 0.2 * dice_loss

print("‚úÖ Scheduler and loss function configured.")

‚úÖ Scheduler and loss function configured.


### Phase 5: Training and Validation Loop

In [None]:
### Phase 5: Training and Validation Loop (Optimized for T4 GPU)

def run_epoch(model, dataloader, optimizer, is_training, device, epoch):
    model.train(is_training)
    epoch_losses, iou_scores = [], []

    desc = "Training" if is_training else "Validation"
    for batch_idx, batch in enumerate(tqdm(dataloader, desc=desc)):
        images = batch['image'].to(device, non_blocking=True)
        all_gt_masks_list = batch['masks']
        all_prompts_list = batch['prompts']

        # For training, we must clear gradients at the start of each image (batch)
        if is_training:
            optimizer.zero_grad()

        with torch.set_grad_enabled(is_training):
            with torch.no_grad():
                image_embeddings = model.image_encoder(images)

            total_loss_for_image = 0
            num_annotations_in_image = 0

            # This loop structure is correct for batch_size = 1
            if len(all_gt_masks_list) > 0:
                # Loop through every annotation for the current image
                for j in range(len(all_gt_masks_list[0])):
                    gt_mask_np = all_gt_masks_list[0][j]
                    prompt = all_prompts_list[0][j]
                    num_annotations_in_image += 1

                    gt_mask_torch = torch.from_numpy(gt_mask_np).unsqueeze(0).unsqueeze(0).to(device)
                    transform = ResizeLongestSide(CONFIG['image_size'])

                    box_torch, points_torch, points_label_torch = None, None, None
                    if 'box' in prompt:
                        box_torch = torch.as_tensor(transform.apply_boxes(prompt['box'].reshape(1, 4), batch['original_size'][0]), dtype=torch.float, device=device)
                    elif 'point_coords' in prompt:
                        points_torch = torch.as_tensor(transform.apply_coords(prompt['point_coords'], batch['original_size'][0]), dtype=torch.float, device=device).unsqueeze(0)
                        points_label_torch = torch.as_tensor(prompt['point_labels'], dtype=torch.float, device=device).unsqueeze(0)

                    with torch.no_grad():
                        sparse_embeddings, dense_embeddings = model.prompt_encoder(
                            points=(points_torch, points_label_torch) if points_torch is not None else None,
                            boxes=box_torch,
                            masks=None
                        )

                    low_res_masks, iou_predictions = model.mask_decoder(
                        image_embeddings=image_embeddings,
                        image_pe=model.prompt_encoder.get_dense_pe(),
                        sparse_prompt_embeddings=sparse_embeddings,
                        dense_prompt_embeddings=dense_embeddings,
                        multimask_output=False,
                    )

                    upscaled_masks = model.postprocess_masks(low_res_masks, batch['input_size'][0], batch['original_size'][0])

                    # Calculate loss for this single annotation
                    loss = combined_loss(upscaled_masks, gt_mask_torch)

                    # --- TRUE GRADIENT ACCUMULATION STEP ---
                    if is_training:
                        # Normalize the loss by the number of annotations to keep gradients stable
                        normalized_loss = loss / len(all_gt_masks_list[0])
                        # Calculate and accumulate gradients for this single, tiny loss
                        normalized_loss.backward()

                    total_loss_for_image += loss.item()

                    if not is_training:
                        pred_mask = (torch.sigmoid(upscaled_masks) > 0.5).squeeze().cpu().numpy().astype(np.uint8)
                        iou = mask_utils.iou([mask_utils.encode(np.asfortranarray(pred_mask))], [mask_utils.encode(np.asfortranarray(gt_mask_np.astype(np.uint8)))], [0])[0][0]
                        iou_scores.append(iou)

            # --- OPTIMIZER STEP AFTER ALL ANNOTATIONS ---
            # After processing all annotations for the image, update the model weights ONCE
            if is_training and num_annotations_in_image > 0:
                optimizer.step()
                # The gradients are automatically cleared by the zero_grad() at the start of the next batch

            if num_annotations_in_image > 0:
                epoch_losses.append(total_loss_for_image / num_annotations_in_image)

        # Checkpoint saving logic
        if is_training and (batch_idx + 1) % CONFIG['checkpoint_batch_interval'] == 0:
            chk_path = CONFIG['latest_checkpoint_save_path']
            torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, chk_path)
            tqdm.write(f"\nüíæ Overwrote latest checkpoint (for crash recovery): {os.path.basename(chk_path)}")

    return {'loss': mean(epoch_losses) if epoch_losses else 0, 'iou': mean(iou_scores) if iou_scores else 0}


# --- Main Training Loop ---
best_val_iou = 0
patience_counter = 0
history = {'train_loss': [], 'val_loss': [], 'val_iou': []}

print(f"\nüöÄ Starting training from epoch {start_epoch} for {CONFIG['num_epochs']} total epochs...")
for epoch in range(start_epoch, CONFIG['num_epochs']):
    print(f"\n--- Epoch {epoch+1}/{CONFIG['num_epochs']} ---")

    train_metrics = run_epoch(sam_model, train_loader, optimizer, is_training=True, device=CONFIG['device'], epoch=epoch)
    history['train_loss'].append(train_metrics['loss'])

    val_metrics = run_epoch(sam_model, val_loader, None, is_training=False, device=CONFIG['device'], epoch=epoch)
    history['val_loss'].append(val_metrics['loss'])
    history['val_iou'].append(val_metrics['iou'])

    print(f"Train Loss: {train_metrics['loss']:.4f}")
    print(f"Val Loss: {val_metrics['loss']:.4f} | Val IoU: {val_metrics['iou']:.4f}")

    scheduler.step(val_metrics['loss'])

    if val_metrics['iou'] > best_val_iou:
        best_val_iou = val_metrics['iou']
        patience_counter = 0
        best_model_path = CONFIG['best_model_save_path']
        torch.save(sam_model.state_dict(), best_model_path)
        print(f"üèÜ New best model saved with IoU: {best_val_iou:.4f}")
    else:
        patience_counter += 1

    if patience_counter >= CONFIG['early_stopping_patience']:
        print(f"üõë Early stopping triggered after {patience_counter} epochs with no improvement.")
        break

print("\n‚úÖ Training completed!")


üöÄ Starting training from epoch 0 for 1 total epochs...

--- Epoch 1/1 ---


Training:  11%|‚ñà         | 300/2802 [17:59<4:05:40,  5.89s/it]


üíæ Overwrote latest checkpoint (for crash recovery): latest_checkpoint_multiprompt_v1.pth


Training:  14%|‚ñà‚ñç        | 391/2802 [23:50<1:32:37,  2.31s/it]

### Phase 6: Results and Export

In [None]:
# Plot training metrics
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

ax1.plot(history['train_loss'], label='Training Loss', color='blue')
ax1.plot(history['val_loss'], label='Validation Loss', color='red')
ax1.set_title('Training and Validation Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True)

ax2.plot(history['val_iou'], label='Validation IoU', color='green')
ax2.set_title('Validation Metrics (IoU)')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Score')
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.savefig(os.path.join(CONFIG['output_dir'], 'training_metrics.png'))
plt.show()

print(f"\nüìä FINAL METRICS:")
if history['val_iou']:
    print(f"Best Validation IoU: {max(history['val_iou']):.4f}")
else:
    print("No validation metrics recorded.")

In [None]:
# Save the final trained model (or the best one)
final_model_path = os.path.join(CONFIG['output_dir'], 'final_model.pth')
torch.save(sam_model.state_dict(), final_model_path)
print(f"‚úÖ Final model state saved to: {final_model_path}")
print(f"‚úÖ Best performing model (by IoU) saved to: {CONFIG['best_model_save_path']}")

In [None]:
### Phase 7: Final, Unbiased Evaluation on the Test Set

print("\n--- Final Model Evaluation on Unseen Test Data ---")

# 1. Load the best performing model that was saved during training
best_model_path = CONFIG['best_model_save_path']
if os.path.exists(best_model_path):
    print(f"üîÑ Loading best model from: {best_model_path}")
    # We need to re-initialize the model structure before loading the state dict
    eval_model = sam_model_registry[CONFIG['model_type']]()
    eval_model.load_state_dict(torch.load(best_model_path))
    eval_model.to(CONFIG['device'])
else:
    print("‚ùå Best model file not found. Cannot perform final evaluation.")
    # You might want to handle this case, but for now we'll assume it exists

# 2. Load the test dataset
try:
    test_coco, test_path = load_coco_split(CONFIG['unzip_path'], 'test', CONFIG['annotations_file_name'])
    test_ids = test_coco.getImgIds()

    # Use is_training=False to ensure validation uses perfect boxes
    test_dataset = HvacSamDataset(test_coco, test_ids, test_path, is_training=False)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG['batch_size'], shuffle=False,
                             num_workers=CONFIG['num_workers'], collate_fn=custom_collate_fn)

    print(f"\n‚úÖ Test dataset loaded with {len(test_dataset)} samples.")

except FileNotFoundError as e:
    print(f"\n‚ö†Ô∏è Test split not found: {e}. Skipping final evaluation.")

# 3. Run a single evaluation pass on the test data
if 'test_loader' in locals() and 'eval_model' in locals():
    print("\nüöÄ Running final evaluation on the test set...")
    final_test_metrics = run_epoch(eval_model, test_loader, optimizer=None, is_training=False, device=CONFIG['device'], epoch=0)

    print("\n" + "="*50)
    print("      üéâ FINAL UNBIASED PERFORMANCE METRICS üéâ")
    print("="*50)
    print(f"Final Test IoU:   {final_test_metrics['iou']:.4f}")
    print(f"Final Test Loss:  {final_test_metrics['loss']:.4f}")
    print("="*50)
    print("\nThis is the true expected performance of your model on new data.")