## 1. Mount Google Drive

In [None]:
# Mount Google Drive at /content/drive so the project folder uploaded to Drive
# is reachable from this Colab runtime (Colab-only; google.colab is not
# available in a regular Jupyter kernel).
from google.colab import drive
drive.mount('/content/drive')

## 2. Navigate to Project Directory

**Update `PROJECT_PATH` below to the folder where you uploaded the project in your Google Drive.**

In [None]:
import os

# Update this path!
PROJECT_PATH = '/content/drive/MyDrive/spondylolisthesis-maht-net'

os.chdir(PROJECT_PATH)
print(f"Current directory: {os.getcwd()}")
!ls -la

## 3. Install Dependencies

In [None]:
import subprocess
import sys

# Install into the kernel's own interpreter (what %pip does) and stop the
# notebook if pip fails. A bare `!pip` never raises, so the success message
# would otherwise be printed even after a failed install.
pip_result = subprocess.run(
    [sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'],
    capture_output=True,
    text=True,
)
if pip_result.returncode != 0:
    print(pip_result.stdout)
    print(pip_result.stderr)
    raise RuntimeError(f"pip install failed with exit code {pip_result.returncode}")
if pip_result.stderr:
    # pip can report warnings (e.g. dependency conflicts) even on success.
    print(pip_result.stderr)
print("✓ Dependencies installed")

## 4. Verify GPU Availability

In [None]:
import torch

# Report the PyTorch build and whether a CUDA device is visible to this runtime.
has_cuda = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {has_cuda}")

if not has_cuda:
    print("⚠️ WARNING: GPU not available. Training will be slow on CPU.")
else:
    # Details for the first visible GPU (device index 0).
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {gpu_props.total_memory / 1e9:.2f} GB")

## 5. Verify Data

In [None]:
from pathlib import Path

# Expected dataset layout, relative to the project root set in section 2.
data_root = Path('data/Train/Keypointrcnn_data')
train_images = data_root / 'images' / 'train'
train_labels = data_root / 'labels' / 'train'
val_images = data_root / 'images' / 'val'
val_labels = data_root / 'labels' / 'val'

# (label, directory, glob pattern): image dirs count every file, label dirs
# count only JSON annotation files.
data_dirs = [
    ('Train images', train_images, '*'),
    ('Train labels', train_labels, '*.json'),
    ('Val images', val_images, '*'),
    ('Val labels', val_labels, '*.json'),
]

print("Data directories:")
missing_dirs = []
for label, directory, pattern in data_dirs:
    exists = directory.is_dir()
    # Count regular files only, so sub-folders are not reported as samples.
    n_files = sum(1 for p in directory.glob(pattern) if p.is_file()) if exists else 0
    print(f"  {label + ':':<13} {exists} - {n_files} files")
    if not exists:
        missing_dirs.append(str(directory))

# Stop here rather than letting the data-loading cells fail later with a less
# obvious error.
if missing_dirs:
    raise FileNotFoundError(
        "✗ ERROR: Some data directories are missing:\n  " + "\n  ".join(missing_dirs)
    )
print("\n✓ All data directories found!")

## 6. Run Quick Tests (Optional but Recommended)

In [None]:
# Run the project's U-Net test script in a separate process. A `!` command does
# not stop the notebook on a non-zero exit, so read the output for failures
# before continuing.
# NOTE(review): `python` is resolved from the shell PATH; outside Colab this may
# not be the kernel's interpreter (consider `!{sys.executable} ...`).
!python scripts/test_unet.py

## 7. Test Data Loading

In [None]:
import sys

# Make the project root (the current directory, set in section 2) importable.
# Guarded so re-running this cell does not keep prepending duplicate entries.
project_root = os.getcwd()
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from src.data.unet_dataset import create_unet_dataloaders
from src.data.preprocessing import ImagePreprocessor
from src.data.augmentation import SpondylolisthesisAugmentation

# Preprocessing: resize to 512x512, normalize, and apply CLAHE.
preprocessor = ImagePreprocessor(
    target_size=(512, 512),
    normalize=True,
    apply_clahe=True
)

# Training-mode augmentation.
# NOTE(review): the same object is passed for both splits below; confirm that
# create_unet_dataloaders applies it to the training split only.
augmentation = SpondylolisthesisAugmentation(mode='train')

# Small batch size: this cell only checks that data loading works.
print("Creating dataloaders...")
train_loader, val_loader = create_unet_dataloaders(
    train_image_dir=train_images,
    train_label_dir=train_labels,
    val_image_dir=val_images,
    val_label_dir=val_labels,
    batch_size=4,
    num_workers=2,
    heatmap_sigma=3.0,
    output_stride=1,
    preprocessor=preprocessor,
    augmentation=augmentation
)

print(f"✓ Train samples: {len(train_loader.dataset)}")
print(f"✓ Val samples:   {len(val_loader.dataset)}")
print(f"✓ Train batches: {len(train_loader)}")
print(f"✓ Val batches:   {len(val_loader)}")

# Pull one batch to confirm collation works end to end.
batch = next(iter(train_loader))
print("\n✓ Batch loaded successfully:")
print(f"  Images shape:   {batch['images'].shape}")
print(f"  Heatmaps shape: {batch['heatmaps'].shape}")

## 8. Visualize Sample Data

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Take the first training sample.
# NOTE(review): train_loader was built with train-mode augmentation, so this
# sample may look different on each run.
sample = train_loader.dataset[0]

# Min-max rescale the (C, H, W) image tensor to [0, 1] in (H, W, C) layout for
# imshow. A constant image has zero range, so guard the division.
image = sample['image'].permute(1, 2, 0).numpy()
image_range = image.max() - image.min()
image = (image - image.min()) / image_range if image_range > 0 else np.zeros_like(image)

# One heatmap per keypoint channel (the first 4 channels are plotted below).
heatmaps = sample['heatmaps'].numpy()

# 2x3 grid: image at (0, 0), the four heatmaps in the next four slots, and an
# overlay of all heatmaps in the last slot (1, 2).
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

axes[0, 0].imshow(image)
axes[0, 0].set_title('Original Image')
axes[0, 0].axis('off')

# NOTE(review): confirm this title order matches the dataset's channel order.
heatmap_titles = ['Bottom-Left Corner', 'Bottom-Right Corner',
                  'Top-Left Corner', 'Top-Right Corner']
for i in range(4):
    row, col = divmod(i + 1, 3)
    axes[row, col].imshow(heatmaps[i], cmap='hot')
    axes[row, col].set_title(heatmap_titles[i])
    axes[row, col].axis('off')

# Per-pixel maximum over all channels, overlaid on the image.
combined = np.max(heatmaps, axis=0)
axes[1, 2].imshow(image)
axes[1, 2].imshow(combined, cmap='hot', alpha=0.5)
axes[1, 2].set_title('All Keypoints Overlay')
axes[1, 2].axis('off')

plt.tight_layout()
plt.show()

print(f"Sample filename: {sample['filename']}")
# Presumably one keypoint entry per vertebra — confirm against the dataset.
print(f"Number of vertebrae: {len(sample['keypoints'])}")

## 9. Initialize Model

In [None]:
from models.unet import create_unet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# U-Net with one output channel per keypoint type.
model = create_unet(
    in_channels=3,
    num_keypoints=4,
    bilinear=False,
    base_channels=64
)
model = model.to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("\n✓ Model created successfully")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

# Sanity-check a forward pass on random input. Run it in eval mode so that
# normalization layers (e.g. BatchNorm, if present) do not fold random-noise
# statistics into their running estimates before training; then restore the
# previous mode.
dummy_input = torch.randn(1, 3, 512, 512, device=device)
was_training = model.training
model.eval()
with torch.no_grad():
    output = model(dummy_input)
model.train(was_training)

print(f"  Input shape:  {dummy_input.shape}")
print(f"  Output shape: {output.shape}")
print("\n✓ Model ready for training!")

## 10. Start Training

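### Option A: Use the training script (Recommended)

Run **either** Option A or Option B, not both; running both trains the model twice. Sections 11 and 12 use the `trainer` and `save_dir` objects created in Option B.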

In [None]:
# Train with the standalone script. It runs as a separate process, so it does
# not create the `trainer` / `save_dir` objects used by sections 11 and 12.
!python scripts/train_unet.py

### Option B: Custom training loop in notebook

In [None]:
import yaml
import torch.optim as optim
from scripts.train_unet import UNetKeypointLoss, UNetTrainer

# Base configuration, re-read on every run so the overrides below always start
# from the file's values.
with open('experiments/configs/unet_config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)
train_cfg = config['training']  # alias: edits below also update `config`

# Colab overrides (adjust to GPU memory).
train_cfg['batch_size'] = 8  # Reduce if OOM
train_cfg['num_epochs'] = 50

print(f"Training configuration:")
print(f"  Batch size: {train_cfg['batch_size']}")
print(f"  Epochs: {train_cfg['num_epochs']}")
print(f"  Learning rate: {train_cfg['learning_rate']}")

# Rebuild the dataloaders with the training batch size.
train_loader, val_loader = create_unet_dataloaders(
    train_image_dir=train_images,
    train_label_dir=train_labels,
    val_image_dir=val_images,
    val_label_dir=val_labels,
    batch_size=train_cfg['batch_size'],
    num_workers=2,
    heatmap_sigma=3.0,
    output_stride=1,
    preprocessor=preprocessor,
    augmentation=augmentation
)

# Loss, optimizer, and step-decay schedule.
criterion = UNetKeypointLoss(use_focal=True, focal_alpha=2.0, focal_beta=4.0)
optimizer = optim.Adam(
    model.parameters(),
    lr=train_cfg['learning_rate'],
    weight_decay=train_cfg['weight_decay']
)
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=train_cfg['step_size'],
    gamma=train_cfg['gamma']
)

# Checkpoints and results go under the project folder (i.e. on Drive).
save_dir = Path('experiments/results/unet_colab')
trainer = UNetTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    config=config,
    save_dir=save_dir
)

print("\n✓ Trainer initialized. Starting training...\n")
trainer.train(train_cfg['num_epochs'])

## 11. Visualize Training Results

In [None]:
# Plot the losses recorded by the trainer from the Option B cell.
# Assumes one loss entry per epoch. Epochs are numbered from 1, matching the
# "epoch N" reported when the best checkpoint is loaded in section 12.
if not trainer.train_losses or not trainer.val_losses:
    raise RuntimeError("No losses recorded yet - run the Option B training cell first.")

train_epochs = range(1, len(trainer.train_losses) + 1)
val_epochs = range(1, len(trainer.val_losses) + 1)

fig, (ax_all, ax_val) = plt.subplots(1, 2, figsize=(12, 5))

ax_all.plot(train_epochs, trainer.train_losses, label='Train Loss')
ax_all.plot(val_epochs, trainer.val_losses, label='Val Loss')
ax_all.set(xlabel='Epoch', ylabel='Loss', title='Training Progress')
ax_all.legend()
ax_all.grid(True)

ax_val.plot(val_epochs, trainer.val_losses, label='Val Loss', color='orange')
ax_val.axhline(y=trainer.best_val_loss, color='r', linestyle='--',
               label=f'Best: {trainer.best_val_loss:.4f}')
ax_val.set(xlabel='Epoch', ylabel='Validation Loss', title='Validation Loss')
ax_val.legend()
ax_val.grid(True)

fig.tight_layout()
plt.show()

print(f"\nBest validation loss: {trainer.best_val_loss:.4f}")
print(f"Final train loss: {trainer.train_losses[-1]:.4f}")
print(f"Final val loss: {trainer.val_losses[-1]:.4f}")

## 12. Test Inference

In [None]:
# Load the best checkpoint written by the trainer (Option B). map_location puts
# the stored tensors on the current device, so a checkpoint saved on GPU also
# loads in a CPU-only runtime.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True; if the
# checkpoint stores extra non-tensor objects, loading may need
# weights_only=False (only for files you produced yourself).
checkpoint = torch.load(save_dir / 'best_unet_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"✓ Loaded best model from epoch {checkpoint['epoch']+1}")
print(f"  Validation loss: {checkpoint['val_loss']:.4f}")

# Predict on the first validation sample (unsqueeze adds the batch dimension).
val_sample = val_loader.dataset[0]
input_tensor = val_sample['image'].unsqueeze(0).to(device)
target_heatmaps = val_sample['heatmaps'].numpy()

# NOTE(review): sigmoid assumes the model outputs raw logits; confirm against
# how UNetKeypointLoss consumes the model output.
with torch.no_grad():
    pred_heatmaps = torch.sigmoid(model(input_tensor))
pred_heatmaps = pred_heatmaps.cpu().squeeze(0).numpy()

# Min-max rescale the image to [0, 1] in (H, W, C) layout for imshow; guard
# against a constant image to avoid dividing by zero.
img_np = val_sample['image'].permute(1, 2, 0).numpy()
img_range = img_np.max() - img_np.min()
img_np = (img_np - img_np.min()) / img_range if img_range > 0 else np.zeros_like(img_np)

# One column per keypoint channel: rows are image, target, prediction.
fig, axes = plt.subplots(3, 4, figsize=(16, 12))
for i in range(4):
    axes[0, i].imshow(img_np)
    axes[0, i].set_title(f'Corner {i+1}')
    axes[0, i].axis('off')

    axes[1, i].imshow(target_heatmaps[i], cmap='hot')
    axes[1, i].set_title('Target')
    axes[1, i].axis('off')

    axes[2, i].imshow(pred_heatmaps[i], cmap='hot')
    axes[2, i].set_title('Predicted')
    axes[2, i].axis('off')

plt.tight_layout()
plt.show()

print("✓ Inference successful!")

## 13. Download Results

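The trained model and results are written inside the project folder on Google Drive, so they persist after the Colab session ends. The cell below lists them; download them from Drive if you need a local copy.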

In [None]:
# save_dir is relative to PROJECT_PATH (section 2), which is inside the mounted
# Google Drive, so the trainer's outputs are already written to Drive.
# Resolve it so the full Drive path is shown.
results_dir = save_dir.resolve()
print(f"Results saved in: {results_dir}")

if not results_dir.is_dir():
    print("⚠️ Results directory does not exist yet - run the training cell first.")
else:
    print("\nFiles:")
    # Sorted for a stable, readable listing.
    for result_file in sorted(results_dir.iterdir()):
        print(f"  - {result_file.name}")