Code inspired by https://github.com/assafshocher/ZSSR and https://github.com/jacobgil/pytorch-zssr 

In [None]:
# Install the notebook's third-party dependencies (%pip targets the running kernel's environment).
# NOTE(review): versions are unpinned - consider pinning them for reproducibility.
# NOTE(review): `lpips` is not imported in the visible cells - confirm it is still needed.
%pip install torch torchvision pillow scikit-image lpips matplotlib 


Collecting models
  Using cached models-0.9.3.tar.gz (16 kB)
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'error'
Note: you may need to restart the kernel to use updated packages.


  error: subprocess-exited-with-error
  
  × python setup.py egg_info did not run successfully.
  │ exit code: 1
  ╰─> [8 lines of output]
      Traceback (most recent call last):
        File "<string>", line 2, in <module>
        File "<pip-setuptools-caller>", line 35, in <module>
        File "C:\Users\blobf.DESKTOP-IUEL8R6\AppData\Local\Temp\pip-install-8q3wi8ft\models_0d20757c3ae2494c8a442aa4f67535a7\setup.py", line 25, in <module>
          import models
        File "C:\Users\blobf.DESKTOP-IUEL8R6\AppData\Local\Temp\pip-install-8q3wi8ft\models_0d20757c3ae2494c8a442aa4f67535a7\models\__init__.py", line 23, in <module>
          from base import *
      ModuleNotFoundError: No module named 'base'
      [end of output]
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
error: metadata-generation-failed

× Encountered error while generating package metadata.
╰─> See above for output.

note: This is an issue with the package mentioned above, n

Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as TF
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torchvision import transforms
from PIL import Image
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import glob
import os
import math
import random
from IPython.display import clear_output


from models import *
from utils import *
from utils.sr_utils import * 
from utils.common_utils import *

# Pick the compute device once, then derive the legacy tensor type (Lab syntax 'dtype') from it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.cuda.FloatTensor if device.type == "cuda" else torch.FloatTensor
print(f"Running on device: {device}")


Get LR dataset and HR dataset (for ground truths)

In [None]:
class DIV2KDataset(Dataset):
    """Paired dataset yielding ``(lr_img, hr_img)`` tensors from two PNG folders.

    Both folders are globbed for ``*.png`` and sorted by path; the i-th LR file
    is paired with the i-th HR file.

    Args:
        hr_dir: directory holding the high-resolution PNG images.
        lr_dir: directory holding the matching low-resolution PNG images.

    Raises:
        ValueError: if the folders contain a different number of PNG files,
            because index-based pairing would then be misaligned.
    """

    def __init__(self, hr_dir, lr_dir):
        self.hr_paths = sorted(glob.glob(os.path.join(hr_dir, '*.png')))
        self.lr_paths = sorted(glob.glob(os.path.join(lr_dir, '*.png')))

        # Pairs are matched purely by sorted position, so unequal counts would
        # silently mispair images (or raise IndexError later); fail early instead.
        if len(self.hr_paths) != len(self.lr_paths):
            raise ValueError(
                f"HR/LR image count mismatch: {len(self.hr_paths)} in '{hr_dir}' "
                f"vs {len(self.lr_paths)} in '{lr_dir}'"
            )

        self.transform = transforms.ToTensor() # PIL to Tensor

    def __len__(self): # dataloader needs access to length
        """Return the number of image pairs."""
        return len(self.hr_paths)

    def __getitem__(self, index): # dataloader needs access to dataset items by index
        """Load the pair stored at ``index`` and return ``(lr_img, hr_img)``."""
        lr_img = self._load_rgb(self.lr_paths[index])
        hr_img = self._load_rgb(self.hr_paths[index])
        return lr_img, hr_img

    def _load_rgb(self, path):
        """Open ``path``, force 3-channel RGB and convert it to a tensor."""
        # The context manager releases the file handle right away; convert('RGB')
        # keeps palette / grayscale / RGBA PNGs from producing a different channel count.
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

