# Underwater Image Enhancement with Autoencoder - Google Drive Edition
This notebook trains a U-Net autoencoder to enhance underwater GoPro images so that they match the quality of manually edited versions.

## 🚀 Quick Start
1. **Run in Google Colab** with GPU runtime (Runtime → Change runtime type → GPU)
2. **Mount Google Drive** (you'll be prompted once)
3. **First run**: Dataset downloads automatically to Drive (~15 min, one-time only)
4. **Subsequent runs**: The dataset loads from Drive in seconds!

## 📁 Google Drive Structure
After first run, your Drive will contain:
```
MyDrive/
└── underwater_enhancement/
    ├── dataset/           # Cached dataset (loads in seconds)
    ├── models/            # Trained models
    └── checkpoints/       # Training checkpoints (resume capability)
```

## ✨ Features
- **Persistent storage**: Everything saves to Google Drive
- **Resume training**: Automatically continues from last checkpoint
- **Fast loading**: Dataset cached after first download
- **No re-downloads**: Access your data and models anytime

## 1. Setup and Installation

In [None]:
# Check GPU availability
import torch
# Choose the compute device: CUDA when present, otherwise fall back to the CPU.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')

if not has_cuda:
    print('No GPU available, using CPU')
else:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_props = torch.cuda.get_device_properties(0)
    print(f'GPU Available: {gpu_name}')
    print(f'Memory Available: {gpu_props.total_memory / 1e9:.2f} GB')

In [None]:
# Install required packages quietly (-q). The leading `!` is IPython/Colab syntax
# that runs the rest of the line as a shell command; this is not plain Python.
# NOTE(review): tensorboard is installed here but no visible cell imports it.
!pip install -q datasets pillow torch torchvision tqdm matplotlib tensorboard

In [None]:
# Import required libraries
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from datasets import load_dataset
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')  # hide every library warning to keep notebook output short

# Fix the torch and NumPy RNG seeds so runs are reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

## 2. Mount Google Drive and Set Up Paths

### 2a. Load Dataset from Google Drive (Fast after first run)

In [None]:
# Mount Google Drive
from google.colab import drive
import os
import pickle
from pathlib import Path

drive.mount('/content/drive')

# Every project artifact lives under a single folder in the user's Drive.
DRIVE_BASE = Path('/content/drive/MyDrive/underwater_enhancement')
DATASET_PATH = DRIVE_BASE / 'dataset'
MODELS_PATH = DRIVE_BASE / 'models'
CHECKPOINT_PATH = DRIVE_BASE / 'checkpoints'

# Create the project folder first, then its sub-folders; existing folders are kept.
for _project_dir in (DRIVE_BASE, DATASET_PATH, MODELS_PATH, CHECKPOINT_PATH):
    _project_dir.mkdir(exist_ok=True)

for _status_line in (
    "✓ Google Drive mounted",
    f"✓ Project directory: {DRIVE_BASE}",
    f"✓ Dataset will be stored at: {DATASET_PATH}",
    f"✓ Models will be saved at: {MODELS_PATH}",
):
    print(_status_line)

In [None]:
# Load or download dataset.
#
# Primary cache: the Hugging Face `save_to_disk` copy under DATASET_PROCESSED.
# It stores the Arrow data files themselves, so it reloads in any later Colab session.
DATASET_FILE = DATASET_PATH / 'underwater_dataset.pkl'  # legacy pickle cache (read-only fallback)
DATASET_PROCESSED = DATASET_PATH / 'processed'

dataset = None

if DATASET_PROCESSED.exists():
    # Load processed dataset from disk format
    print("Loading processed dataset from Google Drive...")
    from datasets import load_from_disk
    dataset = load_from_disk(str(DATASET_PROCESSED))
    print(f"✓ Loaded {len(dataset)} samples from processed format")

elif DATASET_FILE.exists():
    # Legacy cache written by older versions of this notebook. A pickled
    # `datasets.Dataset` can refer to Arrow files in the session-local HF cache
    # instead of embedding the data, so it may fail to load after a runtime
    # reset. In that case, fall through to a fresh download.
    # (pickle.load executes code from the file; only load a pickle you created.)
    print("Loading dataset from Google Drive...")
    try:
        with open(DATASET_FILE, 'rb') as f:
            dataset = pickle.load(f)
        print(f"✓ Loaded {len(dataset)} samples from Google Drive (fast load)")
    except (OSError, EOFError, AttributeError, ImportError, pickle.UnpicklingError) as err:
        print(f"⚠ Could not load {DATASET_FILE.name} ({err}); downloading again.")
        dataset = None

if dataset is None:
    # First time - download from Hugging Face and save to Drive
    print("First run detected - downloading dataset from Hugging Face...")
    print("This will take a few minutes but only needs to be done once.")

    # Option 1: Load full dataset (recommended for production)
    dataset = load_dataset("keenanj/testing-underwater", split="train")

    # Option 2: Load smaller subset for testing (uncomment if needed)
    # dataset = load_dataset("keenanj/testing-underwater", split="train[:30%]")  # 30% for testing

    print(f"✓ Downloaded {len(dataset)} samples")

    # Save to Google Drive in Hugging Face's on-disk format. No pickle is written:
    # it is not reliably loadable in a new session (see the note above).
    print("Saving dataset to Google Drive for faster future loads...")
    dataset.save_to_disk(str(DATASET_PROCESSED))

    print("✓ Dataset saved to Google Drive!")
    print("✓ Next time, loading will be much faster (10-30 seconds)")

print(f"\nDataset ready with {len(dataset)} samples")
print(f"Dataset features: {dataset.features if hasattr(dataset, 'features') else 'N/A'}")

In [None]:
# Explore the dataset structure
sample = dataset[0]
sample_keys = sample.keys() if hasattr(sample, 'keys') else type(sample)
print("Sample keys:", sample_keys)

# Display a sample pair
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Accept either column naming scheme: input/output or image/target.
input_img, output_img = None, None
if not isinstance(sample, dict):
    print("Dataset format not recognized, skipping visualization")
else:
    input_img = sample.get('input', sample.get('image', None))
    output_img = sample.get('output', sample.get('target', None))

sample_panels = (
    (axes[0], input_img, 'Input (GPR/RAW)'),
    (axes[1], output_img, 'Target (Manually Edited JPEG)'),
)
for panel_ax, panel_img, panel_title in sample_panels:
    if panel_img is None:
        continue
    panel_ax.imshow(panel_img)
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Explore the dataset structure (compact variant of the previous cell)
sample = dataset[0]
print("Sample keys:", sample.keys())

# Resolve the image columns with the same fallbacks UnderwaterDataset uses:
# 'input' or 'image' for the raw frame, 'output' or 'target' for the edited one.
input_key = 'input' if 'input' in sample else 'image'
target_key = 'output' if 'output' in sample else 'target'

# Display a sample pair
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
if input_key in sample:
    axes[0].imshow(sample[input_key])
    axes[0].set_title('Input (GPR/RAW)')
    axes[0].axis('off')
if target_key in sample:
    axes[1].imshow(sample[target_key])
    axes[1].set_title('Target (Manually Edited JPEG)')
    axes[1].axis('off')
plt.tight_layout()
plt.show()

## 3. Create PyTorch Dataset

In [None]:
class UnderwaterDataset(Dataset):
    """Paired (raw input, edited target) image dataset over a Hugging Face split.

    Each item is a tuple ``(input_tensor, target_tensor)``. Both images are
    converted to RGB, resized to ``image_size x image_size`` and passed through
    ``ToTensor``.

    Args:
        hf_dataset: Indexable dataset whose rows are dicts holding PIL images
            under ``'input'`` (or ``'image'``) and ``'output'`` (or ``'target'``).
        image_size: Side length of the square output tensors.
        augment: If True, apply random flips and a small rotation before
            resizing. The *same* random parameters are used for the input and
            the target so the pair stays pixel-aligned.
    """

    def __init__(self, hf_dataset, image_size=512, augment=False):
        self.dataset = hf_dataset
        self.augment = augment

        # Define transforms
        transform_list = [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ]

        if augment:
            # Geometric augmentation runs on the full-resolution PIL image, before resizing.
            self.transform = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomRotation(degrees=10),
                *transform_list,
            ])
        else:
            self.transform = transforms.Compose(transform_list)

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _rgb_image(sample, key, fallback_key):
        """Return ``sample[key]`` (or ``sample[fallback_key]`` if key is absent) as RGB."""
        chosen = key if key in sample else fallback_key
        return sample[chosen].convert('RGB')

    def __getitem__(self, idx):
        sample = self.dataset[idx]

        # Load images
        input_img = self._rgb_image(sample, 'input', 'image')
        target_img = self._rgb_image(sample, 'output', 'target')

        if not self.augment:
            return self.transform(input_img), self.transform(target_img)

        # Replay identical random draws for both images. The torchvision
        # transforms used here draw from torch's global CPU generator, so we
        # snapshot its state and restore it before the second transform.
        # Unlike reseeding with torch.manual_seed, this does not reset the
        # global generator (or reseed every CUDA device) on every item.
        rng_state = torch.get_rng_state()
        input_tensor = self.transform(input_img)
        torch.set_rng_state(rng_state)
        target_tensor = self.transform(target_img)

        return input_tensor, target_tensor

