In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import display, clear_output

import os
import numpy as np
from models.skip import skip
from models.downsampler import Downsampler
import torch
import torch.optim

from skimage.util import random_noise
from skimage.metrics import peak_signal_noise_ratio

from utils.common_utils import get_image, crop_image, np_to_pil, pil_to_np, np_to_torch, torch_to_np, plot_image_grid, get_image_grid, get_noise
from utils.sr_utils import load_LR_HR_imgs_sr, tv_loss


# Enable cuDNN and let it benchmark convolution algorithms to pick the fastest one.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Tensor type applied to every model and tensor below through `.type(dtype)`.
# Fall back to the CPU float type when CUDA is not available, so the notebook
# still runs (slowly) on a machine without a GPU instead of failing on the
# first `.type(dtype)` call.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

# Seed the global torch and numpy random number generators.
torch.manual_seed(0)
np.random.seed(0)

save_every = 250        # refresh the plots / record PSNR every `save_every` iterations
num_iterations = 3000   # each optimisation loop runs for num_iterations + 1 steps

## Data

In [None]:
# ---- Settings --------------------------------------------------------------
img_path = 'data/sr/zebra.png'

imsize = -1        # forwarded unchanged to load_LR_HR_imgs_sr
dim_div_by = 64    # not referenced by any of the visible cells
factor = 4         # down-sampling factor shared by the loader and the Downsampler

# Differentiable down-sampling operator; it maps an HR image to the LR observation space.
downsampler = Downsampler(n_planes=3, factor=factor, kernel_type='lanczos2', phase=0.5, preserve_size=True).type(dtype)

imgs = load_LR_HR_imgs_sr(img_path, imsize, factor, 'CROP')

# ---- Clean images ----------------------------------------------------------
clean_hr_img_np = imgs['HR_np']
clean_hr_img_torch = np_to_torch(clean_hr_img_np).type(dtype)
# The clean LR image is produced with the same operator that appears in the losses below.
clean_lr_img_np = np.clip(torch_to_np(downsampler(clean_hr_img_torch)), 0, 1)
# LR image resized back to the HR size, for display only.
clean_lr_upscale_img_np = pil_to_np(np_to_pil(clean_lr_img_np).resize(imgs['HR_pil'].size))

# ---- Noisy LR observation --------------------------------------------------
# Salt-and-pepper corruption. The third positional argument is the RNG seed
# (named `seed` in older scikit-image releases and `rng` in newer ones, hence
# positional). It is passed explicitly because recent scikit-image versions
# draw from their own generator, so np.random.seed alone does not make the
# noise reproducible.
noisy_lr_img_np = random_noise(clean_lr_img_np, 's&p', 0, amount=0.2)
noisy_lr_upscale_img_np = pil_to_np(np_to_pil(noisy_lr_img_np).resize(imgs['HR_pil'].size))

clean_lr_img_torch = np_to_torch(clean_lr_img_np).type(dtype)
noisy_lr_img_torch = np_to_torch(noisy_lr_img_np).type(dtype)

# Ground-truth corruption (noisy minus clean LR), kept for reference.
noise_diff_np = noisy_lr_img_np - clean_lr_img_np
noise_diff_torch = np_to_torch(noise_diff_np).type(dtype)

# ---- Overview figure: HR | up-scaled clean LR | up-scaled noisy LR ----------
grid = get_image_grid([clean_hr_img_np, clean_lr_upscale_img_np, noisy_lr_upscale_img_np], 3)
plt.figure(figsize=(15, 10))
plt.imshow(grid.transpose((1, 2, 0)))
plt.axis('off')
plt.savefig('sr_example.png', dpi=300)

In [None]:
# Per-method result containers, keyed by method name:
#   image_dict     -> final reconstruction
#   psnr_dict      -> PSNR values recorded during optimisation
#   iteration_dict -> iteration indices at which those PSNR values were recorded
image_dict, psnr_dict, iteration_dict = {}, {}, {}

## Clean SR

In [None]:
# Baseline: fit the network to the *clean* LR image (no corruption).
input_depth = 2

# Generator network; the number of output channels matches the HR image.
net = skip(
    input_depth,
    clean_hr_img_np.shape[0],
    num_channels_down = [128] * 5,
    num_channels_up   = [128] * 5,
    num_channels_skip = [4] * 5, 
    upsample_mode='bilinear',
    downsample_mode='stride',
    need_sigmoid=True, 
    need_bias=True, 
    pad='reflection',
    act_fun='LeakyReLU'
).type(dtype)

# Fixed network input at HR spatial resolution; only the network weights are optimised.
net_input = get_noise(input_depth, 'noise', clean_hr_img_np.shape[1:]).type(dtype).detach()
criterion = torch.nn.MSELoss().type(dtype)
optimizer = torch.optim.Adam(net.parameters())