# Dataset folders are resolved relative to the notebook's working directory
BASE_DIR = Path.cwd()
DATA_DIR = BASE_DIR.joinpath('data')

# Training split: HR folder paired with the LR folder (x8 according to the folder name)
train_dataset = DIV2KDataset(
    str(DATA_DIR.joinpath('DIV2K_train_HR')),
    str(DATA_DIR.joinpath('DIV2K_train_LR_x8')),
)

# Validation split, built the same way
val_dataset = DIV2KDataset(
    str(DATA_DIR.joinpath('DIV2K_valid_HR')),
    str(DATA_DIR.joinpath('DIV2K_valid_LR_x8')),
)

Get data - 1 randomly selected image (LR as x0 and corresponding HR for PSNR)

In [None]:
# draw one position uniformly from the training set
index = random.randrange(0, len(train_dataset))
print(f"Selecting image index: {index} from dataset")

# fetch the (LR, HR) tensor pair stored at that position
img_LR_tensor, img_HR_tensor = train_dataset[index]

# prepend a batch axis ([C, H, W] -> [1, C, H, W]) and place both tensors on the chosen device
img_LR_var = img_LR_tensor[None].to(device)
img_HR_var = img_HR_tensor[None].to(device)

print(f"HR Image Shape: {img_HR_var.shape}")
print(f"LR Input Shape: {img_LR_var.shape}")

Define network hyperparameters

In [None]:
# Training / upscaling configuration for the ZSSR experiment.
# SCALE_FACTOR was 4, which contradicted the *_LR_x8 data folders and the x8 inference cell.
# NOTE(review): SCALE_FACTOR is not referenced by the visible cells - confirm no other code expects 4.
SCALE_FACTOR = 8        # Overall LR -> HR factor (x8)
EPOCHS = 15             # Number of passes over the internal crop dataset
CROPS_PER_EPOCH = 500   # Number of training examples extracted per epoch
LEARNING_RATE = 0.0005  # Step size given to the optimizer

Define degradation function

In [None]:
def degradation(img_tensor, scale=0.5):
    """Resize ``img_tensor`` by ``scale`` using bicubic interpolation.

    With the default ``scale=0.5`` the result is a lower-resolution copy of the
    input, which the network can then learn to map back (reverse the degradation).

    Args:
        img_tensor: image batch passed straight to ``TF.interpolate``.
        scale: multiplicative resize factor.

    Returns:
        The resized tensor.
    """
    resize_options = dict(scale_factor=scale, mode='bicubic', align_corners=False)
    return TF.interpolate(img_tensor, **resize_options)

Create the internal dataset for the image: small crops of the LR image are downsampled again, so the network can learn the mapping from an even-lower-resolution image back to the LR image. That same mapping is later applied to go from LR to HR.

For this stage, we treat the LR image as the ground truth and learn the mapping to it from smaller, even lower-resolution versions of its crops.

We use data augmentation here (flips and rotations) to obtain more training data for learning the mapping function.