In [None]:
# Split dataset into train and validation (80/20)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Shuffle indices with a dedicated, fixed-seed generator. The split is then the
# same in every session no matter what else has consumed the global RNG. This
# matters when training resumes from a Drive checkpoint: a different split would
# move previously-validated images into the training set.
split_generator = torch.Generator().manual_seed(42)
indices = torch.randperm(len(dataset), generator=split_generator).tolist()
train_indices = indices[:train_size]
val_indices = indices[train_size:]

# Create train and validation datasets
train_dataset = UnderwaterDataset(dataset.select(train_indices), image_size=256, augment=True)
val_dataset = UnderwaterDataset(dataset.select(val_indices), image_size=256, augment=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

In [None]:
# Create optimized data loaders
batch_size = 16  # Adjust based on GPU memory

# Settings shared by both loaders. Pinned host memory only speeds up
# host-to-GPU copies, so enable it only when training on CUDA.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=2,              # Colab typically has 2 CPU cores
    pin_memory=device.type == 'cuda',
    prefetch_factor=2,          # batches prefetched per worker
    persistent_workers=True,    # keep workers alive between epochs
)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f"Batches per epoch: {len(train_loader)}")
print("DataLoader optimized with prefetching and persistent workers")

## 4. Define the U-Net-Based Autoencoder Architecture

