# Lab 6: Baseline Model Training (2 Hours)

## ⏱️ Time Allocation
- **Part 1 (40 min):** Model architecture design
- **Part 2 (55 min):** Data loading and training implementation
- **Part 3 (25 min):** GPU training with SLURM

## 🎯 Learning Objectives

### Core (Essential)
- ✅ Build a CNN classifier with PyTorch
- ✅ Implement DataLoaders for patch datasets
- ✅ Create a training loop with an optimizer
- ✅ Submit a GPU job via SLURM
- ✅ Monitor training and save checkpoints

### Optional (For Early Finishers)
- 🔵 Experiment with different architectures (ResNet, EfficientNet)
- 🔵 Implement learning rate scheduling
- 🔵 Add early stopping
- 🔵 Try mixed precision training
- 🔵 Set up TensorBoard logging

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from pathlib import Path
import json
import os
from tqdm import tqdm

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")

### Load Dataset

In [None]:
# Define paths
project_dir = Path(os.getenv('PROJECT_training2600')) / 'my_workspace'
data_dir = project_dir / 'data' / 'preprocessed'
model_dir = project_dir / 'models'
model_dir.mkdir(parents=True, exist_ok=True)

print(f"📂 Data directory: {data_dir}")
print(f"📂 Model directory: {model_dir}")

# Load data
print("\n📥 Loading preprocessed data...")
X_train = np.load(data_dir / 'X_train.npy')
y_train = np.load(data_dir / 'y_train.npy')
X_val = np.load(data_dir / 'X_val.npy')
y_val = np.load(data_dir / 'y_val.npy')

# Load metadata
with open(data_dir / 'dataset_metadata.json', 'r') as f:
    metadata = json.load(f)

print(f"\n✅ Dataset loaded:")
print(f"   Train: {X_train.shape}")
print(f"   Val: {X_val.shape}")
print(f"   Classes: {metadata['num_classes']} ({', '.join(metadata['class_names'])})")
print(f"   Bands: {metadata['num_bands']}")
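`Path(os.getenv(...))` raises an opaque `TypeError` when the environment variable is not set. A minimal sketch of a more explicit lookup follows; the helper name `resolve_project_dir` is hypothetical and not part of the lab code.

```python
import os
from pathlib import Path


def resolve_project_dir(env_var: str, workspace: str = "my_workspace") -> Path:
    """Return <$env_var>/<workspace>, failing loudly if the variable is unset."""
    root = os.environ.get(env_var)
    if not root:
        raise RuntimeError(
            f"Environment variable {env_var} is not set; "
            "set it before running this notebook."
        )
    return Path(root) / workspace
```

With this helper, the first line of the cell above could read `project_dir = resolve_project_dir('PROJECT_training2600')`.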

## Section 2: Create PyTorch Dataset and DataLoader (5 min)

### Custom Dataset Class

In [None]:
class SatelliteDataset(Dataset):
    """PyTorch Dataset for satellite imagery patches."""
    
    def __init__(self, X, y, transform=None):
        """
        Args:
            X: numpy array of shape (N, C, H, W)
            y: numpy array of shape (N,)
            transform: optional transforms to apply
        """
        self.X = torch.FloatTensor(X)
        self.y = torch.LongTensor(y)
        self.transform = transform
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        image = self.X[idx]
        label = self.y[idx]
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

# Create datasets
train_dataset = SatelliteDataset(X_train, y_train)
val_dataset = SatelliteDataset(X_val, y_val)

print(f"✅ Created datasets:")
print(f"   Train: {len(train_dataset)} samples")
print(f"   Val: {len(val_dataset)} samples")

### Create DataLoaders

In [None]:
# Hyperparameters
batch_size = 32
num_workers = 4  # For parallel data loading

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True  # Faster GPU transfer
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)

print(f"✅ Created DataLoaders:")
print(f"   Train batches: {len(train_loader)}")
print(f"   Val batches: {len(val_loader)}")
print(f"   Batch size: {batch_size}")

# Test a batch
images, labels = next(iter(train_loader))
print(f"\n   Sample batch shape: {images.shape}")
print(f"   Sample labels shape: {labels.shape}")
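The batch counts printed above follow directly from the dataset size and `batch_size`. As a small sketch (the sample count of 1000 is a made-up value, not the lab dataset), the arithmetic looks like this:

```python
import math


def num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields per epoch."""
    if drop_last:
        # Incomplete final batch is discarded
        return num_samples // batch_size
    # Incomplete final batch is kept as a smaller batch
    return math.ceil(num_samples / batch_size)


print(num_batches(1000, 32))                  # 32: 31 full batches plus one of 8
print(num_batches(1000, 32, drop_last=True))  # 31
```

Because the last batch can be smaller than `batch_size`, the training functions later weight each batch loss by its actual size.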

## Section 3: Build CNN Classifier (8 min)

### Simple CNN Architecture
We'll start with a baseline CNN, then optionally use pre-trained models later.

In [None]:
class SatelliteCNN(nn.Module):
    """Baseline CNN for satellite image classification."""
    
    def __init__(self, in_channels=6, num_classes=7):
        super(SatelliteCNN, self).__init__()
        
        # Convolutional layers
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2)  # 224 -> 112
        )
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)  # 112 -> 56
        )
        
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2)  # 56 -> 28
        )
        
        self.conv4 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2)  # 28 -> 14
        )
        
        # Global Average Pooling
        self.gap = nn.AdaptiveAvgPool2d(1)
        
        # Fully connected layers
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.gap(x)
        x = x.view(x.size(0), -1)  # Flatten
        x = self.fc(x)
        return x

# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SatelliteCNN(
    in_channels=metadata['num_bands'],
    num_classes=metadata['num_classes']
).to(device)

print(f"✅ Model created and moved to {device}")
print(f"\n📊 Model Summary:")
print(model)

In [None]:
# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n🔢 Model Parameters:")
print(f"   Total: {total_params:,}")
print(f"   Trainable: {trainable_params:,}")
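The count printed above can be checked by hand. The sketch below reproduces it with plain arithmetic, assuming the defaults `in_channels=6` and `num_classes=7`; the helper names are illustrative only. BatchNorm running statistics are buffers, not parameters, so they are not counted.

```python
def conv_block_params(c_in: int, c_out: int, k: int = 3) -> int:
    """Conv2d weights and biases plus BatchNorm2d scale and shift."""
    conv = c_in * c_out * k * k + c_out
    batch_norm = 2 * c_out
    return conv + batch_norm


def linear_params(n_in: int, n_out: int) -> int:
    """Linear layer weights plus biases."""
    return n_in * n_out + n_out


def satellite_cnn_params(in_channels: int = 6, num_classes: int = 7) -> int:
    channels = [in_channels, 32, 64, 128, 256]
    conv_total = sum(conv_block_params(a, b) for a, b in zip(channels, channels[1:]))
    fc_total = linear_params(256, 128) + linear_params(128, num_classes)
    return conv_total + fc_total


print(f"{satellite_cnn_params():,}")  # 424,039
```

Most of the parameters sit in the last convolution (128 to 256 channels); the global average pooling keeps the fully connected head small regardless of patch size.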

## Section 4: Training Setup (7 min)

### Loss Function and Optimizer

In [None]:
# Loss function (CrossEntropyLoss for multi-class classification)
criterion = nn.CrossEntropyLoss()

# Optimizer (Adam with weight decay)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

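# Learning rate scheduler: halve the LR when validation loss stops improving.
# `verbose` is omitted because it is deprecated in recent PyTorch releases;
# read the current LR from optimizer.param_groups[0]['lr'] if you need it.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

print("✅ Training components initialized:")
print("   Loss: CrossEntropyLoss")
print("   Optimizer: Adam (lr=0.001)")
print("   Scheduler: ReduceLROnPlateau")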
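To see when `ReduceLROnPlateau` actually lowers the learning rate, here is a simplified pure-Python simulation. It assumes the default relative threshold of `1e-4` and ignores `cooldown` and `min_lr`; the function name and the loss values are made up for illustration.

```python
def simulate_plateau_schedule(val_losses, lr=0.001, factor=0.5, patience=3,
                              threshold=1e-4):
    """Return the learning rate in effect after each epoch's scheduler step."""
    best = float("inf")
    bad_epochs = 0
    lrs = []
    for loss in val_losses:
        if loss < best * (1 - threshold):  # relative-improvement test (mode='min')
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:  # reduce on the (patience + 1)-th bad epoch
            lr *= factor
            bad_epochs = 0
        lrs.append(lr)
    return lrs


print(simulate_plateau_schedule([1.0, 0.9, 0.9, 0.9, 0.9, 0.9]))
# [0.001, 0.001, 0.001, 0.001, 0.001, 0.0005]
```

With `patience=3`, the rate is halved only on the fourth consecutive epoch without improvement, so short noisy stretches in the validation loss do not trigger a reduction.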

### Training and Validation Functions

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Train for one epoch."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for images, labels in tqdm(loader, desc='Training', leave=False):
        images, labels = images.to(device), labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    
    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total
    
    return epoch_loss, epoch_acc


def validate_epoch(model, loader, criterion, device):
    """Validate for one epoch."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Validation', leave=False):
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total
    
    return epoch_loss, epoch_acc

print("✅ Training functions defined")
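Both functions multiply each batch loss by `images.size(0)` before dividing by `total`. `CrossEntropyLoss` returns the mean over a batch by default, so this weighting recovers the true per-sample mean even when the last batch is smaller. A small sketch with made-up numbers shows the difference:

```python
def epoch_mean_loss(batch_losses, batch_sizes):
    """Sample-weighted mean of per-batch mean losses."""
    total = sum(batch_sizes)
    weighted = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return weighted / total


# Hypothetical epoch: three full batches of 32 and a final batch of 4.
losses = [0.50, 0.50, 0.50, 2.00]
sizes = [32, 32, 32, 4]

naive = sum(losses) / len(losses)          # treats the tiny batch as equal
weighted = epoch_mean_loss(losses, sizes)  # what train_epoch computes
print(f"naive={naive:.3f}  weighted={weighted:.3f}")  # naive=0.875  weighted=0.560
```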

## Section 5: Train the Model (10 min)

### Training Loop

In [None]:
# Training configuration
num_epochs = 20
best_val_acc = 0.0
patience_counter = 0
early_stop_patience = 5

# History tracking
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}

print(f"🚀 Starting training for {num_epochs} epochs...\n")
print(f"{'Epoch':<8} {'Train Loss':<12} {'Train Acc':<12} {'Val Loss':<12} {'Val Acc':<12}")
print("-" * 60)

for epoch in range(num_epochs):
    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validate
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)
    
    # Update scheduler
    scheduler.step(val_loss)
    
    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    
    # Print progress
    print(f"{epoch+1:<8} {train_loss:<12.4f} {train_acc:<12.2f} {val_loss:<12.4f} {val_acc:<12.2f}")
    
    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
            'metadata': metadata
        }
        torch.save(checkpoint, model_dir / 'best_model.pth')
        print(f"   ✓ Saved best model (val_acc: {val_acc:.2f}%)")
    else:
        patience_counter += 1
    
    # Early stopping
    if patience_counter >= early_stop_patience:
        print(f"\n⚠️ Early stopping triggered at epoch {epoch+1}")
        break

print(f"\n✅ Training complete!")
print(f"   Best validation accuracy: {best_val_acc:.2f}%")
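The `patience_counter` bookkeeping in the loop above can also be packaged as a small helper. This is one possible sketch, not the lab's required structure; the class name `EarlyStopping` and the accuracy values are illustrative, and unlike the loop it starts from negative infinity rather than `0.0`.

```python
class EarlyStopping:
    """Stop when the monitored metric has not improved for `patience` epochs."""

    def __init__(self, patience: int = 5):
        self.patience = patience
        self.best = float("-inf")
        self.counter = 0

    def step(self, metric: float) -> bool:
        """Record one epoch; return True if this epoch set a new best."""
        if metric > self.best:
            self.best = metric
            self.counter = 0
            return True
        self.counter += 1
        return False

    @property
    def should_stop(self) -> bool:
        return self.counter >= self.patience


stopper = EarlyStopping(patience=2)
for epoch, val_acc in enumerate([70.0, 72.5, 72.0, 72.4, 73.0], start=1):
    improved = stopper.step(val_acc)
    print(epoch, val_acc, "new best" if improved else f"no gain ({stopper.counter})")
    if stopper.should_stop:
        print(f"stopping at epoch {epoch}")
        break
```

In this toy run training stops at epoch 4, before the later improvement to 73.0 is ever seen. That is the trade-off of a short patience window: less wasted compute, at the risk of stopping during a temporary plateau.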

## Section 6: Visualize Training Progress (5 min)

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss
axes[0].plot(history['train_loss'], label='Train Loss', marker='o')
axes[0].plot(history['val_loss'], label='Val Loss', marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy
axes[1].plot(history['train_acc'], label='Train Acc', marker='o')
axes[1].plot(history['val_acc'], label='Val Acc', marker='s')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Training and Validation Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(model_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\n💾 Saved training curves to: {model_dir / 'training_curves.png'}")

### Save Training History

In [None]:
# Save history as JSON
with open(model_dir / 'training_history.json', 'w') as f:
    json.dump(history, f, indent=2)

print(f"💾 Saved training history to: {model_dir / 'training_history.json'}")
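Once the history is on disk, it can be reloaded without rerunning training. A minimal sketch, assuming the JSON layout written above (lists keyed by `train_loss`, `train_acc`, `val_loss`, `val_acc`); the helper names and the demo values are hypothetical.

```python
import json
from pathlib import Path


def load_history(path: Path) -> dict:
    """Read a training history saved with json.dump."""
    with open(path) as f:
        return json.load(f)


def best_epoch(history: dict) -> tuple:
    """Return (1-based epoch, val_acc) for the highest validation accuracy."""
    accs = history["val_acc"]
    idx = max(range(len(accs)), key=accs.__getitem__)
    return idx + 1, accs[idx]


print(best_epoch({"val_acc": [61.2, 68.9, 67.5]}))  # (2, 68.9)
```

In the notebook, `best_epoch(load_history(model_dir / 'training_history.json'))` should report the same epoch that is stored in the `best_model.pth` checkpoint.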

## Section 7: SLURM Job Submission (5 min)

### Create Training Script
For longer training runs, submit as a batch job:

In [None]:
# Save this notebook's training code as a standalone script
training_script = """#!/usr/bin/env python3
# Standalone training script for SLURM submission

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import json
import os
from tqdm import tqdm

# [Include SatelliteDataset and SatelliteCNN classes here]
# [Include train_epoch and validate_epoch functions here]

if __name__ == '__main__':
    # Load data
    project_dir = Path(os.getenv('PROJECT_training2600')) / 'my_workspace'
    data_dir = project_dir / 'data' / 'preprocessed'
    model_dir = project_dir / 'models'
    
    # [Include full training loop here]
    
    print("Training complete!")
"""

script_path = project_dir / 'scripts' / 'train_baseline.py'
script_path.parent.mkdir(parents=True, exist_ok=True)  # create scripts/ if missing
with open(script_path, 'w') as f:
    f.write(training_script)

print(f"💾 Saved training script to: {script_path}")

### Create SLURM Submission Script

In [None]:
# sbatch requires '#!' on the very first line, so the string starts with the shebang
slurm_script = f"""#!/bin/bash
#SBATCH --job-name=baseline_train
#SBATCH --account=training2600
#SBATCH --partition=dc-gpu
#SBATCH --nodes=1
#SBATCH --gres=gpu:1
#SBATCH --time=02:00:00
#SBATCH --output=logs/train_%j.out
#SBATCH --error=logs/train_%j.err

# Load modules
module load Python/3.11.3
module load PyTorch/2.0.1

# Activate virtual environment
source ~/envs/ml_eo_course/bin/activate

# Print GPU info
nvidia-smi

# Run training
python {script_path}

echo "Job finished at $(date)"
"""

sbatch_path = project_dir / 'scripts' / 'submit_training.sbatch'
# SLURM does not create the --output/--error directory, so make scripts/logs now
(sbatch_path.parent / 'logs').mkdir(parents=True, exist_ok=True)
with open(sbatch_path, 'w') as f:
    f.write(slurm_script)

print(f"💾 Saved SLURM script to: {sbatch_path}")
print(f"\n📋 To submit job, run:")
print(f"   cd {project_dir / 'scripts'}")
print(f"   sbatch submit_training.sbatch")
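On success, `sbatch` prints a line of the form `Submitted batch job <id>`. If you submit from Python instead of the shell, a small parser keeps the job ID for later `squeue` or `scancel` calls. The sketch below is an assumption-labelled helper; `parse_job_id` is a hypothetical name and the job number is made up.

```python
import re


def parse_job_id(sbatch_output: str) -> int:
    """Extract the job ID from sbatch's 'Submitted batch job <id>' message."""
    match = re.search(r"Submitted batch job (\d+)", sbatch_output)
    if match is None:
        raise ValueError(f"No job ID found in: {sbatch_output!r}")
    return int(match.group(1))


print(parse_job_id("Submitted batch job 1234567"))  # 1234567
```

On the cluster, the input string could come from `subprocess.run(["sbatch", str(sbatch_path)], capture_output=True, text=True, check=True).stdout`, run from the `scripts` directory so that the relative `logs/` path resolves.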

## Summary & Next Steps

### What We Covered
✅ Built a CNN classifier for satellite imagery  
✅ Created PyTorch DataLoaders  
✅ Trained model with GPU acceleration  
✅ Monitored training metrics  
✅ Saved model checkpoints  
✅ Created SLURM submission scripts  

### Model Performance
- **Best Validation Accuracy:** Variable (depends on data)
- **Training Time:** ~10-20 min on GPU
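- **Model Size:** ~1.7 MB of float32 weights; the saved checkpoint is roughly 5 MB because it also stores the Adam optimizer state
- **Parameters:** ~0.42 million (424,039 with the default 6 bands and 7 classes)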

### Key Training Concepts
- **Batch Training:** Process data in mini-batches for efficiency
- **Learning Rate Scheduling:** Adapt learning rate during training
- **Early Stopping:** Prevent overfitting
- **Checkpointing:** Save best model based on validation

### Prepare for Lab 7
Next lab: **Model Evaluation Metrics**
- Calculate precision, recall, F1-score
- Generate confusion matrix
- Visualize predictions on test set
- Compare against baseline

### Best Practices
1. **Always monitor both train and val metrics** (detect overfitting)
2. **Use GPU for faster training** (often 10-50x faster than CPU for CNNs, depending on hardware and batch size)
3. **Save checkpoints regularly** (protect against crashes)
4. **Track experiments** (log hyperparameters and results)
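For the last point, even a plain JSON-lines file is enough to compare runs before adopting a dedicated tracking tool. A minimal sketch, assuming one record per training run; the function name `log_run` and the record fields are illustrative choices, not a course convention.

```python
import json
import time
from pathlib import Path


def log_run(log_path: Path, hyperparams: dict, results: dict) -> dict:
    """Append one experiment record as a JSON line and return it."""
    record = {
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "hyperparams": hyperparams,
        "results": results,
    }
    log_path.parent.mkdir(parents=True, exist_ok=True)
    with open(log_path, "a") as f:  # append so earlier runs are preserved
        f.write(json.dumps(record) + "\n")
    return record
```

At the end of training this could be called as `log_run(model_dir / 'experiments.jsonl', {'lr': 0.001, 'batch_size': batch_size}, {'best_val_acc': best_val_acc})`, where the file name is again an assumption.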

### Monitoring SLURM Jobs
```bash
# Check job status
squeue -u $USER

# View output
tail -f logs/train_<job_id>.out

# Cancel job
scancel <job_id>
```

### Additional Resources
- **PyTorch Tutorials:** https://pytorch.org/tutorials/
- **SLURM Docs:** https://slurm.schedmd.com/
- **Model Training Best Practices:** https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html

### Homework (Optional)
1. Experiment with different architectures (ResNet, EfficientNet)
2. Try different optimizers (SGD, AdamW)
3. Implement data augmentation
4. Track experiments with Weights & Biases or TensorBoard

---

**Excellent progress!** Your baseline model is trained! Next, we'll evaluate its performance in detail. 🎯

---

## ✅ Lab 6 Completion Checklist

### Core Tasks
- [ ] CNN model defined and tested
- [ ] DataLoader implemented
- [ ] Training loop working
- [ ] SLURM script created
- [ ] GPU job submitted
- [ ] Checkpoint saved

### Optional Tasks
- [ ] Tried different architectures
- [ ] Implemented LR scheduling
- [ ] Added early stopping
- [ ] Set up TensorBoard

## 🚀 Next Lab
**Lab 7: Model Evaluation** - Load trained model, compute metrics, analyze errors