In [None]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models import vgg19, VGG19_Weights
from torch.utils.data import Dataset
from PIL import Image
import matplotlib.pyplot as plt
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F


Configuration

In [None]:
# Image sizes. The Generator's two x2 PixelShuffle stages give a 4x upscale, so
# high_res_size must be 4 * low_res_size.
low_res_size = 128
high_res_size = 512

# Seed the RNGs used for weight initialisation and DataLoader shuffling so that a
# Restart & Run All reproduces the same run.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)


def _make_transform(size):
    """Build a square resize -> tensor -> [-1, 1] normalisation pipeline of side `size`."""
    return transforms.Compose([
        # Bicubic, as in the SRGAN setup where LR inputs are bicubic-downsampled images
        # (torchvision's Resize defaults to bilinear).
        # NOTE(review): resizing to a square distorts the aspect ratio of non-square images.
        transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        # Maps ToTensor's [0, 1] to [-1, 1], the range of the generator's tanh output.
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])


transform_low = _make_transform(low_res_size)
transform_high = _make_transform(high_res_size)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
BATCH_SIZE = 16
NUM_EPOCHS = 100
NUM_WORKERS = 2

Model Classes

In [None]:
class ConvBlock(nn.Module):
    """Conv2d, then BatchNorm2d (or Identity), then an optional activation.

    Args:
        in_channels, out_channels: channel counts passed to nn.Conv2d.
        discriminator: True selects LeakyReLU(0.2); False selects a per-channel PReLU.
        use_act: whether forward() applies the activation. The activation module is
            created regardless, so its parameters are registered even when unused.
        use_bn: insert BatchNorm2d after the conv. The conv bias is disabled when BN is
            used, because BN's affine shift makes it redundant.
        **kwargs: forwarded to nn.Conv2d (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_channels, out_channels, discriminator=False, use_act=True, use_bn=True, **kwargs):
        super().__init__()
        self.use_act = use_act
        self.cnn = nn.Conv2d(in_channels, out_channels, bias=not use_bn, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
        if discriminator:
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            self.act = nn.PReLU(num_parameters=out_channels)

    def forward(self, x):
        features = self.bn(self.cnn(x))
        if not self.use_act:
            return features
        return self.act(features)

In [None]:
class UpsampleBlock(nn.Module):
    """Sub-pixel upsampling: 3x3 conv to C*r^2 channels, PixelShuffle(r), then PReLU.

    Maps (N, in_channels, H, W) to (N, in_channels, H*r, W*r) where r = scale_factor.

    NOTE(review): `self.bn` is constructed but never applied in forward(). It is kept so
    the module's parameters and state_dict keys match checkpoints saved with this layout.
    """

    def __init__(self, in_channels, scale_factor):
        super().__init__()
        expanded_channels = in_channels * scale_factor ** 2
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(expanded_channels)
        self.ps = nn.PixelShuffle(scale_factor)
        self.act = nn.PReLU(num_parameters=in_channels)

    def forward(self, x):
        shuffled = self.ps(self.conv(x))
        return self.act(shuffled)

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 ConvBlocks (the second without activation) plus an identity skip.

    Both convs use stride 1 / padding 1 and keep the channel count, so the input can be
    added straight back onto the output.
    """

    def __init__(self, in_channels):
        super().__init__()
        same_size = dict(kernel_size=3, stride=1, padding=1)
        self.block1 = ConvBlock(in_channels, in_channels, **same_size)
        self.block2 = ConvBlock(in_channels, in_channels, use_act=False, **same_size)

    def forward(self, x):
        residual = self.block2(self.block1(x))
        return x + residual