In [None]:
class DoubleConv(nn.Module):
    """Two stacked Conv3x3 -> BatchNorm -> ReLU stages.

    Spatial size is preserved (padding=1). Channels flow
    ``in_channels -> mid_channels -> out_channels``; a falsy ``mid_channels``
    (None or 0) means "same as out_channels".
    """
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        # The attribute name `double_conv` and the layer order define the
        # state_dict keys (double_conv.0.weight, ...); keep them stable so
        # saved checkpoints still load.
        layers = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Halve the spatial resolution with a 2x2 max-pool, then apply a DoubleConv."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # `maxpool_conv` is part of the state_dict key path; do not rename it.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upsample 2x with a transposed conv, concatenate the skip map, then apply a DoubleConv.

    ``in_channels`` is the channel count of the deeper map ``x1``. The
    transposed conv halves it, so that after concatenation with the skip map
    ``x2`` (presumably ``in_channels // 2`` channels, as wired in
    UNetAutoencoder) the DoubleConv sees ``in_channels`` channels.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Max-pooling floors odd sizes, so when an input side is not divisible by
        # 16 the upsampled map is a pixel smaller than the skip map. Pad it
        # (split evenly left/right and top/bottom) so torch.cat gets matching
        # spatial dims. The pad amounts are zero when sizes already match
        # (e.g. 256x256), leaving the result unchanged.
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        x1 = nn.functional.pad(
            x1, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]
        )
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class UNetAutoencoder(nn.Module):
    """U-Net encoder/decoder that maps an image to an enhanced image.

    The encoder has an input DoubleConv plus four Down stages (64 -> 1024
    channels). The decoder has four Up stages that fuse the matching encoder
    maps through skip connections. A 1x1 conv projects to ``n_classes``
    channels and a sigmoid bounds the output to (0, 1).
    """
    def __init__(self, n_channels=3, n_classes=3):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Submodule names and creation order fix both the state_dict keys and
        # the (seeded) weight-initialisation sequence; keep them unchanged.
        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)

        # Decoder
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        # Encoder: keep every intermediate map for the skip connections.
        features = [self.inc(x)]
        for down_stage in (self.down1, self.down2, self.down3, self.down4):
            features.append(down_stage(features[-1]))

        # Decoder: start at the bottleneck and fuse skips from deepest to shallowest.
        decoded = features.pop()
        for up_stage in (self.up1, self.up2, self.up3, self.up4):
            decoded = up_stage(decoded, features.pop())

        return torch.sigmoid(self.outc(decoded))