for i in range(num_iterations+1):
    optimizer.zero_grad()
    out = net(net_input)
    # Data term: the down-sampled network output should match the clean LR image.
    loss = criterion(downsampler(out), clean_lr_img_torch)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.02)
    optimizer.step()
    
    if i % save_every == 0:
        clear_output(wait=True)
        out_np = torch_to_np(out)
        # The reference image goes first (`image_true`), and the data range is given
        # explicitly: images are handled as floats in [0, 1] throughout this notebook.
        psnr = peak_signal_noise_ratio(clean_hr_img_np, out_np, data_range=1.0)
        
        plt.imshow(np.clip(out_np, 0, 1).transpose((1, 2, 0)))
        plt.title(f'Iteration: {i}/{num_iterations}, Loss: {loss.item():0.3e}, PSNR: {psnr:0.1f}')
        plt.show()

In [None]:
# Keep the clean-SR reconstruction and its PSNR; `clean_psnr` is drawn as the
# 'No Noise' reference line in the PSNR summary plot.
# No gradients are needed for this forward pass, so skip building the autograd graph.
with torch.no_grad():
    clean_sr_np = torch_to_np(net(net_input))
# Reference image first (`image_true`); images are floats in [0, 1], hence data_range=1.0.
clean_psnr = peak_signal_noise_ratio(clean_hr_img_np, clean_sr_np, data_range=1.0)

In [None]:
# Side by side: ground-truth HR, up-scaled clean LR, and the clean-SR reconstruction.
clean_comparison_imgs = [clean_hr_img_np, clean_lr_upscale_img_np, clean_sr_np]
plot_image_grid(clean_comparison_imgs, 3, 11);

## DCT-Lasso

In [None]:
from torch_dct import idct_2d

# DCT-Lasso: optimise DCT coefficients `w` directly, with an L1 data term and
# an L1 penalty on the coefficients.
method = 'DCT-Lasso'

psnr_dict[method] = []
iteration_dict[method] = []

# Coefficient tensor with the same shape as the HR image; this is the only optimised variable.
w = torch.randn_like(clean_hr_img_torch)
w.requires_grad = True
criterion = torch.nn.L1Loss().type(dtype)
optimizer = torch.optim.SGD([w], lr=5e7, momentum=0.99)
lam = 5e-3  # weight of the L1 penalty on the coefficients

for i in range(num_iterations+1):
    optimizer.zero_grad()
    out = idct_2d(w)  # image estimate = inverse 2-D DCT of the coefficients
    loss = criterion(downsampler(out), noisy_lr_img_torch) + lam * torch.mean(torch.abs(w))
    loss.backward()
    optimizer.step()
    
    if i % save_every == 0:
        clear_output(wait=True)
        # `out` is unconstrained here, so clip to [0, 1] before scoring and plotting.
        out_np = np.clip(torch_to_np(out), 0, 1)
        
        iteration_dict[method].append(i)
        # Reference image first (`image_true`); images are floats in [0, 1], hence data_range=1.0.
        psnr = peak_signal_noise_ratio(clean_hr_img_np, out_np, data_range=1.0)
        psnr_dict[method].append(psnr)
        
        plt.figure(figsize=(10, 5))
        plt.subplot(121)
        # out_np is already clipped above, so no second clip is needed.
        plt.imshow(out_np.transpose((1, 2, 0)))
        plt.title(f'Iteration: {i}/{num_iterations}, Loss: {loss.item():0.3e}, PSNR: {psnr:0.1f}')
        
        plt.subplot(122)
        plt.plot(iteration_dict[method], psnr_dict[method])
        plt.xlabel('Iteration')
        plt.ylabel('PSNR')
        
        plt.show()
        
# Final reconstruction from the optimised coefficients; no autograd graph needed.
with torch.no_grad():
    image_dict[method] = np.clip(torch_to_np(idct_2d(w)), 0, 1)

## Robust-DIP

In [None]:
# Robust-DIP: network prior fitted to the noisy LR image with an L1 data term.
input_depth = 2
method = 'Robust-DIP'

psnr_dict[method] = []
iteration_dict[method] = []

# Skip channels are set to 0 in this configuration (presumably disabling the
# skip connections -- confirm against models.skip).
net = skip(
    input_depth, 
    clean_hr_img_np.shape[0], 
    num_channels_down = [128] * 5,
    num_channels_up   = [128] * 5,
    num_channels_skip = [0] * 5,  
    upsample_mode='nearest', 
    filter_skip_size=1, 
    filter_size_up=3, 
    filter_size_down=3,
    need_sigmoid=True, 
    need_bias=True, 
    pad='reflection', 
    act_fun='LeakyReLU'
).type(dtype)

