In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
from datasets import load_dataset
from skimage.metrics import peak_signal_noise_ratio as compute_psnr
from skimage.metrics import structural_similarity as compute_ssim
from google.colab import drive

# Mount Google Drive at /content/drive (Colab-only: relies on google.colab).
# The output directory used by this script is located under this mount point.
drive.mount('/content/drive')

# Seed every random number generator used by the script
def set_seed(seed=42, deterministic=False):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to `random`, `numpy.random` and PyTorch
            (CPU and all CUDA devices).
        deterministic: When False (the default) cuDNN autotuning is enabled
            and deterministic kernels are not requested; this favours speed,
            so GPU results may still vary between runs despite the seed.
            When True, cuDNN is asked for deterministic kernels and
            autotuning is disabled.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # The two cuDNN switches are opposites: benchmark mode picks kernels by
    # timing them, which is incompatible with forcing deterministic ones.
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic


set_seed()

# Root output directory in Google Drive, plus one sub-folder per artefact type
save_dir = '/content/drive/MyDrive/esrgan_output2'
os.makedirs(save_dir, exist_ok=True)
for _subdir in ("samples", "checkpoints", "inference"):
    os.makedirs(f"{save_dir}/{_subdir}", exist_ok=True)

# Default model widths used as constructor defaults by the network classes
nf1 = 64    # Number of filters
gc1 = 32    # Growth channels

