# 🟣 HVAC-Specific SAM Fine-Tuning Pipeline (DGX Master - v4.0)
## ⚡ Full Multi-GPU Architecture with HVAC-Specialized Logic

**Pipeline Overview:**
1.  **Environment:** Installs headless libraries to prevent DGX display driver crashes.
2.  **Data Prep:** Checks, downloads, and intelligently extracts your HVAC dataset structure.
3.  **Script Generation:** Writes the **complete** training logic (Dataset, Adaptive Prompting, AMP, DDP) to `train_dgx.py`.
4.  **Execution:** Launches the job across all available GPUs using `torchrun`.
5.  **Visualization:** Loads the results and plots metrics.

In [None]:
# Cell 1: Environment & Dependency Check
import sys, subprocess, os

def install(package):
    """Install *package* with pip, using the interpreter that runs this notebook.

    Runs quietly (`-q`) and without pip's download cache. Raises
    `subprocess.CalledProcessError` when pip exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", "-q", "--no-cache-dir", package]
    subprocess.check_call(pip_command)

print("‚öôÔ∏è Checking DGX Environment...")
import importlib

# (module to probe, pip requirement that provides it), in installation order.
# Every package the notebook installs is probed individually, so a missing
# matplotlib/onnx/onnxruntime is detected even when the SAM stack is present,
# and packages that already import cleanly are not reinstalled.
_REQUIRED_PACKAGES = [
    ("segment_anything", "git+https://github.com/facebookresearch/segment-anything.git"),
    ("cv2", "opencv-python-headless"),  # Critical for server environments
    ("pycocotools", "pycocotools"),
    ("matplotlib", "matplotlib"),
    ("onnxruntime", "onnxruntime"),
    ("onnx", "onnx"),
]


def _is_importable(module_name):
    """Return True when *module_name* can actually be imported.

    A real import is attempted (rather than a spec lookup) so that an
    installed-but-broken package still counts as missing.
    """
    try:
        importlib.import_module(module_name)
    except ImportError:
        return False
    return True


_missing_requirements = [
    requirement
    for module_name, requirement in _REQUIRED_PACKAGES
    if not _is_importable(module_name)
]

if _missing_requirements:
    print("   Installing missing libraries...")
    for requirement in _missing_requirements:
        install(requirement)
    # Let the import system notice the freshly installed packages.
    importlib.invalidate_caches()

print("‚úÖ Environment Ready.")

In [None]:
# Cell 2: Configuration, Asset Management & Data Extraction
import os, json, torch, zipfile, urllib.request, shutil
from pathlib import Path

# --- 1. PATH CONFIGURATION (DGX Standard) ---
# All inputs and outputs live under the directory the notebook runs from.
WORKSPACE_ROOT = Path.cwd()
DATA_DIR = WORKSPACE_ROOT / 'data'
OUTPUT_DIR = WORKSPACE_ROOT / 'results'

# Create both folders up front; folders that already exist are left alone.
for workspace_dir in (DATA_DIR, OUTPUT_DIR):
    workspace_dir.mkdir(parents=True, exist_ok=True)

# --- 2. HYPERPARAMETERS (Complete) ---
# Single source of truth for the run. It is serialized to config.json below
# and handed to train_dgx.py through its --config argument.
CONFIG = {
    # --- Model ---
    'model_type': 'vit_h',    # key looked up in segment_anything's sam_model_registry
    'image_size': 1024,       # longest image side after resizing; images are padded to a square of this size

    # --- Per-process training settings ---
    'batch_size': 4,          # images per GPU process (effective batch = batch_size x number of GPUs)
    'num_workers': 8,         # DataLoader worker processes per GPU process
    'num_epochs': 50,
    'learning_rate': 1e-4,
    'weight_decay': 1e-4,
    'use_amp': True,          # enables autocast and the gradient scaler in the training script

    # --- Prompt generation ---
    'prompt_strategy': 'multi_prompt',  # 'multi_prompt' picks box / point / scribble at random per annotation
    'bbox_noise_max': 0.1,    # box jitter factor when the box covers >= 5% of the image
    'bbox_noise_min': 0.02,   # box jitter factor when the box covers < 5% of the image
    'min_mask_area': 25,      # annotations whose mask has fewer pixels than this are skipped

    # --- Checkpointing ---
    'resume_training': False,
    # NOTE(review): not read by the train_dgx.py generated in this notebook,
    # which writes a checkpoint once per epoch.
    'checkpoint_interval': 300,

    # --- Asset locations ---
    'model_url': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    'model_path': str(DATA_DIR / 'sam_vit_h_4b8939.pth'),
    'dataset_zip_path': str(DATA_DIR / 'hvac_dataset_coco.zip'),
    'unzip_path': str(DATA_DIR / 'hvac_dataset_coco'),
    'annotations_file_name': '_annotations.coco.json',
    'best_model_save_path': str(OUTPUT_DIR / 'best_model_dgx.pth'),
    'latest_checkpoint_save_path': str(OUTPUT_DIR / 'latest_checkpoint_dgx.pth'),
    'metrics_save_path': str(OUTPUT_DIR / 'training_log.json')
}

# Persist the configuration so the training subprocess can load it.
CONFIG_PATH = WORKSPACE_ROOT / 'config.json'
CONFIG_PATH.write_text(json.dumps(CONFIG, indent=4))

# --- 3. ASSET FETCHING ---
# The checkpoint is downloaded to a temporary '.part' file and only renamed to
# its final name once the transfer has finished. An interrupted download
# therefore never leaves a truncated file at 'model_path' that a later run
# would mistake for a complete checkpoint.
if not os.path.exists(CONFIG['model_path']):
    print(f"‚¨áÔ∏è Downloading Base SAM Weights...")
    partial_model_path = CONFIG['model_path'] + '.part'
    try:
        urllib.request.urlretrieve(CONFIG['model_url'], partial_model_path)
        os.replace(partial_model_path, CONFIG['model_path'])
    finally:
        # Only still present when the download or the rename failed.
        if os.path.exists(partial_model_path):
            os.remove(partial_model_path)
else:
    print("‚úÖ Base SAM Weights Present.")

# --- 4. SMART DATASET EXTRACTION ---
# Target layout: <unzip_path>/train/_annotations.coco.json (the training
# script additionally reads a 'valid' split with the same annotation file name).
if not os.path.exists(CONFIG['unzip_path']):
    if os.path.exists(CONFIG['dataset_zip_path']):
        print(f"üìÇ Unzipping HVAC Dataset...")
        dataset_root = Path(CONFIG['unzip_path'])
        with zipfile.ZipFile(CONFIG['dataset_zip_path'], 'r') as zf:
            # The archive may or may not already contain the top-level dataset
            # folder. Inspect every entry rather than only the first one, so an
            # empty archive does not raise and a stray leading entry does not
            # cause the layout to be misdetected.
            root_prefix = dataset_root.name + '/'
            has_root_folder = any(name.startswith(root_prefix) for name in zf.namelist())
            zf.extractall(dataset_root.parent if has_root_folder else dataset_root)
        print(f"‚úÖ Dataset Extracted.")
    else:
        print(f"‚ö†Ô∏è WARNING: Dataset Zip not found at {CONFIG['dataset_zip_path']}. Please upload it.")
else:
    print(f"‚úÖ Dataset directory ready at: {CONFIG['unzip_path']}")

# Surface a broken or partial extraction here instead of inside the training job.
expected_train_annotations = os.path.join(CONFIG['unzip_path'], 'train', CONFIG['annotations_file_name'])
if os.path.exists(CONFIG['unzip_path']) and not os.path.exists(expected_train_annotations):
    print(f"WARNING: Expected annotation file is missing: {expected_train_annotations}")

print(f"\nüöÄ Ready to launch on {torch.cuda.device_count()} GPUs.")

### Phase 3: The Complete Training Script (`train_dgx.py`)
This cell generates the actual python script. It contains **all** the logic from the original notebook, including the Dataset class, the specialized prompt generation, the combined loss function, and the training loop, wrapped in DDP (Distributed Data Parallel) code.

In [None]:
%%writefile train_dgx.py
import os
import json
import random
import argparse
import time
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import autocast, GradScaler
from typing import Dict, List, Any, Tuple
from statistics import mean
from pycocotools.coco import COCO
from pycocotools import mask as mask_utils

# Distributed Training Imports
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

# SAM Imports
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide

# --- DDP INFRASTRUCTURE ---
def setup_ddp():
    """Join the NCCL process group and bind this process to its GPU.

    Reads the `LOCAL_RANK` environment variable, makes that index the current
    CUDA device, and returns it as an int.
    """
    dist.init_process_group(backend="nccl")
    gpu_index = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(gpu_index)
    return gpu_index

def cleanup_ddp():
    """Destroy the default process group if one is currently active.

    Safe to call when no group was initialized (or after it was already
    destroyed), so it can be used unconditionally on shutdown paths.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def is_main_process():
    """Return True on the process that should log and write files.

    That is global rank 0 when a process group is active. Without an
    initialized process group there is only one process, which is treated
    as the main one instead of raising.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0

# --- HVAC SPECIALIZED DATASET CLASS ---
class HvacSamDataset(Dataset):
    """COCO-backed dataset yielding a SAM-ready image plus per-annotation masks and prompts.

    Each item is a dict with:
      * 'image'         - normalized, resized and zero-padded float tensor,
      * 'masks'         - list of boolean numpy masks, one per kept annotation,
      * 'prompts'       - list of prompt dicts aligned with 'masks'
                          ({'box': ...} or {'point_coords': ..., 'point_labels': ...}),
      * 'original_size' - (height, width) of the image as loaded from disk,
      * 'input_size'    - (height, width) after resizing, before padding.

    Box prompts are emitted in [x1, y1, x2, y2] corner format, which is what
    SAM's `ResizeLongestSide.apply_boxes` and prompt encoder consume. COCO
    stores boxes as [x, y, width, height], so they are converted here.
    """

    def __init__(self, coco: "COCO", image_ids: List[int], split_path: str, config: dict, is_training: bool = True):
        """Store the COCO index, the image ids to serve and the prompt settings.

        Args:
            coco: loaded pycocotools COCO index for this split.
            image_ids: ids of the images this dataset iterates over.
            split_path: directory that holds the split's image files.
            config: run configuration; reads 'image_size', 'min_mask_area',
                and optionally 'prompt_strategy', 'bbox_noise_min', 'bbox_noise_max'.
            is_training: when False, prompts are always exact (noise-free) boxes.
        """
        self.coco = coco
        self.image_ids = image_ids
        self.split_path = split_path
        self.is_training = is_training
        self.config = config
        self.resize_transform = ResizeLongestSide(config['image_size'])
        # Per-channel statistics applied to RGB pixel values in __getitem__.
        self.pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
        self.pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
        self.prompt_strategy = config.get('prompt_strategy', 'multi_prompt')

    def __len__(self) -> int:
        """Return the number of images served by this dataset."""
        return len(self.image_ids)

    @staticmethod
    def _xywh_to_xyxy(x: float, y: float, w: float, h: float) -> np.ndarray:
        """Convert a COCO [x, y, width, height] box to [x1, y1, x2, y2] corners."""
        return np.array([x, y, x + w, y + h], dtype=np.float64)

    def _generate_prompt(self, mask: np.ndarray, bbox: List[float], img_area: float):
        """Build one prompt for an annotation.

        Args:
            mask: binary mask of the annotation (non-zero inside the object).
            bbox: COCO box [x, y, width, height] in original-image pixels.
            img_area: area of the original image in pixels, used to decide
                how much noise a box prompt receives.

        Returns:
            {'box': array([x1, y1, x2, y2])}, or
            {'point_coords': (N, 2) array of (x, y), 'point_labels': (N,) array of ones},
            or None when a point/scribble prompt is requested for an empty mask.
        """
        x, y, w, h = bbox

        # Validation always uses the exact box so metrics are comparable across epochs.
        if not self.is_training:
            return {'box': self._xywh_to_xyxy(x, y, w, h)}

        # Determine prompt type
        if self.prompt_strategy == 'multi_prompt':
            prompt_type = random.choice(['box', 'point', 'scribble'])
        elif self.prompt_strategy == 'point':
            prompt_type = 'point'
        else:
            prompt_type = 'box'

        # --- HVAC ADAPTIVE LOGIC ---
        if prompt_type == 'box':
            rel_size = (w * h) / (img_area + 1e-6)

            # Adaptive noise: boxes covering less than 5% of the image get the
            # small noise factor so the jittered box stays on the object;
            # larger boxes get the large factor.
            if rel_size < 0.05:
                noise_factor = self.config.get('bbox_noise_min', 0.02)
            else:
                noise_factor = self.config.get('bbox_noise_max', 0.1)

            # Each term is uniform in [-noise_factor, +noise_factor] times the box side.
            x_noise = w * noise_factor * (random.random() - 0.5) * 2
            y_noise = h * noise_factor * (random.random() - 0.5) * 2
            w_noise = w * noise_factor * (random.random() - 0.5) * 2
            h_noise = h * noise_factor * (random.random() - 0.5) * 2

            # Jitter in x/y/w/h space, then convert to the corner format SAM expects.
            return {'box': self._xywh_to_xyxy(x + x_noise, y + y_noise, w + w_noise, h + h_noise)}

        points = np.argwhere(mask)  # rows of (row, col) for every foreground pixel
        if len(points) == 0:
            return None

        if prompt_type == 'point':
            point = points[random.randint(0, len(points) - 1)]
            point_coords = np.array([[point[1], point[0]]])  # (x, y)
            point_labels = np.array([1])
            return {'point_coords': point_coords, 'point_labels': point_labels}

        if prompt_type == 'scribble':
            num_points = min(5, len(points))
            point_indices = np.random.choice(len(points), num_points, replace=False)
            # Flip (row, col) -> (x, y); copy so the result is a contiguous
            # array instead of a negatively-strided view.
            point_coords = np.ascontiguousarray(points[point_indices][:, ::-1])
            point_labels = np.ones(num_points)
            return {'point_coords': point_coords, 'point_labels': point_labels}

        return None

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Load image *idx*, build its masks/prompts and preprocess it for SAM.

        Raises:
            FileNotFoundError: the image is neither in `split_path` nor in its parent directory.
            ValueError: OpenCV could not decode the image file.
        """
        img_id = self.image_ids[idx]
        img_info = self.coco.loadImgs([img_id])[0]

        # Handle nested directory structures: fall back to the parent of the split folder.
        image_path = os.path.join(self.split_path, img_info['file_name'])
        if not os.path.exists(image_path):
            image_path = os.path.join(os.path.dirname(self.split_path), img_info['file_name'])
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image not found: {image_path}")

        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Failed to load image: {image_path}")

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        orig_h, orig_w = image.shape[:2]
        img_area = orig_h * orig_w

        ann_ids = self.coco.getAnnIds(imgIds=[img_id])
        annotations = self.coco.loadAnns(ann_ids)

        masks, prompts = [], []
        for ann in annotations:
            if 'segmentation' not in ann or ann.get('iscrowd', 0) == 1:
                continue
            mask = self.coco.annToMask(ann)
            # HVAC filter: drop masks with fewer pixels than 'min_mask_area'.
            if mask.sum() < self.config['min_mask_area']:
                continue

            prompt = self._generate_prompt(mask, ann['bbox'], img_area)
            if prompt is not None:
                masks.append(mask.astype(bool))
                prompts.append(prompt)

        original_size = (orig_h, orig_w)
        resized_image = self.resize_transform.apply_image(image)
        input_image_torch = torch.as_tensor(resized_image, dtype=torch.float32).permute(2, 0, 1).contiguous()
        input_image_torch = (input_image_torch - self.pixel_mean) / self.pixel_std
        h, w = input_image_torch.shape[-2:]
        # Zero-pad on the right/bottom up to a square of side 'image_size'.
        padh, padw = self.config['image_size'] - h, self.config['image_size'] - w
        input_image_padded = torch.nn.functional.pad(input_image_torch, (0, padw, 0, padh))

        return {
            'image': input_image_padded,
            'masks': masks,
            'prompts': prompts,
            'original_size': original_size,
            'input_size': (h, w)
        }

