# Training with Weights & Biases Integration

This notebook demonstrates full W&B tracking for iSAID instance segmentation training.

## 1. Setup

In [None]:
!git clone https://github.com/michaelo-ponteski/isaid-instance-segmentation.git

In [None]:
# Enter the cloned repository; relative paths used by later cells
# (e.g. "checkpoints/") resolve against this directory.
%cd /kaggle/working/isaid-instance-segmentation
# Update the currently checked-out branch, then switch to the `wandb` branch.
!git pull
!git switch wandb

In [None]:
import os
import sys
import gc
import numpy as np
import torch
from pathlib import Path

# Add project root to path.
# The working directory may be the repository root itself (after the `%cd` above)
# or a sub-directory of it, so pick whichever candidate contains the `training`
# package; fall back to the parent directory if neither does.
_cwd = Path.cwd()
project_root = next(
    (candidate for candidate in (_cwd, _cwd.parent) if (candidate / "training").is_dir()),
    _cwd.parent,
)
# Guard against stacking duplicate entries when the cell is re-run.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Set memory optimization for CUDA
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Check device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Available memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Install wandb if not available
try:
    import wandb
    print(f"wandb version: {wandb.__version__}")
except ImportError:
    print("Installing wandb...")
    !pip install wandb
    import wandb

In [None]:
# Login to W&B (run once)
# wandb.login()

In [None]:
import importlib
import datasets.isaid_dataset
import models.maskrcnn_model
import training.transforms
import training.wandb_logger

# Reload every project module used below so that source changes made while the
# kernel is alive are picked up. `training.wandb_logger` is included so the
# names imported from it are not left stale while the other modules refresh.
importlib.reload(datasets.isaid_dataset)
importlib.reload(models.maskrcnn_model)
importlib.reload(training.transforms)
importlib.reload(training.wandb_logger)

# Import names only after reloading, so they bind to the refreshed modules.
from datasets.isaid_dataset import iSAIDDataset
from training.transforms import get_transforms
from training.wandb_logger import (
    WandbLogger,
    WandbConfig,
    create_wandb_logger,
    compute_gradient_norms,
    get_fixed_val_batch,
    ISAID_CLASS_LABELS,
)
from models.maskrcnn_model import CustomMaskRCNN, get_custom_maskrcnn

print("All modules imported successfully!")
print(f"\niSAID Class Labels:")
for idx, name in ISAID_CLASS_LABELS.items():
    print(f"  {idx}: {name}")

## 2. Configuration

In [None]:
# Single source of truth for every tunable value; the whole mapping is handed
# to the W&B logger so each run records its configuration.
HYPERPARAMETERS = dict(
    # Dataset
    data_root="/kaggle/input/isaid-patches/iSAID_patches",
    num_classes=16,
    image_size=800,

    # Training
    batch_size=2,
    num_epochs=50,
    learning_rate=0.005,
    weight_decay=0.0005,
    momentum=0.9,

    # Model Architecture
    backbone="efficientnet_b0",
    pretrained_backbone=True,
    cbam_reduction_ratio=16,
    roi_head_layers=4,

    # RPN Anchors
    anchor_sizes=((8, 16), (16, 32), (32, 64), (64, 128)),
    aspect_ratios=((0.5, 1.0, 2.0),) * 4,

    # Scheduler
    scheduler_type="onecycle",
    max_lr=0.01,

    # W&B Logging
    log_freq=20,  # Log every N batches
    num_val_images=4,  # Number of images for validation visualization
    conf_threshold=0.5,  # Confidence threshold for predictions
)

# Echo the configuration, one "key: value" pair per line.
print(
    "Hyperparameters:",
    *(f"  {key}: {value}" for key, value in HYPERPARAMETERS.items()),
    sep="\n",
)

## 3. Initialize W&B

In [None]:
# Describe the W&B run; the logging knobs are taken from HYPERPARAMETERS.
_wandb_settings = dict(
    project="isaid-custom-segmentation",
    run_name=None,  # Auto-generated, or set a custom name
    tags=["maskrcnn", "efficientnet", "cbam"],
    notes="Training with custom EfficientNet backbone + CBAM + FPN",
    log_freq=HYPERPARAMETERS["log_freq"],
    num_val_images=HYPERPARAMETERS["num_val_images"],
    conf_threshold=HYPERPARAMETERS["conf_threshold"],
)
wandb_config = WandbConfig(**_wandb_settings)

# Create the logger, passing the full hyperparameter mapping alongside the config.
logger = WandbLogger(wandb_config, HYPERPARAMETERS)

# Show where the run can be followed.
print(f"\nW&B Run: {logger.run.name}")
print(f"URL: {logger.run.url}")