In [None]:
class Generator(nn.Module):
    """SRGAN generator: conv stem, residual trunk with a long skip, PixelShuffle upsampling.

    Args:
        in_channels: number of image channels in both input and output.
        num_channels: width of the internal feature maps.
        num_blocks: number of ResidualBlocks in the trunk.
        upscale_factor: total spatial upscaling; must be a power of two. Each factor of 2
            adds one UpsampleBlock(scale_factor=2). The default of 4 gives the original
            two-block layout, so state_dict keys are unchanged.

    Raises:
        ValueError: if upscale_factor is not a positive power of two.

    The output goes through tanh, so it lies in [-1, 1].
    """

    def __init__(self, in_channels=3, num_channels=64, num_blocks=16, upscale_factor=4):
        super().__init__()
        if upscale_factor < 1 or upscale_factor & (upscale_factor - 1):
            raise ValueError(f"upscale_factor must be a power of two, got {upscale_factor}")
        num_upsamples = upscale_factor.bit_length() - 1  # log2(upscale_factor)

        # Construction order matches the original so seeded initialisation is unchanged.
        self.initial = ConvBlock(in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False)
        self.residuals = nn.Sequential(*[ResidualBlock(num_channels) for _ in range(num_blocks)])
        self.conv = ConvBlock(num_channels, num_channels, kernel_size=3, stride=1, padding=1, use_act=False)
        self.upsamples = nn.Sequential(
            *[UpsampleBlock(num_channels, scale_factor=2) for _ in range(num_upsamples)]
        )
        self.final = nn.Conv2d(num_channels, in_channels, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        initial = self.initial(x)
        x = self.residuals(initial)
        # Long skip connection around the entire residual trunk.
        x = self.conv(x) + initial
        x = self.upsamples(x)
        return torch.tanh(self.final(x))

In [None]:
class Discriminator(nn.Module):
    """SRGAN discriminator: stack of ConvBlocks followed by an MLP head giving one logit.

    Args:
        in_channels: number of input image channels.
        features: output channels of each ConvBlock. Even-indexed blocks use stride 1 and
            odd-indexed blocks stride 2. The first block has no BatchNorm. Any non-empty
            sequence is accepted; the head's input width follows features[-1].

    Raises:
        ValueError: if `features` is empty.

    forward() returns raw logits of shape (batch, 1) (no sigmoid is applied).
    """

    def __init__(self, in_channels=3, features=(64, 64, 128, 128, 256, 256, 512, 512)):
        super().__init__()
        # Tuple default avoids a shared mutable default; copy so callers' lists are untouched.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel count")

        blocks = []
        for index, feature in enumerate(features):
            blocks.append(
                ConvBlock(
                    in_channels,
                    feature,
                    kernel_size=3,
                    stride=1 + index % 2,
                    padding=1,
                    discriminator=True,
                    use_act=True,
                    use_bn=index != 0,
                )
            )
            in_channels = feature

        self.blocks = nn.Sequential(*blocks)
        # AdaptiveAvgPool2d fixes the spatial size at 6x6, so the first Linear's input size
        # depends only on the last block's channel count (previously hard-coded as 512).
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            nn.Linear(features[-1] * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        x = self.blocks(x)
        return self.classifier(x)


def test(low_resolution=24, batch_size=5):
    """Smoke-test Generator and Discriminator output shapes on a small random CPU batch.

    Runs without autograd (no training happens here). Returns (gen_out, disc_out).

    Raises:
        AssertionError: if either output shape is not the expected one.
    """
    # The previous autocast('cuda') context had no effect on these CPU tensors (and warns
    # on CPU-only machines), so it is dropped; no_grad avoids building an unused graph.
    with torch.no_grad():
        x = torch.randn((batch_size, 3, low_resolution, low_resolution))
        gen = Generator()
        gen_out = gen(x)
        disc = Discriminator()
        disc_out = disc(gen_out)

    # Default Generator upsamples 4x; the Discriminator emits one logit per image.
    expected_gen = (batch_size, 3, 4 * low_resolution, 4 * low_resolution)
    assert tuple(gen_out.shape) == expected_gen, f"generator output {tuple(gen_out.shape)} != {expected_gen}"
    assert tuple(disc_out.shape) == (batch_size, 1), f"discriminator output {tuple(disc_out.shape)} != {(batch_size, 1)}"
    return gen_out, disc_out


# In a notebook kernel __name__ is '__main__', so this smoke test runs whenever the cell runs.
if __name__ == '__main__':
    test()

Loss (from SRGAN_Loss)

In [None]:
# ImageNet per-channel statistics used to normalise [0, 1] images for the pretrained VGG19.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


class VGGLoss(nn.Module):
    """Perceptual loss: MSE between frozen feature-extractor activations of two images.

    Inputs are expected in [-1, 1] (the generator's tanh range). They are mapped to [0, 1]
    and ImageNet-normalised before feature extraction.

    Args:
        feature_extractor: optional module mapping normalised images to feature maps. It is
            put in eval mode and frozen. Defaults to torchvision VGG19
            `features[:layer_index]` with IMAGENET1K_V1 weights.
        layer_index: slice end for the default VGG19 extractor (36, as before).
            NOTE(review): in torchvision's VGG19 layout this should end at relu5_4 — confirm.
    """

    def __init__(self, feature_extractor=None, layer_index=36):
        super().__init__()
        if feature_extractor is None:
            feature_extractor = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features[:layer_index]
        self.vgg = feature_extractor.eval()
        # The extractor is a fixed metric and is never trained.
        for p in self.vgg.parameters():
            p.requires_grad = False
        self.loss = nn.MSELoss()
        # Buffers follow `.to(device)` together with the module. This avoids copying the
        # constants host->device on every forward. persistent=False keeps them out of
        # state_dict.
        self.register_buffer("mean", IMAGENET_MEAN.clone(), persistent=False)
        self.register_buffer("std", IMAGENET_STD.clone(), persistent=False)

    def forward(self, input, target):
        """Return the MSE between extractor features of `input` and `target`.

        Both are assumed to be (batch, 3, H, W) in [-1, 1] — the (1, 3, 1, 1) statistics
        broadcast over dim 1.
        """
        # [-1, 1] -> [0, 1] -> ImageNet-normalised.
        input_n = ((input + 1) / 2 - self.mean) / self.std
        target_n = ((target + 1) / 2 - self.mean) / self.std
        return self.loss(self.vgg(input_n), self.vgg(target_n))


Dataset (from SRGAN_Dataset)

In [None]:
class SRGANDataset(Dataset):
    """Paired (low_res, high_res) samples built from every image file in `root_dir`.

    Both tensors are produced from the same source image, passed through `transform_low`
    and `transform_high` respectively.

    Args:
        root_dir: directory holding the images (not searched recursively).
        transform_low: callable applied to the RGB PIL image to produce the LR input.
        transform_high: callable applied to the RGB PIL image to produce the HR target.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform_low, transform_high):
        self.root_dir = root_dir
        self.transform_low = transform_low
        self.transform_high = transform_high
        # Case-insensitive extension match (so 'IMG.JPG' counts), regular files only.
        # The list is sorted because os.listdir order is arbitrary, which would make sample
        # indices (and the unshuffled validation order) platform-dependent.
        self.image_files = sorted(
            name for name in os.listdir(root_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(root_dir, name))
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        """Return (low_res, high_res) for the image at position `index`."""
        img_path = os.path.join(self.root_dir, self.image_files[index])
        # The context manager closes the file handle. convert() yields an independent image
        # that stays usable after the source file is closed.
        with Image.open(img_path) as src:
            img = src.convert('RGB')

        # Create both versions from the same image
        high_res = self.transform_high(img)
        low_res = self.transform_low(img)
        return low_res, high_res

Training loop

In [None]:
# Sanity check: show what is inside the training data folder. The training dataset is
# read from the nested ./data/DIV2K_train_HR/DIV2K_train_HR/ directory.
!ls ./data/DIV2K_train_HR/

In [None]:
# Release cached GPU memory (e.g. left over from a previous run of this cell).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

TRAIN_DIR = './data/DIV2K_train_HR/DIV2K_train_HR/'
VAL_DIR = './data/DIV2K_valid_HR/DIV2K_valid_HR/'

gen = Generator(in_channels=3).to(DEVICE)
disc = Discriminator(in_channels=3).to(DEVICE)

# Optimizers
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))

# Loss functions
vgg_loss = VGGLoss().to(DEVICE)
bce_loss = nn.BCEWithLogitsLoss()
mse_loss = nn.MSELoss()

train_dataset = SRGANDataset(
    root_dir=TRAIN_DIR,
    transform_low=transform_low,
    transform_high=transform_high
)

val_dataset = SRGANDataset(
    root_dir=VAL_DIR,
    transform_low=transform_low,
    transform_high=transform_high
)

# Fail fast on a wrong path. An empty dataset would otherwise only fail later, inside
# train_one_epoch or when drawing a validation batch.
assert len(train_dataset) > 0, f"No training images found in {TRAIN_DIR}"
assert len(val_dataset) > 0, f"No validation images found in {VAL_DIR}"

# Pinned host memory speeds up host->GPU copies; it only helps when training on CUDA.
PIN_MEMORY = DEVICE == "cuda"
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")


def train_one_epoch(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg, device=None):
    """Run one epoch of alternating discriminator / generator updates.

    Args:
        loader: iterable of (low_res, high_res) batches.
        disc, gen: discriminator and generator modules (both put in train mode).
        opt_gen, opt_disc: optimizers for gen and disc.
        mse: pixel content loss, called as mse(fake, high_res).
        bce: adversarial loss on raw discriminator logits, called as bce(logits, labels).
        vgg: perceptual loss, called as vgg(fake, high_res).
        device: where batches are moved. Defaults to the device of gen's first parameter
            (CPU if gen has no parameters).

    Returns:
        (mean generator loss, mean discriminator loss) over all batches.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    if device is None:
        first_param = next(gen.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    gen.train()
    disc.train()

    loop = tqdm(loader, leave=True)

    # Track losses across ALL batches
    total_gen_loss = 0.0
    total_disc_loss = 0.0
    num_batches = 0

    for low_res, high_res in loop:
        low_res = low_res.to(device)
        high_res = high_res.to(device)

        # A single generator forward serves both updates. The discriminator step sees a
        # detached copy, and the generator step backpropagates through `fake`.
        # (Previously gen ran twice per batch.)
        fake = gen(low_res)

        ### Train Discriminator ###
        disc_real = disc(high_res)
        disc_fake = disc(fake.detach())
        disc_loss = (
            bce(disc_real, torch.ones_like(disc_real))
            + bce(disc_fake, torch.zeros_like(disc_fake))
        )

        opt_disc.zero_grad()
        disc_loss.backward()
        opt_disc.step()

        ### Train Generator ###
        # Re-score `fake` with the just-updated discriminator.
        disc_fake = disc(fake)
        adversarial_loss = bce(disc_fake, torch.ones_like(disc_fake))
        perceptual_loss = vgg(fake, high_res)
        # Use the injected `mse` argument (the old code read the global `mse_loss`).
        content_loss = mse(fake, high_res)
        gen_loss = content_loss + 0.1 * perceptual_loss + 1e-3 * adversarial_loss

        opt_gen.zero_grad()
        gen_loss.backward()
        opt_gen.step()

        # .item() syncs with the device; read each scalar once.
        gen_value = gen_loss.item()
        disc_value = disc_loss.item()
        total_gen_loss += gen_value
        total_disc_loss += disc_value
        num_batches += 1

        loop.set_postfix(gen_loss=gen_value, disc_loss=disc_value)

    if num_batches == 0:
        raise ValueError("loader yielded no batches")
    # Return AVERAGE loss across all batches
    return total_gen_loss / num_batches, total_disc_loss / num_batches


CHECKPOINT_EVERY = 10  # epochs between checkpoints and loss plots

d_losses = []
g_losses = []

for epoch in range(NUM_EPOCHS):
    epoch_num = epoch + 1
    print(f"\nEpoch {epoch_num}/{NUM_EPOCHS}")

    gen_loss, disc_loss = train_one_epoch(
        train_loader, disc, gen, opt_gen, opt_disc,
        mse_loss, bce_loss, vgg_loss
    )

    d_losses.append(disc_loss)
    g_losses.append(gen_loss)

    if epoch_num % CHECKPOINT_EVERY == 0:
        torch.save(gen.state_dict(), f'generator_epoch_{epoch_num}.pth')
        torch.save(disc.state_dict(), f'discriminator_epoch_{epoch_num}.pth')
        print(f"Checkpoint saved at epoch {epoch_num}")

        # Plot the per-epoch mean losses so far; the x axis is 1-based to match "Epoch N".
        epochs_axis = range(1, len(d_losses) + 1)
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(epochs_axis, d_losses, label='Discriminator Loss')
        ax.plot(epochs_axis, g_losses, label='Generator Loss')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        ax.set_title('Training Progress')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures over a long run

print("Training complete!")

torch.save(gen.state_dict(), 'generator_final.pth')
torch.save(disc.state_dict(), 'discriminator_final.pth')
print("Final models saved!")

In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr

NUM_EXAMPLES = 3  # how many validation images to visualise


def denormalize(tensor):
    """Invert Normalize(mean=0.5, std=0.5): map [-1, 1] back to [0, 1]."""
    return (tensor * 0.5) + 0.5


def _to_hwc01(tensor):
    """Convert a normalised (C, H, W) tensor to an (H, W, C) numpy array clipped to [0, 1]."""
    return np.clip(denormalize(tensor).cpu().numpy().transpose(1, 2, 0), 0, 1)


# Get one batch of validation images. Inference runs in eval mode; train mode is restored
# even if plotting or PSNR raises.
gen.eval()
try:
    with torch.no_grad():
        low_res, high_res = next(iter(val_loader))
        low_res = low_res.to(DEVICE)
        high_res = high_res.to(DEVICE)
        # Generate SR images
        sr_images = gen(low_res)

    for i in range(min(NUM_EXAMPLES, low_res.size(0))):
        lr_np = _to_hwc01(low_res[i])
        sr_np = _to_hwc01(sr_images[i])
        hr_np = _to_hwc01(high_res[i])

        # The bicubic baseline must be upscaled to the ground-truth size (PIL takes
        # (width, height)). Resizing to the LR size itself is a no-op, and PSNR against
        # the larger HR image then fails on mismatched shapes.
        hr_h, hr_w = hr_np.shape[:2]
        lr_uint8 = (lr_np * 255).round().astype(np.uint8)
        lr_upscaled = np.asarray(
            Image.fromarray(lr_uint8).resize((hr_w, hr_h), Image.BICUBIC), dtype=np.float32
        ) / 255.0

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = (
            (lr_upscaled, 'Bicubic Upscale'),
            (sr_np, 'SRGAN Output'),
            (hr_np, 'Ground Truth'),
        )
        for ax, (image, title) in zip(axes, panels):
            ax.imshow(image)
            ax.set_title(title)
            ax.axis('off')

        # All images are float in [0, 1]; state data_range instead of letting skimage guess.
        psnr_bicubic = psnr(hr_np, lr_upscaled, data_range=1.0)
        psnr_srgan = psnr(hr_np, sr_np, data_range=1.0)
        print(f"Image {i}: Bicubic PSNR = {psnr_bicubic:.2f} dB, SRGAN PSNR = {psnr_srgan:.2f} dB")

        fig.tight_layout()
        fig.savefig(f'comparison_{i}.png', dpi=150, bbox_inches='tight')
        plt.show()
        plt.close(fig)
finally:
    gen.train()