# Underwater Image Enhancement with Autoencoder
This notebook trains a U-Net-based autoencoder to enhance raw underwater GoPro images so that they match the quality of manually edited versions.

## 🚀 Quick Start
1. **Run in Google Colab** with GPU runtime (Runtime → Change runtime type → GPU)
2. **Mount Google Drive** (dataset should be pre-uploaded)
3. **Train the model** using your pre-processed dataset

## 📁 Expected Dataset Structure
Your dataset should be in Google Drive at:
```
/content/drive/MyDrive/testing-dataset-1000-underwater/
├── input/          # Input TIFF images
├── target/         # Target (manually edited) TIFF images
└── split.txt       # Optional train/validation split indices (an 80/20 split is generated if missing)
```

## 1. Setup and Installation

In [None]:
# Pick the compute device: CUDA when a GPU is visible, otherwise the CPU
import torch

has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')
if has_cuda:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f'GPU Available: {torch.cuda.get_device_name(0)}')
    print(f'Memory Available: {gpu_props.total_memory / 1e9:.2f} GB')
else:
    print('No GPU available, using CPU')

In [None]:
# Install required packages
!pip install -q pillow torch torchvision tqdm matplotlib tensorboard

In [None]:
# Import required libraries
import os
import random
import warnings
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from tqdm.auto import tqdm

# Silence library warnings to keep the training log readable.
# NOTE(review): this also hides potentially useful warnings; narrow the filter when debugging.
warnings.filterwarnings('ignore')

# Seed every RNG the notebook may touch (Python, NumPy, PyTorch) for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## 2. Mount Google Drive and Set Up Paths

In [None]:
# Mount Google Drive so the dataset and training outputs persist across Colab sessions
from google.colab import drive
import os
from pathlib import Path

drive.mount('/content/drive')

# Dataset path (pre-uploaded)
DATASET_PATH = Path('/content/drive/MyDrive/testing-dataset-1000-underwater')
INPUT_DIR = DATASET_PATH / 'input'
TARGET_DIR = DATASET_PATH / 'target'
SPLIT_FILE = DATASET_PATH / 'split.txt'

# Output directories
OUTPUT_BASE = Path('/content/drive/MyDrive/underwater_enhancement')
MODELS_PATH = OUTPUT_BASE / 'models'
CHECKPOINT_PATH = OUTPUT_BASE / 'checkpoints'

# Create output directories (parents=True so a missing parent folder is not an error)
for out_dir in (OUTPUT_BASE, MODELS_PATH, CHECKPOINT_PATH):
    out_dir.mkdir(parents=True, exist_ok=True)

# Verify the dataset layout; fail fast here instead of with a confusing
# os.listdir error in the next cell.
missing_dirs = [p for p in (DATASET_PATH, INPUT_DIR, TARGET_DIR) if not p.is_dir()]
if missing_dirs:
    print(f"❌ Dataset not found at {DATASET_PATH}")
    print("Please upload your dataset to Google Drive first.")
    raise FileNotFoundError(
        "Missing dataset directories: " + ", ".join(str(p) for p in missing_dirs)
    )

print(f"✓ Dataset found at: {DATASET_PATH}")
print(f"✓ Input images: {INPUT_DIR}")
print(f"✓ Target images: {TARGET_DIR}")
print(f"✓ Models will be saved at: {MODELS_PATH}")

In [None]:
# Load image filenames and train/validation split information
import json

TIFF_EXTENSIONS = ('.tiff', '.tif', '.TIFF', '.TIF')
ALT_TARGET_EXTENSIONS = ('.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG')


def list_image_files(directory, extensions):
    """Return the sorted names of files in `directory` that end with any of `extensions`."""
    return sorted(name for name in os.listdir(directory) if name.endswith(extensions))