# Fixed network input generated with the 'meshgrid' method at HR spatial resolution.
net_input = get_noise(input_depth, 'meshgrid', clean_hr_img_np.shape[1:]).type(dtype)
criterion = torch.nn.L1Loss().type(dtype)
optimizer = torch.optim.Adam(net.parameters())

for i in range(num_iterations+1):
    optimizer.zero_grad()
    out = net(net_input)
    # Data term: the down-sampled network output should match the noisy LR image.
    loss = criterion(downsampler(out), noisy_lr_img_torch)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.02)
    optimizer.step()
    
    if i % save_every == 0:
        clear_output(wait=True)
        out_np = torch_to_np(out)
        
        iteration_dict[method].append(i)
        # Reference image first (`image_true`); images are floats in [0, 1], hence data_range=1.0.
        psnr = peak_signal_noise_ratio(clean_hr_img_np, out_np, data_range=1.0)
        psnr_dict[method].append(psnr)
        
        plt.figure(figsize=(10, 5))
        plt.subplot(121)
        plt.imshow(np.clip(out_np, 0, 1).transpose((1, 2, 0)))
        plt.title(f'Iteration: {i}/{num_iterations}, Loss: {loss.item():0.3e}, PSNR: {psnr:0.1f}')
        
        plt.subplot(122)
        plt.plot(iteration_dict[method], psnr_dict[method])
        plt.xlabel('Iteration')
        plt.ylabel('PSNR')
        
        plt.show()
        
# Final reconstruction after the last optimiser step; no autograd graph needed.
with torch.no_grad():
    image_dict[method] = torch_to_np(net(net_input))

## TV-DIP

In [None]:
# TV-DIP: network prior fitted to the noisy LR image with an L1 data term plus
# a total-variation penalty on the HR output.
input_depth = 32
method = 'TV-DIP'

psnr_dict[method] = []
iteration_dict[method] = []

net = skip(
    input_depth,
    clean_hr_img_np.shape[0],
    num_channels_down = [128] * 5,
    num_channels_up   = [128] * 5,
    num_channels_skip = [4] * 5, 
    upsample_mode='bilinear',
    downsample_mode='stride',
    need_sigmoid=True, 
    need_bias=True, 
    pad='reflection',
    act_fun='LeakyReLU'
).type(dtype)

# Fixed network input at HR spatial resolution; only the network weights are optimised.
net_input = get_noise(input_depth, 'noise', clean_hr_img_np.shape[1:]).type(dtype).detach()
criterion = torch.nn.L1Loss().type(dtype)
optimizer = torch.optim.Adam(net.parameters())

for i in range(num_iterations+1):
    optimizer.zero_grad()
    out = net(net_input)
    # L1 data term on the down-sampled output + TV regulariser (weight 2e-6) on the HR output.
    loss = criterion(downsampler(out), noisy_lr_img_torch) + 2e-6 * tv_loss(out)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.02)
    optimizer.step()
    
    if i % save_every == 0:
        clear_output(wait=True)
        out_np = np.clip(torch_to_np(out), 0, 1)
        
        iteration_dict[method].append(i)
        # Reference image first (`image_true`); images are floats in [0, 1], hence data_range=1.0.
        psnr = peak_signal_noise_ratio(clean_hr_img_np, out_np, data_range=1.0)
        psnr_dict[method].append(psnr)
        
        plt.figure(figsize=(10, 5))
        plt.subplot(121)
        # out_np is already clipped above, so no second clip is needed.
        plt.imshow(out_np.transpose((1, 2, 0)))
        plt.title(f'Iteration: {i}/{num_iterations}, Loss: {loss.item():0.3e}, PSNR: {psnr:0.1f}')
        
        plt.subplot(122)
        plt.plot(iteration_dict[method], psnr_dict[method])
        plt.xlabel('Iteration')
        plt.ylabel('PSNR')
        
        plt.show()
        
# Final reconstruction after the last optimiser step; no autograd graph needed.
with torch.no_grad():
    image_dict[method] = torch_to_np(net(net_input))

## CS-DODIP

In [None]:
# CS-DODIP: jointly optimise the network and an explicit corruption estimate
# that is added to the down-sampled output before comparing with the noisy LR image.
input_depth = 32
method = 'CS-DODIP'

psnr_dict[method] = []
iteration_dict[method] = []

net = skip(
    input_depth,
    clean_hr_img_np.shape[0],
    num_channels_down = [128] * 5,
    num_channels_up   = [128] * 5,
    num_channels_skip = [4] * 5, 
    upsample_mode='bilinear',
    downsample_mode='stride',
    need_sigmoid=True, 
    need_bias=True, 
    pad='reflection',
    act_fun='LeakyReLU'
).type(dtype)