In [None]:
class ZSSRInternalDataset(Dataset):
    """Random (input, target) crop pairs drawn from one image for ZSSR training.

    Every item is a random square crop of ``target_img`` (the training target)
    plus a x0.5 bicubic-downsampled copy of that crop (the network input). The
    same random flips / 90-degree rotation are applied to both.

    Args:
        target_img: image tensor, expected as (1, C, H, W); the leading axis is
            squeezed away per item.
        num_samples: value reported by ``__len__`` (crops per epoch).
        crop_size: requested side length of the target crop. It is reduced to fit
            the image and rounded down to an even number.
    """

    def __init__(self, target_img, num_samples=1000, crop_size=64):
        self.target = target_img
        self.num_samples = num_samples
        self.crop_size = crop_size

    def __len__(self):
        """Return the configured number of crops per epoch."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return one augmented ``(lr_crop, hr_crop)`` pair; ``idx`` is ignored.

        Raises:
            ValueError: if the image is too small to cut a crop of at least 2x2.
        """
        # randomly crop the target (our original lr image)
        _, _, h, w = self.target.shape

        # Never ask for more than the image offers (keeps crops square, so rot90
        # cannot change their shape), and keep the side even so that halving it
        # and upsampling x2 again reproduces the target size exactly.
        cs = min(self.crop_size, h, w)
        cs -= cs % 2
        if cs < 2:
            raise ValueError(f"image of size {h}x{w} is too small to crop")

        # random.randint is inclusive on both ends: h - cs is the last valid offset
        top = random.randint(0, h - cs)
        left = random.randint(0, w - cs)

        hr_crop = self.target[:, :, top:top+cs, left:left+cs]

        # create the input by degrading/downsampling the cropped area
        lr_crop = degradation(hr_crop, scale=0.5)

        # squeeze to (C, H, W) for augmentation
        hr_crop = hr_crop.squeeze(0)
        lr_crop = lr_crop.squeeze(0)

        # data augmentation
        # torch.flip is used because TF is torch.nn.functional, which provides
        # no hflip/vflip helpers.
        if random.random() > 0.5: # Horizontal Flip (reverse the width axis)
            hr_crop = torch.flip(hr_crop, dims=[-1])
            lr_crop = torch.flip(lr_crop, dims=[-1])
        if random.random() > 0.5: # Vertical Flip (reverse the height axis)
            hr_crop = torch.flip(hr_crop, dims=[-2])
            lr_crop = torch.flip(lr_crop, dims=[-2])
        if random.random() > 0.5: # 90-degree Rotation
            hr_crop = torch.rot90(hr_crop, 1, [1, 2])
            lr_crop = torch.rot90(lr_crop, 1, [1, 2])

        return lr_crop, hr_crop

Network architecture