def parse_split_text(path):
    """Parse a plain-text split file into (train_indices, val_indices).

    A line containing 'training' / 'validation' (case-insensitive) starts that
    section; indices written after a ':' on the header line itself are kept.
    Indices may be separated by commas and/or whitespace. Blank lines and lines
    starting with '#' are ignored. If indices appear before any header, the
    positional rule applies: the first line of indices is training and every later
    headerless line is validation.
    """
    sections = {'train': [], 'validation': []}
    current = None
    with open(path, 'r') as fh:
        for raw_line in fh:
            line = raw_line.strip()
            if not line or line.startswith('#'):
                continue
            lowered = line.lower()
            if 'training' in lowered or 'validation' in lowered:
                current = 'train' if 'training' in lowered else 'validation'
                # Keep indices on the header line itself, e.g. "Training: 1, 2, 3";
                # without a ':' the header is treated as text only (e.g. "Training (800)")
                line = line.split(':', 1)[1] if ':' in line else ''
            indices = [int(tok) for tok in line.replace(',', ' ').split() if tok.isdigit()]
            if not indices:
                continue
            if current is None:
                section = 'train' if not sections['train'] else 'validation'
            else:
                section = current
            sections[section].extend(indices)
    return sections['train'], sections['validation']


# Get list of all images - check for both .tiff and .tif extensions
input_files = list_image_files(INPUT_DIR, TIFF_EXTENSIONS)
target_files = list_image_files(TARGET_DIR, TIFF_EXTENSIONS)

print(f"Found {len(input_files)} input images")
print(f"Found {len(target_files)} target images")

# If no target files found with TIFF extensions, check for other image formats
if not target_files:
    target_files = list_image_files(TARGET_DIR, ALT_TARGET_EXTENSIONS)
    print(f"Found {len(target_files)} target images with alternative extensions")

if not input_files or not target_files:
    raise FileNotFoundError(
        f"Need at least one input and one target image (found {len(input_files)} in "
        f"{INPUT_DIR}, {len(target_files)} in {TARGET_DIR})"
    )

# Load train/validation split: JSON {"train": [...], "validation": [...]} or plain text
if SPLIT_FILE.exists():
    split_data = None
    try:
        with open(SPLIT_FILE, 'r') as f:
            split_data = json.load(f)
    except json.JSONDecodeError:
        pass  # Not JSON: handled by the plain-text parser below
    if isinstance(split_data, dict):
        train_indices = list(split_data['train'])
        val_indices = list(split_data['validation'])
    else:
        train_indices, val_indices = parse_split_text(SPLIT_FILE)
    print(f"✓ Loaded split: {len(train_indices)} train, {len(val_indices)} validation")

    overlap = set(train_indices) & set(val_indices)
    if overlap:
        print(f"⚠️ Warning: {len(overlap)} indices appear in both the train and validation splits")
else:
    # If no split file, create 80/20 split
    print("No split.txt found, creating 80/20 train/validation split")
    n_images = min(len(input_files), len(target_files))
    indices = np.random.permutation(n_images)
    split_point = int(0.8 * n_images)
    train_indices = indices[:split_point].tolist()
    val_indices = indices[split_point:].tolist()
    print(f"Created split: {len(train_indices)} train, {len(val_indices)} validation")

# Verify we have matching input and target files
if len(input_files) != len(target_files):
    print(f"⚠️ Warning: Number of input files ({len(input_files)}) doesn't match target files ({len(target_files)})")
    print("Using the minimum number of files available")
    min_files = min(len(input_files), len(target_files))
    input_files = input_files[:min_files]
    target_files = target_files[:min_files]

In [None]:
# Display a sample image pair to sanity-check that inputs and targets line up
sample_idx = 0
if not input_files or not target_files:
    raise RuntimeError("No image pairs available to display")

input_path = INPUT_DIR / input_files[sample_idx]
target_path = TARGET_DIR / target_files[sample_idx]

# Context managers close the file handles; converting to RGB mirrors what
# UnderwaterDataset feeds the model and keeps imshow happy with non-RGB modes.
with Image.open(input_path) as img:
    input_img = img.convert('RGB')