In [None]:
# Initialize model
model = UNetAutoencoder(n_channels=3, n_classes=3).to(device)

# Parameter counts: all parameters vs. only those that receive gradients.
model_params = list(model.parameters())
total_params = sum(p.numel() for p in model_params)
trainable_params = sum(p.numel() for p in model_params if p.requires_grad)
for count_label, count_value in (("Total", total_params), ("Trainable", trainable_params)):
    print(f"{count_label} parameters: {count_value:,}")

## 5. Define Loss Functions and Metrics

In [None]:
class CombinedLoss(nn.Module):
    """Weighted reconstruction loss: ``alpha * L1(pred, target) + beta * MSE(pred, target)``.

    Both terms use PyTorch's default mean reduction over all elements.
    """
    def __init__(self, alpha=0.8, beta=0.2):
        super().__init__()
        self.alpha, self.beta = alpha, beta
        self.l1_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()

    def forward(self, pred, target):
        l1_term = self.alpha * self.l1_loss(pred, target)
        mse_term = self.beta * self.mse_loss(pred, target)
        return l1_term + mse_term


def calculate_psnr(img1, img2, max_val=1.0):
    """Return the peak signal-to-noise ratio (dB) between two image tensors.

    The MSE is averaged over *all* elements, so for a batch this is the PSNR of
    the pooled batch rather than the mean of per-image PSNRs.

    Args:
        img1, img2: Floating-point tensors of the same (or broadcastable) shape.
        max_val: Peak signal value; 1.0 for images scaled to [0, 1].

    Returns:
        A 0-dim tensor. Identical inputs give ``inf``, also as a tensor, so
        callers can use ``.item()`` / ``float()`` uniformly.
    """
    mse = torch.mean((img1 - img2) ** 2)
    # 10*log10(max^2 / mse) == 20*log10(max / sqrt(mse)). A zero MSE makes the
    # division +inf in floating point, so no special case is needed.
    return 10 * torch.log10(max_val ** 2 / mse)

## 6. Training Setup

In [None]:
# Define loss function and optimizer
criterion = CombinedLoss(alpha=0.8, beta=0.2)
optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))
# Halve the LR when the validation loss plateaus (patience=5). The `verbose`
# argument is deprecated for PyTorch LR schedulers and is omitted; read the
# current LR from optimizer.param_groups[0]['lr'] if you need to log it.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Training configuration
num_epochs = 50
best_val_loss = float('inf')
patience = 10          # early-stopping patience: epochs without val-loss improvement
patience_counter = 0

## 7. Training Loop

In [None]:
# Training history
train_losses = []
val_losses = []
train_psnrs = []
val_psnrs = []

# Resume from checkpoint if exists
CHECKPOINT_FILE = CHECKPOINT_PATH / 'latest_checkpoint.pth'
best_model_path = MODELS_PATH / 'best_underwater_enhancer.pth'
start_epoch = 0

# NOTE(review): `train_epoch` and `validate_epoch` are not defined in any cell
# shown here; the cell that defines them must run before this one.

