# SRGAN demonstration for QR code enhancement (PyTorch)


## Installing dependencies + Unzipping data


In [None]:
!unzip archive.zip


In [None]:
!pip install torch torchvision opencv-python tqdm scikit-image


In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from main_pytorch import Generator, Discriminator, build_vgg

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')


## Looking at a single example


In [None]:
datadir = 'qr_dataset'

# Walk the dataset folder and keep the first file OpenCV manages to decode.
first_image = None
for img in os.listdir(datadir):
    img_array = cv2.imread(os.path.join(datadir, img), cv2.IMREAD_COLOR)
    if img_array is not None:
        # The decoded array is BGR; convert to RGB before handing it to matplotlib.
        first_image = cv2.cvtColor(img_array, cv2.COLOR_BGR2RGB)
        break

if first_image is None:
    print('No images found in dataset.')
else:
    plt.imshow(first_image)
    plt.title('Sample QR image (RGB)')
    plt.axis('off')


## Storing the images


In [None]:
high_res_images = []
low_res_images = []

def create_training_data():
    """Fill ``high_res_images`` and ``low_res_images`` from the files in ``datadir``.

    Every file that OpenCV can decode is converted from BGR to RGB and resized
    twice: to 128x128 (appended to ``high_res_images``) and to 32x32 with area
    interpolation (appended to ``low_res_images``). Files that cannot be
    decoded are skipped.

    Both lists are emptied first, so calling the function again rebuilds them
    instead of appending duplicates. Directory entries are visited in sorted
    order because ``os.listdir`` returns them in arbitrary order, which would
    otherwise make the later seeded train/validation split differ between
    machines or runs.
    """
    high_res_images.clear()
    low_res_images.clear()
    for img in tqdm(sorted(os.listdir(datadir))):
        img_path = os.path.join(datadir, img)
        img_array = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img_array is None:
            # Not a readable image (e.g. a stray non-image file): skip it.
            continue
        img_rgb = cv2.cvtColor(img_array, cv2.COLOR_BGR2RGB)
        high_res = cv2.resize(img_rgb, (128, 128))
        low_res = cv2.resize(img_rgb, (32, 32), interpolation=cv2.INTER_AREA)
        high_res_images.append(high_res)
        low_res_images.append(low_res)

create_training_data()
print(f'Loaded {len(high_res_images)} images')


In [None]:
# Stack the per-image arrays into float32 batches and divide by 255 so that
# 8-bit pixel values land in the [0, 1] range.
high_res = np.asarray(high_res_images, dtype=np.float32) / 255.0
low_res = np.asarray(low_res_images, dtype=np.float32) / 255.0
print('High-res shape:', high_res.shape)
print('Low-res shape:', low_res.shape)


In [None]:
# Hold out 20% of the low-res/high-res pairs for validation; the fixed seed
# keeps the split repeatable for a given input ordering.
X_train, X_valid, y_train, y_valid = train_test_split(
    low_res,
    high_res,
    test_size=0.2,
    random_state=42,
)
print('Training set:', X_train.shape, y_train.shape)
print('Validation set:', X_valid.shape, y_valid.shape)


In [None]:
def numpy_to_tensor(arr):
    """Turn a 4-D NumPy batch into a float32 tensor with axis 3 moved to axis 1.

    The axes are reordered as (0, 3, 1, 2), i.e. a channels-last batch
    (N, H, W, C) becomes channels-first (N, C, H, W).
    """
    channels_first = torch.from_numpy(arr).permute(0, 3, 1, 2)
    return channels_first.to(torch.float32)

# Convert both splits to channels-first tensors.
X_train_t, y_train_t = numpy_to_tensor(X_train), numpy_to_tensor(y_train)
X_valid_t, y_valid_t = numpy_to_tensor(X_valid), numpy_to_tensor(y_valid)

# Pair each low-res input with its high-res target.
train_dataset = TensorDataset(X_train_t, y_train_t)
valid_dataset = TensorDataset(X_valid_t, y_valid_t)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=8, shuffle=False)

print('Training batches:', len(train_loader))
print('Validation batches:', len(valid_loader))


## Building the PyTorch models


In [None]:
# Generator (1 residual block, 2 upsampling blocks) and its adversary.
generator = Generator(res_blocks=1, upsample_blocks=2).to(device)
discriminator = Discriminator().to(device)

# The VGG network only supplies features for the content loss and is never
# trained: freeze its weights and keep it in eval mode so that any
# dropout/batch-norm layers it may contain behave deterministically
# (eval() is a no-op if it has none).
vgg = build_vgg().to(device)
vgg.requires_grad_(False)
vgg.eval()

# NOTE(review): nn.BCELoss expects probabilities in [0, 1]; this presumably
# relies on Discriminator ending in a sigmoid -- confirm in main_pytorch.
adversarial_criterion = nn.BCELoss()
pixel_criterion = nn.MSELoss()    # generated image vs. high-res target
content_criterion = nn.MSELoss()  # VGG features of generated vs. target image

optimizer_G = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.9, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.9, 0.999))

print('Models ready on device:', device)


In [None]:
# Per-channel statistics, shaped (1, 3, 1, 1) so they broadcast over a
# channels-first image batch.
imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=device).reshape(1, 3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=device).reshape(1, 3, 1, 1)

def preprocess_vgg(x):
    """Standardise ``x`` channel-wise with ``imagenet_mean`` / ``imagenet_std``.

    Returns ``(x - imagenet_mean) / imagenet_std``; the statistics broadcast
    over the batch and spatial dimensions.
    """
    centered = x - imagenet_mean
    return centered / imagenet_std