with Image.open(target_path) as img:
    target_img = img.convert('RGB')

fig, axes = plt.subplots(1, 2, figsize=(12, 6))

axes[0].imshow(input_img)
axes[0].set_title(f'Input Image\n{input_files[sample_idx]}')
axes[0].axis('off')

axes[1].imshow(target_img)
axes[1].set_title(f'Target Image\n{target_files[sample_idx]}')
axes[1].axis('off')

plt.tight_layout()
plt.show()

print(f"Input image size: {input_img.size}")
print(f"Target image size: {target_img.size}")

## 3. Create PyTorch Dataset

In [None]:
class UnderwaterDataset(Dataset):
    """Paired (raw input, manually edited target) underwater images.

    Files are paired by position: the i-th sorted input file is matched with the
    i-th sorted target file, and `file_indices` selects which positions belong to
    this dataset. Both images are converted to RGB, resized to
    `image_size` x `image_size` and returned as tensors produced by ToTensor.

    Args:
        input_dir: directory containing the input TIFF images.
        target_dir: directory containing the target images (TIFF, PNG or JPEG).
        file_indices: positions into the sorted file lists; indices outside
            [0, number of pairs) are dropped.
        image_size: side length of the square output images.
        augment: if True, apply identical random flips/rotation to input and target.
    """
    def __init__(self, input_dir, target_dir, file_indices, image_size=512, augment=False):
        self.input_dir = Path(input_dir)
        self.target_dir = Path(target_dir)
        self.augment = augment

        # Get all files with various image extensions
        input_extensions = ('.tiff', '.tif', '.TIFF', '.TIF')
        target_extensions = ('.tiff', '.tif', '.TIFF', '.TIF', '.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG')

        all_input_files = sorted(f for f in os.listdir(input_dir) if f.endswith(input_extensions))
        all_target_files = sorted(f for f in os.listdir(target_dir) if f.endswith(target_extensions))

        # Keep only indices that address an existing pair (negative ones would silently wrap)
        n_pairs = min(len(all_input_files), len(all_target_files))
        valid_indices = [i for i in file_indices if 0 <= i < n_pairs]

        self.input_files = [all_input_files[i] for i in valid_indices]
        self.target_files = [all_target_files[i] for i in valid_indices]

        transform_list = [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ]
        if augment:
            # Geometric augmentation only, so input/target stay pixel-aligned
            self.transform = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomRotation(degrees=10),
                *transform_list,
            ])
        else:
            self.transform = transforms.Compose(transform_list)

    def __len__(self):
        return len(self.input_files)

    @staticmethod
    def _load_rgb(path):
        """Open an image, convert it to RGB and close the underlying file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        input_img = self._load_rgb(self.input_dir / self.input_files[idx])
        target_img = self._load_rgb(self.target_dir / self.target_files[idx])

        if self.augment:
            # Replay the same random draws for both images so their flips/rotation match.
            # fork_rng restores the global CPU RNG on exit, so reseeding here does not
            # clobber the RNG stream other code relies on (e.g. shuffling when num_workers=0).
            seed = torch.randint(0, 2**32, (1,)).item()
            with torch.random.fork_rng(devices=[]):
                torch.manual_seed(seed)
                input_img = self.transform(input_img)
                torch.manual_seed(seed)
                target_img = self.transform(target_img)
        else:
            input_img = self.transform(input_img)
            target_img = self.transform(target_img)

        return input_img, target_img

In [None]:
# Build the datasets from the loaded split; only the training split is augmented
IMAGE_SIZE = 256  # square training resolution

train_dataset = UnderwaterDataset(INPUT_DIR, TARGET_DIR, train_indices, image_size=IMAGE_SIZE, augment=True)
val_dataset = UnderwaterDataset(INPUT_DIR, TARGET_DIR, val_indices, image_size=IMAGE_SIZE, augment=False)

for split_name, split_dataset in (("Training", train_dataset), ("Validation", val_dataset)):
    print(f"{split_name} samples: {len(split_dataset)}")

## 3b. Create DataLoaders

In [None]:
# Create data loaders tuned for Colab
batch_size = 16  # Adjust based on GPU memory

# Colab typically has 2 CPU cores; never request more workers than exist.
# Set num_workers = 0 (e.g. for debugging) and the options below adapt.
num_workers = min(2, os.cpu_count() or 1)

loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    # Pinned host memory only helps CPU->GPU copies; on the CPU fallback it just warns
    pin_memory=(device.type == 'cuda'),
)
if num_workers > 0:
    # These options are only valid when worker processes are used
    loader_kwargs.update(prefetch_factor=2, persistent_workers=True)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f"Batches per epoch: {len(train_loader)}")
if num_workers > 0:
    print("DataLoader optimized with prefetching and persistent workers")
else:
    print("DataLoader running in the main process (no worker processes)")

## 4. Define U-Net Based Autoencoder Architecture

In [None]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages; padding=1 keeps H and W.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block.
        mid_channels: channels between the two convolutions; falls back to
            `out_channels` when omitted (or falsy).
    """
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        layers = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder step: 2x2 max-pool (halves H and W, flooring odd sizes) followed by a DoubleConv."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        downsampled = self.maxpool_conv(x)
        return downsampled