if CHECKPOINT_FILE.exists():
    print("Found checkpoint, resuming training...")
    # map_location lets a checkpoint written on a GPU runtime resume on a CPU one.
    checkpoint = torch.load(CHECKPOINT_FILE, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Older checkpoints lack scheduler / early-stopping state; those simply start fresh.
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    patience_counter = checkpoint.get('patience_counter', patience_counter)
    start_epoch = checkpoint['epoch'] + 1
    train_losses = checkpoint.get('train_losses', [])
    val_losses = checkpoint.get('val_losses', [])
    train_psnrs = checkpoint.get('train_psnrs', [])
    val_psnrs = checkpoint.get('val_psnrs', [])
    best_val_loss = checkpoint.get('best_val_loss', float('inf'))
    print(f"Resuming from epoch {start_epoch}")

# Training loop
for epoch in range(start_epoch, num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 50)

    # Train
    train_loss, train_psnr = train_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(train_loss)
    train_psnrs.append(train_psnr)

    # Validate
    val_loss, val_psnr = validate_epoch(model, val_loader, criterion, device)
    val_losses.append(val_loss)
    val_psnrs.append(val_psnr)

    # Update learning rate
    scheduler.step(val_loss)

    print(f"Train Loss: {train_loss:.4f}, Train PSNR: {train_psnr:.2f} dB")
    print(f"Val Loss: {val_loss:.4f}, Val PSNR: {val_psnr:.2f} dB")

    # Update best / early-stopping state BEFORE building the checkpoint, so the
    # saved 'best_val_loss' includes this epoch. Otherwise a resumed run would
    # compare against a stale best and could overwrite the best model with a
    # worse one.
    improved = val_loss < best_val_loss
    if improved:
        best_val_loss = val_loss
        patience_counter = 0
    else:
        patience_counter += 1

    # Save checkpoint to Google Drive
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_loss': val_loss,
        'val_psnr': val_psnr,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_psnrs': train_psnrs,
        'val_psnrs': val_psnrs,
        'best_val_loss': best_val_loss,
        'patience_counter': patience_counter,
    }

    # Save latest checkpoint
    torch.save(checkpoint, CHECKPOINT_FILE)

    # Save best model
    if improved:
        torch.save(checkpoint, best_model_path)
        print(f"✓ Saved best model to Google Drive: {best_model_path}")

    # Save periodic checkpoint (every 5 epochs)
    if (epoch + 1) % 5 == 0:
        periodic_checkpoint = CHECKPOINT_PATH / f'checkpoint_epoch_{epoch+1}.pth'
        torch.save(checkpoint, periodic_checkpoint)
        print(f"✓ Saved periodic checkpoint to: {periodic_checkpoint}")

    # Early stopping
    if patience_counter >= patience:
        print(f"Early stopping triggered after {epoch+1} epochs")
        break

print("\n✓ Training complete!")
print(f"✓ All models and checkpoints saved to Google Drive: {DRIVE_BASE}")
print(f"✓ All models and checkpoints saved to Google Drive: {DRIVE_BASE}")

In [None]:
# Legacy local-only training loop: it writes 'best_underwater_enhancer.pth' to the
# ephemeral Colab working directory, not to Google Drive.
# It starts a fresh history, so it also resets the best-loss and early-stopping
# state. Values left over from an earlier loop could otherwise stop it after one
# epoch and never write the local file that the local loading cell expects.
train_losses = []
val_losses = []
train_psnrs = []
val_psnrs = []
best_val_loss = float('inf')
patience_counter = 0

# Training loop
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 50)

    # Train
    train_loss, train_psnr = train_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(train_loss)
    train_psnrs.append(train_psnr)

    # Validate
    val_loss, val_psnr = validate_epoch(model, val_loader, criterion, device)
    val_losses.append(val_loss)
    val_psnrs.append(val_psnr)

    # Update learning rate
    scheduler.step(val_loss)

    print(f"Train Loss: {train_loss:.4f}, Train PSNR: {train_psnr:.2f} dB")
    print(f"Val Loss: {val_loss:.4f}, Val PSNR: {val_psnr:.2f} dB")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_psnr': val_psnr,
        }, 'best_underwater_enhancer.pth')
        print("✓ Saved best model")
        patience_counter = 0
    else:
        patience_counter += 1

    # Early stopping
    if patience_counter >= patience:
        print(f"Early stopping triggered after {epoch+1} epochs")
        break

## 8. Visualize Training History

