<a href="https://colab.research.google.com/github/iamfaham/image-upscaling-GAN/blob/main/image_upscaling_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Real-ESRGAN-style image upscaling

In [None]:
!pip install kagglehub lpips torchsummary opencv-python-headless

Collecting lpips
  Downloading lpips-0.1.4-py3-none-any.whl.metadata (10 kB)
Downloading lpips-0.1.4-py3-none-any.whl (53 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lpips
Successfully installed lpips-0.1.4


In [None]:
# ================================================
# Real-ESRGAN Style Training Script (Colab Ready)
# Datasets: DIV2K + Flickr2K via KaggleHub
# ================================================
import os, random, glob
import numpy as np
import cv2
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision.transforms import functional as TF
from torchvision.models import vgg19
from torchvision import transforms # Import transforms

import lpips

# ============================
# Dataset download via KaggleHub
# ============================
import kagglehub

def _find_image_dir(base_path, prefer=("HR",), avoid=("LR",)):
    """Return the most plausible HR image folder below ``base_path``.

    Every directory (recursively, including ``base_path``) that directly holds
    ``.png`` / ``.jpg`` / ``.jpeg`` files is a candidate. Candidates are ranked by:

    1. a path token (relative to ``base_path``, split on non-alphanumerics)
       found in ``prefer`` -- e.g. ``DIV2K_train_HR`` contains ``HR``;
    2. no token found in ``avoid`` -- e.g. ``Flickr2K_LR_unknown`` contains ``LR``;
    3. more images;
    4. path string, so the result does not depend on ``os.walk`` ordering.

    Token matching is case-insensitive.

    Raises:
        RuntimeError: if no directory under ``base_path`` contains image files.
    """
    import re

    exts = (".png", ".jpg", ".jpeg")
    prefer_set = {tok.upper() for tok in prefer}
    avoid_set = {tok.upper() for tok in avoid}

    ranked = []
    for root, _dirs, files in os.walk(base_path):
        n_images = sum(name.endswith(exts) for name in files)
        if n_images == 0:
            continue
        rel = os.path.relpath(root, base_path)
        tokens = {tok.upper() for tok in re.split(r"[^0-9A-Za-z]+", rel) if tok}
        ranked.append((
            -int(bool(tokens & prefer_set)),  # preferred folders first
            int(bool(tokens & avoid_set)),    # avoided folders (e.g. LR) last
            -n_images,                        # then the biggest folder
            root,                             # deterministic tie-break
        ))
    if not ranked:
        raise RuntimeError(f"No image files found under {base_path}")
    return min(ranked)[-1]


def get_dataset_path(kaggle_id, prefer=("HR",), avoid=("LR",)):
    """Download ``kaggle_id`` via KaggleHub and return its HR image folder.

    Picking the first folder that contains images is not enough. Some datasets
    (e.g. Flickr2K) also ship low-resolution copies, and the old version
    resolved Flickr2K to ``Flickr2K_LR_unknown/X3``. See ``_find_image_dir``
    for the ranking.

    Args:
        kaggle_id: KaggleHub dataset handle, ``"owner/dataset"``.
        prefer: Path tokens that mark the wanted (HR) folder.
        avoid: Path tokens that mark folders to skip when possible.

    Raises:
        RuntimeError: if the download contains no image files.
    """
    base_path = kagglehub.dataset_download(kaggle_id)
    return _find_image_dir(base_path, prefer=prefer, avoid=avoid)

DIV2K_DIR = get_dataset_path("soumikrakshit/div2k-high-resolution-images")
FLICKR2K_DIR = get_dataset_path("hliang001/flickr2k")

# Quick sanity check: print up to three PNG paths from each resolved folder.
for _heading, _folder in (("DIV2K sample:", DIV2K_DIR), ("Flickr2K sample:", FLICKR2K_DIR)):
    _preview = glob.glob(os.path.join(_folder, "*.png"))
    print(_heading, _preview[:3])
del _heading, _folder, _preview

# ============================
# Realistic Degradation Function
# ============================
def degrade_image_pil(hr_pil, scale=None):
    """Synthesize a degraded low-resolution version of a high-resolution image.

    Stages, each randomized: Gaussian blur (p=0.7) -> downscale with a random
    interpolation -> JPEG compression (p=0.8, quality 30-95) -> additive
    Gaussian noise (p=0.5, sigma <= 8) -> contrast jitter (p=0.5, x0.9-1.1).

    Args:
        hr_pil: High-resolution PIL image; converted to RGB first so the
            JPEG stage always sees three channels in a known order.
        scale: Integer downscale factor. ``None`` (default) keeps the original
            behavior of drawing 2, 3 or 4 at random. Pass the model's scale
            (e.g. 4) to get LR images that match a fixed-scale generator.

    Returns:
        A uint8 RGB PIL image of size ``(max(1, w // s), max(1, h // s))``.
    """
    hr = np.array(hr_pil.convert("RGB")).astype(np.uint8)
    h, w = hr.shape[:2]

    # Blur (k == 1 means "no blur" even when this branch is taken)
    if random.random() < 0.7:
        k = random.choice([1, 3, 5, 7])
        if k > 1:
            hr = cv2.GaussianBlur(hr, (k, k), 0)

    # Downscale (factor drawn before the interpolation, as in the original order)
    s = random.choice([2, 3, 4]) if scale is None else int(scale)
    interp_down = random.choice([cv2.INTER_AREA, cv2.INTER_LINEAR, cv2.INTER_CUBIC])
    # max(1, ...) keeps cv2.resize from failing on images smaller than the factor.
    lr_small = cv2.resize(hr, (max(1, w // s), max(1, h // s)), interpolation=interp_down)

    # JPEG compression. OpenCV's codec assumes BGR order, so convert around it;
    # otherwise the luma/chroma split and chroma subsampling hit swapped channels.
    if random.random() < 0.8:
        q = random.randint(30, 95)
        bgr = cv2.cvtColor(lr_small, cv2.COLOR_RGB2BGR)
        ok, enc = cv2.imencode('.jpg', bgr, [int(cv2.IMWRITE_JPEG_QUALITY), q])
        if ok:
            lr_small = cv2.cvtColor(cv2.imdecode(enc, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)

    # Add Gaussian noise
    if random.random() < 0.5:
        sigma = random.uniform(0, 8)
        noise = np.random.randn(*lr_small.shape) * sigma
        lr_small = np.clip(lr_small + noise, 0, 255).astype(np.uint8)

    # Contrast jitter
    if random.random() < 0.5:
        alpha = random.uniform(0.9, 1.1)
        lr_small = np.clip(lr_small * alpha, 0, 255).astype(np.uint8)

    return Image.fromarray(lr_small)

# ============================
# Dataset
# ============================
def center_crop(img: Image.Image, size: int) -> Image.Image:
    """Return the centered ``size`` x ``size`` square of ``img``.

    Offsets use floor division, so when the margin is odd the extra pixel is
    left on the right/bottom. If ``img`` is smaller than ``size``, the box
    reaches past the border; see PIL's ``Image.crop`` for how that is handled.
    """
    width, height = img.size
    x0, y0 = (width - size) // 2, (height - size) // 2
    box = (x0, y0, x0 + size, y0 + size)
    return img.crop(box)

class RealSRDataset(Dataset):
    """Paired (LR, HR) patch dataset built from a flat folder of HR images.

    Each item loads one image and takes an ``hr_patch`` x ``hr_patch`` crop:
    random during training, centered otherwise. It then derives an
    ``hr_patch // scale`` square LR patch from that crop. Both are returned as
    tensors produced by ``TF.to_tensor``.

    Args:
        hr_dir: Folder holding ``*.png`` / ``*.jpg`` / ``*.jpeg`` images
            (not searched recursively).
        scale: HR/LR size ratio; ``hr_patch`` must be divisible by it.
        hr_patch: Side length of the square HR patch.
        train: Random crop (plus optional flips) if True, center crop if False.
        augment: Enable random horizontal/vertical flips when ``train`` is True.
        degrade: Optional callable ``PIL.Image -> PIL.Image`` that produces the
            LR patch from the HR patch (e.g. ``degrade_image_pil``). Its output
            is bicubic-resized to the exact LR size if needed. ``None``
            (default) keeps plain bicubic downsampling.

    Raises:
        RuntimeError: if ``hr_dir`` contains no matching images.
        ValueError: if ``hr_patch`` is not divisible by ``scale``, in which
            case an upscaled LR patch could never match the HR patch.
    """

    def __init__(self, hr_dir, scale=4, hr_patch=128, train=True, augment=True, degrade=None):
        root = Path(hr_dir)
        self.paths = (
            sorted(root.glob("*.png")) + sorted(root.glob("*.jpg")) + sorted(root.glob("*.jpeg"))
        )
        if len(self.paths) == 0:
            raise RuntimeError(f"No images found in {hr_dir}")
        if hr_patch % scale != 0:
            raise ValueError(
                f"hr_patch ({hr_patch}) must be divisible by scale ({scale}) "
                "so the upscaled LR patch matches the HR patch"
            )
        self.scale = scale
        self.hr_patch = hr_patch
        self.train = train
        self.augment = augment
        self.degrade = degrade
        self.lr_size = hr_patch // scale
        # Built once instead of on every __getitem__ call.
        self._to_lr = transforms.Resize(
            (self.lr_size, self.lr_size),
            interpolation=transforms.InterpolationMode.BICUBIC,
        )

    def __len__(self):
        return len(self.paths)

    def _load_min_size(self, path):
        """Load ``path`` as RGB, bicubic-upscaling any side shorter than ``hr_patch``."""
        with Image.open(path) as f:
            img = f.convert("RGB")
        w, h = img.size
        if w < self.hr_patch or h < self.hr_patch:
            img = img.resize((max(w, self.hr_patch), max(h, self.hr_patch)), Image.BICUBIC)
        return img

    def __getitem__(self, idx):
        # Small images are enlarged for validation too; before this change,
        # center_crop reached past their borders.
        img = self._load_min_size(self.paths[idx])
        w, h = img.size

        # --- fixed-size HR crop ---
        if self.train:
            x = random.randint(0, w - self.hr_patch)
            y = random.randint(0, h - self.hr_patch)
            hr = img.crop((x, y, x + self.hr_patch, y + self.hr_patch))
        else:
            hr = center_crop(img, self.hr_patch)

        # --- LR patch, always exactly lr_size x lr_size ---
        if self.degrade is None:
            lr = self._to_lr(hr)
        else:
            lr = self.degrade(hr)
            if lr.size != (self.lr_size, self.lr_size):
                lr = self._to_lr(lr)

        lr = TF.to_tensor(lr)
        hr = TF.to_tensor(hr)

        # --- augmentation (same flip applied to both patches) ---
        if self.train and self.augment and random.random() < 0.5:
            lr = torch.flip(lr, [2])
            hr = torch.flip(hr, [2])
        if self.train and self.augment and random.random() < 0.5:
            lr = torch.flip(lr, [1])
            hr = torch.flip(hr, [1])

        return lr, hr


# ============================
# Model: Generator & Discriminator
# ============================
class ResidualDenseBlock(nn.Module):
    """Five densely connected 3x3 convolutions with a scaled residual.

    Conv ``i`` receives the block input concatenated with every earlier conv
    output, so its input width is ``channels + i * growth``. The first four
    convs emit ``growth`` maps followed by LeakyReLU(0.2). The fifth maps back
    to ``channels`` with no activation, and its output is scaled by 0.2 and
    added to the input.
    """

    def __init__(self, channels=64, growth=32):
        super().__init__()
        out_widths = [growth, growth, growth, growth, channels]
        self.layers = nn.ModuleList(
            nn.Conv2d(channels + idx * growth, width, 3, 1, 1)
            for idx, width in enumerate(out_widths)
        )

    def forward(self, x):
        features = [x]
        for idx, conv in enumerate(self.layers):
            y = conv(torch.cat(features, 1))
            if idx < 4:
                y = F.leaky_relu(y, 0.2)
            features.append(y)
        return x + 0.2 * y

class RRDB(nn.Module):
    """Residual-in-residual dense block.

    Runs three ResidualDenseBlocks in sequence, scales the result by 0.2 and
    adds it back to the block input. The submodule names ``rdb1``..``rdb3``
    are part of saved state_dict keys.
    """

    def __init__(self, channels=64, growth=32):
        super().__init__()
        self.rdb1 = ResidualDenseBlock(channels, growth)
        self.rdb2 = ResidualDenseBlock(channels, growth)
        self.rdb3 = ResidualDenseBlock(channels, growth)

    def forward(self, x):
        y = x
        for stage in (self.rdb1, self.rdb2, self.rdb3):
            y = stage(y)
        return x + 0.2 * y

class Generator(nn.Module):
    """RRDB-based super-resolution generator.

    Pipeline: 3x3 conv -> ``blocks`` RRDBs -> 3x3 conv, added back to the first
    conv's features (long skip) -> ``log2(scale)`` x2 PixelShuffle stages ->
    3x3 conv to 3 channels, clamped to [0, 1].

    Args:
        scale: Upscaling factor. Must be a positive power of two (1, 2, 4, 8,
            ...) because upsampling is built from x2 PixelShuffle stages.
        channels: Feature width of the trunk.
        blocks: Number of RRDB blocks in the trunk.

    Raises:
        ValueError: if ``scale`` is not a positive power of two. Previously,
            ``int(np.log2(3)) == 1`` meant ``scale=3`` silently built a x2 net.
    """

    def __init__(self, scale=4, channels=64, blocks=8):
        super().__init__()
        n_up = self._num_upsample_stages(scale)
        self.conv1 = nn.Conv2d(3, channels, 3, 1, 1)
        self.trunk = nn.Sequential(*[RRDB(channels) for _ in range(blocks)])
        self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1)
        up = []
        for _ in range(n_up):
            up += [nn.Conv2d(channels, channels * 4, 3, 1, 1), nn.PixelShuffle(2), nn.LeakyReLU(0.2, inplace=True)]
        self.up = nn.Sequential(*up)
        self.conv3 = nn.Conv2d(channels, 3, 3, 1, 1)

    @staticmethod
    def _num_upsample_stages(scale):
        """Return log2(scale) for a positive power-of-two ``scale``; raise ValueError otherwise."""
        s = int(scale)
        if s != scale or s < 1 or s & (s - 1):
            raise ValueError(f"scale must be a positive power of two, got {scale!r}")
        return s.bit_length() - 1

    def forward(self, x):
        fea = self.conv1(x)
        trunk = self.conv2(self.trunk(fea))
        fea = fea + trunk
        fea = self.up(fea)
        return torch.clamp(self.conv3(fea), 0, 1)

class Discriminator(nn.Module):
    """Image-level real/fake critic.

    Four stride-2 3x3 conv stages (64, 128, 256, 512 channels), each followed
    by BatchNorm and LeakyReLU(0.2), then global average pooling, flatten and
    a linear layer producing one unbounded score per image, shape ``(N, 1)``.
    """

    def __init__(self, in_ch=3):
        super().__init__()
        widths = (in_ch, 64, 128, 256, 512)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend((
                nn.Conv2d(c_in, c_out, 3, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ))
        self.features = nn.Sequential(*stages)
        head = [nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(widths[-1], 1)]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        return self.classifier(self.features(x))

# ============================
# Losses & Helpers
# ============================
class PerceptualLoss(nn.Module):
    """L1 distance between frozen VGG19 feature maps of two image batches.

    Uses ``vgg19(pretrained=True).features[:35]``. In torchvision's layout this
    ends at conv5_4, before its activation. The VGG weights are ImageNet
    weights trained on mean/std-normalized input. The images in this notebook
    are in [0, 1] (``TF.to_tensor`` and the clamped Generator output), so both
    inputs are normalized here first. The previous version fed raw [0, 1]
    values to VGG.

    Note: normalization changes the loss magnitude, so the perceptual weight
    used in the training loop may need re-tuning.
    """

    def __init__(self):
        super().__init__()
        vgg = vgg19(pretrained=True).features[:35].eval()
        for p in vgg.parameters():
            p.requires_grad = False
        self.vgg = vgg
        # ImageNet statistics; persistent=False keeps them out of state_dict.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False
        )

    def _normalize(self, x):
        """Map [0, 1] RGB to the ImageNet-normalized space VGG was trained on."""
        return (x - self.mean) / self.std

    def forward(self, x, y):
        return F.l1_loss(self.vgg(self._normalize(x)), self.vgg(self._normalize(y)))

class EMA:
    """Exponential moving average of a model's ``state_dict``.

    Call :meth:`update` after each optimizer step. Floating-point entries move
    toward the live values as ``shadow = decay * shadow + (1 - decay) * live``.
    Non-floating entries, such as BatchNorm's integer ``num_batches_tracked``,
    cannot be averaged; they are copied as-is. (The old version crashed on
    them.)

    :meth:`apply_shadow` swaps the averaged weights into the model, and
    :meth:`restore` puts the live weights back.
    """

    def __init__(self, model, decay=0.999):
        self.model = model
        self.shadow = {k: v.clone().detach() for k, v in model.state_dict().items()}
        self.decay = decay
        self.backup = None  # live weights saved by apply_shadow()

    @torch.no_grad()
    def update(self):
        for k, v in self.model.state_dict().items():
            shadow = self.shadow[k]
            if shadow.is_floating_point():
                shadow.mul_(self.decay).add_(v.detach(), alpha=1 - self.decay)
            else:
                shadow.copy_(v)

    def apply_shadow(self):
        # Keep the first backup if called twice in a row, so restore() still
        # returns the live weights rather than the shadow.
        if self.backup is None:
            self.backup = {k: v.clone() for k, v in self.model.state_dict().items()}
        self.model.load_state_dict(self.shadow)

    def restore(self):
        if self.backup is None:
            raise RuntimeError("EMA.restore() called without a matching apply_shadow()")
        self.model.load_state_dict(self.backup)
        self.backup = None

# ============================
# Training Setup
# ============================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SCALE = 4
HR_PATCH = 128
BATCH_SIZE = 8
VAL_HOLDOUT = 100  # DIV2K images reserved for validation only

train_div = RealSRDataset(DIV2K_DIR, scale=SCALE, hr_patch=HR_PATCH, train=True)
train_flickr = RealSRDataset(FLICKR2K_DIR, scale=SCALE, hr_patch=HR_PATCH, train=True)
val_ds   = RealSRDataset(DIV2K_DIR, scale=SCALE, hr_patch=HR_PATCH, train=False)

# Validation used to read exactly the images the model trains on, which
# inflates PSNR/SSIM/LPIPS. Hold out the last VAL_HOLDOUT DIV2K files instead.
# This must happen before ConcatDataset, which caches dataset lengths.
if len(train_div.paths) > VAL_HOLDOUT:
    val_ds.paths = train_div.paths[-VAL_HOLDOUT:]
    train_div.paths = train_div.paths[:-VAL_HOLDOUT]
else:
    print(f"Warning: only {len(train_div.paths)} DIV2K images; validation overlaps training.")
train_ds = ConcatDataset([train_div, train_flickr])

import multiprocessing
num_workers = min(2, max(0, multiprocessing.cpu_count() - 1))
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=(DEVICE=="cuda"))
val_dl   = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=min(1,num_workers), pin_memory=(DEVICE=="cuda"))

G = Generator(scale=SCALE).to(DEVICE)
D = Discriminator().to(DEVICE)
percep = PerceptualLoss().to(DEVICE)
lpips_metric = lpips.LPIPS(net='alex').to(DEVICE)

optG = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0.9,0.999))
optD = torch.optim.Adam(D.parameters(), lr=1e-4, betas=(0.9,0.999))

