Code inspired by https://github.com/UoB-CS-AVAI/Week2-train-Deep-Neural-Network-to-denoise-image and https://github.com/DmitryUlyanov/deep-image-prior

The model architectures (skip, unet, resnet) and the utility functions are copied directly from these GitHub repositories.

In [None]:
%pip install torch torchvision pillow scikit-image lpips matplotlib 


Collecting models
  Using cached models-0.9.3.tar.gz (16 kB)
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'error'
Note: you may need to restart the kernel to use updated packages.


  error: subprocess-exited-with-error
  
  × python setup.py egg_info did not run successfully.
  │ exit code: 1
  ╰─> [8 lines of output]
      Traceback (most recent call last):
        File "<string>", line 2, in <module>
        File "<pip-setuptools-caller>", line 35, in <module>
        File "C:\Users\blobf.DESKTOP-IUEL8R6\AppData\Local\Temp\pip-install-8q3wi8ft\models_0d20757c3ae2494c8a442aa4f67535a7\setup.py", line 25, in <module>
          import models
        File "C:\Users\blobf.DESKTOP-IUEL8R6\AppData\Local\Temp\pip-install-8q3wi8ft\models_0d20757c3ae2494c8a442aa4f67535a7\models\__init__.py", line 23, in <module>
          from base import *
      ModuleNotFoundError: No module named 'base'
      [end of output]
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
error: metadata-generation-failed

× Encountered error while generating package metadata.
╰─> See above for output.

note: This is an issue with the package mentioned above, n

Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torchvision import transforms
from PIL import Image
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import lpips
import glob
import os
import math
import random

from models import *
from utils import *
from utils.sr_utils import * 
from utils.common_utils import *

# Choose the tensor type and device in one place: use CUDA when a GPU is
# visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    dtype, device = torch.cuda.FloatTensor, torch.device("cuda")
else:
    dtype, device = torch.FloatTensor, torch.device("cpu")
print(f"Running on device: {device}")


Get LR dataset and HR dataset (for ground truths)

In [None]:
class DIV2KDataset(Dataset):
    """Paired low-/high-resolution image dataset backed by two PNG folders.

    The `*.png` files of each directory are listed and sorted, and the
    i-th LR file is paired with the i-th HR file. Pairing is purely
    positional, so both directories must sort into the same order.

    Args:
        hr_dir: directory containing the high-resolution PNG images.
        lr_dir: directory containing the low-resolution PNG images.

    Raises:
        ValueError: if the two directories contain a different number of
            PNG files (positional pairing would then be wrong).
    """

    def __init__(self, hr_dir, lr_dir):
        self.hr_paths = sorted(glob.glob(os.path.join(hr_dir, '*.png')))
        self.lr_paths = sorted(glob.glob(os.path.join(lr_dir, '*.png')))

        # Fail early instead of raising IndexError (or silently pairing the
        # wrong images) later in __getitem__.
        if len(self.hr_paths) != len(self.lr_paths):
            raise ValueError(
                f"HR/LR image count mismatch: {len(self.hr_paths)} in "
                f"{hr_dir!r} vs {len(self.lr_paths)} in {lr_dir!r}"
            )

        self.transform = transforms.ToTensor()  # PIL image -> tensor

    def __len__(self):  # dataloader needs access to length
        return len(self.hr_paths)

    def _load(self, path):
        """Read one image file, force 3-channel RGB, and convert to a tensor."""
        # The context manager closes the file handle as soon as the pixel
        # data has been read; convert('RGB') guards against palette, RGBA or
        # greyscale PNGs producing a different channel count.
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

    def __getitem__(self, index):  # dataloader needs access to items by index
        """Return the `(lr_img, hr_img)` tensor pair at position `index`."""
        lr_img = self._load(self.lr_paths[index])
        hr_img = self._load(self.hr_paths[index])
        return lr_img, hr_img

# All data is expected under ./data relative to the current working directory.
BASE_DIR = Path.cwd()
DATA_DIR = BASE_DIR.joinpath('data')

# Training split: HR ground truths and their LR counterparts.
train_dataset = DIV2KDataset(
    lr_dir=os.fspath(DATA_DIR / 'DIV2K_train_LR_x8'),
    hr_dir=os.fspath(DATA_DIR / 'DIV2K_train_HR'),
)

# Validation split, same directory layout.
val_dataset = DIV2KDataset(
    lr_dir=os.fspath(DATA_DIR / 'DIV2K_valid_LR_x8'),
    hr_dir=os.fspath(DATA_DIR / 'DIV2K_valid_HR'),
)

Get data - 1 randomly selected image (LR as x0 and corresponding HR for PSNR)

In [None]:
# Pick one training pair uniformly at random (both bounds inclusive).
index = random.randint(0, len(train_dataset) - 1)
print(f"Selecting image index: {index} from dataset")

# The dataset returns the pair as (LR, HR).
img_LR_tensor, img_HR_tensor = train_dataset[index]

# Prepend a batch axis ([C, H, W] -> [1, C, H, W]) and move to the device.
img_LR_var = img_LR_tensor[None].to(device)
img_HR_var = img_HR_tensor[None].to(device)

print(f"HR Image Shape: {img_HR_var.shape}")
print(f"LR Input Shape: {img_LR_var.shape}")

Define network hyperparameters

In [None]:
# --- network / input configuration ---
NET_TYPE = 'skip'      # architecture: 'skip', 'resnet' or 'unet'
INPUT = 'noise'        # input type: 'noise' or 'meshgrid' - just use noise for DIP
input_depth = 32       # number of channels in the input noise
pad = 'reflection'     # padding type handed to the network constructor

# --- optimisation ---
OPT_OVER = 'net'       # 'net' - optimise the network weights; 'net,input' - optimise the noise too
OPTIMIZER = 'adam'     # 'adam' or 'LBFGS'
LR = 0.01              # learning rate for the optimizer
num_iter = 2000        # total iterations
reg_noise_std = 1 / 30 # std of the noise added to the input at each iteration

# --- logging / plotting ---
show_every = 100       # how often (in iterations) to show results
figsize = 4            # figure size for plotting

Define the input noise z, with the same spatial dimensions as the HR image

In [None]:
# Network input z: `input_depth` channels at the HR image's height and width
# (the last two axes of the [1, C, H, W] tensor), cast to the selected tensor
# type and detached from autograd.
net_input = get_noise(input_depth, INPUT, tuple(img_HR_var.shape[-2:])).type(dtype).detach()

print("Input noise shape:", net_input.shape)

Define network architecture (skip, resnet or unet) and loss function

In [None]:
def get_net(input_depth, NET_TYPE, pad, skip_n33d=128, skip_n33u=128, skip_n11=4, num_scales=5, upsample_mode='bilinear'):
    """Build the generator network selected by `NET_TYPE`.

    Every branch is constructed with `input_depth` input channels and
    3 output channels.

    Args:
        input_depth: number of channels of the network input.
        NET_TYPE: 'skip', 'resnet' or 'unet'.
        pad: padding type; only forwarded to the 'skip' architecture.
        skip_n33d: channels per scale on the downsampling path ('skip' only).
        skip_n33u: channels per scale on the upsampling path ('skip' only).
        skip_n11: channels per scale of the skip connections ('skip' only).
        num_scales: number of scales ('skip' only).
        upsample_mode: upsampling mode ('skip' only).

    Returns:
        The constructed network.

    Raises:
        AssertionError: if `NET_TYPE` is not one of the supported names.
    """
    if NET_TYPE == 'skip':
        return skip(input_depth, 3,
               num_channels_down = [skip_n33d] * num_scales,
               num_channels_up =   [skip_n33u] * num_scales,
               num_channels_skip =    [skip_n11] * num_scales,
               filter_size_up = 3, filter_size_down = 3,
               upsample_mode=upsample_mode, filter_skip_size=1,
               need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU')
    elif NET_TYPE == 'resnet':
        # NOTE(review): upstream deep-image-prior exports `ResNet`/`UNet`
        # classes with different constructor arguments - confirm that
        # `resnet` and `unet` are callables with these keywords in the
        # local `models` package.
        return resnet(input_depth, 3, num_channels=128, num_blocks=8, act_fun='LeakyReLU')
    elif NET_TYPE == 'unet':
        return unet(input_depth, 3, num_channels=[128]*5, act_fun='LeakyReLU')
    else:
        # Raise explicitly: a bare `assert False` is stripped under
        # `python -O`, which would make this function return None.
        raise AssertionError(
            f"Unknown NET_TYPE {NET_TYPE!r}; expected 'skip', 'resnet' or 'unet'"
        )

# Build the network with the chosen architecture and cast it to the selected
# tensor type.
net = get_net(
    input_depth, NET_TYPE, pad,
    skip_n33d=128, skip_n33u=128, skip_n11=4,
    num_scales=5, upsample_mode='bilinear',
).type(dtype)

# Total parameter count: sum of the element counts of every parameter tensor.
s = sum(np.prod(list(p.size())) for p in net.parameters())
print('Number of params: %d' % s)

# Mean-squared-error loss.
mse = torch.nn.MSELoss().type(dtype)


Define degradation function, H - known to be bicubic x8 downsampling

In [None]:
def degradation_operator(hr_tensor, factor=8):
    """Apply the degradation H: anti-aliased bicubic downsampling by `factor`.

    Args:
        hr_tensor: the HR estimate x_hat produced by the model; a 4-D
            tensor, as required by `interpolate` in bicubic mode.
        factor: downsampling factor (default 8, i.e. x8 super-resolution).

    Returns:
        The input downsampled by `1 / factor` along its spatial axes.
    """
    return torch.nn.functional.interpolate(
        hr_tensor,
        scale_factor=1 / factor,
        mode='bicubic',
        # Pixel-centre alignment plus anti-aliasing. Without antialias,
        # bicubic interpolation only looks at a 4x4 neighbourhood around
        # each sample point and therefore aliases badly at an x8 reduction,
        # so it would not model a low-pass-filtered bicubic downscale.
        # `antialias` requires torch >= 1.11.
        align_corners=False,
        antialias=True,
    )

Get loss: MSE between x0 (original LR image) and model output x_hat downsampled by H

In [None]:
# Data-fidelity loss: mean squared error, cast to the selected tensor type.
# (Re-creates the same `mse` object already defined with the network.)
mse = nn.MSELoss().type(dtype)

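Training loop - fit the network so that its downsampled output matches the LR image, logging PSNR, SSIM and LPIPS against the HR ground truth every `show_every` iterations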

In [None]:
# LPIPS perceptual metric ('alex' backbone), placed on the active device.
# Requires `lpips` to be imported beforehand.
loss_fn_lpips = lpips.LPIPS(net='alex').to(device)

# Metric traces; one value is appended at every logging step.
psnr_history, ssim_history, lpips_history = [], [], []

# Keep an unperturbed copy of the input, plus a separate buffer of the same
# shape that is refilled with fresh noise on every iteration.
net_input_saved = net_input.detach().clone()
noise = net_input.detach().clone()

# Iteration counter, advanced by the closure.
i = 0

# define closure function (runs once per iteration inside "optimize")
def closure():
    """Run one optimisation step's forward and backward pass and return the loss.

    The network output is downsampled with `degradation_operator` and compared
    to the LR image with `mse`; gradients are accumulated via `backward()`.
    Every `show_every` calls, PSNR, SSIM and LPIPS of the HR estimate against
    the HR ground truth are appended to the history lists and printed.
    """
    global i, net_input, net_input_saved, noise

    # Regularisation: perturb the saved input with freshly drawn noise.
    # `normal_()` refills the `noise` buffer in place.
    if reg_noise_std > 0:
        noise.normal_()
        net_input = net_input_saved + (noise * reg_noise_std)

    # HR estimate, then its downsampled version at the LR image's resolution.
    sr_pred = net(net_input)
    lr_pred = degradation_operator(sr_pred)

    # Data term: downsampled prediction vs. the observed LR image.
    total_loss = mse(lr_pred, img_LR_var)
    total_loss.backward()

    if i % show_every == 0:
        pred_np = torch_to_np(sr_pred)
        target_np = torch_to_np(img_HR_var)

        # PSNR, with a peak signal value of 1.
        sq_err = np.mean((target_np - pred_np) ** 2)
        psnr_HR = 10 * np.log10(1 / sq_err)
        psnr_history.append(psnr_HR)

        # SSIM expects channels last: (C, H, W) -> (H, W, C).
        current_ssim = ssim(
            np.moveaxis(target_np, 0, -1),
            np.moveaxis(pred_np, 0, -1),
            data_range=1.0,
            channel_axis=2,
        )
        ssim_history.append(current_ssim)

        # LPIPS: rescale from [0, 1] to [-1, 1]; no gradients needed here.
        with torch.no_grad():
            lpips_value = loss_fn_lpips(sr_pred * 2 - 1, img_HR_var * 2 - 1)
            current_lpips = lpips_value.item()
        lpips_history.append(current_lpips)

        print(f'Iter {i:05d} | Loss {total_loss.item():.6f} | PSNR: {psnr_HR:.2f} | SSIM: {current_ssim:.4f} | LPIPS: {current_lpips:.4f}')

    i += 1

    return total_loss

# Collect the tensors to optimise (selected by OPT_OVER) and run the
# optimisation, calling `closure` for each step.
p = get_params(OPT_OVER, net, net_input)
optimize(OPTIMIZER, p, closure, LR, num_iter)

# get final result
# Inference only: disable autograd so that no graph is built (and kept in
# memory) for this last full-resolution forward pass.
with torch.no_grad():
    out_HR_final = np.clip(torch_to_np(net(net_input)), 0, 1)


Visualisation

In [None]:
# ---- Figure 1: image comparison (ground truth / LR input / DIP output) ----
fig, ax = plt.subplots(1, 3, figsize=(18, 6))

# display hr
# transpose(1, 2, 0) puts channels last, the layout imshow expects
# (assumes torch_to_np returns a (C, H, W) array - as the SSIM code does).
ax[0].imshow(torch_to_np(img_HR_var).transpose(1, 2, 0))
ax[0].set_title("Ground Truth HR")
ax[0].axis('off')

# lr, enlarged to the HR size with nearest-neighbour so the pixels stay visible
display_LR = torch.nn.functional.interpolate(img_LR_var, size=img_HR_var.shape[2:], mode='nearest')
ax[1].imshow(torch_to_np(display_LR).transpose(1, 2, 0))
ax[1].set_title("Input LR (x8 bicubic down)")
ax[1].axis('off')

# DIP output; the PSNR shown is the most recently logged value
ax[2].imshow(out_HR_final.transpose(1, 2, 0))
ax[2].set_title(f"DIP Output HR\nPSNR: {psnr_history[-1]:.2f} dB")
ax[2].axis('off')

plt.show()

# ---- Figure 2: metric curves ----
# A new figure is needed here: the axes above belong to the figure that has
# already been shown (with axes switched off), so plotting on them again
# would draw over the images or display nothing at all.
fig, ax = plt.subplots(1, 3, figsize=(18, 5))

# Plot PSNR
ax[0].plot(psnr_history, color='g')
ax[0].set_title("PSNR (Higher is better)")
ax[0].set_xlabel("Iteration (x100)")
ax[0].set_ylabel("dB")
ax[0].grid(True, alpha=0.3)

# Plot SSIM
ax[1].plot(ssim_history, color='b')
ax[1].set_title("SSIM (Higher is better)")
ax[1].set_xlabel("Iteration (x100)")
ax[1].set_ylabel("Index")
ax[1].grid(True, alpha=0.3)

# Plot LPIPS
ax[2].plot(lpips_history, color='r')
ax[2].set_title("LPIPS (Lower is better)")
ax[2].set_xlabel("Iteration (x100)")
ax[2].set_ylabel("Distance")
ax[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()