# Ocean SST Super-Resolution with CNNs

This notebook demonstrates the complete pipeline for training a CNN-based 5× super-resolution model on Sea Surface Temperature (SST) data, using paired low- and high-resolution PNG image patches.

## Contents
1. Setup and Imports
2. Data Loading and Preprocessing
3. Model Definition
4. Training
5. Evaluation and Visualization

## 1. Setup and Imports

In [None]:
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from PIL import Image
import matplotlib.pyplot as plt
import glob
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
import torch.nn.functional as F

# Fix random seeds so weight initialisation and DataLoader shuffling are
# reproducible under Restart & Run All (train_test_split below is seeded
# separately via its own random_state).
SEED = 42
torch.manual_seed(SEED)  # also seeds all CUDA devices
np.random.seed(SEED)

# Check for GPU; the model and each batch are moved to this device later.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Data Loading and Preprocessing

In [None]:
def process_path(file_path):
    """Load one image file and return it as an RGB pixel array.

    Args:
        file_path: Path to an image file readable by PIL (the notebook globs
            ``*.png`` patches).

    Returns:
        A ``numpy.ndarray`` of shape (H, W, 3) with 8-bit RGB pixel values.
    """
    # The context manager guarantees the file handle is released as soon as
    # the pixels are copied out, including when decoding raises part-way.
    with Image.open(file_path) as img:
        return np.array(img.convert('RGB'))

In [None]:
# Load file paths - update these paths to match your data location.
# glob.glob returns paths in arbitrary, filesystem-dependent order, so both
# lists are sorted: low_res_paths[i] and high_res_paths[i] are later treated
# as a training pair purely by position.
print("Loading file paths...")
low_res_paths = sorted(glob.glob("../data/patches/low_res/*/*.png"))
high_res_paths = sorted(glob.glob("../data/patches/high_res/*/*.png"))

print(f"Found {len(low_res_paths)} low-res images")
print(f"Found {len(high_res_paths)} high-res images")

# Positional pairing is only meaningful when both trees hold the same number
# of patches. NOTE(review): this also assumes matching patches share the same
# relative path (subdir/filename) under low_res/ and high_res/ - confirm.
assert len(low_res_paths) == len(high_res_paths), (
    "low-res and high-res patch counts differ; pairs would be misaligned"
)

In [None]:
# Decode every patch into an (H, W, 3) RGB array; this may take a while for
# large datasets.
print("Processing images...")
low_res_imgs = list(map(process_path, tqdm(low_res_paths)))
high_res_imgs = list(map(process_path, tqdm(high_res_paths)))

for list_name, imgs in (("low_res_imgs", low_res_imgs), ("high_res_imgs", high_res_imgs)):
    print(f'Total {list_name}: {len(imgs)}')

In [None]:
def _to_chw_tensor(images):
    """Stack (H, W, C) 8-bit image arrays into an (N, C, H, W) float32 tensor in [0, 1].

    Stacking into a single NumPy array first avoids ``torch.tensor(list_of_ndarrays)``,
    which PyTorch warns is extremely slow. ``np.stack`` also raises a clear error if
    any patch has a different shape from the others.
    """
    if len(images) == 0:
        raise ValueError("no images to convert; check the glob paths above")
    stacked = np.stack(images)  # (N, H, W, C)
    channels_first = np.ascontiguousarray(stacked.transpose(0, 3, 1, 2))
    return torch.from_numpy(channels_first).to(torch.float32) / 255.0


# Convert to torch tensors and normalize to [0, 1]
print("Converting to torch tensors...")
x_data = _to_chw_tensor(low_res_imgs)
y_data = _to_chw_tensor(high_res_imgs)

print(f"Low-res data shape: {x_data.shape}")
print(f"High-res data shape: {y_data.shape}")

In [None]:
print("Splitting data...")
# Hold out 10% of the pairs for evaluation; random_state pins the split.
x_train, x_test, y_train, y_test = train_test_split(
    x_data, y_data, test_size=0.1, random_state=42,
)

for split_name, split in (("Training", x_train), ("Test", x_test)):
    print(f"{split_name} samples: {len(split)}")

# Wrap each split in a DataLoader; only training batches are reshuffled each epoch.
BATCH_SIZE = 32
train_dataset = TensorDataset(x_train, y_train)
test_dataset = TensorDataset(x_test, y_test)
train_loader, test_loader = (
    DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True),
    DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False),
)

### Visualize Sample Data

In [None]:
# Show the first low-/high-res pair side by side, labelled with each patch's size.
idx = 0
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

for ax, label, patch in zip(axes, ('Low', 'High'), (x_data[idx], y_data[idx])):
    ax.imshow(patch.permute(1, 2, 0).numpy())
    ax.set_title(f'{label} Resolution ({patch.shape[1]}×{patch.shape[2]})')
    ax.axis('off')

plt.tight_layout()
plt.show()

## 3. Model Definition

In [None]:
class SuperResolutionModel(nn.Module):
    """CNN for integer-factor (default 5×) super-resolution of SST images.

    The input is first upsampled bilinearly by ``scale_factor``. Four 3×3
    convolutions then predict a residual that is added back to that upsampled
    image, so the network only has to learn the missing detail.

    Args:
        scale_factor: Upsampling factor from input to output resolution. It must
            equal the size ratio between the high- and low-res patches, or the
            output will not match the targets' shape in the loss.
        in_channels: Number of image channels in and out (3 for RGB).
        base_channels: Width of conv1/conv3; conv2 uses twice this many channels.

    Shape:
        Input ``(N, in_channels, H, W)`` -> output
        ``(N, in_channels, H * scale_factor, W * scale_factor)``.
    """

    def __init__(self, scale_factor=5, in_channels=3, base_channels=64):
        super().__init__()
        self.initial_upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=True)
        self.conv1 = nn.Conv2d(in_channels, base_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(base_channels, 2 * base_channels, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(2 * base_channels, base_channels, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(base_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Upsample ``x`` bilinearly and refine it with a residual CNN."""
        upsampled = self.initial_upsample(x)
        x1 = F.relu(self.conv1(upsampled))
        x2 = F.relu(self.conv2(x1))
        x3 = F.relu(self.conv3(x2) + x1)  # skip connection around conv2/conv3
        return self.conv4(x3) + upsampled  # global residual: predict only the detail

In [None]:
# Build the network, place it on the selected device, and report its size.
model = SuperResolutionModel().to(device)
print(model)

total_params = sum(param.numel() for param in model.parameters())
print(f"\nTotal parameters: {total_params:,}")

## 4. Training

In [None]:
def visualize_results(model, x, y, epoch):
    """Plot input, target and model prediction side by side for one sample.

    Args:
        model: The super-resolution network. It is run in eval mode and then
            restored to whatever train/eval mode it was in before the call.
        x: Low-res input batch of shape (N, C, h, w); only sample 0 is shown.
        y: Matching high-res target batch of shape (N, C, H, W).
        epoch: Epoch number shown in the figure title.
    """
    was_training = model.training
    model.eval()
    try:
        # Run on whatever device the model's weights live on instead of
        # relying on a notebook-level global.
        first_param = next(model.parameters(), None)
        model_device = first_param.device if first_param is not None else x.device
        with torch.no_grad():
            output = model(x.to(model_device))
    finally:
        model.train(was_training)

    # Index sample 0 rather than .squeeze(): squeeze() also drops size-1
    # spatial/channel dims and leaves a 4-D tensor when N > 1.
    output_img = np.clip(output[0].cpu().permute(1, 2, 0).numpy(), 0, 1)
    input_img = x[0].cpu().permute(1, 2, 0).numpy()
    target_img = y[0].cpu().permute(1, 2, 0).numpy()

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = (
        ('Input (Low-Res)', input_img),
        ('Target (High-Res)', target_img),
        ('Model Output', output_img),
    )
    for ax, (title, img) in zip(axes, panels):
        ax.set_title(title)
        ax.imshow(img)
        ax.axis('off')

    fig.suptitle(f'Epoch {epoch}')
    fig.tight_layout()
    plt.show()

In [None]:
# Pixel-wise MSE objective, optimised with Adam.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Per-epoch losses, appended to by the training loop.
train_losses, test_losses = [], []

In [None]:
# Train model
# NOTE: re-running this cell keeps training the existing `model` and keeps
# appending to train_losses/test_losses; re-run the model and training-setup
# cells first for a fresh run.
print("Training model...")
n_epochs = 10

for epoch in range(n_epochs):
    # Training
    model.train()
    train_loss_sum = 0.0

    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{n_epochs}')
    for x_batch, y_batch in progress_bar:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        optimizer.zero_grad()
        outputs = model(x_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()

        # MSELoss returns the per-batch mean; weight by batch size so the
        # smaller final batch does not skew the epoch average.
        train_loss_sum += loss.item() * x_batch.size(0)
        progress_bar.set_postfix({'loss': f'{loss.item():.6f}'})

    avg_train_loss = train_loss_sum / len(train_loader.dataset)
    train_losses.append(avg_train_loss)

    # Evaluation
    model.eval()
    test_loss_sum = 0.0
    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            outputs = model(x_batch)
            test_loss_sum += criterion(outputs, y_batch).item() * x_batch.size(0)

    avg_test_loss = test_loss_sum / len(test_loader.dataset)
    test_losses.append(avg_test_loss)

    print(f'Epoch {epoch+1}: Train Loss = {avg_train_loss:.6f}, Test Loss = {avg_test_loss:.6f}')

    # Visualize results
    visualize_results(model, x_test[0].unsqueeze(0), y_test[0].unsqueeze(0), epoch+1)

## 5. Evaluation and Visualization

In [None]:
# Plot the per-epoch train/test loss curves.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(range(1, len(train_losses) + 1), train_losses, 'b-', label='Training Loss')
ax.plot(range(1, len(test_losses) + 1), test_losses, 'r-', label='Test Loss')
ax.set(xlabel='Epoch', ylabel='Loss (MSE)', title='Training History')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

In [None]:
# Compare multiple samples
# Cap at the test-set size, and keep `axes` 2-D even for a single row
# (plt.subplots returns a 1-D array when nrows == 1 unless squeeze=False).
n_samples = min(4, len(x_test))
fig, axes = plt.subplots(n_samples, 3, figsize=(15, 4*n_samples), squeeze=False)

model.eval()
with torch.no_grad():
    # One batched forward pass instead of one call per sample.
    outputs = model(x_test[:n_samples].to(device)).cpu()

for i in range(n_samples):
    panels = (
        ('Input', x_test[i].permute(1, 2, 0).numpy()),
        ('Target', y_test[i].permute(1, 2, 0).numpy()),
        ('Output', np.clip(outputs[i].permute(1, 2, 0).numpy(), 0, 1)),
    )
    for ax, (title, img) in zip(axes[i], panels):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Persist only the learned weights; rebuild with SuperResolutionModel() and
# call load_state_dict() to reuse them.
MODEL_PATH = 'sst_superres_model.pth'
torch.save(model.state_dict(), MODEL_PATH)
print(f"Model saved to {MODEL_PATH}")

## Summary

In this notebook, we:
1. Loaded and preprocessed SST patch data
2. Defined a CNN architecture for 5× super-resolution
3. Trained the model using MSE loss
4. Visualized the results

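The model learns to enhance low-resolution SST images, but its outputs still show limitations (blurriness, edge artifacts) that more advanced architectures could address. Note that the MSE loss is computed on normalized RGB pixel values of the PNG patches, not directly on physical temperatures.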