# 🟣 HVAC-Specific SAM Fine-Tuning Pipeline (Optimized)
## 🔧 Complete Production-Ready Implementation

This notebook has been refactored to provide a clean, linear, and robust workflow for fine-tuning the Segment Anything Model (SAM) on your HVAC dataset.

### Phase 1: Initial Setup

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Install required dependencies
!pip install torch torchvision --quiet
!pip install opencv-python pycocotools matplotlib onnxruntime onnx --quiet
!pip install git+https://github.com/facebookresearch/segment-anything.git --quiet

[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m17.4/17.4 MB[0m [31m120.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m18.1/18.1 MB[0m [31m119.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m46.0/46.0 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m86.8/86.8 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for segment_anything (setup.py) ... [?25l[?25hdone


In [4]:
import os
import zipfile
import json
import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from tqdm import tqdm
from statistics import mean
from pycocotools.coco import COCO
from pycocotools import mask as mask_utils

# SAM imports
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

### Phase 2: Configuration and Dataset Preparation

In [7]:
import torch
from pathlib import Path

# Central run configuration shared by every cell below.
CONFIG = {
    # Compute device: GPU when CUDA is available, otherwise CPU.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    # Checkpoint to load. NOTE: this initial value is overwritten in Phase 4
    # with the path of the freshly downloaded base SAM checkpoint.
    'model_path': '/content/drive/MyDrive/sam_finetuning_results/best_model.pth',
    # Dataset archive and the directory it is extracted into.
    'dataset_zip_path': '/content/drive/MyDrive/hvac_dataset_final.zip',
    'unzip_path': '/content/drive/MyDrive/hvac_dataset_final',
    # File name of the COCO annotations inside each split directory.
    'annotations_file_name': '_annotations.coco.json',
    # Where checkpoints, the best model and the metrics plot are written.
    'output_dir': '/content/drive/MyDrive/sam_finetuning_results',
    # Key into sam_model_registry.
    'model_type': 'vit_h',
    # Target size used for ResizeLongestSide and for padding the input image.
    'image_size': 1024,
    # DataLoader settings.
    'batch_size': 1,
    'num_workers': 0,
    # Optimisation settings (Adam).
    'num_epochs': 1,
    'learning_rate': 1e-4,
    'weight_decay': 0,
    # Stop after this many consecutive epochs without a validation-IoU improvement.
    'early_stopping_patience': 10,
    'checkpoint_interval': 10, # Save an end-of-epoch checkpoint every N epochs
    'checkpoint_batch_interval': 300, # Save an intra-epoch checkpoint every 300 training batches
    # Masks with fewer foreground pixels than this are skipped by the dataset.
    'min_mask_area': 100,
}

# Make sure the output directory exists before anything tries to write to it.
Path(CONFIG['output_dir']).mkdir(parents=True, exist_ok=True)
print(f"‚úì Configuration loaded. Using device: {CONFIG['device']}")

‚úì Configuration loaded. Using device: cuda


In [8]:
# Extract the dataset archive once; later runs reuse the extracted directory.
if os.path.exists(CONFIG['unzip_path']):
    print("‚úÖ Dataset already unzipped.")
else:
    print(f"üìÅ Unzipping dataset from {CONFIG['dataset_zip_path']}...")
    with zipfile.ZipFile(CONFIG['dataset_zip_path'], mode='r') as archive:
        archive.extractall(path=CONFIG['unzip_path'])
    print("‚úÖ Unzipping complete.")

‚úÖ Dataset already unzipped.


### Phase 3: Dataset Loading and DataLoader Creation

In [9]:
def load_coco_split(dataset_root_path: str, split_name: str, annotations_file: str) -> Tuple[COCO, str]:
    """Open the COCO annotation index of one dataset split.

    Args:
        dataset_root_path: Directory holding one sub-directory per split.
        split_name: Name of the split sub-directory to load.
        annotations_file: File name of the COCO JSON inside that sub-directory.

    Returns:
        A ``(coco, split_dir)`` pair: the loaded ``COCO`` object and the path
        of the split directory the annotations were read from.

    Raises:
        FileNotFoundError: If the annotations file does not exist.
    """
    split_dir = os.path.join(dataset_root_path, split_name)
    ann_path = os.path.join(split_dir, annotations_file)

    # Fail early with the exact path that was probed.
    if not os.path.exists(ann_path):
        raise FileNotFoundError(f"Annotations file not found for '{split_name}' at: {ann_path}")

    print(f"üîÑ Loading '{split_name}' annotations from: {ann_path}")
    coco_index = COCO(ann_path)
    image_count = len(coco_index.getImgIds())
    print(f"üìä Found {image_count} images in '{split_name}'.")
    return coco_index, split_dir

def get_image_path(split_path: str, img_info: dict) -> str:
    """Resolve an image record to an existing file inside a split directory.

    Args:
        split_path: Directory of the split that should contain the image.
        img_info: Image record; its ``'file_name'`` entry is joined onto ``split_path``.

    Returns:
        The joined path, guaranteed to exist at the time of the call.

    Raises:
        FileNotFoundError: If the joined path does not exist.
    """
    file_name = img_info['file_name']
    candidate = os.path.join(split_path, file_name)
    if os.path.exists(candidate):
        return candidate
    raise FileNotFoundError(f"Image file not found: {file_name} in {split_path}")

class HvacSamDataset(Dataset):
    """COCO-backed dataset yielding preprocessed images with their instance masks and boxes.

    Each item corresponds to one image. The image is resized with
    ``ResizeLongestSide(CONFIG['image_size'])``, normalised per channel and
    zero-padded on the bottom/right up to ``CONFIG['image_size']`` in both
    dimensions. Masks and boxes are returned untouched, i.e. in the coordinate
    frame of the original (un-resized) image.
    """

    def __init__(self, coco: COCO, image_ids: List[int], split_path: str):
        """Store the annotation index and build the preprocessing constants.

        Args:
            coco: Loaded COCO annotation index for the split.
            image_ids: Image ids (from ``coco``) that make up this dataset.
            split_path: Directory containing the image files of the split.
        """
        self.coco = coco
        self.image_ids = image_ids
        self.split_path = split_path
        self.resize_transform = ResizeLongestSide(CONFIG['image_size'])
        # Per-channel RGB statistics, shaped (3, 1, 1) so they broadcast over a CHW image.
        self.pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
        self.pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.image_ids)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Load, preprocess and return the sample at position ``idx``.

        Returns:
            A dict with:
              * ``'image'``: float32 CHW tensor, normalised and padded.
              * ``'masks'``: list of boolean numpy masks at original resolution.
              * ``'bboxes'``: list of the annotations' raw ``'bbox'`` entries
                (unchanged, one per kept mask).
              * ``'original_size'``: ``(height, width)`` of the image as read from disk.
              * ``'input_size'``: ``(height, width)`` after resizing, before padding.

        Raises:
            FileNotFoundError: If the image file is missing.
            OSError: If the image file exists but cannot be decoded.
        """
        img_id = self.image_ids[idx]
        img_info = self.coco.loadImgs([img_id])[0]
        image_path = get_image_path(self.split_path, img_info)

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports an unreadable/corrupt file by returning None
            # rather than raising; fail here with the offending path instead of
            # letting cvtColor raise an opaque error.
            raise OSError(f"Failed to decode image file: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        ann_ids = self.coco.getAnnIds(imgIds=[img_id])
        annotations = self.coco.loadAnns(ann_ids)
        masks, bboxes = [], []

        for ann in annotations:
            # Skip annotations without a segmentation and crowd regions.
            if 'segmentation' not in ann or ann.get('iscrowd', 0) == 1:
                continue
            mask = self.coco.annToMask(ann)
            # Drop masks too small to be a useful training target.
            if mask.sum() < CONFIG['min_mask_area']:
                continue
            masks.append(mask.astype(bool))
            bboxes.append(ann['bbox'])

        original_size = image.shape[:2]
        resized_image = self.resize_transform.apply_image(image)
        # HWC -> CHW float tensor, then per-channel normalisation.
        input_image_torch = torch.as_tensor(resized_image, dtype=torch.float32).permute(2, 0, 1).contiguous()
        input_image_torch = (input_image_torch - self.pixel_mean) / self.pixel_std

        # Pad only on the right and bottom so pixel coordinates stay valid.
        h, w = input_image_torch.shape[-2:]
        padh, padw = CONFIG['image_size'] - h, CONFIG['image_size'] - w
        input_image_padded = torch.nn.functional.pad(input_image_torch, (0, padw, 0, padh))

        return {
            'image': input_image_padded,
            'masks': masks,
            'bboxes': bboxes,
            'original_size': original_size,
            'input_size': (h, w)
        }

def custom_collate_fn(batch: List[Dict]) -> Dict[str, Any]:
    """Collate samples whose mask/box counts differ from image to image.

    Images are stacked into a single tensor; every other field is kept as a
    plain Python list with one entry per sample, in batch order.
    """
    collated: Dict[str, Any] = {
        'image': torch.stack([sample['image'] for sample in batch]),
    }
    # Variable-length fields cannot be stacked, so they stay as lists.
    for key in ('masks', 'bboxes', 'original_size', 'input_size'):
        collated[key] = [sample[key] for sample in batch]
    return collated

# Read the COCO annotation index of each split
train_coco, train_path = load_coco_split(CONFIG['unzip_path'], 'train', CONFIG['annotations_file_name'])
val_coco, val_path = load_coco_split(CONFIG['unzip_path'], 'valid', CONFIG['annotations_file_name'])

# Every image of a split is used
train_ids = train_coco.getImgIds()
val_ids = val_coco.getImgIds()

# Wrap each split in a Dataset
train_dataset = HvacSamDataset(train_coco, train_ids, train_path)
val_dataset = HvacSamDataset(val_coco, val_ids, val_path)

# Loader settings common to both splits; only shuffling differs
_loader_kwargs = {
    'batch_size': CONFIG['batch_size'],
    'num_workers': CONFIG['num_workers'],
    'collate_fn': custom_collate_fn,
}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

print(f"\n‚úÖ Training dataset initialized with {len(train_dataset)} samples.")
print(f"‚úÖ Validation dataset initialized with {len(val_dataset)} samples.")

üîÑ Loading 'train' annotations from: /content/drive/MyDrive/hvac_dataset_final/train/_annotations.coco.json
loading annotations into memory...
Done (t=1.41s)
creating index...
index created!
üìä Found 2604 images in 'train'.
üîÑ Loading 'valid' annotations from: /content/drive/MyDrive/hvac_dataset_final/valid/_annotations.coco.json
loading annotations into memory...
Done (t=0.54s)
creating index...
index created!
üìä Found 351 images in 'valid'.

‚úÖ Training dataset initialized with 2604 samples.
‚úÖ Validation dataset initialized with 351 samples.


In [11]:
import os
import urllib.request

# URL of the official ViT-H checkpoint and the local path it is cached at
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
CHECKPOINT_PATH = "/content/sam_vit_h_4b8939.pth"

if os.path.exists(CHECKPOINT_PATH):
    print("‚úÖ File already exists.")
else:
    print(f"‚¨áÔ∏è Downloading official SAM ViT-H weights to {CHECKPOINT_PATH}...")
    # Download to a temporary name and move it into place only on success, so an
    # interrupted transfer can never leave a truncated file at CHECKPOINT_PATH
    # that a later run would mistake for a complete checkpoint.
    _partial_path = CHECKPOINT_PATH + ".part"
    try:
        urllib.request.urlretrieve(CHECKPOINT_URL, _partial_path)
        os.replace(_partial_path, CHECKPOINT_PATH)
    finally:
        # After a successful replace the partial file no longer exists; this
        # only removes leftovers of a failed or interrupted download.
        if os.path.exists(_partial_path):
            os.remove(_partial_path)
    print("‚úÖ Download complete.")

‚¨áÔ∏è Downloading official SAM ViT-H weights to /content/sam_vit_h_4b8939.pth...
‚úÖ Download complete.


### Phase 4: Model Preparation and Training Setup

In [12]:
# --- UPDATE CONFIG PATH ---
# Replace the initial 'model_path' (which points into the results directory) so
# that fine-tuning starts from the base checkpoint downloaded to /content.
CONFIG['model_path'] = "/content/sam_vit_h_4b8939.pth"  # Point to the fresh download

def load_sam_model(model_path: str, model_type: str) -> nn.Module:
    """Build a SAM model of the given type, load its checkpoint and move it to the configured device.

    Args:
        model_path: Path of the checkpoint file to load.
        model_type: Key into ``sam_model_registry`` selecting the architecture.

    Returns:
        The loaded model, moved to ``CONFIG['device']``.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        KeyError: If ``model_type`` is not a key of ``sam_model_registry``.
        Exception: Whatever the registry builder raises while reading the
            checkpoint is re-raised unchanged after a hint is printed.
    """
    print(f"üîÑ Loading SAM model from: {model_path}")

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"‚ùå Model checkpoint not found at: {model_path}. Did you run the download cell?")

    # A wrong model type is a configuration error, not a damaged file, so report
    # it separately instead of letting it be mislabelled as corruption below.
    if model_type not in sam_model_registry:
        raise KeyError(f"Unknown SAM model type '{model_type}'. Available types: {sorted(sam_model_registry)}")

    try:
        # Only the checkpoint load is guarded: failures here are the ones that
        # re-downloading the weights can plausibly fix.
        sam = sam_model_registry[model_type](checkpoint=model_path)
    except Exception as e:
        print(f"‚ùå CRITICAL ERROR: Failed to load the checkpoint (corrupted file or wrong model type?): {e}")
        print("üëâ Please re-run the 'Download Official Weights' cell above.")
        raise  # bare raise keeps the original traceback

    print(f"‚úÖ Successfully loaded '{model_type}' model.")
    # Device transfer sits outside the try so device errors surface as themselves.
    return sam.to(device=CONFIG['device'])

# Instantiate SAM from the configured checkpoint
sam_model = load_sam_model(CONFIG['model_path'], CONFIG['model_type'])

# Freeze both encoders so that only the remaining parameters receive gradients
sam_model.train()
for name, param in sam_model.named_parameters():
    if name.startswith(("image_encoder", "prompt_encoder")):
        param.requires_grad = False

# Count what is still trainable after freezing
trainable_params = sum(
    p.numel()
    for p in sam_model.parameters()
    if p.requires_grad
)
print(f"‚úÖ SAM model configured for fine-tuning. Trainable parameters: {trainable_params:,}")

üîÑ Loading SAM model from: /content/sam_vit_h_4b8939.pth
‚úÖ Successfully loaded 'vit_h' model.
‚úÖ SAM model configured for fine-tuning. Trainable parameters: 4,058,340


In [None]:
# Adam over the mask decoder's parameters only
optimizer = optim.Adam(
    params=sam_model.mask_decoder.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
)

# Halve the learning rate when the monitored (minimised) metric stalls for 5 epochs
scheduler = ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=0.5, patience=5)

def combined_loss(pred_masks: torch.Tensor, true_masks: torch.Tensor) -> torch.Tensor:
    """Return a weighted sum of BCE-with-logits (0.8) and soft Dice loss (0.2).

    Args:
        pred_masks: Raw mask logits.
        true_masks: Ground-truth mask with the same shape as ``pred_masks``.

    Returns:
        A scalar tensor holding the combined loss.
    """
    eps = 1e-8  # keeps the Dice ratio defined when both masks are empty

    # Pixel-wise term, computed on the raw logits.
    bce_criterion = nn.BCEWithLogitsLoss()
    bce_term = bce_criterion(pred_masks, true_masks.float())

    # Overlap term, computed on probabilities over all pixels at once.
    probs = torch.sigmoid(pred_masks).reshape(-1)
    targets = true_masks.reshape(-1)
    overlap = (probs * targets).sum()
    dice_coeff = (2. * overlap + eps) / (probs.sum() + targets.sum() + eps)
    dice_term = 1 - dice_coeff

    return 0.8 * bce_term + 0.2 * dice_term

print("‚úÖ Optimizer and loss function configured.")

### Phase 5: Training and Validation Loop

In [None]:
def run_epoch(model, dataloader, optimizer, is_training, device, epoch):
    """Run one full pass over ``dataloader``, either training or evaluating the model.

    For every image that has at least one ground-truth mask, the first
    mask/box pair is used: the box is fed to the prompt encoder, the mask
    decoder predicts a mask, and ``combined_loss`` is computed against the
    ground truth at the original image resolution.

    Args:
        model: SAM model exposing ``image_encoder``, ``prompt_encoder``,
            ``mask_decoder`` and ``postprocess_masks``.
        dataloader: Yields batches in the format produced by ``custom_collate_fn``.
        optimizer: Optimizer stepped once per batch when training; may be
            ``None`` when ``is_training`` is false.
        is_training: ``True`` to back-propagate and save intra-epoch
            checkpoints, ``False`` to evaluate and collect IoU/Dice.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index, used in checkpoint contents and file names.

    Returns:
        Dict with the mean ``'loss'``, ``'iou'`` and ``'dice'`` over the pass.
        IoU and Dice are only collected in evaluation mode; any metric without
        samples is reported as 0.
    """
    model.train(is_training)
    epoch_losses = []
    iou_scores, dice_scores = [], []

    # Loop-invariant: the transform depends only on the configured image size.
    transform = ResizeLongestSide(CONFIG['image_size'])

    desc = "Training" if is_training else "Validation"
    for batch_idx, batch in enumerate(tqdm(dataloader, desc=desc)):
        images = batch['image'].to(device, non_blocking=True)
        all_gt_masks_list = batch['masks']
        all_bboxes_list = batch['bboxes']

        batch_loss = 0

        with torch.set_grad_enabled(is_training):
            # The image encoder is frozen, so its activations never need gradients.
            with torch.no_grad():
                image_embeddings = model.image_encoder(images)

            for i, gt_masks in enumerate(all_gt_masks_list):
                if not gt_masks:
                    continue

                # NOTE(review): only the first annotation of each image is used;
                # remaining masks/boxes are ignored. Confirm this is intended.
                gt_mask_np = gt_masks[0]
                original_size = batch['original_size'][i]

                # COCO boxes are [x, y, width, height], whereas SAM's box prompt
                # is two corners [x1, y1, x2, y2]. Convert before resizing,
                # otherwise width/height are misread as the bottom-right corner.
                box_x, box_y, box_w, box_h = all_bboxes_list[i][0]
                box_xyxy = np.array([[box_x, box_y, box_x + box_w, box_y + box_h]], dtype=float)

                gt_mask_torch = torch.from_numpy(gt_mask_np).unsqueeze(0).unsqueeze(0).to(device)
                box_torch = torch.as_tensor(transform.apply_boxes(box_xyxy, original_size),
                                            dtype=torch.float, device=device)

                # The prompt encoder is frozen as well.
                with torch.no_grad():
                    sparse_embeddings, dense_embeddings = model.prompt_encoder(points=None, boxes=box_torch, masks=None)

                low_res_masks, _ = model.mask_decoder(
                    image_embeddings=image_embeddings[i].unsqueeze(0),
                    image_pe=model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=False,
                )

                # Undo padding/resizing so the logits line up with the ground-truth mask.
                upscaled_masks = model.postprocess_masks(low_res_masks, batch['input_size'][i], original_size)

                loss = combined_loss(upscaled_masks, gt_mask_torch)
                batch_loss += loss

                if not is_training:
                    pred_mask = torch.sigmoid(upscaled_masks) > 0.5

                    # IoU via pycocotools on RLE-encoded masks.
                    pred_mask_np = pred_mask.cpu().numpy().squeeze().astype(np.uint8)
                    rle_pred = mask_utils.encode(np.asfortranarray(pred_mask_np))
                    rle_gt = mask_utils.encode(np.asfortranarray(gt_mask_np.astype(np.uint8)))
                    iou_scores.append(mask_utils.iou([rle_pred], [rle_gt], [0])[0][0])

                    # Dice computed directly on the boolean tensors.
                    intersection = torch.logical_and(pred_mask.squeeze(), gt_mask_torch.squeeze()).sum().item()
                    total = pred_mask.sum().item() + gt_mask_torch.sum().item()
                    dice_scores.append((2 * intersection) / total if total > 0 else 0)

        # batch_loss stays the int 0 when no image in the batch had a usable mask.
        if batch_loss > 0:
            if is_training:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
            epoch_losses.append(batch_loss.item() / len(all_gt_masks_list))

        # Intra-epoch checkpoint so a long epoch can be resumed after a disconnect.
        if is_training and (batch_idx + 1) % CONFIG['checkpoint_batch_interval'] == 0:
            chk_path = os.path.join(CONFIG['output_dir'], f'checkpoint_epoch_{epoch+1}_batch_{batch_idx+1}.pth')
            torch.save({
                'epoch': epoch,
                'batch_idx': batch_idx,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, chk_path)
            # Use .write to avoid interfering with tqdm progress bar
            tqdm.write(f"\nüíæ Intra-epoch checkpoint saved: {chk_path}")

    return {
        'loss': mean(epoch_losses) if epoch_losses else 0,
        'iou': mean(iou_scores) if iou_scores else 0,
        'dice': mean(dice_scores) if dice_scores else 0
    }

# Training driver: alternates one training pass and one validation pass per
# epoch, tracks the best validation IoU, and applies early stopping.

# Best validation IoU seen so far. Because it starts at 0 and the comparison
# below is strict, a best model is only written once validation IoU exceeds 0.
best_val_iou = 0
# Number of consecutive epochs without a validation-IoU improvement.
patience_counter = 0
# Per-epoch metrics, consumed by the plotting cell.
history = {'train_loss': [], 'val_loss': [], 'val_iou': [], 'val_dice': []}

print(f"\nüöÄ Starting training for {CONFIG['num_epochs']} epochs...")
for epoch in range(CONFIG['num_epochs']):
    print(f"\n--- Epoch {epoch+1}/{CONFIG['num_epochs']} ---")

    # Training pass (updates the weights and may write intra-epoch checkpoints).
    train_metrics = run_epoch(sam_model, train_loader, optimizer, is_training=True, device=CONFIG['device'], epoch=epoch)
    history['train_loss'].append(train_metrics['loss'])

    # Validation pass; no optimizer is needed in evaluation mode.
    val_metrics = run_epoch(sam_model, val_loader, None, is_training=False, device=CONFIG['device'], epoch=epoch)
    history['val_loss'].append(val_metrics['loss'])
    history['val_iou'].append(val_metrics['iou'])
    history['val_dice'].append(val_metrics['dice'])

    print(f"Train Loss: {train_metrics['loss']:.4f}")
    print(f"Val Loss: {val_metrics['loss']:.4f} | Val IoU: {val_metrics['iou']:.4f} | Val Dice: {val_metrics['dice']:.4f}")

    # The scheduler monitors validation loss, while model selection below uses validation IoU.
    scheduler.step(val_metrics['loss'])

    if val_metrics['iou'] > best_val_iou:
        # New best: reset patience and overwrite the best-model weights file.
        best_val_iou = val_metrics['iou']
        patience_counter = 0
        best_model_path = os.path.join(CONFIG['output_dir'], 'best_model.pth')
        torch.save(sam_model.state_dict(), best_model_path)
        print(f"üèÜ New best model saved with IoU: {best_val_iou:.4f}")
    else:
        patience_counter += 1

    # NOTE: this break runs before the periodic checkpoint below, so the epoch
    # that triggers early stopping never writes an end-of-epoch checkpoint.
    if patience_counter >= CONFIG['early_stopping_patience']:
        print(f"üõë Early stopping triggered after {patience_counter} epochs with no improvement.")
        break

    # Periodic full checkpoint (model + optimizer state) every 'checkpoint_interval' epochs.
    if (epoch + 1) % CONFIG['checkpoint_interval'] == 0:
        chk_path = os.path.join(CONFIG['output_dir'], f'checkpoint_epoch_{epoch+1}.pth')
        torch.save({
            'epoch': epoch,
            'model_state_dict': sam_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': val_metrics['loss'],
        }, chk_path)
        print(f"üíæ End-of-epoch checkpoint saved: {chk_path}")

print("\n‚úÖ Training completed!")

### Phase 6: Results and Export

In [None]:
# Two side-by-side panels: losses on the left, validation scores on the right
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Left panel: training vs. validation loss per epoch
for _key, _label, _color in (
    ('train_loss', 'Training Loss', 'blue'),
    ('val_loss', 'Validation Loss', 'red'),
):
    ax1.plot(history[_key], label=_label, color=_color)
ax1.set_title('Training and Validation Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True)

# Right panel: validation IoU and Dice per epoch
for _key, _label, _color in (
    ('val_iou', 'Validation IoU', 'green'),
    ('val_dice', 'Validation Dice', 'purple'),
):
    ax2.plot(history[_key], label=_label, color=_color)
ax2.set_title('Validation Metrics (IoU & Dice)')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Score')
ax2.legend()
ax2.grid(True)

# Save the figure next to the checkpoints before displaying it
plt.tight_layout()
plt.savefig(os.path.join(CONFIG['output_dir'], 'training_metrics.png'))
plt.show()

print(f"\nüìä FINAL METRICS:")
print(f"Best Validation IoU: {max(history['val_iou']):.4f}")

In [None]:
# Persist the weights exactly as they stand after the last executed epoch.
# The best-by-IoU weights were already written to 'best_model.pth' during
# training; this file is the final-epoch counterpart.
final_model_path = os.path.join(CONFIG['output_dir'], 'final_model.pth')
torch.save(sam_model.state_dict(), final_model_path)
_best_path = os.path.join(CONFIG['output_dir'], 'best_model.pth')
print(f"‚úÖ Final model state saved to: {final_model_path}")
print(f"‚úÖ Best performing model (by IoU) saved to: {_best_path}")

In [None]:
### Phase 7: Final, Unbiased Evaluation on the Test Set

print("\n--- Final Model Evaluation on Unseen Test Data ---")

# 1. Load the best performing model that was saved during training
best_model_path = os.path.join(CONFIG['output_dir'], 'best_model.pth')
best_model_loaded = False
if os.path.exists(best_model_path):
    print(f"üîÑ Loading best model from: {best_model_path}")
    # Re-create the architecture without weights, then restore the saved state dict.
    sam_model = sam_model_registry[CONFIG['model_type']]()
    # map_location lets weights saved on a GPU be restored on whatever device is configured now.
    sam_model.load_state_dict(torch.load(best_model_path, map_location=CONFIG['device']))
    sam_model.to(CONFIG['device'])
    best_model_loaded = True
else:
    print("‚ùå Best model file not found. Cannot perform final evaluation.")

# 2. Load the test dataset
# Start from None so a loader left over from an earlier run of this cell can
# never be picked up by accident.
test_loader = None
try:
    test_coco, test_path = load_coco_split(CONFIG['unzip_path'], 'test', CONFIG['annotations_file_name'])
except FileNotFoundError as e:
    print(f"\n‚ö†Ô∏è Test split not found: {e}. Skipping final evaluation.")
else:
    test_ids = test_coco.getImgIds()

    test_dataset = HvacSamDataset(test_coco, test_ids, test_path)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG['batch_size'], shuffle=False,
                             num_workers=CONFIG['num_workers'], collate_fn=custom_collate_fn)

    print(f"\n‚úÖ Test dataset loaded with {len(test_dataset)} samples.")

# 3. Run a single evaluation pass on the test data
# Evaluate only when BOTH the best weights and the test split are available;
# otherwise the reported numbers would not describe the best model.
if best_model_loaded and test_loader is not None:
    print("\nüöÄ Running final evaluation on the test set...")
    final_test_metrics = run_epoch(sam_model, test_loader, optimizer=None, is_training=False, device=CONFIG['device'], epoch=0)

    print("\n" + "="*50)
    print("      üéâ FINAL UNBIASED PERFORMANCE METRICS üéâ")
    print("="*50)
    print(f"Final Test IoU:   {final_test_metrics['iou']:.4f}")
    print(f"Final Test Dice:  {final_test_metrics['dice']:.4f}")
    print(f"Final Test Loss:  {final_test_metrics['loss']:.4f}")
    print("="*50)
    print("\nThis is the true expected performance of your model on new data.")