Code for **"Blind restoration of a JPEG-compressed image"** and **"Blind image denoising"** figures. Select `fname` below to switch between the two.

- To see overfitting, set `num_iter` to a large value: given enough iterations the network starts reproducing the noise / compression artifacts as well.

In [None]:
"""
*Uncomment if running on colab* 
Set Runtime -> Change runtime type -> Under Hardware Accelerator select GPU in Google Colab 
"""
# !git clone https://github.com/ubc-vision/juho-usra.git
# !mv juho-usra/* ./

# Import libraries

In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline

import os

import numpy as np

import torch
import torch.nn as nn
import torch.optim

## compare_psnr is renamed to peak_signal_noise_ratio
## from skimage.measure import compare_psnr
from skimage.metrics import peak_signal_noise_ratio

# use tensorboard with pytorch
from torch.utils.tensorboard import SummaryWriter

from models import *
from utils.denoising_utils import *

# Let cuDNN pick the fastest convolution algorithms for our fixed input size.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Tensor type for the network, its input and the target:
# GPU floats when CUDA is available, CPU floats otherwise.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

imsize = -1            # size argument forwarded to get_image
PLOT = True            # enables image plots and TensorBoard logging
sigma = 25             # noise level in 0-255 units
sigma_ = sigma / 255.  # same noise level rescaled by 1/255

In [None]:
# Choose the experiment by setting `fname`:
#   JPEG artifact removal ("deJPEG"): 'data/denoising/snail.jpg'
#   denoising:                        'data/denoising/F16_GT.png'
fname = 'data/denoising/snail.jpg'

# Load image

In [None]:
# Two supported experiments: a real JPEG (no clean reference available) and a
# clean image to which synthetic noise is added. Reject anything else up front
# with a real exception (a bare `assert False` is stripped under `python -O`).
if fname not in ('data/denoising/snail.jpg', 'data/denoising/F16_GT.png'):
    raise ValueError('Unsupported fname: %r' % (fname,))

# Load and crop the image (d=32 is forwarded to crop_image; presumably it trims
# the sides to multiples of 32 so the encoder/decoder sizes line up -- confirm).
img_pil = crop_image(get_image(fname, imsize)[0], d=32)
img_np = pil_to_np(img_pil)

if fname == 'data/denoising/snail.jpg':
    # As we don't have ground truth, the compressed image is both the
    # observation and the reference used for PSNR.
    img_noisy_pil, img_noisy_np = img_pil, img_np

    if PLOT:
        plot_image_grid([img_np], 4, 5)
else:
    # Add synthetic noise with standard deviation sigma_.
    img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma_)

    if PLOT:
        plot_image_grid([img_np, img_noisy_np], 4, 6)

# Debug helper: a pass-through layer that prints its input (e.g. inside a PyTorch `Sequential`)

In [None]:
class PrintLayer(nn.Module):
    """Identity module that prints whatever flows through it.

    Drop it into an ``nn.Sequential`` (or call it directly) to dump the value
    at that point of the pipeline; the input is returned unchanged.
    """

    def forward(self, x):
        # Debug hook: show the value, then hand it on untouched.
        print(x)
        return x


# Shared instance for ad-hoc debugging.
printer = PrintLayer()

# Setup

In [None]:
# Kind of network input passed to get_noise (alternative: 'meshgrid').
INPUT = 'noise'
pad = 'reflection'
# Selector passed to get_params (alternative: 'net,input').
OPT_OVER = 'net'

# Std of the fresh noise added to the fixed input on every iteration.
reg_noise_std = 1.0 / 30.0  # set to 1./20. for sigma=50
LR = 0.01  # Adam learning rate

show_every = 100   # period (in iterations) for logging and backtracking checks
exp_weight = 0.99  # decay factor of the output moving average

# Per-experiment configuration: iteration budget, input depth and architecture.
if fname == 'data/denoising/snail.jpg':
    num_iter = 2400
    input_depth = 3
    figsize = 5

    net = Net(
        input_depth, 3,
        channels_down=[8, 16, 32, 64, 128],
        channels_up=[8, 16, 32, 64, 128],
        channels_skip=[0, 0, 0, 4, 4],
        kernel_size_down=[3, 3, 3, 3, 3],
        kernel_size_up=[3, 3, 3, 3, 3],
        upsample_mode='bilinear',
        need_sigmoid=True, need_bias=True, pad=pad)

elif fname == 'data/denoising/F16_GT.png':
    num_iter = 3000
    input_depth = 32
    figsize = 4

    net = Net(
        input_depth, 3,
        channels_down=[128, 128, 128, 128, 128],
        channels_up=[128, 128, 128, 128, 128],
        channels_skip=[4, 4, 4, 4, 4],
        kernel_size_down=[3, 3, 3, 3, 3],
        kernel_size_up=[3, 3, 3, 3, 3],
        upsample_mode='bilinear',
        need_sigmoid=True, need_bias=True, pad=pad)