ema = EMA(G)

# ============================
# Training Loop (shortened)
# ============================
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim

max_steps = 15000   # total iterations; lower it (e.g. 3000) for a quick Colab demo
log_interval = 500
save_interval = 3000
os.makedirs("checkpoints", exist_ok=True)

if len(train_dl) == 0:
    # An empty loader would make the `while` loop below spin forever.
    raise RuntimeError("train_dl yields no batches; check the training datasets")

# Validation leaves G in eval mode; reset so re-running this cell trains correctly.
G.train()
D.train()

step = 0
while step < max_steps:
    for lr_imgs, hr_imgs in train_dl:
        step += 1
        lr_imgs, hr_imgs = lr_imgs.to(DEVICE), hr_imgs.to(DEVICE)

        # D step (least-squares GAN targets: real -> 1, fake -> 0)
        optD.zero_grad()
        sr = G(lr_imgs)
        pred_real = D(hr_imgs)
        pred_fake = D(sr.detach())
        loss_D = F.mse_loss(pred_real, torch.ones_like(pred_real)) + F.mse_loss(pred_fake, torch.zeros_like(pred_fake))
        loss_D.backward(); optD.step()

        # G step. Freeze D so loss_G.backward() skips gradients for D's weights;
        # they would be discarded by optD.zero_grad() on the next step anyway.
        D.requires_grad_(False)
        optG.zero_grad()
        pred_fake_forG = D(sr)
        adv_loss = F.mse_loss(pred_fake_forG, torch.ones_like(pred_fake_forG))
        pix_loss = F.l1_loss(sr, hr_imgs)
        perc_loss = percep(sr, hr_imgs)
        loss_G = pix_loss + 0.01*perc_loss + 0.005*adv_loss
        loss_G.backward(); optG.step()
        D.requires_grad_(True)

        ema.update()

        if step % log_interval == 0:
            print(f"[Step {step}] D_loss: {loss_D.item():.4f} | G_loss: {loss_G.item():.4f}")

        if step % save_interval == 0:
            torch.save(G.state_dict(), f"checkpoints/G_step{step}.pth")

        if step >= max_steps:
            break