def custom_collate_fn(batch: List[Dict]) -> Dict[str, Any]:
    """Collate dataset items into one batch dict.

    Images are stacked into a single tensor along a new leading dimension.
    The remaining fields differ in length from image to image, so each is
    kept as a plain Python list with one entry per item.
    """
    collated = {'image': torch.stack([sample['image'] for sample in batch])}
    for field in ('masks', 'prompts', 'original_size', 'input_size'):
        collated[field] = [sample[field] for sample in batch]
    return collated

# --- OPTIMIZED LOSS FUNCTION ---
def combined_loss(pred_masks: torch.Tensor, true_masks: torch.Tensor) -> torch.Tensor:
    """Return 0.8 * BCE-with-logits + 0.2 * Dice loss for a mask prediction.

    Args:
        pred_masks: raw (pre-sigmoid) mask logits.
        true_masks: ground-truth mask with the same shape as `pred_masks`.

    Returns:
        Scalar loss tensor. The Dice term is computed over all elements of
        the flattened tensors at once.
    """
    eps = 1e-8  # keeps the Dice ratio defined when both masks are empty
    bce_term = nn.functional.binary_cross_entropy_with_logits(pred_masks, true_masks.float())

    probs = torch.sigmoid(pred_masks).reshape(-1)
    targets = true_masks.reshape(-1)
    overlap = (probs * targets).sum()
    dice_coeff = (2. * overlap + eps) / (probs.sum() + targets.sum() + eps)

    return 0.8 * bce_term + 0.2 * (1 - dice_coeff)