else:
    # Explicit error instead of `assert False` (asserts vanish under -O).
    raise ValueError('Unsupported fname: %r' % (fname,))

# Cast the network to the same tensor type as its input and target for BOTH
# configurations. Previously only the snail branch did this, so the F16 network
# stayed on the CPU while its input lived on the GPU.
net = net.type(dtype)

# Print the module structure (its repr lists every layer and its settings).
print(net)

# Compute number of parameters
s = sum(p.numel() for p in net.parameters())
print('Number of params: %d' % s)

# Train

In [None]:
# Network input built by get_noise with the image's spatial size.
# PIL reports size as (width, height); reversing gives (height, width).
net_input = get_noise(input_depth, INPUT, img_pil.size[::-1]).type(dtype).detach()
# Training target: the noisy / compressed observation as a tensor.
img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)

# An untouched copy of the input, plus a same-shaped buffer that train() refills
# with fresh noise each iteration.
net_input_saved = net_input.detach().clone()
noise = net_input.detach().clone()

# State shared with train(): output moving average, last parameter snapshot
# for backtracking, and the noisy-target PSNR recorded with that snapshot.
out_avg, last_net, psrn_noisy_last = None, None, 0


def train():
    """Fit ``net`` to the noisy observation for ``num_iter`` iterations.

    Reads the module-level configuration (``net``, ``num_iter``, ``LR``,
    ``OPT_OVER``, ``reg_noise_std``, ``exp_weight``, ``show_every``, ``PLOT``)
    and data (``img_noisy_torch``, ``img_noisy_np``, ``img_np``,
    ``net_input_saved``, ``noise``). Updates the module-level state
    ``net_input`` (last perturbed input), ``out_avg`` (exponential moving
    average of the outputs), ``last_net`` and ``psrn_noisy_last``
    (backtracking checkpoint).

    When ``PLOT`` is set, every ``show_every`` iterations the current output,
    the loss and the PSNR values are written to TensorBoard under
    ``./logs/denoising/train``.

    If the PSNR against the noisy target drops by more than 5 dB compared with
    the last checkpoint, the parameters are restored from that checkpoint and
    training continues with the next iteration.
    """
    global out_avg, psrn_noisy_last, last_net, net_input

    # Create optimizer over the parameters selected by OPT_OVER
    parameters = get_params(OPT_OVER, net, net_input)
    optimizer = torch.optim.Adam(parameters, lr=LR)

    # Loss
    loss = nn.MSELoss().type(dtype)

    # TensorBoard log directory: make sure it exists before the writer opens it.
    log_dir = "./logs/denoising/train"
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir)

    try:
        for i in range(num_iter):
            if reg_noise_std > 0:
                # Perturb the fixed input with freshly sampled noise.
                net_input = net_input_saved + (noise.normal_() * reg_noise_std)

            # Forward pass
            out = net(net_input)

            # Exponential moving average of the outputs
            if out_avg is None:
                out_avg = out.detach()
            else:
                out_avg = out_avg * exp_weight + out.detach() * (1 - exp_weight)

            # Loss, gradients, parameter update, then clear the gradients
            total_loss = loss(out, img_noisy_torch)
            total_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Copy the output to the host once and reuse it for every PSNR value.
            out_np = out.detach().cpu().numpy()[0]
            psrn_noisy = peak_signal_noise_ratio(img_noisy_np, out_np)
            psrn_gt = peak_signal_noise_ratio(img_np, out_np)
            psrn_gt_sm = peak_signal_noise_ratio(img_np, out_avg.cpu().numpy()[0])

            # Monitor results in TensorBoard
            if PLOT and i % show_every == 0:
                writer.add_image("image_output", out.detach(), global_step=i, dataformats='NCHW')
                writer.add_scalar("loss", total_loss.item(), global_step=i)
                writer.add_scalar("psnr_noisy", psrn_noisy, global_step=i)
                writer.add_scalar("psnr_gt", psrn_gt, global_step=i)
                writer.add_scalar("psnr_gt_sm", psrn_gt_sm, global_step=i)

            # Backtracking (checked on iterations that are NOT multiples of show_every)
            if i % show_every:
                if last_net is not None and psrn_noisy - psrn_noisy_last < -5:
                    print('Falling back to previous checkpoint.')

                    # copy_ moves the CPU snapshot to the parameter's device,
                    # so this works with and without CUDA.
                    for saved_param, net_param in zip(last_net, net.parameters()):
                        net_param.data.copy_(saved_param)

                    # Keep training from the restored weights. A `return` here
                    # would end the whole run at the first fallback.
                    continue

                last_net = [x.detach().cpu() for x in net.parameters()]
                psrn_noisy_last = psrn_noisy
    finally:
        # Flush pending events and release the event file.
        writer.close()


train()

# Result

In [None]:
# Final reconstruction from the last network input. This pass is inference only,
# so run it without autograd to avoid building (and holding) a gradient graph.
with torch.no_grad():
    out_np = torch_to_np(net(net_input))
q = plot_image_grid([np.clip(out_np, 0, 1), img_np], factor=13);