# Fixed network input at HR spatial resolution.
net_input = get_noise(input_depth, 'noise', clean_hr_img_np.shape[1:]).type(dtype).detach()
criterion = torch.nn.MSELoss().type(dtype)
optimizer = torch.optim.Adam(net.parameters())

# Corruption estimate parameterised as p**2 - n**2, with both factors
# initialised to small Gaussian values (scale 1e-5) in the LR image shape.
r_img_cor_p_torch = torch.zeros_like(noisy_lr_img_torch).normal_()*1e-5
r_img_cor_n_torch = torch.zeros_like(noisy_lr_img_torch).normal_()*1e-5
r_img_cor_p_torch.requires_grad = True
r_img_cor_n_torch.requires_grad = True

# The corruption factors get their own optimiser, separate from the network's Adam.
optimizer_sop = torch.optim.SGD([r_img_cor_p_torch, r_img_cor_n_torch], lr=1000, momentum=0.99)

for i in range(num_iterations+1):
    optimizer.zero_grad()
    optimizer_sop.zero_grad()
    
    out = net(net_input)
    r_img_cor_torch = r_img_cor_p_torch ** 2 - r_img_cor_n_torch ** 2
    # Down-sampled output plus estimated corruption should reproduce the noisy observation.
    loss = criterion(downsampler(out) + r_img_cor_torch, noisy_lr_img_torch)
    loss.backward()
    # Gradient clipping is applied to the network parameters only.
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.01)
    
    optimizer.step()
    optimizer_sop.step()
    
    if i % save_every == 0:
        clear_output(wait=True)
        out_np = np.clip(torch_to_np(out), 0, 1)
        
        iteration_dict[method].append(i)
        # Reference image first (`image_true`); images are floats in [0, 1], hence data_range=1.0.
        psnr = peak_signal_noise_ratio(clean_hr_img_np, out_np, data_range=1.0)
        psnr_dict[method].append(psnr)
        
        plt.figure(figsize=(10, 5))
        plt.subplot(121)
        # out_np is already clipped above, so no second clip is needed.
        plt.imshow(out_np.transpose((1, 2, 0)))
        plt.title(f'Iteration: {i}/{num_iterations}, Loss: {loss.item():0.3e}, PSNR: {psnr:0.1f}')
        
        plt.subplot(122)
        plt.plot(iteration_dict[method], psnr_dict[method])
        plt.xlabel('Iteration')
        plt.ylabel('PSNR')
        
        plt.show()
        
# Final reconstruction after the last optimiser step; no autograd graph needed.
with torch.no_grad():
    image_dict[method] = torch_to_np(net(net_input))

In [None]:
# PSNR-vs-iteration curves for every method, plus the clean-SR PSNR as a reference line.
for k, v in psnr_dict.items():
    plt.plot(iteration_dict[k], v, '--', linewidth=3, label=k)
# The curves are recorded from iteration 0 (i % save_every == 0), so the reference
# line starts at 0 as well and spans the same x-range as the curves.
plt.hlines(clean_psnr, 0, num_iterations, linestyles='dashdot', label='No Noise', linewidth=3, colors='black')
plt.xlabel('Iteration', fontsize=15)
plt.ylabel('PSNR', fontsize=15)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.legend()
plt.savefig('sr_psnr.png', dpi=300)

In [None]:
# Final reconstructions in one row-major grid: clean SR first, then each method
# in the same order as the figure title.
method_order = ['DCT-Lasso', 'Robust-DIP', 'TV-DIP', 'CS-DODIP']
result_imgs = [clean_sr_np] + [image_dict[m] for m in method_order]
grid = get_image_grid(result_imgs)

plt.figure(figsize=(15, 10))
# The grid is transposed with (1, 2, 0) before being handed to imshow.
plt.imshow(grid.transpose((1, 2, 0)))
plt.axis('off')
plt.title('[Clean, DCT-Lasso, Robust-DIP, TV-DIP, CS-DODIP]', fontsize=20)
plt.savefig('sr_result_images.png', dpi=300)

In [None]:
import pickle

# Bundle everything needed to regenerate the figures without re-running the optimisation.
result_dict = dict(
    iteration=iteration_dict,
    image=image_dict,
    psnr=psnr_dict,
    clean_result=clean_sr_np,
    clean_psnr=clean_psnr,
)

# Serialise the bundle to disk with the highest available pickle protocol.
with open('sr_result.pl', 'wb') as handle:
    pickle.dump(result_dict, handle, pickle.HIGHEST_PROTOCOL)

In [None]:
# Re-bind the plotting inputs from the results bundle.
# NOTE(review): this reads the in-memory `result_dict`, not the pickle file on disk.
image_dict, clean_sr_np = result_dict['image'], result_dict['clean_result']