In [None]:
class ZSSRNet(nn.Module):
    """Residual super-resolution CNN: bicubic upsample plus a learned correction.

    The input is first upsampled with bicubic interpolation; a stack of 3x3
    convolutions then predicts a residual that is added to that upsampled image.

    Args:
        channels: number of feature maps in every hidden convolution.
        in_channels: channels of the input (and output) image; 3 for RGB.
        num_layers: number of conv + ReLU pairs in the body.
        scale_factor: upsampling factor applied in ``forward``.

    The defaults reproduce the original fixed configuration (RGB, 8 layers, x2).
    """

    def __init__(self, channels=64, in_channels=3, num_layers=8, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor

        # head
        self.head = nn.Conv2d(in_channels, channels, kernel_size=3, padding=1)

        # body "We use a simple, fully convolutional network, with 8 hidden layers"
        body_layers = []
        for _ in range(num_layers):
            body_layers.append(nn.Conv2d(channels, channels, kernel_size=3, padding=1))
            body_layers.append(nn.ReLU(inplace=True))
        self.body = nn.Sequential(*body_layers)

        # tail (predicts residual, i.e. corrections)
        self.tail = nn.Conv2d(channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Upsample ``x`` by ``scale_factor`` and add the predicted residual.

        Args:
            x: image batch of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, in_channels, H * scale_factor, W * scale_factor).
        """
        # bicubic upsample - "base guess" is direct upsampling - this will be blurry but we can learn improvements
        x_upscaled = TF.interpolate(
            x, scale_factor=self.scale_factor, mode='bicubic', align_corners=False
        )

        # predict residual (corrections to the blurry upsampled image)
        feat = self.head(x_upscaled)
        feat = self.body(feat)
        residual = self.tail(feat)

        # add residual to base
        return x_upscaled + residual

Training loop

In [None]:

# Build model, optimizer and the internal crop loader; the LR image itself is the training target
model_zssr = ZSSRNet().to(device)
optimizer_zssr = torch.optim.Adam(model_zssr.parameters(), lr=LEARNING_RATE)
zssr_ds = ZSSRInternalDataset(img_LR_var, num_samples=CROPS_PER_EPOCH)
zssr_loader = DataLoader(zssr_ds, batch_size=16, shuffle=True)

loss_history = []
print("Starting ZSSR Training...")

for epoch in range(EPOCHS):
    model_zssr.train()
    epoch_loss = 0

    for lr_batch, hr_batch in zssr_loader:
        lr_batch = lr_batch.to(device)
        hr_batch = hr_batch.to(device)

        optimizer_zssr.zero_grad()
        output = model_zssr(lr_batch)

        # L1 loss like in the original paper
        loss = TF.l1_loss(output, hr_batch)
        loss.backward()
        optimizer_zssr.step()

        epoch_loss += loss.item()

    # mean batch loss of this epoch
    avg_loss = epoch_loss / len(zssr_loader)
    loss_history.append(avg_loss)

    # only refresh the preview every third epoch and after the final one
    if not (epoch % 3 == 0 or epoch == EPOCHS - 1):
        continue

    clear_output(wait=True)

    # Quick check on the full LR image (x2 scale), without tracking gradients
    model_zssr.eval()
    with torch.no_grad():
        test_out = model_zssr(img_LR_var)
    disp_pred = np.clip(test_out.squeeze(0).permute(1, 2, 0).cpu().numpy(), 0, 1)

    fig, (loss_ax, preview_ax) = plt.subplots(1, 2, figsize=(15, 5))

    loss_ax.plot(loss_history, label='L1 Loss')
    loss_ax.set_title(f"Training Loss (Epoch {epoch})")
    loss_ax.legend()

    preview_ax.imshow(disp_pred)
    preview_ax.set_title(f"Internal Validation (x2) | Epoch {epoch}")
    preview_ax.axis('off')

    plt.show()

print("Training Complete.")

Inference - three successive x2 passes through the network (reaching x2, x4, then x8) to go from LR to HR

In [None]:
# --- Cell: Final Inference (x8 Upscaling) ---

print("Running ZSSR Inference (x8 Gradual)...")
model_zssr.eval()

with torch.no_grad():
    # Each pass feeds the previous result back into the network: LR -> x2 -> x4 -> x8
    print("Upscaling Step 1/3 (x2)...")
    sr_x2 = model_zssr(img_LR_var)

    print("Upscaling Step 2/3 (x4)...")
    sr_x4 = model_zssr(sr_x2)

    print("Upscaling Step 3/3 (x8)...")
    sr_x8 = model_zssr(sr_x4)

    # The x8 output need not match the HR size exactly (e.g. a 175-pixel LR side
    # gives 175 * 8 = 1400 while the HR side is 1404), so resample onto the
    # ground-truth grid before computing metrics, then clamp to the [0, 1] range.
    target_h, target_w = img_HR_var.shape[-2:]
    final_sr = TF.interpolate(
        sr_x8,
        size=(target_h, target_w),
        mode='bicubic',
        align_corners=False,
    ).clamp(0, 1)

# drop the batch axis and move channels last ([1, C, H, W] -> [H, W, C]) as numpy arrays
sr_np, gt_np, lr_np = (
    batch.squeeze(0).permute(1, 2, 0).cpu().numpy()
    for batch in (final_sr, img_HR_var, img_LR_var)
)

# PSNR for a peak value of 1.0; an exact match is reported as 100 dB
mse = np.mean((gt_np - sr_np) ** 2)
final_psnr = 100 if mse == 0 else 20 * np.log10(1.0 / np.sqrt(mse))

print(f"Final ZSSR PSNR (x8): {final_psnr:.2f} dB")

# side-by-side comparison: input, super-resolved output, ground truth
panels = [
    (lr_np, f"LR Input (x8 smaller)\n{lr_np.shape}"),
    (sr_np, f"ZSSR Output (x8)\nPSNR: {final_psnr:.2f} dB"),
    (gt_np, f"Ground Truth\n{gt_np.shape}"),
]

plt.figure(figsize=(18, 6))
for panel_pos, (panel_img, panel_title) in enumerate(panels, start=1):
    plt.subplot(1, 3, panel_pos)
    plt.imshow(panel_img)
    plt.title(panel_title)
    plt.axis('off')

plt.show()