# Prithvi 2.0 to U-Net Feature Distillation

This notebook demonstrates how to distill knowledge from the Prithvi 2.0 foundation model to a U-Net student model using feature distillation.

## Overview
- **Teacher Model**: Prithvi 2.0 (100M parameter foundation model)
- **Student Model**: U-Net (lightweight architecture)
- **Distillation Method**: Feature Distillation
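- **Dataset**: SEN12MS (Sentinel-1 SAR and Sentinel-2 multispectral imagery; this notebook feeds the 13 Sentinel-2 bands to the models)
- **Task**: Land cover classification (per-pixel semantic segmentation, evaluated with accuracy, mean IoU and F1)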

## 1. Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import yaml

# Import GeoKD components
from src.data.sen12ms_dataset import SEN12MSDataset
from teachers.prithvi_loader import PrithviTeacher
from students.unet_student import UNetStudent
from src.distillation.losses import FeatureDistillation
from src.distillation.distiller import GeospatialDistiller
from src.evaluation.metrics import GeospatialMetrics

# Pick the compute device: the GPU when CUDA is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## 2. Configuration and Hyperparameters

In [None]:
# Hyperparameters and paths shared by every later cell
config = dict(
    batch_size=16,
    learning_rate=1e-4,
    num_epochs=50,
    temperature=4.0,
    alpha=0.7,  # Weight for distillation loss
    beta=0.3,   # Weight for task loss
    feature_weight=1.0,  # Feature distillation weight
    num_classes=17,  # SEN12MS land cover classes
    input_channels=13,  # Sentinel-2 bands
    image_size=256,
    data_path='/path/to/sen12ms/dataset',
    checkpoint_dir='./checkpoints',
    log_interval=10,
)

# Checkpoints, plots and exported models are all written below this directory
os.makedirs(config['checkpoint_dir'], exist_ok=True)

# Echo the settings so each run documents itself in the notebook output
print('Configuration:')
print('\n'.join(f'  {name}: {setting}' for name, setting in config.items()))

## 3. Data Loading and Preprocessing

In [None]:
# Build the SEN12MS train/val datasets and their loaders
print('Loading SEN12MS dataset...')

# The two splits are configured identically apart from the split name
datasets = {
    split_name: SEN12MSDataset(
        root_dir=config['data_path'],
        split=split_name,
        image_size=config['image_size'],
        normalize=True,
    )
    for split_name in ('train', 'val')
}
train_dataset = datasets['train']
val_dataset = datasets['val']

# Loader settings common to both splits; only the training loader shuffles
loader_options = dict(
    batch_size=config['batch_size'],
    num_workers=4,
    pin_memory=True,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

# Report dataset and loader sizes
for label, size in (
    ('Training samples', len(train_dataset)),
    ('Validation samples', len(val_dataset)),
    ('Training batches', len(train_loader)),
    ('Validation batches', len(val_loader)),
):
    print(f'{label}: {size}')

## 4. Model Initialization

In [None]:
# Build the teacher (Prithvi 2.0) and move it to the selected device
print('Loading Prithvi 2.0 teacher model...')
teacher_model = PrithviTeacher(
    model_name='prithvi_100M',
    num_classes=config['num_classes'],
    pretrained=True,
)
teacher_model = teacher_model.to(device)

# The teacher is never trained here: eval mode, and no gradients for its weights
teacher_model.eval()
for weight in teacher_model.parameters():
    weight.requires_grad = False

teacher_params = sum(weight.numel() for weight in teacher_model.parameters())
print(f'Teacher model parameters: {teacher_params:,}')

# Build the student (U-Net) that will be trained
print('Initializing U-Net student model...')
student_model = UNetStudent(
    in_channels=config['input_channels'],
    num_classes=config['num_classes'],
    base_channels=64,
)
student_model = student_model.to(device)

student_params = sum(weight.numel() for weight in student_model.parameters())
print(f'Student model parameters: {student_params:,}')

# How many times smaller the student is than the teacher, by parameter count
compression_ratio = teacher_params / student_params
print(f'Compression ratio: {compression_ratio:.2f}x')

## 5. Feature Distillation Setup

In [None]:
# Channel widths of the feature maps paired up between teacher and student
prithvi_feature_dims = [768] * 4         # Prithvi feature dimensions
unet_feature_dims = [64, 128, 256, 512]  # U-Net feature dimensions

# Feature distillation loss, moved to the same device as the models
feature_distillation = FeatureDistillation(
    teacher_channels=prithvi_feature_dims,
    student_channels=unet_feature_dims,
    temperature=config['temperature'],
    feature_weight=config['feature_weight'],
)
feature_distillation = feature_distillation.to(device)

# The distiller wraps both models and the loss; alpha/beta come from the config
distiller = GeospatialDistiller(
    teacher_model=teacher_model,
    student_model=student_model,
    distillation_loss=feature_distillation,
    temperature=config['temperature'],
    alpha=config['alpha'],
    beta=config['beta'],
)

# Only the student's parameters are handed to the optimizer
optimizer = optim.AdamW(
    student_model.parameters(),
    lr=config['learning_rate'],
    weight_decay=1e-4,
)

# Cosine annealing over the whole run, stepped once per epoch by the training loop
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=config['num_epochs'],
    eta_min=1e-6,
)

# Metric accumulator used for validation
metrics = GeospatialMetrics(num_classes=config['num_classes'])

print('Feature distillation setup complete!')

## 6. Training Loop

In [None]:
# Per-epoch history; the plotting cell reads these lists
train_losses = []
val_losses = []
val_accuracies = []
val_ious = []

# Start below any attainable mIoU so the first epoch always writes a checkpoint.
# Starting at 0.0 with a strict '>' would leave no 'best_student_model.pth' on
# disk if the validation mIoU never rises above zero, and the evaluation cell
# that loads that file would then fail.
best_val_iou = float('-inf')
best_checkpoint_path = os.path.join(config['checkpoint_dir'], 'best_student_model.pth')
num_epochs = config['num_epochs']

print('Starting training...')

for epoch in range(num_epochs):
    # ---- Training phase ----
    student_model.train()
    teacher_model.eval()  # the teacher stays in eval mode for the whole run

    train_loss = 0.0
    train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')

    for batch_idx, (images, targets) in enumerate(train_pbar):
        images = images.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()

        # Forward pass through distiller; it returns a dict of loss terms
        loss_dict = distiller(images, targets)
        total_loss = loss_dict['total_loss']

        # Backward pass, with the student's gradient norm clipped to 1.0
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)
        optimizer.step()

        train_loss += total_loss.item()

        # Refresh the progress-bar readout every `log_interval` batches
        if batch_idx % config['log_interval'] == 0:
            train_pbar.set_postfix({
                'Loss': f'{total_loss.item():.4f}',
                'KD': f'{loss_dict["kd_loss"].item():.4f}',
                'Feature': f'{loss_dict["feature_loss"].item():.4f}',
                'Task': f'{loss_dict["task_loss"].item():.4f}'
            })

    avg_train_loss = train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # ---- Validation phase ----
    student_model.eval()
    val_loss = 0.0
    metrics.reset()

    val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]')

    with torch.no_grad():
        for images, targets in val_pbar:
            images = images.to(device)
            targets = targets.to(device)

            # Same combined loss as in training, for a comparable curve
            loss_dict = distiller(images, targets)
            val_loss += loss_dict['total_loss'].item()

            # Class predictions come from a separate student forward pass.
            # NOTE(review): if the distiller can return the student's outputs,
            # reusing them would save this second forward pass - confirm.
            student_outputs = student_model(images)
            predictions = torch.argmax(student_outputs, dim=1)

            metrics.update(predictions, targets)

    avg_val_loss = val_loss / len(val_loader)
    val_losses.append(avg_val_loss)

    # Coerce to plain Python floats so the history lists can be plotted and the
    # checkpoint holds only primitive values.
    # NOTE(review): assumes 'accuracy' and 'mean_iou' are scalar values - confirm
    # against GeospatialMetrics.compute().
    val_metrics = metrics.compute()
    val_accuracy = float(val_metrics['accuracy'])
    val_iou = float(val_metrics['mean_iou'])

    val_accuracies.append(val_accuracy)
    val_ious.append(val_iou)

    # Read the learning rate this epoch actually trained with *before* stepping
    # the scheduler; after step() the scheduler reports the next epoch's value.
    epoch_lr = optimizer.param_groups[0]['lr']
    scheduler.step()

    # Print epoch results
    print(f'Epoch {epoch+1}/{num_epochs}:')
    print(f'  Train Loss: {avg_train_loss:.4f}')
    print(f'  Val Loss: {avg_val_loss:.4f}')
    print(f'  Val Accuracy: {val_accuracy:.4f}')
    print(f'  Val mIoU: {val_iou:.4f}')
    print(f'  Learning Rate: {epoch_lr:.6f}')

    # Keep only the checkpoint with the best validation mIoU seen so far
    if val_iou > best_val_iou:
        best_val_iou = val_iou
        torch.save({
            'epoch': epoch + 1,
            'student_state_dict': student_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_iou': best_val_iou,
            'config': config
        }, best_checkpoint_path)
        print(f'  New best model saved! (mIoU: {best_val_iou:.4f})')

    print('-' * 60)

print('Training completed!')

## 7. Training Visualization

In [None]:
# Plot training curves on a 2x2 grid: loss, accuracy, mIoU, learning rate
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Loss curves
axes[0, 0].plot(train_losses, label='Train Loss', color='blue')
axes[0, 0].plot(val_losses, label='Val Loss', color='red')
axes[0, 0].set_title('Training and Validation Loss')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True)

# Accuracy curve
axes[0, 1].plot(val_accuracies, label='Val Accuracy', color='green')
axes[0, 1].set_title('Validation Accuracy')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True)

# IoU curve
axes[1, 0].plot(val_ious, label='Val mIoU', color='purple')
axes[1, 0].set_title('Validation Mean IoU')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('mIoU')
axes[1, 0].legend()
axes[1, 0].grid(True)

# Learning rate curve.
# Training uses CosineAnnealingLR stepped once per epoch, so the rate used in
# epoch t is the cosine closed form
#     eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2.
# The parameters are read from the scheduler so the plot follows the schedule
# that was really configured rather than a separate approximation.
epoch_indices = np.arange(len(train_losses))
base_lr = scheduler.base_lrs[0]
lr_history = (
    scheduler.eta_min
    + 0.5 * (base_lr - scheduler.eta_min)
    * (1 + np.cos(np.pi * epoch_indices / scheduler.T_max))
).tolist()
axes[1, 1].plot(lr_history, label='Learning Rate', color='orange')
axes[1, 1].set_title('Learning Rate Schedule')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Learning Rate')
axes[1, 1].set_yscale('log')
axes[1, 1].legend()
axes[1, 1].grid(True)

# Save the figure next to the checkpoints, then display it
plt.tight_layout()
plt.savefig(os.path.join(config['checkpoint_dir'], 'training_curves.png'), dpi=300, bbox_inches='tight')
plt.show()

print(f'Best validation mIoU: {best_val_iou:.4f}')

## 8. Model Evaluation and Comparison

In [None]:
# Load best model.
# map_location remaps the saved tensors onto the current device, so a checkpoint
# written on a GPU can still be loaded in a CPU-only session.
checkpoint = torch.load(
    os.path.join(config['checkpoint_dir'], 'best_student_model.pth'),
    map_location=device,
)
student_model.load_state_dict(checkpoint['student_state_dict'])
student_model.eval()

# Separate metric accumulators for the student and the teacher
print('Evaluating distilled student model...')
student_metrics = GeospatialMetrics(num_classes=config['num_classes'])
student_metrics.reset()

print('Evaluating teacher model...')
teacher_metrics = GeospatialMetrics(num_classes=config['num_classes'])
teacher_metrics.reset()

# Score both models on the same validation batches
with torch.no_grad():
    for images, targets in tqdm(val_loader, desc='Evaluation'):
        images = images.to(device)
        targets = targets.to(device)

        # Student predictions
        student_outputs = student_model(images)
        student_preds = torch.argmax(student_outputs, dim=1)
        student_metrics.update(student_preds, targets)

        # Teacher predictions
        teacher_outputs = teacher_model(images)
        teacher_preds = torch.argmax(teacher_outputs, dim=1)
        teacher_metrics.update(teacher_preds, targets)

# Compute final metrics
student_results = student_metrics.compute()
teacher_results = teacher_metrics.compute()


def _retention(student_value, teacher_value):
    """Return the student's score as a percentage of the teacher's.

    Yields NaN instead of dividing by zero when the teacher's score is 0, so the
    report still prints after the full evaluation pass.
    """
    if teacher_value == 0:
        return float('nan')
    return (student_value / teacher_value) * 100


# Print comparison
print('\n' + '='*60)
print('FINAL EVALUATION RESULTS')
print('='*60)

print('Teacher Model (Prithvi 2.0):')
print(f'  Parameters: {teacher_params:,}')
print(f'  Accuracy: {teacher_results["accuracy"]:.4f}')
print(f'  Mean IoU: {teacher_results["mean_iou"]:.4f}')
print(f'  F1 Score: {teacher_results["f1_score"]:.4f}')

print('Student Model (U-Net):')
print(f'  Parameters: {student_params:,}')
print(f'  Accuracy: {student_results["accuracy"]:.4f}')
print(f'  Mean IoU: {student_results["mean_iou"]:.4f}')
print(f'  F1 Score: {student_results["f1_score"]:.4f}')

print('Distillation Results:')
print(f'  Compression Ratio: {compression_ratio:.2f}x')
print(f'  Accuracy Retention: {_retention(student_results["accuracy"], teacher_results["accuracy"]):.1f}%')
print(f'  IoU Retention: {_retention(student_results["mean_iou"], teacher_results["mean_iou"]):.1f}%')
print(f'  F1 Retention: {_retention(student_results["f1_score"], teacher_results["f1_score"]):.1f}%')

## 9. Inference Example

In [None]:
# Get a sample for inference demonstration: up to four images from the first
# validation batch (fewer if that batch is smaller, instead of an IndexError)
sample_images, sample_targets = next(iter(val_loader))
num_samples = min(4, sample_images.shape[0])
sample_images = sample_images[:num_samples].to(device)
sample_targets = sample_targets[:num_samples].to(device)

# Perform inference with both models on the same inputs
with torch.no_grad():
    teacher_outputs = teacher_model(sample_images)
    student_outputs = student_model(sample_images)

    teacher_preds = torch.argmax(teacher_outputs, dim=1)
    student_preds = torch.argmax(student_outputs, dim=1)

# One row per sample: RGB composite, ground truth, teacher, student.
# squeeze=False keeps `axes` two-dimensional even when there is a single row.
fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4 * num_samples), squeeze=False)

# Fix the colour scale for all class maps. Without vmin/vmax each panel is
# auto-scaled to its own min/max, so the same class would get different colours
# in the ground-truth, teacher and student panels.
# NOTE(review): assumes labels lie in [0, num_classes - 1] - confirm for targets.
class_map_style = dict(cmap='tab20', vmin=0, vmax=config['num_classes'] - 1)

for i in range(num_samples):
    # Original image (RGB bands): channels 3, 2, 1 are taken as red, green, blue
    rgb_image = sample_images[i][[3, 2, 1]].cpu().numpy()  # RGB bands
    rgb_image = np.transpose(rgb_image, (1, 2, 0))
    # Min-max stretch to [0, 1] for display; a constant tile has zero range, so
    # show it as black rather than dividing by zero
    value_range = rgb_image.max() - rgb_image.min()
    if value_range > 0:
        rgb_image = (rgb_image - rgb_image.min()) / value_range
    else:
        rgb_image = np.zeros_like(rgb_image)
    axes[i, 0].imshow(rgb_image)
    axes[i, 0].set_title(f'Sample {i+1}: RGB Image')
    axes[i, 0].axis('off')

    # Ground truth
    axes[i, 1].imshow(sample_targets[i].cpu().numpy(), **class_map_style)
    axes[i, 1].set_title('Ground Truth')
    axes[i, 1].axis('off')

    # Teacher prediction
    axes[i, 2].imshow(teacher_preds[i].cpu().numpy(), **class_map_style)
    axes[i, 2].set_title('Teacher Prediction')
    axes[i, 2].axis('off')

    # Student prediction
    axes[i, 3].imshow(student_preds[i].cpu().numpy(), **class_map_style)
    axes[i, 3].set_title('Student Prediction')
    axes[i, 3].axis('off')

# Save the figure next to the checkpoints, then display it
plt.tight_layout()
plt.savefig(os.path.join(config['checkpoint_dir'], 'inference_examples.png'), dpi=300, bbox_inches='tight')
plt.show()

## 10. Model Export and Deployment

In [None]:
# Export student model for deployment
print('Exporting student model...')

# All exported artifacts go into a sub-directory of the checkpoint directory
export_dir = os.path.join(config['checkpoint_dir'], 'exported_models')
os.makedirs(export_dir, exist_ok=True)

# 1. PyTorch format (state dict only)
torch.save(student_model.state_dict(), os.path.join(export_dir, 'unet_student.pth'))

# 2. TorchScript format, traced in eval mode with a random input of the
#    configured channel count and image size
student_model.eval()
example_input = torch.randn(1, config['input_channels'], config['image_size'], config['image_size']).to(device)
traced_model = torch.jit.trace(student_model, example_input)
traced_model.save(os.path.join(export_dir, 'unet_student_traced.pt'))

# 3. ONNX format (optional). This step is best-effort: a failure is reported
#    but does not abort the export, and the outcome is remembered so the summary
#    below only lists files that were really written.
onnx_exported = False
try:
    torch.onnx.export(
        student_model,
        example_input,
        os.path.join(export_dir, 'unet_student.onnx'),
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
    )
except Exception as e:
    print(f'ONNX export failed: {e}')
else:
    onnx_exported = True
    print('ONNX export successful!')

# Save model metadata.
# Metric values are coerced to plain floats so the YAML file contains ordinary
# numbers rather than Python-object tags.
# NOTE(review): assumes the three metrics are scalar values - confirm against
# GeospatialMetrics.compute().
metadata = {
    'model_type': 'UNet Student',
    'teacher_model': 'Prithvi 2.0',
    'distillation_method': 'Feature Distillation',
    'dataset': 'SEN12MS',
    'num_classes': config['num_classes'],
    'input_channels': config['input_channels'],
    'image_size': config['image_size'],
    'parameters': student_params,
    'compression_ratio': compression_ratio,
    'final_accuracy': float(student_results['accuracy']),
    'final_miou': float(student_results['mean_iou']),
    'final_f1': float(student_results['f1_score']),
    'training_config': config
}

with open(os.path.join(export_dir, 'model_metadata.yaml'), 'w') as f:
    yaml.dump(metadata, f, default_flow_style=False)

# Summarize what was written; the ONNX entry appears only if that export succeeded
print(f'Models exported to: {export_dir}')
print('Available formats:')
print('  - unet_student.pth (PyTorch state dict)')
print('  - unet_student_traced.pt (TorchScript)')
if onnx_exported:
    print('  - unet_student.onnx (ONNX)')
print('  - model_metadata.yaml (Model information)')