In [None]:
# Plot training history: loss on the left, PSNR on the right.
def _plot_metric_curves(ax, train_values, val_values, metric, ylabel):
    """Draw train (blue) and validation (red) curves for one metric on `ax`."""
    ax.plot(train_values, label=f'Train {metric}', color='blue')
    ax.plot(val_values, label=f'Val {metric}', color='red')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(f'Training and Validation {metric}')
    ax.legend()
    ax.grid(True, alpha=0.3)


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
_plot_metric_curves(ax1, train_losses, val_losses, 'Loss', 'Loss')
_plot_metric_curves(ax2, train_psnrs, val_psnrs, 'PSNR', 'PSNR (dB)')

plt.tight_layout()
plt.show()

# Load best model from Google Drive
best_model_path = MODELS_PATH / 'best_underwater_enhancer.pth'

if best_model_path.exists():
    # map_location lets a GPU-saved checkpoint load on a CPU runtime (and vice versa).
    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("✓ Loaded best model from Google Drive")
    print(f"  Epoch: {checkpoint['epoch']+1}")
    print(f"  Best Val Loss: {checkpoint['val_loss']:.4f}")
    print(f"  Best Val PSNR: {checkpoint['val_psnr']:.2f} dB")
else:
    print("No saved model found in Google Drive")
    print("Please train the model first")

In [None]:
# Load best model (legacy local copy written by the local-only training loop)
local_best_path = 'best_underwater_enhancer.pth'

if os.path.exists(local_best_path):
    # map_location lets a GPU-saved checkpoint load on a CPU runtime (and vice versa).
    checkpoint = torch.load(local_best_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']+1}")
    print(f"Best Val Loss: {checkpoint['val_loss']:.4f}")
    print(f"Best Val PSNR: {checkpoint['val_psnr']:.2f} dB")
else:
    print(f"No local '{local_best_path}' found - use the copy in {MODELS_PATH} instead.")

In [None]:
def visualize_results(model, dataloader, device, num_samples=5):
    """Plot input / model output / target for the first image of the first
    ``num_samples`` batches of ``dataloader``, one row per batch.

    Puts ``model`` into eval mode and leaves it there. If the loader yields
    fewer batches than ``num_samples``, the remaining rows stay blank. Does
    nothing when ``num_samples < 1``.
    """
    if num_samples < 1:
        return
    model.eval()

    # squeeze=False keeps `axes` 2-D even when num_samples == 1, so axes[row, col] always works.
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)

    def to_hwc(t):
        """(C, H, W) tensor -> (H, W, C) NumPy array for imshow."""
        return t.detach().cpu().permute(1, 2, 0).numpy()

    with torch.no_grad():
        for row, (inputs, targets) in enumerate(dataloader):
            if row >= num_samples:
                break

            # Only the first image of each batch is shown, so run just that one.
            first_input = inputs[:1].to(device)
            first_output = model(first_input)

            panels = (
                (to_hwc(first_input[0]), 'Input (Raw)'),
                (to_hwc(first_output[0]), 'Model Output'),
                (to_hwc(targets[0]), 'Target (Manual Edit)'),
            )
            for col, (img, title) in enumerate(panels):
                ax = axes[row, col]
                ax.imshow(img)
                ax.set_title(title)
                ax.axis('off')

    plt.tight_layout()
    plt.show()

# Visualize results
visualize_results(model, val_loader, device, num_samples=5)

# Save final model to Google Drive: weights, architecture config and training curves.
final_model_path = MODELS_PATH / 'underwater_enhancer_final.pth'

model_config = {'n_channels': 3, 'n_classes': 3, 'image_size': 256}
training_history = {
    'train_losses': train_losses,
    'val_losses': val_losses,
    'train_psnrs': train_psnrs,
    'val_psnrs': val_psnrs,
}
final_payload = {
    'model_state_dict': model.state_dict(),
    'model_config': model_config,
    'training_history': training_history,
}
torch.save(final_payload, final_model_path)

print(f"✓ Model saved to Google Drive: {final_model_path}")

In [None]:
# Export to ONNX for deployment (saved to Google Drive)
onnx_path = MODELS_PATH / 'underwater_enhancer.onnx'

# A fixed 1x3x256x256 example drives tracing; only the batch axis is declared dynamic.
dummy_input = torch.randn(1, 3, 256, 256).to(device)
onnx_dynamic_axes = {tensor_name: {0: 'batch_size'} for tensor_name in ('input', 'output')}