class Up(nn.Module):
    """Decoder step: transposed-conv upsampling, concat with the skip features, then DoubleConv.

    Args:
        in_channels: channels of the deeper map `x1`; the transposed conv halves
            them and the skip map `x2` is expected to supply the other half.
        out_channels: channels produced by the block.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """x1: deeper features to upsample; x2: skip features from the encoder."""
        x1 = self.up(x1)
        # Max-pooling floors odd sizes, so for inputs whose sides are not multiples of
        # 16 the upsampled map can be smaller than the skip map. Zero-pad it to match
        # (no-op when the sizes already agree, e.g. 256x256 inputs).
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        if diff_y or diff_x:
            x1 = nn.functional.pad(
                x1, [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2]
            )
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class UNetAutoencoder(nn.Module):
    """U-Net image-to-image network with a 4-level encoder/decoder and skip connections.

    Channel widths go 64 -> 128 -> 256 -> 512 -> 1024 down the encoder and back up
    the decoder; a 1x1 convolution maps to `n_classes` channels and a sigmoid
    squashes the output into (0, 1).
    """
    def __init__(self, n_channels=3, n_classes=3):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Encoder (attribute names form the state_dict keys saved in checkpoints)
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)

        # Decoder
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        y = self.down4(skip4)  # bottleneck

        # Walk back up, pairing each decoder stage with its mirrored encoder output
        for up_block, skip in ((self.up1, skip4), (self.up2, skip3), (self.up3, skip2), (self.up4, skip1)):
            y = up_block(y, skip)

        return torch.sigmoid(self.outc(y))

In [None]:
# Instantiate the network on the selected device and report its size
model = UNetAutoencoder(n_channels=3, n_classes=3).to(device)

param_list = list(model.parameters())
total_params = sum(p.numel() for p in param_list)
trainable_params = sum(p.numel() for p in param_list if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 5. Define Loss Functions and Metrics

In [None]:
class CombinedLoss(nn.Module):
    """Weighted sum of mean-reduced L1 and MSE losses.

    loss = alpha * L1(pred, target) + beta * MSE(pred, target)

    Args:
        alpha: weight of the L1 term.
        beta: weight of the MSE term.
    """
    def __init__(self, alpha=0.8, beta=0.2):
        super().__init__()
        self.alpha, self.beta = alpha, beta
        self.l1_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()

    def forward(self, pred, target):
        l1_term = self.alpha * self.l1_loss(pred, target)
        mse_term = self.beta * self.mse_loss(pred, target)
        return l1_term + mse_term


def calculate_psnr(img1, img2, max_val=1.0):
    """Peak signal-to-noise ratio in dB between two tensors of equal shape.

    The MSE is averaged over every element, so for a batch this is the PSNR of the
    whole batch rather than the mean of per-image PSNRs.

    Args:
        img1, img2: tensors holding values in [0, max_val].
        max_val: peak signal value (1.0 for images scaled to [0, 1]).

    Returns:
        0-dim tensor with the PSNR in dB, or +inf when the inputs are identical.
        The result is always a tensor, so callers can use `.item()` unconditionally.
    """
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        # Return a tensor (not a Python float) so `.item()` works for every caller
        return torch.tensor(float('inf'), device=mse.device)
    return 20 * torch.log10(max_val / torch.sqrt(mse))

## 6. Training Setup

In [None]:
# Loss, optimizer and learning-rate schedule
criterion = CombinedLoss(alpha=0.8, beta=0.2)
optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))
# Halve the LR once the validation loss stops improving (patience=5 epochs)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
)

# Training configuration
num_epochs = 50               # upper bound on the number of epochs
patience = 10                 # early-stopping patience, in epochs without val improvement
best_val_loss = float('inf')
patience_counter = 0

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over `dataloader`.

    Args:
        model: network to train (switched to train mode).
        dataloader: yields (inputs, targets) batches.
        criterion: loss called as criterion(outputs, targets).
        optimizer: optimizer over `model`'s parameters.
        device: device the batches are moved to.

    Returns:
        (avg_loss, avg_psnr): per-sample averages over the epoch. Each batch is
        weighted by its size, so a smaller final batch does not skew the means.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_psnr = 0.0
    num_samples = 0

    with tqdm(dataloader, desc="Training") as pbar:
        for inputs, targets in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size = inputs.size(0)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            # Metric only, so skip autograd; float() accepts a 0-dim tensor or a Python float
            with torch.no_grad():
                batch_psnr = float(calculate_psnr(outputs, targets))

            total_loss += batch_loss * batch_size
            total_psnr += batch_psnr * batch_size
            num_samples += batch_size

            pbar.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'PSNR': f'{batch_psnr:.2f} dB'
            })

    if num_samples == 0:
        raise ValueError("train_epoch: dataloader yielded no samples")
    return total_loss / num_samples, total_psnr / num_samples


def validate_epoch(model, dataloader, criterion, device):
    """Evaluate `model` on `dataloader` without updating any weights.

    Args:
        model: network to evaluate (switched to eval mode).
        dataloader: yields (inputs, targets) batches.
        criterion: loss called as criterion(outputs, targets).
        device: device the batches are moved to.

    Returns:
        (avg_loss, avg_psnr): per-sample averages, each batch weighted by its size.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_psnr = 0.0
    num_samples = 0

    with torch.no_grad(), tqdm(dataloader, desc="Validation") as pbar:
        for inputs, targets in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size = inputs.size(0)

            outputs = model(inputs)
            batch_loss = criterion(outputs, targets).item()
            # float() accepts a 0-dim tensor or a Python float
            batch_psnr = float(calculate_psnr(outputs, targets))

            total_loss += batch_loss * batch_size
            total_psnr += batch_psnr * batch_size
            num_samples += batch_size

            pbar.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'PSNR': f'{batch_psnr:.2f} dB'
            })

    if num_samples == 0:
        raise ValueError("validate_epoch: dataloader yielded no samples")
    return total_loss / num_samples, total_psnr / num_samples