## 4. Load Data

In [None]:
# Build the train and validation datasets from the configured root directory.
print("Loading datasets...")

_data_root = HYPERPARAMETERS["data_root"]
_image_size = HYPERPARAMETERS["image_size"]

# Training split, using the transforms returned for train=True.
train_dataset = iSAIDDataset(
    _data_root,
    split="train",
    transforms=get_transforms(train=True),
    image_size=_image_size,
)

# Validation split, using the transforms returned for train=False.
val_dataset = iSAIDDataset(
    _data_root,
    split="val",
    transforms=get_transforms(train=False),
    image_size=_image_size,
)

for _label, _dataset in (("Train", train_dataset), ("Val", val_dataset)):
    print(f"{_label} samples: {len(_dataset)}")

# Create data loaders
def collate_fn(batch):
    """Regroup a list of samples into one tuple per sample field.

    ``[(a1, b1), (a2, b2)]`` becomes ``((a1, a2), (b1, b2))``, so items are
    kept as separate entries rather than being stacked together.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# Options common to both loaders.
_shared_loader_options = {"collate_fn": collate_fn, "num_workers": 4}

# Training loader: shuffled, configured batch size, pinned host memory.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=HYPERPARAMETERS["batch_size"],
    shuffle=True,
    pin_memory=True,
    **_shared_loader_options,
)

# Validation loader: one image per batch, in dataset order.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    **_shared_loader_options,
)

In [None]:
# Select fixed validation images for consistent visualization.
# The candidate indices are filtered so every index is valid for `val_dataset`,
# and capped at the configured `num_val_images` so this cell agrees with the
# value handed to the W&B logger config.
_candidate_val_indices = [0, 10, 25, 50]  # Fixed indices for tracking progress
val_indices = [
    idx for idx in _candidate_val_indices if idx < len(val_dataset)
][: HYPERPARAMETERS["num_val_images"]]
logger.set_validation_images(val_indices)

# Pre-load these images
fixed_val_images, fixed_val_targets = get_fixed_val_batch(val_dataset, val_indices, device)
print(f"Selected {len(val_indices)} validation images for visualization")

## 5. Create Model

In [None]:
# Instantiate the detector from the configured hyperparameters and place it on the device.
_model_options = {
    "num_classes": HYPERPARAMETERS["num_classes"],
    "pretrained_backbone": HYPERPARAMETERS["pretrained_backbone"],
    "rpn_anchor_sizes": HYPERPARAMETERS["anchor_sizes"],
    "rpn_aspect_ratios": HYPERPARAMETERS["aspect_ratios"],
}
model = CustomMaskRCNN(**_model_options)
model.to(device)

# Count all parameters and the trainable subset in a single pass.
total_params = 0
trainable_params = 0
for _param in model.parameters():
    _size = _param.numel()
    total_params += _size
    if _param.requires_grad:
        trainable_params += _size
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Register the model with W&B (log="all", every 100 steps).
wandb.watch(model, log="all", log_freq=100)

## 6. Setup Optimizer & Scheduler

In [None]:
# Optimizer: SGD over the parameters that require gradients.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
_sgd_options = {key: HYPERPARAMETERS[key] for key in ("momentum", "weight_decay")}
optimizer = torch.optim.SGD(
    params,
    lr=HYPERPARAMETERS["learning_rate"],
    **_sgd_options,
)

# Scheduler: one-cycle policy sized for one step per training batch over all epochs.
_steps_per_epoch = len(train_loader)
total_steps = _steps_per_epoch * HYPERPARAMETERS["num_epochs"]
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=HYPERPARAMETERS["max_lr"],
    total_steps=total_steps,
    pct_start=0.1,
)

print(f"Total training steps: {total_steps}")

## 7. Training Loop with W&B Logging

In [None]:
from tqdm.auto import tqdm


def _batch_to_device(images, targets, device):
    """Move one collated batch onto ``device``.

    Args:
        images: iterable of image tensors.
        targets: iterable of per-image target dicts; tensor values are moved,
            any other value is passed through unchanged.
        device: destination device.

    Returns:
        ``(images, targets)`` as new lists living on ``device``.
    """
    images = [img.to(device) for img in images]
    targets = [
        {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()}
        for t in targets
    ]
    return images, targets


best_val_loss = float('inf')
global_step = 0

# The checkpoint directory does not change between epochs, so create it once.
os.makedirs("checkpoints", exist_ok=True)

for epoch in range(HYPERPARAMETERS["num_epochs"]):
    logger.epoch = epoch

    # =========================================================================
    # TRAINING
    # =========================================================================
    model.train()
    epoch_loss = 0.0
    # Number of batches that actually contributed to `epoch_loss`; batches
    # skipped below must not dilute the epoch average.
    num_train_batches = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{HYPERPARAMETERS['num_epochs']}")

    for images, targets in pbar:
        # Move to device
        images, targets = _batch_to_device(images, targets, device)

        # Skip empty batches
        if all(len(t['boxes']) == 0 for t in targets):
            continue

        # Forward pass
        loss_dict = model(images, targets)
        total_loss = sum(loss_dict.values())

        # Backward pass
        optimizer.zero_grad()
        total_loss.backward()

        # =====================================================================
        # W&B LOGGING: Training metrics & Gradients (before optimizer step)
        # =====================================================================
        current_lr = scheduler.get_last_lr()[0]

        # Log training metrics (every N steps)
        logger.log_training_step(
            loss_dict=loss_dict,
            learning_rate=current_lr,
            step=global_step,
            epoch=epoch,
        )

        # Log gradient norms for CBAM and RoI layers.
        # This runs before clipping, so the logged norms are the unclipped ones.
        logger.log_gradient_norms(model, step=global_step)

        # Optimizer step
        torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
        optimizer.step()
        scheduler.step()

        # Read the scalar loss once and reuse it for the running sum and the bar.
        batch_loss = total_loss.item()
        epoch_loss += batch_loss
        num_train_batches += 1
        global_step += 1

        # Update progress bar
        pbar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'lr': f'{current_lr:.6f}',
        })

    # Average over the batches that were trained on (0.0 if every batch was skipped).
    avg_train_loss = epoch_loss / max(num_train_batches, 1)

    # =========================================================================
    # VALIDATION
    # =========================================================================
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for images, targets in tqdm(val_loader, desc="Validating"):
            images, targets = _batch_to_device(images, targets, device)

            # NOTE(review): the model is in eval mode here. torchvision-style
            # detection models return detections instead of a loss dict in eval
            # mode; this code assumes CustomMaskRCNN returns losses whenever
            # targets are passed - confirm against models/maskrcnn_model.py.
            loss_dict = model(images, targets)
            val_loss += sum(loss_dict.values()).item()

    avg_val_loss = val_loss / len(val_loader)

    # =========================================================================
    # W&B LOGGING: Validation metrics
    # =========================================================================
    logger.log_validation_metrics(
        val_loss=avg_val_loss,
        epoch=epoch,
    )

    # =========================================================================
    # W&B LOGGING: Validation predictions visualization
    # =========================================================================
    model.eval()
    with torch.no_grad():
        # Get predictions on fixed validation images
        val_images_device = [img.to(device) for img in fixed_val_images]
        predictions = model(val_images_device)

        # Move predictions back to CPU for logging
        predictions_cpu = [
            {k: v.cpu() for k, v in pred.items()}
            for pred in predictions
        ]

        # Log visualization
        logger.log_validation_predictions(
            images=fixed_val_images,
            targets=fixed_val_targets,
            predictions=predictions_cpu,
            epoch=epoch,
        )

    # =========================================================================
    # MODEL CHECKPOINTING
    # =========================================================================
    # Save a full (model + optimizer + scheduler) checkpoint for this epoch
    checkpoint_path = f"checkpoints/epoch_{epoch+1}.pth"
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_loss': avg_val_loss,
    }, checkpoint_path)

    # Save best model and log as W&B artifact
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_model_path = "checkpoints/best_model.pth"
        torch.save(model.state_dict(), best_model_path)

        # Log best model as W&B artifact
        logger.log_best_model(best_model_path, avg_val_loss)
        print(f"  -> New best model saved! Val Loss: {avg_val_loss:.4f}")

    print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f}")

    # Memory cleanup
    gc.collect()
    torch.cuda.empty_cache()

print("\nTraining complete!")

## 8. Finish W&B Run

In [None]:
# Finish the W&B run.
# The URL is read before finish() so the final link does not depend on
# `logger.run` still being usable afterwards.
# NOTE(review): WandbLogger.finish() is not visible here - confirm whether it clears `run`.
run_url = logger.run.url
logger.finish()

print(f"\nW&B run completed!")
print(f"View results at: {run_url}")

## 9. Load Model from W&B Artifact

In [None]:
# Example: Load best model from W&B artifacts
# Uncomment to use

# api = wandb.Api()
# artifact = api.artifact('YOUR_ENTITY/isaid-custom-segmentation/isaid-model:best')
# artifact_dir = artifact.download()
# 
# model = CustomMaskRCNN(num_classes=16)
# model.load_state_dict(torch.load(f"{artifact_dir}/best_model.pth"))
# model.eval()