torch.onnx.export(
    model,
    dummy_input,
    str(onnx_path),
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes=onnx_dynamic_axes,
)
print(f"✓ Model exported to ONNX format: {onnx_path}")

In [None]:
# Export to ONNX for deployment (optional) - legacy copy in the local working directory.
local_onnx_file = "underwater_enhancer.onnx"
dummy_input = torch.randn(1, 3, 256, 256).to(device)

torch.onnx.export(
    model,
    dummy_input,
    local_onnx_file,
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={tensor_name: {0: 'batch_size'} for tensor_name in ('input', 'output')},
)
print("Model exported to ONNX format as 'underwater_enhancer.onnx'")

## 11. Inference Function for New Images

In [None]:
def enhance_image(model, image_path, device, save_path=None, image_size=256):
    """Enhance a single underwater image with the trained model.

    The image is converted to RGB and resized to ``image_size x image_size``,
    then run through ``model``. The result is resized back to the original
    dimensions with Lanczos resampling.

    Args:
        model: Trained enhancement network; this call puts it into eval mode.
        image_path: Anything accepted by ``PIL.Image.open`` (path or file object).
        device: Device the model's parameters live on.
        save_path: If given, the enhanced image is also saved there.
        image_size: Inference resolution. Defaults to 256, the size the model is
            trained at in this notebook; multiples of 16 are safest for the
            4-level U-Net.

    Returns:
        The enhanced ``PIL.Image`` at the original image size.
    """
    model.eval()

    # Load and preprocess image
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])

    # Context manager closes the source file handle; convert() returns a loaded copy.
    with Image.open(image_path) as src:
        image = src.convert('RGB')
    original_size = image.size
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Generate enhanced image
    with torch.no_grad():
        enhanced_tensor = model(input_tensor)

    # Convert back to PIL image at the original resolution
    enhanced_image = transforms.ToPILImage()(enhanced_tensor.squeeze(0).cpu())
    enhanced_image = enhanced_image.resize(original_size, Image.LANCZOS)

    # Save if path provided
    if save_path:
        enhanced_image.save(save_path)
        print(f"Enhanced image saved to {save_path}")

    return enhanced_image

# Example usage (uncomment and modify path when using):
# enhanced = enhance_image(model, 'path_to_raw_image.jpg', device, 'enhanced_output.jpg')

# Models are already saved to Google Drive! (file name, description)
saved_artifacts = (
    ('best_underwater_enhancer.pth', 'Best model from training'),
    ('underwater_enhancer_final.pth', 'Final model with full config'),
    ('underwater_enhancer.onnx', 'ONNX format for deployment'),
)

print("✅ All models are saved in your Google Drive at:")
print(f"   {MODELS_PATH}")
print("\nSaved files:")
for artifact_name, artifact_desc in saved_artifacts:
    print(f"  • {artifact_name} - {artifact_desc}")
print("\nYou can access these files anytime through Google Drive!")
print("No need to download - they persist between Colab sessions.")

# Optional: Download to local machine if needed
download_locally = False  # Change to True if you want to download

if download_locally:
    from google.colab import files
    print("\nDownloading model files to your local machine...")
    for artifact_name, _ in saved_artifacts:
        files.download(str(MODELS_PATH / artifact_name))
    print("✓ Download complete!")

In [None]:
# Download model files from Colab
from google.colab import files

# Look for each file on Drive first, then in the working directory, and skip
# (with a message) any that exist in neither. For example,
# 'underwater_enhancer_final.pth' is only ever written to MODELS_PATH, so
# requesting it from the working directory raises FileNotFoundError.
print("Downloading model files...")
for download_name in ('best_underwater_enhancer.pth',
                      'underwater_enhancer_final.pth',
                      'underwater_enhancer.onnx'):
    candidate_paths = (MODELS_PATH / download_name, Path(download_name))
    found_path = next((p for p in candidate_paths if p.exists()), None)
    if found_path is None:
        print(f"⚠ Skipping {download_name}: not found in {MODELS_PATH} or the working directory")
        continue
    files.download(str(found_path))
print("Download complete!")