## 7. Training Loop

In [None]:
# Training history
train_losses = []
val_losses = []
train_psnrs = []
val_psnrs = []

# Resume from the latest checkpoint on Google Drive if one exists
CHECKPOINT_FILE = CHECKPOINT_PATH / 'latest_checkpoint.pth'
best_model_path = MODELS_PATH / 'best_underwater_enhancer.pth'
start_epoch = 0

if CHECKPOINT_FILE.exists():
    print("Found checkpoint, resuming training...")
    # map_location lets a checkpoint written on GPU be resumed on CPU (and vice versa)
    checkpoint = torch.load(CHECKPOINT_FILE, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Older checkpoints lack scheduler / early-stopping state; keep the fresh defaults then
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    train_losses = checkpoint.get('train_losses', [])
    val_losses = checkpoint.get('val_losses', [])
    train_psnrs = checkpoint.get('train_psnrs', [])
    val_psnrs = checkpoint.get('val_psnrs', [])
    # The stored best may be one epoch stale in older checkpoints; the history is authoritative
    best_val_loss = min(
        checkpoint.get('best_val_loss', float('inf')),
        min(val_losses, default=float('inf')),
    )
    patience_counter = checkpoint.get('patience_counter', patience_counter)
    print(f"Resuming from epoch {start_epoch}")

# Training loop
for epoch in range(start_epoch, num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 50)

    # Train
    train_loss, train_psnr = train_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(train_loss)
    train_psnrs.append(train_psnr)

    # Validate
    val_loss, val_psnr = validate_epoch(model, val_loader, criterion, device)
    val_losses.append(val_loss)
    val_psnrs.append(val_psnr)

    # Update learning rate
    scheduler.step(val_loss)

    print(f"Train Loss: {train_loss:.4f}, Train PSNR: {train_psnr:.2f} dB")
    print(f"Val Loss: {val_loss:.4f}, Val PSNR: {val_psnr:.2f} dB")

    # Update best-model / early-stopping state BEFORE checkpointing, so the saved
    # checkpoint reflects this epoch and a resume cannot overwrite the best model
    # with a worse one.
    is_best = val_loss < best_val_loss
    if is_best:
        best_val_loss = val_loss
        patience_counter = 0
    else:
        patience_counter += 1

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_loss': val_loss,
        'val_psnr': val_psnr,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_psnrs': train_psnrs,
        'val_psnrs': val_psnrs,
        'best_val_loss': best_val_loss,
        'patience_counter': patience_counter,
    }

    # Save latest checkpoint
    torch.save(checkpoint, CHECKPOINT_FILE)

    # Save best model
    if is_best:
        torch.save(checkpoint, best_model_path)
        print(f"✓ Saved best model to Google Drive: {best_model_path}")

    # Save periodic checkpoint (every 5 epochs)
    if (epoch + 1) % 5 == 0:
        periodic_checkpoint = CHECKPOINT_PATH / f'checkpoint_epoch_{epoch+1}.pth'
        torch.save(checkpoint, periodic_checkpoint)
        print(f"✓ Saved periodic checkpoint to: {periodic_checkpoint}")

    # Early stopping
    if patience_counter >= patience:
        print(f"Early stopping triggered after {epoch+1} epochs")
        break

