In [None]:
import sys
from pathlib import Path

# Set up paths
# The repository root defaults to the original author's checkout location but
# can be overridden with the TORTOISE_ROOT environment variable, so the
# notebook also runs on machines where the project lives somewhere else.
# An unset or empty variable falls back to the default.
import os

DEFAULT_REPO_ROOT = "D:/Divas/Projects/MSCS/7643/Project/TORTOISE"
repo_root = Path(os.environ.get("TORTOISE_ROOT") or DEFAULT_REPO_ROOT).resolve()

# Put <repo_root>/src at the front of sys.path so the local `tortoise` package
# is importable. Guarded so that re-running this cell does not keep adding
# duplicate entries.
src_dir = str(repo_root / "src")
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)

import torch
import torch.optim as optim
from tortoise.train import get_device, print_device_info, train_one_epoch, evaluate
from tortoise.dataloader import build_dataloaders
from tortoise.model import U_Net

# Presumably prints a summary of the available compute device(s); the helper is
# imported from tortoise.train and its exact output is defined there — TODO confirm.
print_device_info()

## Step 1: Set up device and model

In [None]:
# Ask the project helper for the compute device
# (per the original note, it selects a GPU automatically when one is available).
device = get_device()
print(f"Training on: {device}")

# Build the U-Net and place it on the chosen device in one expression.
#   img_ch=13   -> 13 input channels for the MS tiles
#   output_ch=1 -> single output channel for binary segmentation
model = U_Net(img_ch=13, output_ch=1).to(device)

print(f"Model moved to {device}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

## Step 2: Build dataloaders (with pin_memory for GPU)

In [None]:
# Construct the train / validation / test dataloaders.
# Original note: pin_memory is automatically enabled when CUDA is available.
tiles_dir = repo_root / "data" / "tiles"
loaders = build_dataloaders(
    tiles_dir=tiles_dir,
    use_ms=True,       # multispectral input (13 bands)
    use_rgb=False,
    train_ratio=0.8,
    val_ratio=0.1,
    batch_size=8,
    num_workers=0,     # raise on multi-core systems (e.g., 4)
    seed=42,
)
train_loader, val_loader, test_loader = loaders

print(f"Train batches: {len(train_loader)}")
print(f"Val batches:   {len(val_loader)}")
print(f"Test batches:  {len(test_loader)}")

## Step 3: Verify that batches can be moved to the GPU

In [None]:
# Pull a single batch from the training loader to demonstrate the device transfer
batch = next(iter(train_loader))

# Before any transfer: report where the MS tensor currently lives
# (the original note says the DataLoader yields CPU tensors by default).
print(f"Batch MS shape: {batch['ms'].shape}, device: {batch['ms'].device}")

# Move each entry of the batch to the target device, in the order ms, label, mask.
# The original note says train_one_epoch performs this step itself; it is repeated
# here only for illustration.
ms_gpu, label_gpu, mask_gpu = (
    batch[key].to(device) for key in ("ms", "label", "mask")
)

print(f"After .to(device):")
print(f"  MS device:    {ms_gpu.device}")
print(f"  Label device: {label_gpu.device}")
print(f"  Mask device:  {mask_gpu.device}")

## Step 4: Training loop with GPU and mixed precision

In [None]:
# Optimizer and learning rate scheduler.
# StepLR multiplies the learning rate by gamma=0.5 every step_size=10 calls to
# scheduler.step(); the loop below calls it once per epoch.
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Optional: mixed precision training (faster, uses less GPU memory).
# A gradient scaler is only created on CUDA; on any other device `scaler` stays
# None and the calls below run without AMP.
# torch.cuda.amp.GradScaler() is deprecated in recent PyTorch releases in favour
# of torch.amp.GradScaler("cuda"), so prefer the new entry point when this torch
# build provides it and fall back to the legacy one otherwise.
if device.type == "cuda":
    _grad_scaler_cls = getattr(getattr(torch, "amp", None), "GradScaler", None)
    if _grad_scaler_cls is not None:
        scaler = _grad_scaler_cls("cuda")
    else:
        scaler = torch.cuda.amp.GradScaler()
else:
    scaler = None

num_epochs = 3

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # Training pass over the training loader
    train_loss = train_one_epoch(
        model=model,
        loader=train_loader,
        optimizer=optimizer,
        device=device,
        scaler=scaler,  # Enable mixed precision if CUDA available
    )

    # Validation pass; use_amp mirrors whether a scaler was created above
    val_loss = evaluate(
        model=model,
        loader=val_loader,
        desc="Val",
        device=device,
        use_amp=(scaler is not None),
    )

    # Advance the LR schedule once per epoch, after the optimizer has been used
    scheduler.step()

    print(f"  Train Loss: {train_loss:.6f}")
    print(f"  Val Loss:   {val_loss:.6f}")

print("\nTraining complete!")

## Step 5: Test set evaluation

In [None]:
# Evaluate the model on the test loader. use_amp mirrors whether a GradScaler
# exists, i.e. the same condition the validation calls use.
amp_enabled = scaler is not None
test_loss = evaluate(
    model=model,
    loader=test_loader,
    device=device,
    desc="Test",
    use_amp=amp_enabled,
)

print(f"\nTest Loss: {test_loss:.6f}")

## GPU Memory Management Tips

1. **Monitor GPU memory:**
   ```python
   import torch
   print(f"GPU allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
   print(f"GPU reserved:  {torch.cuda.memory_reserved() / 1e9:.2f} GB")
   ```

2. **Clear GPU cache** (use sparingly):
   ```python
   torch.cuda.empty_cache()
   ```

3. **Enable mixed precision** (AMP) for 2-3x speedup on modern GPUs:
   - Already enabled in `train_one_epoch` when `scaler` is passed

4. **Increase batch size** if GPU memory allows (speeds up training)

5. **Increase num_workers** for faster data loading (set to CPU core count, e.g., 4-8)

In [None]:
# Optional: report current GPU memory usage in GB (bytes / 1e9).
# Nothing is printed when the selected device is not CUDA.
on_cuda = device.type == "cuda"
if on_cuda:
    print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"GPU memory reserved:  {torch.cuda.memory_reserved() / 1e9:.2f} GB")