# --- CORE TRAINING LOOP ---
def _build_prompt_inputs(prompt, transform, original_size, device):
    """Turn one dataset prompt dict into the (points, boxes) arguments of the SAM prompt encoder.

    Coordinates are mapped from original-image pixels to the resized input
    frame with `transform`. Returns `(None, None)` for an unrecognized prompt.
    """
    if 'box' in prompt:
        boxes = torch.as_tensor(
            transform.apply_boxes(prompt['box'].reshape(1, 4), original_size),
            dtype=torch.float, device=device)
        return None, boxes
    if 'point_coords' in prompt:
        coords = torch.as_tensor(
            transform.apply_coords(prompt['point_coords'], original_size),
            dtype=torch.float, device=device).unsqueeze(0)
        labels = torch.as_tensor(prompt['point_labels'], dtype=torch.float, device=device).unsqueeze(0)
        return (coords, labels), None
    return None, None


def _average_gradients(model, device) -> bool:
    """Average the gradients of all trainable parameters across ranks.

    Every rank must call this once per training step. A parameter is reduced
    when at least one rank produced a gradient for it; ranks without a local
    gradient (e.g. a batch with no annotations) contribute zeros. Parameters
    that received no gradient on any rank are left untouched, so the optimizer
    keeps skipping them.

    Returns:
        True when at least one parameter holds a gradient after the exchange.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    if not params:
        return False

    world_size = dist.get_world_size()
    has_grad = torch.tensor([float(p.grad is not None) for p in params], device=device)
    if world_size > 1:
        # Agree on the set of parameters to reduce so all ranks issue the same collectives.
        dist.all_reduce(has_grad, op=dist.ReduceOp.MAX)

    any_grad = False
    for param, flag in zip(params, has_grad.tolist()):
        if not flag:
            continue
        any_grad = True
        if param.grad is None:
            param.grad = torch.zeros_like(param)
        if world_size > 1:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad.div_(world_size)
    return any_grad


def run_epoch(model, dataloader, optimizer, is_training, device, config, scaler):
    """Run one training or validation pass and return `(mean_loss, mean_iou)`.

    Args:
        model: DDP-wrapped SAM model; its sub-modules are called via `model.module`.
        dataloader: yields batches produced by `custom_collate_fn`.
        optimizer: optimizer to step; may be None when `is_training` is False.
        is_training: True to update weights, False to only evaluate (IoU is
            collected only in this mode, so it is 0.0 for training passes).
        device: CUDA device of this rank.
        config: run configuration; reads 'image_size' and 'use_amp'.
        scaler: AMP gradient scaler used for the backward pass and optimizer step.

    Returns:
        Tuple of Python floats: mean per-batch loss and mean per-annotation
        IoU, both averaged over all ranks weighted by how many values each
        rank contributed.

    Note:
        The SAM sub-modules are invoked through `model.module`, which bypasses
        `DistributedDataParallel.forward`. DDP only arms its automatic gradient
        all-reduce from that forward call, so gradients are averaged across
        ranks explicitly before each optimizer step.
    """
    model.train(is_training)
    sam = model.module
    use_amp = config['use_amp']
    # The coordinate transform only depends on the configured size; build it once.
    transform = ResizeLongestSide(config['image_size'])
    epoch_losses, epoch_ious = [], []

    if is_main_process():
        print(f"{'Training' if is_training else 'Validating'}...")

    for batch in dataloader:
        images = batch['image'].to(device, non_blocking=True)

        if is_training:
            optimizer.zero_grad()

        with torch.set_grad_enabled(is_training):
            # AMP context: the image encoder runs under autocast for speed.
            with autocast(enabled=use_amp):
                image_embeddings = sam.image_encoder(images)

            batch_loss = 0
            num_anns = 0

            # Loop through batch images
            for i, (gt_masks, prompts) in enumerate(zip(batch['masks'], batch['prompts'])):
                if len(gt_masks) == 0:
                    continue

                curr_embedding = image_embeddings[i].unsqueeze(0)
                original_size = batch['original_size'][i]
                input_size = batch['input_size'][i]

                # Loop through HVAC components in current image
                for gt_mask_np, prompt in zip(gt_masks, prompts):
                    num_anns += 1

                    gt_mask_torch = torch.from_numpy(gt_mask_np).unsqueeze(0).unsqueeze(0).to(device)
                    points, boxes = _build_prompt_inputs(prompt, transform, original_size, device)

                    # Decoder forward pass (mixed precision)
                    with autocast(enabled=use_amp):
                        # The prompt encoder is only used for inference here.
                        with torch.no_grad():
                            sparse, dense = sam.prompt_encoder(points=points, boxes=boxes, masks=None)

                        low_res_masks, _ = sam.mask_decoder(
                            image_embeddings=curr_embedding,
                            image_pe=sam.prompt_encoder.get_dense_pe(),
                            sparse_prompt_embeddings=sparse,
                            dense_prompt_embeddings=dense,
                            multimask_output=False
                        )

                        upscaled_masks = sam.postprocess_masks(low_res_masks, input_size, original_size)
                        batch_loss += combined_loss(upscaled_masks, gt_mask_torch)

                    if not is_training:
                        pred = (torch.sigmoid(upscaled_masks) > 0.5).squeeze().cpu().numpy().astype(np.uint8)
                        iou = mask_utils.iou(
                            [mask_utils.encode(np.asfortranarray(pred))],
                            [mask_utils.encode(np.asfortranarray(gt_mask_np.astype(np.uint8)))],
                            [0])[0][0]
                        epoch_ious.append(iou)

            norm_loss = batch_loss / num_anns if num_anns > 0 else None

            if is_training:
                if norm_loss is not None:
                    # AMP scaled backward pass
                    scaler.scale(norm_loss).backward()
                else:
                    # No annotations on this rank for this batch. The scaler
                    # creates its scale tensor lazily in scale(), and this rank
                    # may still have to step with gradients from other ranks.
                    scaler.scale(torch.zeros((), device=device))

                # All ranks take part in the exchange, even those without a
                # local loss; otherwise the collectives would not line up.
                if _average_gradients(model, device):
                    scaler.step(optimizer)
                    scaler.update()
                optimizer.zero_grad()

            if norm_loss is not None:
                epoch_losses.append(norm_loss.item())

    # Sync metrics across GPUs as sums and counts, so ranks that saw fewer
    # (or no) samples do not distort the global mean.
    totals = torch.tensor(
        [float(sum(epoch_losses)), float(len(epoch_losses)),
         float(sum(epoch_ious)), float(len(epoch_ious))],
        dtype=torch.float64, device=device)
    dist.all_reduce(totals, op=dist.ReduceOp.SUM)
    loss_sum, loss_count, iou_sum, iou_count = totals.tolist()
    mean_loss = loss_sum / loss_count if loss_count else 0.0
    mean_iou = iou_sum / iou_count if iou_count else 0.0
    return mean_loss, mean_iou

# --- MAIN ORCHESTRATOR ---
def main(config):
    """Run the full distributed fine-tuning job described by *config*.

    Sets up DDP, builds the train/valid datasets from
    `<unzip_path>/train` and `<unzip_path>/valid`, loads SAM weights (base
    weights or the latest checkpoint when 'resume_training' is set), freezes
    the image and prompt encoders, then alternates training and validation
    for 'num_epochs' epochs. Rank 0 writes the metrics log, a checkpoint after
    every epoch, and the best model by validation IoU.

    When resuming, the optimizer, LR scheduler and AMP scaler state, the best
    IoU so far and the metrics history are restored as well (each only if the
    checkpoint contains it), so a resumed run continues the previous one
    instead of restarting its bookkeeping.
    """
    local_rank = setup_ddp()
    device = torch.device("cuda", local_rank)

    # Load Datasets
    train_dir = os.path.join(config['unzip_path'], 'train')
    valid_dir = os.path.join(config['unzip_path'], 'valid')

    train_coco = COCO(os.path.join(train_dir, config['annotations_file_name']))
    val_coco = COCO(os.path.join(valid_dir, config['annotations_file_name']))

    train_ds = HvacSamDataset(train_coco, train_coco.getImgIds(), train_dir, config, is_training=True)
    val_ds = HvacSamDataset(val_coco, val_coco.getImgIds(), valid_dir, config, is_training=False)

    train_sampler = DistributedSampler(train_ds)
    val_sampler = DistributedSampler(val_ds, shuffle=False)

    train_loader = DataLoader(train_ds, batch_size=config['batch_size'], sampler=train_sampler,
                              num_workers=config['num_workers'], collate_fn=custom_collate_fn, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=config['batch_size'], sampler=val_sampler,
                            num_workers=config['num_workers'], collate_fn=custom_collate_fn, pin_memory=True)

    # Model Setup
    sam_model = sam_model_registry[config['model_type']]()

    # Load Weights
    resume_state = None
    if config['resume_training'] and os.path.exists(config['latest_checkpoint_save_path']):
        resume_state = torch.load(config['latest_checkpoint_save_path'], map_location='cpu')
        # Pop the weights so the CPU copy can be freed once the model holds them.
        sam_model.load_state_dict(resume_state.pop('model_state_dict'))
        start_epoch = resume_state['epoch'] + 1
        if is_main_process(): print(f"üîÑ Resuming from epoch {start_epoch}")
    else:
        sam_model.load_state_dict(torch.load(config['model_path'], map_location='cpu'))
        start_epoch = 0
        if is_main_process(): print("üîÑ Loaded base SAM weights.")

    sam_model.to(device)

    # Freeze Encoders: only the remaining parameters are trained.
    for name, param in sam_model.named_parameters():
        if name.startswith("image_encoder") or name.startswith("prompt_encoder"):
            param.requires_grad = False

    model = DDP(sam_model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
    scaler = GradScaler(enabled=config['use_amp'])
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    best_iou = 0.0

    metrics_history = {'train_loss': [], 'val_loss': [], 'val_iou': []}

    # Restore the rest of the training state. Every entry is optional so that
    # checkpoints written without these keys still load.
    if resume_state is not None:
        if 'optimizer_state_dict' in resume_state:
            optimizer.load_state_dict(resume_state['optimizer_state_dict'])
        if 'scheduler_state_dict' in resume_state:
            scheduler.load_state_dict(resume_state['scheduler_state_dict'])
        # A disabled scaler saves an empty dict, which load_state_dict rejects.
        if resume_state.get('scaler_state_dict'):
            scaler.load_state_dict(resume_state['scaler_state_dict'])
        best_iou = resume_state.get('best_iou', best_iou)
        saved_history = resume_state.get('metrics_history') or {}
        for key in metrics_history:
            metrics_history[key] = list(saved_history.get(key, []))
        resume_state = None

    # Epoch Loop
    for epoch in range(start_epoch, config['num_epochs']):
        # Re-seed the sampler so each epoch uses a different shuffle.
        train_sampler.set_epoch(epoch)
        train_loss, _ = run_epoch(model, train_loader, optimizer, True, device, config, scaler)
        val_loss, val_iou = run_epoch(model, val_loader, None, False, device, config, scaler)

        scheduler.step(val_loss)

        if is_main_process():
            print(f"Epoch {epoch+1}/{config['num_epochs']} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val IoU: {val_iou:.4f}")

            metrics_history['train_loss'].append(train_loss)
            metrics_history['val_loss'].append(val_loss)
            metrics_history['val_iou'].append(val_iou)

            # Save Metrics
            with open(config['metrics_save_path'], 'w') as f:
                json.dump(metrics_history, f)

            # Save Best Model (before the checkpoint, so the checkpoint records the updated best IoU)
            if val_iou > best_iou:
                best_iou = val_iou
                torch.save(model.module.state_dict(), config['best_model_save_path'])
                print(f"üèÜ New Best Model Saved (IoU: {best_iou:.4f})")

            # Save Checkpoint
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.module.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'best_iou': best_iou,
                'metrics_history': metrics_history,
            }, config['latest_checkpoint_save_path'])

    cleanup_ddp()

if __name__ == '__main__':
    # Entry point used by torchrun: `train_dgx.py --config <path to JSON config>`.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
    args = parser.parse_args()
    # Parse the config and close the file before training starts, so the
    # handle is not held open for the whole run.
    with open(args.config, 'r') as f:
        config = json.load(f)
    main(config)


### Phase 4: Launch Multi-GPU Training
This uses `torchrun` to spawn one process per GPU on the node.

In [None]:
# Cell 4: Execute Distributed Training
num_gpus = torch.cuda.device_count()
print(f"\nüöÄ Launching training on {num_gpus} GPUs...")

!torchrun --nproc_per_node={num_gpus} train_dgx.py --config config.json

print("\n‚úÖ Training pipeline finished. Check 'results/' directory.")

### Phase 5: Visualization & Final Reporting
Since training ran in a subprocess, we load the saved JSON metrics to visualize performance.

In [None]:
# Cell 5: Visualize Results
import json
import matplotlib.pyplot as plt
import os

# Metrics are written by the training subprocess; read them back from disk.
metrics_path = os.path.join(os.getcwd(), 'results', 'training_log.json')

if not os.path.exists(metrics_path):
    print("‚ö†Ô∏è Metrics file not found. Did training complete successfully?")
else:
    with open(metrics_path, 'r') as f:
        history = json.load(f)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # One entry per subplot: (axis, title, y-axis label, curves), where each
    # curve is (history key, legend label, line color).
    panels = (
        (ax1, 'Training vs Validation Loss', 'Loss',
         (('train_loss', 'Train Loss', 'blue'), ('val_loss', 'Val Loss', 'red'))),
        (ax2, 'Mean IoU Performance', 'IoU Score',
         (('val_iou', 'Validation IoU', 'green'),)),
    )
    for axis, panel_title, y_label, curves in panels:
        for history_key, curve_label, curve_color in curves:
            axis.plot(history[history_key], label=curve_label, color=curve_color)
        axis.set_title(panel_title)
        axis.set_xlabel('Epochs')
        axis.set_ylabel(y_label)
        axis.legend()
        axis.grid(True)

    plt.show()
    peak_iou = max(history['val_iou'])
    print(f"üèÜ Peak Validation IoU: {peak_iou:.4f}")