def train_one_epoch(epoch):
    """Run one pass over ``train_loader``, updating the generator and then the
    discriminator on every batch.

    Args:
        epoch: Index of the current epoch. Not used inside the function.

    Returns:
        ``(epoch_g_loss, epoch_d_loss)``: the generator and discriminator
        losses, accumulated with each batch weighted by its size and divided
        by ``len(train_dataset)``.
    """
    generator.train()
    discriminator.train()
    running_g_loss = 0.0
    running_d_loss = 0.0
    for lr_imgs, hr_imgs in train_loader:
        lr_imgs = lr_imgs.to(device)
        hr_imgs = hr_imgs.to(device)

        # Adversarial targets, one per image in the batch: 1 = real, 0 = generated.
        valid = torch.ones((lr_imgs.size(0), 1), device=device)
        fake = torch.zeros((lr_imgs.size(0), 1), device=device)

        # ---- Generator step ----
        optimizer_G.zero_grad()
        gen_imgs = generator(lr_imgs)
        # Flattened to (-1, 1) to line up with the label tensors; assumes the
        # discriminator emits a single score per image -- TODO confirm.
        pred_fake = discriminator(gen_imgs).view(-1, 1)
        # The generator is rewarded when its output is scored as real.
        g_adv = adversarial_criterion(pred_fake, valid)
        gen_features = vgg(preprocess_vgg(gen_imgs))
        # Target features are constants for this step, so no graph is needed.
        with torch.no_grad():
            real_features = vgg(preprocess_vgg(hr_imgs))
        g_content = content_criterion(gen_features, real_features)
        g_pixel = pixel_criterion(gen_imgs, hr_imgs)
        # Pixel loss dominates; adversarial and content terms are down-weighted.
        g_loss = g_pixel + 1e-3 * g_adv + 0.006 * g_content
        g_loss.backward()
        optimizer_G.step()

        # ---- Discriminator step ----
        optimizer_D.zero_grad()
        pred_real = discriminator(hr_imgs).view(-1, 1)
        loss_real = adversarial_criterion(pred_real, valid)
        # detach() keeps the discriminator loss from backpropagating into the generator.
        pred_fake = discriminator(gen_imgs.detach()).view(-1, 1)
        loss_fake = adversarial_criterion(pred_fake, fake)
        d_loss = 0.5 * (loss_real + loss_fake)
        d_loss.backward()
        optimizer_D.step()

        # Weight by batch size so the epoch figures are per-sample averages.
        running_g_loss += g_loss.item() * lr_imgs.size(0)
        running_d_loss += d_loss.item() * lr_imgs.size(0)

    epoch_g_loss = running_g_loss / len(train_dataset)
    epoch_d_loss = running_d_loss / len(train_dataset)
    return epoch_g_loss, epoch_d_loss

def validate():
    """Return the per-sample average pixel loss of the generator on ``valid_loader``.

    The generator is switched to eval mode and run without gradient tracking.
    Each batch loss is weighted by the number of images in the batch, matching
    the weighting used in ``train_one_epoch``, so a smaller final batch does
    not skew the average.

    Returns:
        The weighted mean of ``pixel_criterion`` over all validation images
        as a float, or ``0.0`` when the loader yields no batches.
    """
    generator.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for lr_imgs, hr_imgs in valid_loader:
            lr_imgs = lr_imgs.to(device)
            hr_imgs = hr_imgs.to(device)
            batch_size = lr_imgs.size(0)
            gen_imgs = generator(lr_imgs)
            # The criterion returns a batch mean; scale it back to a batch sum.
            total_loss += pixel_criterion(gen_imgs, hr_imgs).item() * batch_size
            total_samples += batch_size
    return float(total_loss / total_samples) if total_samples else 0.0


In [None]:
num_epochs = 5
# One (epoch, generator loss, discriminator loss, validation pixel loss) tuple per epoch.
train_history = []
for epoch in range(1, num_epochs + 1):
    # Train first, then score the updated generator on the held-out split.
    g_loss, d_loss = train_one_epoch(epoch)
    val_loss = validate()
    train_history += [(epoch, g_loss, d_loss, val_loss)]
    print(f'Epoch {epoch}/{num_epochs} | G_loss: {g_loss:.4f} | D_loss: {d_loss:.4f} | Val pixel loss: {val_loss:.4f}')


## Checking the generator output


In [None]:
generator.eval()
with torch.no_grad():
    # Take the first validation pair, keeping the batch dimension for the model.
    sample_lr = X_valid_t[:1].to(device)
    sample_hr = y_valid_t[:1].to(device)
    # Drop the batch dimension and move channels last so matplotlib can draw the arrays.
    sr_image, lr_image, hr_image = (
        batch.cpu().squeeze(0).permute(1, 2, 0).numpy()
        for batch in (generator(sample_lr), sample_lr, sample_hr)
    )

# Clamp to the displayable [0, 1] range.
sr_image, lr_image, hr_image = (
    np.clip(picture, 0, 1) for picture in (sr_image, lr_image, hr_image)
)

# Left to right: resized low-res input, generator output, ground truth.
panels = [
    ('LR Image (Upsampled)', cv2.resize(lr_image, (128, 128))),
    ('Superresolution (Generator)', sr_image),
    ('Original HR Image', hr_image),
]

plt.figure(figsize=(16, 6))
for position, (panel_title, panel_image) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(panel_title)
    plt.imshow(panel_image)
    plt.axis('off')
plt.show()