# Keep the final weights even when max_steps is not a multiple of save_interval.
if step % save_interval != 0:
    torch.save(G.state_dict(), f"checkpoints/G_step{step}.pth")

# ============================
# Validation
# ============================
G.eval()
ema.apply_shadow()  # evaluate the EMA weights
psnr_list, ssim_list, lpips_list = [], [], []
try:
    with torch.no_grad():
        for i, (lr, hr) in enumerate(val_dl):
            if i >= 20:
                break
            lr, hr = lr.to(DEVICE), hr.to(DEVICE)
            sr = G(lr)
            hr_np = hr.squeeze(0).permute(1, 2, 0).cpu().numpy()
            sr_np = sr.squeeze(0).permute(1, 2, 0).cpu().numpy()
            psnr_list.append(psnr(hr_np, sr_np, data_range=1.0))
            ssim_list.append(ssim(hr_np, sr_np, channel_axis=2, data_range=1.0))
            # LPIPS expects inputs in [-1, 1] by default; normalize=True
            # rescales our [0, 1] tensors accordingly.
            lpips_list.append(lpips_metric(sr, hr, normalize=True).item())
finally:
    # Always put the live (non-EMA) weights back, even if validation fails.
    ema.restore()

print("Avg PSNR:", np.mean(psnr_list))
print("Avg SSIM:", np.mean(ssim_list))
print("Avg LPIPS:", np.mean(lpips_list))