print("\n✓ Training complete!")
print(f"✓ All models and checkpoints saved to Google Drive: {OUTPUT_BASE}")

In [None]:
# Plot training history. X-axes use 1-based epoch numbers to match the training log.
# (The best checkpoint is loaded in the next cell, so it is not reloaded here.)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

panels = (
    (ax1, train_losses, val_losses, 'Loss', 'Loss', 'Training and Validation Loss'),
    (ax2, train_psnrs, val_psnrs, 'PSNR', 'PSNR (dB)', 'Training and Validation PSNR'),
)
for ax, train_series, val_series, metric, ylabel, title in panels:
    # Separate x ranges so an interrupted epoch (train logged, val not) still plots
    ax.plot(range(1, len(train_series) + 1), train_series, label=f'Train {metric}', color='blue')
    ax.plot(range(1, len(val_series) + 1), val_series, label=f'Val {metric}', color='red')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 8. Load Best Model and Export to ONNX

In [None]:
# Load the best checkpoint (lowest validation loss) from Google Drive into `model`
best_model_path = MODELS_PATH / 'best_underwater_enhancer.pth'

if best_model_path.exists():
    # map_location allows a GPU-trained checkpoint to load on a CPU-only runtime
    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("✓ Loaded best model from Google Drive")
    print(f"  Epoch: {checkpoint['epoch']+1}")
    print(f"  Best Val Loss: {checkpoint['val_loss']:.4f}")
    print(f"  Best Val PSNR: {checkpoint['val_psnr']:.2f} dB")
