Training Implementation

In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import logging
from datetime import datetime
import json
import numpy as np

# Import model modules
from config.config import TrainingConfig
from models.generator import Generator
from models.discriminator import Discriminator
from utils.data_loader import BigEarthNetSRDataset, create_data_splits, create_dataloaders
from utils.training import SRGANTrainer
from utils.metrics import EvaluationMetrics

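The model architectures, data utilities, trainer and evaluation metrics are imported above from the SRGAN_model project packages (`config`, `models`, `utils`).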

In [2]:
# Training configuration: start from the project defaults in TrainingConfig.
config = TrainingConfig()

# To override a default, uncomment (and edit) the corresponding line:
# config.batch_size = 8
# config.num_epochs = 100
# config.lr_generator = 1e-4
# config.lr_discriminator = 4e-4

Set up the dataset and output paths

In [3]:
# Set paths.
# The dataset root can be overridden with the SRGAN_DATA_DIR environment
# variable so the notebook runs on other machines without editing code; the
# original author's Windows path is kept as the fallback default.
import os

data_dir = os.environ.get(
    "SRGAN_DATA_DIR", "C:\\Users\\kimki\\Downloads\\SRGAN_Satellite"
)

# Checkpoints, the final model and the training history are written here.
output_dir = Path("outputs")
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Set random seeds and device.
# Seeding Python's `random`, NumPy and torch (which seeds every device) makes
# weight initialisation and DataLoader shuffling repeatable. It also makes the
# train/val/test split repeatable, assuming `create_data_splits` draws from
# these global RNGs (e.g. via torch's `random_split`) — TODO confirm.
# Some GPU (cuDNN) kernels may still be non-deterministic.
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


Dataset splitting: 80% for training, 10% for validation and 10% for testing.

In [None]:
# Build the full dataset (set `subset_size` to an int to use only part of it).
# NOTE(review): augment=True is set on the one dataset object that is split
# below. If the splits are views onto it (e.g. torch `Subset`), validation and
# test samples are augmented too — TODO confirm in BigEarthNetSRDataset /
# create_data_splits.
dataset = BigEarthNetSRDataset(root_dir=data_dir, subset_size=None, augment=True)

# Split into train / validation / test subsets.
train_set, val_set, test_set = create_data_splits(dataset)

# One DataLoader per split, using the configured batch size and worker count.
train_loader, val_loader, test_loader = create_dataloaders(
    train_set=train_set, val_set=val_set, test_set=test_set,
    batch_size=config.batch_size, num_workers=config.num_workers,
)

for split_label, split in (("Training", train_set), ("Validation", val_set), ("Test", test_set)):
    print(f"{split_label} samples: {len(split)}")

In [6]:
def normalize_for_display(tensor):
    """Min-max scale ``tensor`` into the [0, 1] range for ``imshow``.

    The minimum and maximum are taken over the whole tensor (all channels
    together). The 1e-8 epsilon keeps a constant tensor from dividing by zero;
    such a tensor maps to all zeros.
    """
    lo, hi = tensor.min(), tensor.max()
    span = hi - lo + 1e-8
    return (tensor - lo) / span

def create_rgb_composite(tensor):
    """Return a 3-channel (R, G, B) view of a channel-first image tensor.

    - 1 channel: the channel is repeated for R, G and B (grayscale).
    - 2 channels: R = channel 0, G = channel 1, B = per-pixel mean of both.
    - any other count: the first three channels are returned as-is.
    """
    n_channels = tensor.shape[0]
    if n_channels == 1:
        return torch.cat((tensor, tensor, tensor), dim=0)
    if n_channels != 2:
        return tensor[:3]
    red, green = tensor[0:1], tensor[1:2]
    blue = tensor.mean(dim=0, keepdim=True)
    return torch.cat((red, green, blue), dim=0)

def plot_sample(batch):
    """Show the first image of each resolution group side by side, then print stats.

    Panels (left to right) show the 20m, 60m and 10m bands. Each image is
    reduced to 3 channels with ``create_rgb_composite`` and min-max scaled with
    ``normalize_for_display``. Afterwards the shape and value range of every
    entry in ``batch`` is printed.

    Args:
        batch: dict-like batch holding 'bands_20m', 'bands_60m' and
            'bands_10m' (indexed by sample first). Every value must support
            ``.shape``, ``.min()`` and ``.max()``. If ``None``, a message is
            printed and nothing is drawn.
    """
    if batch is None:
        print("Received empty batch")
        return

    # (batch key, panel title), left to right.
    # NOTE(review): earlier comments gave the per-sample shapes as [6, 60, 60],
    # [2, 20, 20] and [4, 120, 120] — TODO confirm against the dataset.
    panels = (
        ('bands_20m', '20m Bands (First 3 channels)'),
        ('bands_60m', '60m Bands (RGB composite)'),
        ('bands_10m', '10m Bands (First 3 channels)'),
    )
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (key, title) in zip(axes, panels):
        composite = normalize_for_display(create_rgb_composite(batch[key][0]))
        # channel-first -> channel-last for imshow
        ax.imshow(composite.permute(1, 2, 0).cpu().numpy())
        ax.set_title(title)

    plt.tight_layout()
    plt.show()

    print("\nBatch shapes:")
    for key, value in batch.items():
        print(f"{key}: {value.shape}")
        print(f"Value range: [{value.min():.2f}, {value.max():.2f}]")

In [None]:
# Preview one training batch.
# `sample_batch` is bound to None up front so the band-by-band plot further
# down can still test it when loading the batch fails; otherwise a failed
# `next(iter(...))` would leave the name undefined (NameError later).
sample_batch = None
try:
    sample_batch = next(iter(train_loader))
    if sample_batch is not None:
        plot_sample(sample_batch)
except Exception as e:
    # Best-effort preview: report the problem instead of aborting the notebook.
    print(f"Error displaying sample: {str(e)}")

# Display individual bands
def _plot_band_group(bands, label, figsize):
    """Draw each band of ``bands[0]`` as a grayscale panel in one figure row.

    ``bands`` is indexed as ``bands[sample, band]`` below, so dim 1 is taken as
    the band count.
    """
    n_bands = bands.shape[1]
    # squeeze=False keeps `axes` 2-D even when n_bands == 1; otherwise
    # plt.subplots returns a single Axes and indexing it raises TypeError.
    fig, axes = plt.subplots(1, n_bands, figsize=figsize, squeeze=False)
    for i, ax in enumerate(axes[0]):
        img = normalize_for_display(bands[0, i])
        ax.imshow(img.cpu().numpy(), cmap='gray')
        ax.set_title(f'{label} Band {i+1}')
    plt.tight_layout()
    plt.show()


def plot_individual_bands(batch):
    """Plot every band of the first sample in ``batch`` as its own grayscale image.

    One figure is drawn per resolution group, in the order 20m, 60m, 10m. Each
    figure has one panel per band, titled e.g. ``'20m Band 1'``.

    Args:
        batch: dict-like batch holding 'bands_20m', 'bands_60m' and
            'bands_10m'. ``None`` is a no-op.
    """
    if batch is None:
        return

    # (batch key, title prefix, figure size) per resolution group.
    groups = (
        ('bands_20m', '20m', (20, 4)),
        ('bands_60m', '60m', (8, 4)),
        ('bands_10m', '10m', (16, 4)),
    )
    for key, label, figsize in groups:
        _plot_band_group(batch[key], label, figsize)

# Show each band of the previewed sample on its own axes
# (skipped when no sample batch was obtained above).
if sample_batch is not None:
    plot_individual_bands(batch=sample_batch)

Model Initialization

In [None]:
# Build both networks and move them onto the selected device
# (nn.Module.to returns the module itself).
generator = Generator(n_res_blocks=config.n_res_blocks)
generator = generator.to(device)
discriminator = Discriminator()
discriminator = discriminator.to(device)

# Per-epoch history, one list per series; appended to by `training_callback`
# and dumped to JSON at the end of the notebook.
plot_data = {
    series: []
    for series in ('g_losses', 'd_losses', 'psnr', 'ssim', 'mse')
}
def _as_plain_number(value):
    """Return a 0-dim tensor / NumPy scalar as a Python number; anything else unchanged.

    Storing plain numbers keeps ``plot_data`` plottable (matplotlib cannot
    convert CUDA tensors) and JSON-serialisable for the training-history dump.
    """
    if hasattr(value, 'item') and getattr(value, 'ndim', None) == 0:
        return value.item()
    return value


def training_callback(epoch, train_metrics, val_metrics):
    """Record one epoch of metrics in ``plot_data`` and redraw the 2x2 progress figure.

    Args:
        epoch: epoch index supplied by the trainer (not used here).
        train_metrics: mapping with keys 'generator_losses' and
            'discriminator_losses'.
        val_metrics: mapping with keys 'PSNR', 'SSIM' and 'MSE'.

    Scalar tensors and NumPy scalars are stored as Python numbers. Other values
    (e.g. lists) are stored unchanged.
    """
    plot_data['g_losses'].append(_as_plain_number(train_metrics['generator_losses']))
    plot_data['d_losses'].append(_as_plain_number(train_metrics['discriminator_losses']))
    plot_data['psnr'].append(_as_plain_number(val_metrics['PSNR']))
    plot_data['ssim'].append(_as_plain_number(val_metrics['SSIM']))
    plot_data['mse'].append(_as_plain_number(val_metrics['MSE']))

    # 2x2 grid: losses | PSNR / SSIM | MSE
    fig, axes = plt.subplots(2, 2, figsize=(20, 10))
    ax_loss, ax_psnr, ax_ssim, ax_mse = axes.flat

    ax_loss.plot(plot_data['g_losses'], label='Generator')
    ax_loss.plot(plot_data['d_losses'], label='Discriminator')
    ax_loss.set(title='Losses', xlabel='Epoch', ylabel='Loss Value')
    ax_loss.legend()

    for ax, key, title, ylabel in (
        (ax_psnr, 'psnr', 'PSNR', 'PSNR (dB)'),
        (ax_ssim, 'ssim', 'SSIM', 'SSIM Value'),
        (ax_mse, 'mse', 'MSE', 'MSE Value'),
    ):
        ax.plot(plot_data[key])
        ax.set(title=title, xlabel='Epoch', ylabel=ylabel)

    fig.tight_layout()
    plt.show()
    # Release this epoch's figure so memory does not grow over many epochs.
    plt.close(fig)


# Wrap both networks and the configuration in the SRGAN trainer.
trainer = SRGANTrainer(generator=generator, discriminator=discriminator, config=config)

# To resume training, point `checkpoint_path` at a saved checkpoint file.
# NOTE(review): the returned `start_epoch` is only printed. It is not passed to
# `trainer.train` later, so unless the trainer tracks it internally, the epoch
# count restarts — TODO confirm against SRGANTrainer.
checkpoint_path = None
if checkpoint_path:
    start_epoch, metrics = trainer.load_checkpoint(checkpoint_path)
    print(f"Resumed from epoch {start_epoch}")
    print(f"Previous metrics: {metrics}")


Training Loop

In [None]:
# Pretrain the generator for a fixed number of epochs before the main training.
# NOTE(review): this also runs when resuming from a checkpoint — confirm intended.
pretrain_epochs = 5
trainer.pretrain_generator(train_loader, num_epochs=pretrain_epochs)

# Main training; `training_callback` receives (epoch, train_metrics,
# val_metrics) and redraws the progress plots (presumably once per epoch).
trainer.train(
    train_loader=train_loader, val_loader=val_loader,
    num_epochs=config.num_epochs, callback=training_callback,
)

# Turn matplotlib's interactive mode off once training has finished.
plt.ioff()

Evaluation 

In [None]:
# Evaluate the trained generator on the held-out test split.
# NOTE(review): `_validate` and `_generate_samples` are private SRGANTrainer
# methods; the `:.4f` format below assumes every metric value is numeric.
print("Evaluating final model on test set...")
test_metrics = trainer._validate(test_loader)

print("\nTest Results:")
for name, score in test_metrics.items():
    print(f"{name}: {score:.4f}")

# Produce example outputs from the test loader under the "final_results" tag.
trainer._generate_samples(test_loader, "final_results")

Save Model

In [None]:
def _json_default(obj):
    """``json.dump`` fallback: convert tensors / NumPy values via ``tolist()``.

    ``plot_data`` holds whatever values the trainer reported, which may be torch
    tensors or NumPy scalars that the stdlib encoder rejects. Any other
    unsupported type still raises TypeError, as ``json`` expects.
    """
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Save final model weights together with the test metrics
torch.save({
    'generator_state_dict': generator.state_dict(),
    'discriminator_state_dict': discriminator.state_dict(),
    'test_metrics': test_metrics
}, output_dir / "final_model.pth")

# Save training history
with open(output_dir / "training_history.json", 'w') as f:
    json.dump(plot_data, f, indent=4, default=_json_default)

print("Training completed and results saved!")