# Residual Dense Block
class ResidualDenseBlock(nn.Module):
    """Five densely connected 3x3 convolutions wrapped in a scaled residual.

    Each of the first four convolutions emits `gc` channels and receives the
    block input concatenated with all earlier outputs; the fifth maps the full
    concatenation back to `nf` channels.

    Args:
        nf: Channel count of the block input and output.
        gc: Channels produced by each intermediate convolution.
    """

    def __init__(self, nf=nf1, gc=gc1):
        super().__init__()
        conv_opts = dict(kernel_size=3, stride=1, padding=1, bias=False)
        self.conv1 = nn.Conv2d(nf, gc, **conv_opts)
        self.conv2 = nn.Conv2d(nf + gc, gc, **conv_opts)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, **conv_opts)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, **conv_opts)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, **conv_opts)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Kaiming-normal initialisation of every convolution, in definition order
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            nn.init.kaiming_normal_(conv.weight)

    def forward(self, x):
        """Return `x` plus 0.2 times the output of the dense convolution stack."""
        # `feats` accumulates the input followed by each activated conv output
        feats = [x, self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            feats.append(self.lrelu(conv(torch.cat(feats, 1))))
        # The last convolution is not followed by an activation
        residual = self.conv5(torch.cat(feats, 1))
        return residual * 0.2 + x

class RRDB(nn.Module):
    """Three ResidualDenseBlocks in series with an outer scaled residual.

    Args:
        nf: Channel count of the input and output feature maps.
        gc: Growth channels forwarded to each ResidualDenseBlock.
    """

    def __init__(self, nf=nf1, gc=gc1):
        super().__init__()
        self.RDB1 = ResidualDenseBlock(nf, gc)
        self.RDB2 = ResidualDenseBlock(nf, gc)
        self.RDB3 = ResidualDenseBlock(nf, gc)

    def forward(self, x):
        """Run the three blocks in order and add 0.2 times the result to `x`."""
        out = x
        for block in (self.RDB1, self.RDB2, self.RDB3):
            out = block(out)
        return out * 0.2 + x

class GeneratorRRDB(nn.Module):
    """RRDB-based super-resolution generator.

    Pipeline: first conv -> `nb` RRDB blocks -> trunk conv (added back to the
    first-conv features) -> PixelShuffle upsampling -> final conv.

    Args:
        in_nc: Number of input image channels.
        out_nc: Number of output image channels.
        nf: Feature channels used throughout the trunk.
        nb: Number of RRDB blocks in the trunk (default 23).
        gc: Growth channels inside each residual dense block.
        scale: Spatial upscaling factor; must be a positive power of two.
            Each factor of two adds one conv + PixelShuffle(2) + LeakyReLU
            stage. The default of 2 yields a single stage, i.e. the output
            is twice the input height and width (e.g. 256 -> 512).

    Raises:
        ValueError: If `scale` is not a positive power-of-two integer.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=nf1, nb=23, gc=gc1, scale=2):
        super().__init__()
        if not isinstance(scale, int) or scale < 1 or scale & (scale - 1):
            raise ValueError(f"scale must be a positive power of two, got {scale!r}")

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=False)
        self.RRDB_trunk = nn.Sequential(*[RRDB(nf, gc) for _ in range(nb)])
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=False)

        # One 2x stage per factor of two: the conv expands to 4*nf channels
        # and PixelShuffle(2) rearranges them into a 2x larger nf-channel map.
        stages = []
        for _ in range(scale.bit_length() - 1):
            stages += [
                nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=False),
                nn.PixelShuffle(2),
                nn.LeakyReLU(0.2, True),
            ]
        self.upsampling = nn.Sequential(*stages)

        # Final convolution (the only convolution here that keeps its bias)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1)

        # Kaiming-normal init for every conv weight, including those nested
        # inside the RRDB trunk (self.modules() walks all submodules).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        """Map a low-resolution batch to a batch upscaled by `scale`."""
        fea = self.conv_first(x)
        trunk = self.RRDB_trunk(fea)
        # Long skip connection around the whole RRDB trunk
        fea = fea + self.trunk_conv(trunk)
        out = self.conv_last(self.upsampling(fea))
        return out

# Custom Dataset for DIV2K
class DIV2KSRDataset(Dataset):
    """DIV2K dataset yielding `(lr, hr)` tensor pairs for super-resolution.

    Both tensors are produced from the same source image by bicubic resizing
    to fixed square sizes (the source aspect ratio is not preserved).

    Args:
        split: Split name passed to `load_dataset` for "eugenesiow/Div2k".
        cache: When True, keep every generated `(lr, hr)` pair in an
            in-memory dict keyed by index and reuse it on later access.
        hr_size: Side length in pixels of the high-resolution target.
        lr_size: Side length in pixels of the low-resolution input.

    NOTE(review): the cache is a plain per-instance dict; when the dataset is
    read through DataLoader worker processes each worker presumably fills its
    own copy - confirm the cache is effective for the chosen num_workers.
    """

    def __init__(self, split="train", cache=True, hr_size=512, lr_size=256):
        self.dataset = load_dataset("eugenesiow/Div2k", split=split, cache_dir="./div2k_data")
        self.cache = cache
        self.cached_data = {}
        self.hr_size = hr_size
        self.lr_size = lr_size

        # HR transform - bicubic resize to hr_size x hr_size, then to tensor
        self.hr_transform = transforms.Compose([
            transforms.Resize((hr_size, hr_size), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor()
        ])

        # LR transform - bicubic resize to lr_size x lr_size, then to tensor
        self.lr_transform = transforms.Compose([
            transforms.Resize((lr_size, lr_size), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor()
        ])

    def __len__(self):
        """Return the number of images in the underlying split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the `(lr, hr)` tensor pair for image `idx`."""
        if self.cache and idx in self.cached_data:
            return self.cached_data[idx]

        # The "hr" field holds the path of the image file. Open it in a
        # context manager so the file handle is released as soon as the
        # pixels have been decoded by convert().
        img_path = self.dataset[idx]["hr"]
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        # Generate HR and LR images from the same source image
        hr = self.hr_transform(img)
        lr = self.lr_transform(img)

        if self.cache:
            self.cached_data[idx] = (lr, hr)

        return lr, hr

# Convert tensor to numpy for visualization
def tensor_to_numpy(tensor):
    """Return a channels-last NumPy copy of a 3-D tensor, clipped to [0, 1].

    The tensor is detached and moved to the CPU, and its axes are reordered
    from (0, 1, 2) to (1, 2, 0) before clipping.
    """
    channels_first = tensor.detach().cpu().numpy()
    channels_last = np.transpose(channels_first, (1, 2, 0))
    return np.clip(channels_last, 0, 1)

# Calculate PSNR between two images
def calculate_psnr(img1, img2):
    """Return the PSNR of two image tensors, assuming a data range of 1.0.

    Both tensors are first converted with `tensor_to_numpy`.
    """
    first, second = (tensor_to_numpy(t) for t in (img1, img2))
    return compute_psnr(first, second, data_range=1.0)

# Calculate SSIM between two images
def calculate_ssim(img1, img2):
    """Return the SSIM of two image tensors, assuming a data range of 1.0.

    Both tensors are first converted with `tensor_to_numpy`, which puts the
    channels on the last axis; that axis is passed as `channel_axis`.
    """
    img1_np = tensor_to_numpy(img1)
    img2_np = tensor_to_numpy(img2)
    # `channel_axis` alone marks the colour axis; the deprecated
    # `multichannel` flag it superseded is intentionally not passed.
    return compute_ssim(img1_np, img2_np, data_range=1.0, channel_axis=2)

# Visualize results during training
def visualize_samples(model, data_loader, device, epoch, save_dir, num_samples=3):
    """Save side-by-side LR / SR / HR figures for the first few batches.

    Args:
        model: Network mapping an LR batch to an SR batch; it is switched to
            eval mode and left in that mode.
        data_loader: Iterable of `(lr, hr)` batches; only the first image of
            each batch is plotted.
        device: Device the batches are moved to before the forward pass.
        epoch: Epoch label used in the output folder name.
        save_dir: Root directory; figures are written to
            `{save_dir}/samples/epoch_{epoch}/sample_{i}.png`.
        num_samples: Number of batches to visualize.
    """
    model.eval()
    os.makedirs(f"{save_dir}/samples/epoch_{epoch}", exist_ok=True)

    with torch.no_grad():
        for i, (lr, hr) in enumerate(data_loader):
            if i >= num_samples:
                break

            lr, hr = lr.to(device), hr.to(device)
            sr = model(lr)

            # Process only the first image in the batch
            lr_np = tensor_to_numpy(lr[0])
            hr_np = tensor_to_numpy(hr[0])
            sr_np = tensor_to_numpy(sr[0])

            # Calculate metrics (channel_axis replaces the deprecated
            # `multichannel` flag, which is no longer passed)
            psnr_val = compute_psnr(sr_np, hr_np, data_range=1.0)
            ssim_val = compute_ssim(sr_np, hr_np, data_range=1.0, channel_axis=2)

            # One (image, title) pair per subplot, left to right
            panels = (
                (lr_np, f'Low Resolution ({lr_np.shape[0]}x{lr_np.shape[1]})'),
                (sr_np, f'Super Resolution - PSNR: {psnr_val:.2f}, SSIM: {ssim_val:.4f}'),
                (hr_np, f'High Resolution ({hr_np.shape[0]}x{hr_np.shape[1]})'),
            )

            plt.figure(figsize=(18, 6))
            for position, (image, title) in enumerate(panels, start=1):
                plt.subplot(1, 3, position)
                plt.imshow(image)
                plt.title(title)
                plt.axis('off')

            plt.tight_layout()
            plt.savefig(f"{save_dir}/samples/epoch_{epoch}/sample_{i}.png", dpi=200)
            plt.close()

# Evaluate model on validation or test set
def evaluate_model(model, data_loader, device, desc="Evaluating"):
    """Return the mean PSNR and mean SSIM of `model` over `data_loader`.

    The model is put in eval mode and run without gradient tracking; metrics
    are computed per image and averaged over all images seen. `desc` is the
    label shown on the progress bar.
    """
    model.eval()
    psnr_scores, ssim_scores = [], []

    with torch.no_grad():
        for lr, hr in tqdm(data_loader, desc=desc):
            lr, hr = lr.to(device), hr.to(device)
            sr = model(lr)

            # Score every image of the batch individually
            for sr_img, hr_img in zip(sr, hr):
                psnr_scores.append(calculate_psnr(sr_img, hr_img))
                ssim_scores.append(calculate_ssim(sr_img, hr_img))

    return np.mean(psnr_scores), np.mean(ssim_scores)

def _plot_training_curves(train_losses, val_epochs, val_psnrs, val_ssims, out_path):
    """Save loss / PSNR / SSIM curves as one three-panel figure at `out_path`.

    Validation does not run after every epoch, so the metric curves are drawn
    against the 1-based epochs at which they were actually measured.
    """
    loss_epochs = range(1, len(train_losses) + 1)
    panels = (
        (loss_epochs, train_losses, 'Training Loss', 'L1 Loss'),
        (val_epochs, val_psnrs, 'Validation PSNR', 'PSNR (dB)'),
        (val_epochs, val_ssims, 'Validation SSIM', 'SSIM'),
    )

    plt.figure(figsize=(15, 5))
    for position, (xs, ys, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.plot(xs, ys)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.grid(True)

    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


def train_esrgan():
    """Train the RRDB generator with an L1 loss and evaluate it.

    Steps: build the model, split the DIV2K "train" split 80/10/10 into
    train/val/test, train for 20 epochs (mixed precision when CUDA is
    available), validate on epoch 1, every second epoch and the last epoch,
    keep the checkpoint with the best validation PSNR, then report test
    PSNR/SSIM of that checkpoint.

    All artefacts (checkpoints, sample figures, curves, test results) are
    written under the module-level `save_dir`.

    Returns:
        The model, loaded with the best-validation weights when at least one
        checkpoint was recorded, otherwise with the final weights.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name()}")

    # Create model
    model = GeneratorRRDB().to(device)
    print(f"Model created with {sum(p.numel() for p in model.parameters())/1e6:.2f}M parameters")

    # Optimizer; the learning rate is halved every 10 epochs
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.9, 0.99))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    # Mixed precision training: enabled only on CUDA. With enabled=False the
    # scaler and autocast are pass-throughs, so one code path serves both.
    use_amp = torch.cuda.is_available()
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    # Loss function
    criterion = nn.L1Loss()

    # Load and split dataset
    print("Loading dataset...")
    full_dataset = DIV2KSRDataset(split="train", cache=True, hr_size=512, lr_size=256)

    # Split dataset into train, val, test (80%, 10%, 10%)
    dataset_size = len(full_dataset)
    train_size = int(0.8 * dataset_size)
    val_size = int(0.1 * dataset_size)
    test_size = dataset_size - train_size - val_size

    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset, [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )

    print(f"Dataset split: Train={len(train_dataset)}, Val={len(val_dataset)}, Test={len(test_dataset)}")

    # Create data loaders
    batch_size = 2  # Adjust based on GPU memory
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False,
                            num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False,
                             num_workers=4, pin_memory=True)

    # Sample visualization loader (unshuffled, so the same leading
    # validation samples are drawn every time)
    vis_loader = DataLoader(val_dataset, batch_size=1, shuffle=False,
                            num_workers=1, pin_memory=True)

    # Training variables
    num_epochs = 20
    # -inf so the first finite validation PSNR is always recorded as best
    best_psnr = float("-inf")
    best_saved = False
    best_path = f"{save_dir}/checkpoints/esrgan_best.pth"
    train_losses = []
    val_epochs = []  # 1-based epochs at which validation ran
    val_psnrs = []
    val_ssims = []

    print(f"Starting training for {num_epochs} epochs...")

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        batch_count = 0

        # Training phase
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for lr, hr in pbar:
            # Move data to device
            lr, hr = lr.to(device), hr.to(device)

            optimizer.zero_grad()

            # Forward pass (autocast is a no-op when use_amp is False)
            with torch.autocast(device_type=device.type, enabled=use_amp):
                sr = model(lr)
                loss = criterion(sr, hr)

            # Backward pass and optimizer step through the gradient scaler
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Track metrics
            epoch_loss += loss.item()
            batch_count += 1
            pbar.set_postfix({'loss': epoch_loss / batch_count})

        # Step the learning rate scheduler
        scheduler.step()

        # Average loss for the epoch
        avg_loss = epoch_loss / batch_count
        train_losses.append(avg_loss)

        # Validation phase (every 2 epochs to save time)
        if (epoch + 1) % 2 == 0 or epoch == 0 or epoch == num_epochs - 1:
            print(f"Validating after epoch {epoch+1}...")

            # Compute validation metrics
            val_psnr, val_ssim = evaluate_model(model, val_loader, device, desc="Validation")
            val_epochs.append(epoch + 1)
            val_psnrs.append(val_psnr)
            val_ssims.append(val_ssim)

            print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f}, PSNR: {val_psnr:.2f}, SSIM: {val_ssim:.4f}")

            # Save best model
            if val_psnr > best_psnr:
                best_psnr = val_psnr
                print(f"New best model with PSNR: {best_psnr:.2f}")
                torch.save(model.state_dict(), best_path)
                best_saved = True

            # Visualize samples
            visualize_samples(model, vis_loader, device, epoch+1, save_dir)
        else:
            print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f}")

        # Save checkpoint every 5 epochs, including scheduler and scaler
        # state so a run can be resumed exactly where it stopped
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'loss': avg_loss,
            }, f"{save_dir}/checkpoints/esrgan_epoch{epoch+1}.pth")

    # Final model save
    torch.save(model.state_dict(), f"{save_dir}/checkpoints/esrgan_final.pth")

    # Plot training curves
    _plot_training_curves(train_losses, val_epochs, val_psnrs, val_ssims,
                          f"{save_dir}/training_curves.png")

    print("Training completed. Testing best model...")

    # Load best model for testing; fall back to the final weights when no
    # validation PSNR ever qualified (e.g. every value was NaN)
    if best_saved:
        model.load_state_dict(torch.load(best_path, map_location=device))
    else:
        print("No best checkpoint was recorded; testing the final weights instead.")

    # Test evaluation
    test_psnr, test_ssim = evaluate_model(model, test_loader, device, desc="Testing")
    print(f"Test Results - PSNR: {test_psnr:.2f}, SSIM: {test_ssim:.4f}")

    # Save results
    with open(f"{save_dir}/test_results.txt", "w") as f:
        f.write(f"Test PSNR: {test_psnr:.4f}\n")
        f.write(f"Test SSIM: {test_ssim:.4f}\n")

    return model