else:
    print("No saved model found in Google Drive")
    print("Please train the model first")

In [None]:
# Export to ONNX for deployment (optional).
# Only the batch dimension is dynamic; the graph is traced at 256x256.
onnx_path = "underwater_enhancer.onnx"
dummy_input = torch.randn(1, 3, 256, 256).to(device)
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={name: {0: 'batch_size'} for name in ('input', 'output')},
)
print("Model exported to ONNX format as 'underwater_enhancer.onnx'")

## 9. Visualize Results and Save Final Model

In [None]:
def visualize_results(model, dataloader, device, num_samples=5):
    """Plot up to `num_samples` rows of (input, model output, target) images.

    Samples are taken in loader order across as many batches as needed, so the grid
    fills even when the loader has fewer than `num_samples` batches.

    Args:
        model: trained network (switched to eval mode).
        dataloader: expected to yield (inputs, targets) batches of CHW image tensors.
        device: device to run the model on.
        num_samples: maximum number of rows to show; nothing is plotted if <= 0.
    """
    if num_samples <= 0:
        return
    model.eval()

    samples = []
    with torch.no_grad():
        for inputs, targets in dataloader:
            outputs = model(inputs.to(device)).cpu()
            for inp, out, tgt in zip(inputs.cpu(), outputs, targets.cpu()):
                samples.append((inp, out, tgt))
                if len(samples) >= num_samples:
                    break
            if len(samples) >= num_samples:
                break

    if not samples:
        print("No samples available to visualize")
        return

    n_rows = len(samples)
    # squeeze=False keeps `axes` 2-D even when only one row is plotted
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
    titles = ('Input (Raw)', 'Model Output', 'Target (Manual Edit)')
    for row, images in enumerate(samples):
        for col, (img, title) in enumerate(zip(images, titles)):
            ax = axes[row, col]
            # CHW -> HWC for imshow; clamp guards against tiny out-of-range values
            ax.imshow(img.permute(1, 2, 0).clamp(0, 1).numpy())
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()


# Visualize results
visualize_results(model, val_loader, device, num_samples=5)

In [ ]:
# Save the model currently in memory, plus its config and training history, to Google Drive
final_model_path = MODELS_PATH / 'underwater_enhancer_final.pth'

model_config = {
    'n_channels': 3,
    'n_classes': 3,
    'image_size': 256,
}
training_history = {
    'train_losses': train_losses,
    'val_losses': val_losses,
    'train_psnrs': train_psnrs,
    'val_psnrs': val_psnrs,
}
torch.save(
    {
        'model_state_dict': model.state_dict(),
        'model_config': model_config,
        'training_history': training_history,
    },
    final_model_path,
)

print(f"✓ Model saved to Google Drive: {final_model_path}")

## 10. Download Model Files

In [None]:
# Download model files from Colab to the local machine
from google.colab import files

# The .pth files were saved to Google Drive (MODELS_PATH), while the ONNX export
# was written to the Colab working directory, so download each from where it lives.
download_paths = [
    MODELS_PATH / 'best_underwater_enhancer.pth',
    MODELS_PATH / 'underwater_enhancer_final.pth',
    Path('underwater_enhancer.onnx'),
]

print("Downloading model files...")
for path in download_paths:
    if path.exists():
        files.download(str(path))
    else:
        print(f"⚠️ Skipping missing file: {path}")
print("Download complete!")