# ============================
# Inference on Your Own Image
# ============================
def upscale_image(img_path, model_path, scale=4, save_path="output.png"):
    """Upscale one image with a saved ``Generator`` state_dict and write it to disk.

    Args:
        img_path: Input image path; converted to RGB.
        model_path: State dict saved from ``Generator`` (e.g. a
            ``checkpoints/G_step*.pth`` file).
        scale: Must match the scale the checkpoint was trained with.
        save_path: Where the upscaled PNG/JPEG is written.

    Returns:
        ``save_path``.
    """
    model = Generator(scale=scale).to(DEVICE)
    # The checkpoint is a plain state_dict, so refuse arbitrary pickled objects.
    state = torch.load(model_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state)
    model.eval()

    with Image.open(img_path) as f:
        img = f.convert("RGB")
    lr_t = TF.to_tensor(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        sr = model(lr_t)
    # Round to the nearest 8-bit level; a bare astype(uint8) truncates and
    # biases every pixel downward.
    sr_img = sr.squeeze(0).clamp(0, 1).mul(255).round().byte().permute(1, 2, 0).cpu().numpy()
    Image.fromarray(sr_img).save(save_path)
    print(f"Saved upscaled image to {save_path}")
    return save_path

# Example usage: only runs when both the input image and the checkpoint exist,
# so the cell does not crash when training ended early or no input.jpg was uploaded.
example_input, example_ckpt = "input.jpg", "checkpoints/G_step15000.pth"
if os.path.exists(example_input) and os.path.exists(example_ckpt):
    upscale_image(example_input, example_ckpt)
else:
    print(f"Skipping example: need {example_input!r} and {example_ckpt!r}")

Downloading from https://www.kaggle.com/api/v1/datasets/download/soumikrakshit/div2k-high-resolution-images?dataset_version_number=1...


100%|██████████| 3.71G/3.71G [02:59<00:00, 22.2MB/s]

Extracting files...





Downloading from https://www.kaggle.com/api/v1/datasets/download/hliang001/flickr2k?dataset_version_number=1...


100%|██████████| 20.0G/20.0G [15:49<00:00, 22.7MB/s]

Extracting files...





DIV2K sample: ['/root/.cache/kagglehub/datasets/soumikrakshit/div2k-high-resolution-images/versions/1/DIV2K_train_HR/DIV2K_train_HR/0137.png', '/root/.cache/kagglehub/datasets/soumikrakshit/div2k-high-resolution-images/versions/1/DIV2K_train_HR/DIV2K_train_HR/0719.png', '/root/.cache/kagglehub/datasets/soumikrakshit/div2k-high-resolution-images/versions/1/DIV2K_train_HR/DIV2K_train_HR/0059.png']
Flickr2K sample: ['/root/.cache/kagglehub/datasets/hliang001/flickr2k/versions/1/Flickr2K/Flickr2K_LR_unknown/X3/001223x3.png', '/root/.cache/kagglehub/datasets/hliang001/flickr2k/versions/1/Flickr2K/Flickr2K_LR_unknown/X3/000654x3.png', '/root/.cache/kagglehub/datasets/hliang001/flickr2k/versions/1/Flickr2K/Flickr2K_LR_unknown/X3/001872x3.png']




Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


100%|██████████| 548M/548M [00:07<00:00, 74.6MB/s]


Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth


100%|██████████| 233M/233M [00:01<00:00, 143MB/s]


Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth
[Step 500] D_loss: 0.0799 | G_loss: 0.0467
[Step 1000] D_loss: 0.0927 | G_loss: 0.0591
[Step 1500] D_loss: 0.1592 | G_loss: 0.0433
[Step 2000] D_loss: 0.0584 | G_loss: 0.0445
[Step 2500] D_loss: 0.0117 | G_loss: 0.0482
[Step 3000] D_loss: 0.0096 | G_loss: 0.0436
[Step 3500] D_loss: 0.0022 | G_loss: 0.0427
[Step 4000] D_loss: 0.0022 | G_loss: 0.0385
[Step 4500] D_loss: 0.0030 | G_loss: 0.0371
[Step 5000] D_loss: 0.0105 | G_loss: 0.0487
[Step 5500] D_loss: 0.3187 | G_loss: 0.0391
[Step 6000] D_loss: 0.0402 | G_loss: 0.0476
[Step 6500] D_loss: 0.0421 | G_loss: 0.0417
[Step 7000] D_loss: 0.0527 | G_loss: 0.0466
[Step 7500] D_loss: 0.2620 | G_loss: 0.0420
[Step 8000] D_loss: 0.5213 | G_loss: 0.0430
[Step 8500] D_loss: 0.4668 | G_loss: 0.0382
[Step 9000] D_loss: 0.2695 | G_loss: 0.0297
[Step 9500] D_loss: 0.0999 | G_loss: 0.0471
[Step 10000] D_loss: 0.0124 | G_loss: 0.0557
[Step 10500] D_loss: 0.1249 | G_

In [None]:
# === SAVE FILES NEEDED FOR INFERENCE ===
import torch, textwrap
from google.colab import files

# 1) Save EMA generator weights.
# try/finally guarantees G gets its live (non-EMA) weights back even if saving fails.
ema.apply_shadow()  # swap EMA params into G
try:
    torch.save(G.state_dict(), "generator_sr_ema.pth")
finally:
    ema.restore()

# 2) Save inference script (architecture + usage).
# The embedded RRDBNet must mirror the training `Generator` exactly:
#   * same submodule names (conv1 / trunk / conv2 / up / conv3), so that
#     load_state_dict() accepts the checkpoint. The previous names
#     (conv_first / trunk_conv / upsampler / conv_last) made strict loading fail.
#   * the same output clamp to [0, 1]. The previous version applied a sigmoid
#     that training never used, which distorts every output value.
inference_code = textwrap.dedent("""
import math

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
from PIL import Image


# === Generator definition (must match the training `Generator` exactly) ===
class ResidualDenseBlock(nn.Module):
    def __init__(self, channels=64, growth=32):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(5):
            in_c = channels + i * growth
            out_c = growth if i < 4 else channels
            self.layers.append(nn.Conv2d(in_c, out_c, 3, 1, 1))
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        inputs = [x]
        for i, layer in enumerate(self.layers):
            out = layer(torch.cat(inputs, 1))
            if i < 4:
                out = self.lrelu(out)
            inputs.append(out)
        return out * 0.2 + x


class RRDB(nn.Module):
    def __init__(self, channels=64, growth=32):
        super().__init__()
        self.rdb1 = ResidualDenseBlock(channels, growth)
        self.rdb2 = ResidualDenseBlock(channels, growth)
        self.rdb3 = ResidualDenseBlock(channels, growth)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        return out * 0.2 + x


class RRDBNet(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, nf=64, nb=8, gc=32, scale=4):
        super().__init__()
        if scale < 1 or int(scale) != scale or int(scale) & (int(scale) - 1):
            raise ValueError(f"scale must be a positive power of two, got {scale!r}")
        self.scale = int(scale)
        # Attribute names match the training Generator so its state_dict loads.
        self.conv1 = nn.Conv2d(in_ch, nf, 3, 1, 1)
        self.trunk = nn.Sequential(*[RRDB(nf, gc) for _ in range(nb)])
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1)
        up_layers = []
        for _ in range(int(math.log2(self.scale))):
            up_layers += [
                nn.Conv2d(nf, nf * 4, 3, 1, 1),
                nn.PixelShuffle(2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        self.up = nn.Sequential(*up_layers)
        self.conv3 = nn.Conv2d(nf, out_ch, 3, 1, 1)

    def forward(self, x):
        fea = self.conv1(x)
        fea = fea + self.conv2(self.trunk(fea))
        fea = self.up(fea)
        # Training clamps (no sigmoid) the output; do exactly the same here.
        return torch.clamp(self.conv3(fea), 0, 1)


# === Inference helper ===
def upscale_image(input_path, output_path="sr_out.png", model_path="generator_sr_ema.pth", scale=4, nb=8):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    G = RRDBNet(nb=nb, scale=scale).to(device)
    # The checkpoint is a plain state_dict; weights_only avoids unpickling arbitrary objects.
    G.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    G.eval()

    img = Image.open(input_path).convert("RGB")
    t = TF.to_tensor(img).unsqueeze(0).to(device)

    with torch.no_grad():
        sr = G(t).clamp(0, 1)

    # Round to the nearest 8-bit level (to_pil_image on floats would truncate).
    sr_img = TF.to_pil_image(sr.squeeze(0).mul(255).round().byte().cpu())
    sr_img.save(output_path)
    print(f"Saved: {output_path}")
    return output_path

# Example usage:
# upscale_image("myphoto.png", "myphoto_sr.png", model_path="generator_sr_ema.pth")
""")

with open("inference.py", "w") as f:
    f.write(inference_code)

# 3) Download both files to local system
files.download("generator_sr_ema.pth")
files.download("inference.py")


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>