def inference(model_path, test_dataset, save_dir, num_samples=20):
    """Run a saved generator on a dataset and save LR / SR / HR comparison figures.

    Args:
        model_path: Path to a `state_dict` file for a default `GeneratorRRDB`.
        test_dataset: Dataset yielding `(lr, hr)` tensor pairs. The HR images
            must have the same size as the generator output, otherwise the
            PSNR/SSIM computation cannot compare them.
        save_dir: Root directory; figures are written to
            `{save_dir}/inference/sample_{k}.png` with k starting at 1.
        num_samples: Maximum number of samples to process (default 20).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    model = GeneratorRRDB().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Create data loader for test samples
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

    # Create directory for inference results
    os.makedirs(f"{save_dir}/inference", exist_ok=True)

    # Run inference on test samples
    with torch.no_grad():
        for i, (lr, hr) in enumerate(tqdm(test_loader, desc="Inference")):
            if i >= num_samples:
                break

            lr, hr = lr.to(device), hr.to(device)

            # Generate super-resolution image
            sr = model(lr)

            # Calculate metrics (channel_axis replaces the deprecated
            # `multichannel` flag, which is no longer passed)
            lr_np = tensor_to_numpy(lr[0])
            hr_np = tensor_to_numpy(hr[0])
            sr_np = tensor_to_numpy(sr[0])

            psnr_val = compute_psnr(sr_np, hr_np, data_range=1.0)
            ssim_val = compute_ssim(sr_np, hr_np, data_range=1.0, channel_axis=2)

            # One (image, title) pair per subplot, left to right
            panels = (
                (lr_np, f'Low Resolution ({lr_np.shape[0]}x{lr_np.shape[1]})'),
                (sr_np, f'Super Resolution - PSNR: {psnr_val:.2f}, SSIM: {ssim_val:.4f}'),
                (hr_np, f'High Resolution ({hr_np.shape[0]}x{hr_np.shape[1]})'),
            )

            plt.figure(figsize=(18, 8))
            for position, (image, title) in enumerate(panels, start=1):
                plt.subplot(1, 3, position)
                plt.imshow(image)
                plt.title(title)
                plt.axis('off')

            plt.tight_layout()
            plt.savefig(f"{save_dir}/inference/sample_{i+1}.png", dpi=300)
            plt.close()

if __name__ == "__main__":
    # Train the model
    trained_model = train_esrgan()

    # Create the inference dataset from the DIV2K "validation" split.
    # The default generator contains a single PixelShuffle(2) stage, so a
    # 256x256 input yields a 512x512 output; the HR targets must have that
    # same size for the PSNR/SSIM computation in inference() to succeed.
    lr_size = 256
    test_dataset = DIV2KSRDataset(split="validation", cache=True,
                                  hr_size=2 * lr_size, lr_size=lr_size)

    # Run inference
    inference(
        model_path=f"{save_dir}/checkpoints/esrgan_best.pth",
        test_dataset=test_dataset,
        save_dir=save_dir
    )

    print(f"All results saved to {save_dir}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda
CUDA available: True
CUDA device: Tesla T4
Model created with 16.72M parameters
Loading dataset...


  scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Dataset split: Train=640, Val=80, Test=80
Starting training for 20 epochs...


  with torch.cuda.amp.autocast():
Epoch 1/20: 100%|██████████| 320/320 [09:12<00:00,  1.73s/it, loss=5.25]


Validating after epoch 1...


Validation: 100%|██████████| 80/80 [00:50<00:00,  1.59it/s]


Epoch 1 - Loss: 5.2515, PSNR: 7.74, SSIM: 0.0339
New best model with PSNR: 7.74


Epoch 2/20: 100%|██████████| 320/320 [09:08<00:00,  1.71s/it, loss=0.478]


Validating after epoch 2...


Validation: 100%|██████████| 80/80 [00:49<00:00,  1.62it/s]


Epoch 2 - Loss: 0.4779, PSNR: 12.14, SSIM: 0.1593
New best model with PSNR: 12.14


Epoch 3/20: 100%|██████████| 320/320 [09:07<00:00,  1.71s/it, loss=0.235]


Epoch 3 - Loss: 0.2345


Epoch 4/20: 100%|██████████| 320/320 [09:08<00:00,  1.71s/it, loss=0.125]


Validating after epoch 4...


Validation: 100%|██████████| 80/80 [00:48<00:00,  1.67it/s]


Epoch 4 - Loss: 0.1254, PSNR: 16.13, SSIM: 0.3277
New best model with PSNR: 16.13


Epoch 5/20: 100%|██████████| 320/320 [09:08<00:00,  1.71s/it, loss=0.0815]


Epoch 5 - Loss: 0.0815


Epoch 6/20: 100%|██████████| 320/320 [09:06<00:00,  1.71s/it, loss=0.0553]


Validating after epoch 6...


Validation: 100%|██████████| 80/80 [00:48<00:00,  1.64it/s]


Epoch 6 - Loss: 0.0553, PSNR: 24.50, SSIM: 0.7407
New best model with PSNR: 24.50


Epoch 7/20: 100%|██████████| 320/320 [09:06<00:00,  1.71s/it, loss=0.0418]


Epoch 7 - Loss: 0.0418


Epoch 8/20: 100%|██████████| 320/320 [09:08<00:00,  1.71s/it, loss=0.0348]


Validating after epoch 8...


Validation: 100%|██████████| 80/80 [00:48<00:00,  1.66it/s]


Epoch 8 - Loss: 0.0348, PSNR: 27.67, SSIM: 0.8382
New best model with PSNR: 27.67


Epoch 9/20: 100%|██████████| 320/320 [09:07<00:00,  1.71s/it, loss=0.0319]


Epoch 9 - Loss: 0.0319


Epoch 10/20: 100%|██████████| 320/320 [09:06<00:00,  1.71s/it, loss=0.0272]


Validating after epoch 10...


Validation: 100%|██████████| 80/80 [00:47<00:00,  1.67it/s]


Epoch 10 - Loss: 0.0272, PSNR: 28.77, SSIM: 0.8757
New best model with PSNR: 28.77


Epoch 11/20:  61%|██████    | 195/320 [05:33<03:33,  1.71s/it, loss=0.0232]

In [None]:
# Install the `datasets` package, from which the training cell imports
# `load_dataset`. Run this cell first in a fresh environment.
%pip install datasets