# Training with Weights & Biases Integration

This notebook trains an iSAID instance segmentation model with full Weights & Biases (W&B) tracking, using the project's integrated `Trainer` class.

**Features:**
- Automatic logging of training/validation losses and metrics
- Gradient norm tracking for CBAM and RoI layers  
- Learning rate scheduling (OneCycleLR or ReduceLROnPlateau)
- Validation predictions visualization
- Model checkpointing as W&B artifacts
- mAP, mean IoU, and overfitting gap metrics

## 1. Setup

In [None]:
!git clone https://github.com/michaelo-ponteski/isaid-instance-segmentation.git

In [None]:
%cd /kaggle/working/isaid-instance-segmentation
!git pull
!git switch wandb

In [None]:
import os
import sys
import gc
from pathlib import Path

# Set the CUDA caching-allocator options before torch is imported and before
# CUDA is touched, so the setting is guaranteed to be picked up.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import numpy as np
import torch

# Put the project root first on sys.path. After the earlier `%cd` cell the
# current directory IS the repository root, so its packages (`datasets`,
# `models`, `training`) take precedence over any same-named installed package
# (e.g. a pip-installed `datasets`, if present).
PROJECT_ROOT = str(Path.cwd())
if sys.path[:1] != [PROJECT_ROOT]:
    sys.path.insert(0, PROJECT_ROOT)

# Select the compute device and report what is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
if device == "cuda":
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is the card's total capacity, not the currently free amount.
    print(f"Total memory: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
!pip install --upgrade wandb

In [None]:
# Install wandb if not available
try:
    import wandb
    print(f"wandb version: {wandb.__version__}") # Must be newest
except ImportError:
    print("Installing wandb...")
    !pip install --upgrade wandb
    import wandb

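### Kaggle W&B API key setup

The next cell reads the W&B API key from a Kaggle secret named `wandb_key` (add it under *Add-ons → Secrets*) and logs in non-interactively.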

In [None]:
import wandb
from kaggle_secrets import UserSecretsClient

# Read the W&B API key from the Kaggle secret "wandb_key" and log in with it
# explicitly; this bypasses the interactive login prompt, which would block the
# notebook. The key is passed straight through instead of being stored in a
# notebook-level variable, so it cannot leak through variable inspectors,
# `%whos`, or a later cell's output.
wandb.login(key=UserSecretsClient().get_secret("wandb_key"))

In [None]:
import importlib

import datasets.isaid_dataset
import models.maskrcnn_model
import training.transforms
import training.wandb_logger
import training.trainer

# Reload project modules so code pulled from git is picked up without a kernel
# restart. Dependencies are reloaded before the modules that may import from
# them (presumably training.trainer uses training.wandb_logger — verify), so a
# `from x import Y` inside the dependent module does not keep a stale object.
for _module in (
    datasets.isaid_dataset,
    models.maskrcnn_model,
    training.transforms,
    training.wandb_logger,
    training.trainer,
):
    importlib.reload(_module)

from datasets.isaid_dataset import iSAIDDataset
from training.transforms import get_transforms
from training.trainer import Trainer, create_datasets
from models.maskrcnn_model import CustomMaskRCNN
from training.wandb_logger import ISAID_CLASS_LABELS

print("All modules imported successfully!")
print("\niSAID Class Labels:")
for idx, name in ISAID_CLASS_LABELS.items():
    print(f"  {idx}: {name}")

## 2. Configuration

In [None]:
# Single source of truth for the run configuration. The whole dict is handed to
# the Trainer as `hyperparameters=` so it is recorded in the W&B run.
# NOTE(review): weight_decay, momentum, backbone, cbam_reduction_ratio and
# roi_head_layers are not read by any cell in this notebook; they only have an
# effect if the Trainer/model consume them from `hyperparameters` — verify.
HYPERPARAMETERS = dict(
    # --- Dataset ---
    data_root="/kaggle/input/isaid-patches/iSAID_patches",
    num_classes=16,
    image_size=800,
    # --- Training ---
    batch_size=2,
    num_epochs=20,
    learning_rate=1e-4,
    weight_decay=5e-4,
    momentum=0.9,
    # --- Model architecture ---
    backbone="efficientnet_b0",
    pretrained_backbone=True,
    cbam_reduction_ratio=16,
    roi_head_layers=4,
    # --- RPN anchors (small sizes chosen for iSAID's small objects) ---
    anchor_sizes=((8, 16), (16, 32), (32, 64), (64, 128)),
    aspect_ratios=((0.5, 1.0, 2.0),) * 4,
    # --- W&B logging ---
    wandb_project="isaid-custom-segmentation",
    wandb_entity="marek-olnk-put-pozna-",
    wandb_log_freq=20,          # log every N batches
    wandb_num_val_images=4,     # validation images used for prediction panels
    wandb_conf_threshold=0.5,   # minimum score for a prediction to be drawn
)

print(
    "Hyperparameters:",
    *(f"  {key}: {value}" for key, value in HYPERPARAMETERS.items()),
    sep="\n",
)

## 3. Load Data

In [None]:
# Build the train/val datasets from the patch dataset on disk.
train_dataset, val_dataset = create_datasets(
    data_root=HYPERPARAMETERS["data_root"],
    image_size=HYPERPARAMETERS["image_size"],
    subset_fraction=1.0,  # 1.0 uses the full dataset
)

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")

# Fail fast on an empty split (e.g. a wrong data_root) instead of discovering it
# only after the model and the W&B run have been created.
assert len(train_dataset) > 0, f"Empty training set - check data_root={HYPERPARAMETERS['data_root']!r}"
assert len(val_dataset) > 0, f"Empty validation set - check data_root={HYPERPARAMETERS['data_root']!r}"

## 4. Create Model

In [None]:
# Build the model; the RPN anchor configuration comes from HYPERPARAMETERS so it
# matches what is logged to W&B.
model = CustomMaskRCNN(
    num_classes=HYPERPARAMETERS["num_classes"],
    pretrained_backbone=HYPERPARAMETERS["pretrained_backbone"],
    rpn_anchor_sizes=HYPERPARAMETERS["anchor_sizes"],
    rpn_aspect_ratios=HYPERPARAMETERS["aspect_ratios"],
)

# Parameter summary: one pass over the parameters, then two sums.
param_counts = [(param.numel(), param.requires_grad) for param in model.parameters()]
total_params = sum(count for count, _ in param_counts)
trainable_params = sum(count for count, is_trainable in param_counts if is_trainable)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 5. Create Trainer with W&B Integration

In [None]:
# Create the trainer with W&B integration; it performs all W&B logging itself.
# The full HYPERPARAMETERS dict is passed so it is recorded in the run config.
trainer = Trainer(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    model=model,
    batch_size=HYPERPARAMETERS["batch_size"],
    lr=HYPERPARAMETERS["learning_rate"],
    device=device,
    use_amp=True,
    num_workers=4,
    # W&B configuration
    wandb_project=HYPERPARAMETERS["wandb_project"],
    # The config key is "wandb_entity"; "entity" does not exist and raised KeyError.
    wandb_entity=HYPERPARAMETERS["wandb_entity"],
    wandb_tags=["maskrcnn", "efficientnet", "cbam", "trainer-integrated"],
    wandb_notes="Training with integrated Trainer class - EfficientNet backbone + CBAM + FPN",
    wandb_log_freq=HYPERPARAMETERS["wandb_log_freq"],
    wandb_num_val_images=HYPERPARAMETERS["wandb_num_val_images"],
    wandb_conf_threshold=HYPERPARAMETERS["wandb_conf_threshold"],
    hyperparameters=HYPERPARAMETERS,
)

print(f"\nW&B Run: {trainer.wandb_logger.run.name}")
print(f"URL: {trainer.wandb_logger.run.url}")

## 6. Training

The `Trainer.fit()` method handles everything:
- Training loop with gradient clipping and AMP
- Validation loss computation
- mAP and mean IoU metrics
- W&B logging (losses, gradients, predictions, checkpoints)
- Learning rate scheduling
- Best model saving

In [None]:
# trainer.fit() runs the whole loop (training, validation, metrics, W&B logging,
# LR scheduling, best-model checkpointing) and returns the training history.
history = trainer.fit(
    save_dir="checkpoints",
    epochs=HYPERPARAMETERS["num_epochs"],
    max_map_samples=200,       # cap the images used for mAP to keep evaluation fast
    compute_metrics_every=1,   # evaluate mAP after every epoch
)
print("\nTraining complete!")

## 7. Visualize Results

In [None]:
# Plot the training history returned by trainer.fit().
import matplotlib.pyplot as plt


def _plot_history_panel(ax, history, series, title, ylabel, log_scale=False):
    """Draw one or more training-history series on a single axes.

    Args:
        ax: Matplotlib axes to draw on.
        history: Mapping returned by ``trainer.fit()``. Each value is a sequence
            plotted against its index (assumed one entry per epoch — TODO
            confirm against Trainer.fit).
        series: Sequence of ``(history_key, legend_label)`` pairs. Use ``None``
            as the label for an unlabeled single line; a legend is drawn only
            when at least one plotted series has a label.
        title: Axes title.
        ylabel: Y-axis label.
        log_scale: If True and something was plotted, use a log y-axis.

    Keys missing from ``history`` are listed on the panel instead of raising
    ``KeyError``, so one absent metric does not abort the whole figure.
    """
    missing_keys = []
    has_data = False
    has_labels = False
    for key, label in series:
        if key not in history:
            missing_keys.append(key)
            continue
        ax.plot(history[key], label=label)
        has_data = True
        has_labels = has_labels or label is not None

    if missing_keys:
        ax.text(
            0.5, 0.5, "missing: " + ", ".join(missing_keys),
            ha="center", va="center", transform=ax.transAxes,
        )
    if log_scale and has_data:
        ax.set_yscale("log")
    if has_labels:
        ax.legend()

    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(2, 2, figsize=(14, 10))

_plot_history_panel(
    axes[0, 0], history,
    [("train/loss", "Train Loss"), ("val/loss", "Val Loss")],
    title="Training & Validation Loss", ylabel="Loss",
)
_plot_history_panel(
    axes[0, 1], history,
    [("train/mAP@0.5", "Train mAP@0.5"), ("val/mAP@0.5", "Val mAP@0.5")],
    title="mAP Performance", ylabel="mAP@0.5",
)
_plot_history_panel(
    axes[1, 0], history, [("train/lr", None)],
    title="Learning Rate Schedule", ylabel="Learning Rate", log_scale=True,
)
_plot_history_panel(
    axes[1, 1], history, [("train/grad_norm", None)],
    title="Training Gradient Norm", ylabel="Gradient Norm",
)

fig.tight_layout()
plt.show()

## 8. Visualize Predictions

In [None]:
# Show model predictions on a few validation images. The score threshold comes
# from the same config value the Trainer uses for W&B prediction logging, so
# both views filter predictions identically.
trainer.visualize_predictions(
    num_samples=5,
    score_threshold=HYPERPARAMETERS["wandb_conf_threshold"],
    mask_alpha=0.4,
)

## 9. Finish W&B Run

In [None]:
# Finish the W&B run. Read the run URL first: after finish() the run is
# finalised, and depending on the wandb version (and on what Trainer.finish()
# does with its logger) accessing the run object afterwards may warn or fail.
run_url = trainer.wandb_logger.run.url
trainer.finish()

print("\nW&B run completed!")
print(f"View results at: {run_url}")

## 10. Load Model from W&B Artifact (Optional)

In [None]:
# Optional: rebuild the model from a checkpoint stored as a W&B artifact.
# Defining the function has no side effects; call it explicitly (example below).


def load_model_from_artifact(artifact_ref, hyperparameters, map_location="cpu"):
    """Download a checkpoint artifact from W&B and load it into a fresh model.

    Args:
        artifact_ref: Full artifact path, e.g.
            ``"YOUR_ENTITY/isaid-custom-segmentation/isaid-model:best"``.
        hyperparameters: The config dict used for training. The model is
            rebuilt with the same anchor configuration; with a different one the
            RPN head may get a different number of anchors per location and
            ``load_state_dict`` would fail on shape mismatches.
        map_location: Forwarded to ``torch.load`` so a checkpoint saved on a GPU
            can be loaded on a CPU-only machine.

    Returns:
        The loaded model, switched to eval mode.

    Assumes the artifact contains ``best_model.pth`` holding a plain state dict
    (as the earlier commented-out example did) — verify against how the
    Trainer saves checkpoints.
    """
    import wandb

    artifact_dir = Path(wandb.Api().artifact(artifact_ref).download())

    loaded_model = CustomMaskRCNN(
        num_classes=hyperparameters["num_classes"],
        # All weights come from the checkpoint, so pretrained backbone
        # initialisation is unnecessary here.
        pretrained_backbone=False,
        rpn_anchor_sizes=hyperparameters["anchor_sizes"],
        rpn_aspect_ratios=hyperparameters["aspect_ratios"],
    )
    state_dict = torch.load(artifact_dir / "best_model.pth", map_location=map_location)
    loaded_model.load_state_dict(state_dict)
    loaded_model.eval()
    return loaded_model


# Example:
# loaded_model = load_model_from_artifact(
#     "YOUR_ENTITY/isaid-custom-segmentation/isaid-model:best", HYPERPARAMETERS
# )
# print("Model loaded from W&B artifact!")
# print("Model loaded from W&B artifact!")