In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import display, clear_output

import os
import numpy as np
from models.skip import skip
import torch
import torch.optim

from skimage.util import random_noise
from skimage.metrics import peak_signal_noise_ratio

from utils.common_utils import get_image, crop_image, pil_to_np, np_to_torch, torch_to_np, get_image_grid, get_noise, plot_image_grid
from utils.sr_utils import tv_loss

# --- Runtime configuration -------------------------------------------------
# Use cuDNN and let it benchmark/auto-select convolution algorithms.
# NOTE(review): benchmark=True may select non-deterministic kernels, so the
# seeds below do not necessarily give bit-identical GPU results across runs.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

# Tensor type every model/tensor in this notebook is cast to via `.type(dtype)`.
dtype = torch.cuda.FloatTensor

# Seed the NumPy and PyTorch global RNGs.
np.random.seed(0)
torch.manual_seed(0);

# Training schedule shared by every method below.
num_iterations = 5000  # loops run range(num_iterations + 1)
save_every = 500       # log PSNR and refresh the plot every this many steps

In [None]:
def sparse_noise(shape, p, sigma):
    """Draw Gaussian noise in which only a random subset of entries is kept.

    Every entry is first sampled as ``sigma * N(0, 1)`` and cast to the
    notebook-level ``dtype``. An independent ``U[0, 1)`` draw is then made per
    entry, and every entry whose draw exceeds ``p`` is zeroed. On average, a
    fraction ``p`` of the entries stays non-zero.

    Args:
        shape: Output size (anything ``torch.randn`` accepts, e.g. a ``torch.Size``).
        p: Probability that an entry is kept (non-zero).
        sigma: Standard deviation of the kept entries.

    Returns:
        A tensor of type ``dtype`` with the requested shape.
    """
    gaussian = torch.randn(shape).type(dtype)
    noise = sigma * gaussian
    # NOTE(review): the mask comes from the CPU generator even when `noise` is
    # on the GPU; left as-is so seeded runs reproduce the same noise pattern.
    drop_mask = torch.rand(*noise.shape) > p
    noise[drop_mask] = 0
    return noise

## Data

In [None]:
# Load the ground-truth image and move it to the GPU.
img_path = 'data/gaussian/xray.jpeg'
imsize = 256     # size passed to get_image
dim_div_by = 64  # divisor passed to crop_image (presumably makes H and W multiples of it -- confirm)

img_pil = get_image(img_path, imsize)[0]
clean_img_np = pil_to_np(crop_image(img_pil, dim_div_by))
clean_img_torch = np_to_torch(clean_img_np).type(dtype)

# Total number of unknowns in the image array.
signal_size = clean_img_np.size

# Measurement model: noisy = A @ x + s, where s comes from sparse_noise.
num_measurements = 8000
p = 0.1            # expected fraction of measurements that receive noise
noise_level = 1.0  # std of the non-zero noise entries

# Dense Gaussian sensing matrix with entries scaled by sqrt(1 / num_measurements).
meas_scale = np.sqrt(1 / num_measurements)
A = meas_scale * torch.randn(num_measurements, signal_size).type(dtype)

# Noise-free measurements of the flattened image, then sparse noise on top.
clean_measurement_torch = torch.matmul(clean_img_torch.reshape(1, -1), A.T)
noise_torch = sparse_noise(clean_measurement_torch.shape, p=p, sigma=noise_level)
noisy_measurement_torch = clean_measurement_torch + noise_torch

In [None]:
# Preview the ground-truth image with tick labels hidden but the frame kept.
plt.imshow(clean_img_np.transpose((1, 2, 0)), cmap='gray')
# Clear the *major* ticks: passing minor=True only touches minor ticks, which
# matplotlib does not draw on linear axes by default, so it hid nothing.
plt.xticks([])
plt.yticks([])
# NOTE(review): the next cell saves to the same filename and overwrites this file.
plt.savefig('gs_example.png', dpi=300)

In [None]:
# Same preview rendered through get_image_grid at a larger figure size.
# NOTE(review): writes the same 'gs_example.png' as the previous cell, overwriting it.
grid = get_image_grid([clean_img_np], 3)
grid_hwc = grid.transpose((1, 2, 0))  # channel-first -> channel-last for imshow
plt.figure(figsize=(8, 8))
plt.imshow(grid_hwc)
plt.axis('off')
plt.savefig('gs_example.png', dpi=300)

In [None]:
# Per-method results, keyed by method name:
#   image_dict[method]     -> final reconstruction (numpy array)
#   psnr_dict[method]      -> PSNR values logged every `save_every` iterations
#   iteration_dict[method] -> iterations at which those PSNRs were logged
image_dict, psnr_dict, iteration_dict = {}, {}, {}

## Clean (noise-free) baseline

In [None]:
input_depth = 64

# Fresh DIP network for the noise-free reference run.
net = skip(
    input_depth,
    num_output_channels=1,
    num_channels_down=[8, 16, 32, 64, 128],
    num_channels_up=[8, 16, 32, 64, 128],
    num_channels_skip=[0, 0, 0, 4, 4],
    upsample_mode='bilinear',
    need_sigmoid=True,
).type(dtype)

# Fixed random input code; only the network weights are optimized.
net_input = torch.randn(1, input_depth, imsize, imsize).type(dtype)
criterion = torch.nn.MSELoss().type(dtype)
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)

for i in range(num_iterations + 1):
    optimizer.zero_grad()
    out = net(net_input)
    # Fit the *clean* measurements: this run is the no-noise reference.
    predicted_measurements = out.reshape(1, -1) @ A.T
    loss = criterion(predicted_measurements, clean_measurement_torch)
    loss.backward()
    optimizer.step()

    if i % save_every != 0:
        continue

    clear_output(wait=True)
    out_np = torch_to_np(out)
    psnr_now = peak_signal_noise_ratio(out_np, clean_img_np)
    plt.imshow(out_np.transpose((1, 2, 0)), cmap='gray')
    plt.title(f'Iteration: {i}/{num_iterations}, Loss: {loss.item():0.3e}, PSNR: {psnr_now:0.1f}')
    plt.show()

In [None]:
# Final noise-free reconstruction; inference only, so no autograd graph is built.
with torch.no_grad():
    clean_recon_np = torch_to_np(net(net_input))
# Score the reconstruction that is actually kept (and pickled as 'clean_result').
# The loop's `out_np` is a snapshot taken before the last optimizer step and
# depends on which cell ran most recently, so it must not be used here.
clean_psnr = peak_signal_noise_ratio(clean_recon_np, clean_img_np)

In [None]:
# Side by side: ground truth vs. noise-free reconstruction.
clean_vs_recon = [clean_img_np, clean_recon_np]
plot_image_grid(clean_vs_recon, 3, 11);

## DCT-Lasso

In [None]:
from torch_dct import idct_2d

method = 'DCT-Lasso'

psnr_dict[method] = []
iteration_dict[method] = []

# Optimize DCT coefficients `w` directly; the image estimate is idct_2d(w).
w = torch.randn_like(clean_img_torch)
w.requires_grad = True
criterion = torch.nn.L1Loss().type(dtype)  # L1 data-fit term
optimizer = torch.optim.SGD([w], lr=1e6, momentum=0.99)
lam = 1e-1  # weight of the mean-|w| sparsity penalty

for i in range(num_iterations + 1):
    optimizer.zero_grad()
    out = idct_2d(w)
    data_fit = criterion(torch.matmul(out.reshape(1, -1), A.T), noisy_measurement_torch)
    loss = data_fit + lam * w.abs().mean()
    loss.backward()
    optimizer.step()

    if i % save_every:
        continue

    clear_output(wait=True)
    out_np = np.clip(torch_to_np(out), 0, 1)
    psnr = peak_signal_noise_ratio(out_np, clean_img_np)
    iteration_dict[method].append(i)
    psnr_dict[method].append(psnr)

    plt.figure(figsize=(10, 5))
    plt.subplot(121)
    # out_np is already clipped to [0, 1].
    plt.imshow(out_np.transpose((1, 2, 0)), cmap='gray')
    plt.title(f'Iteration: {i}/{num_iterations}, Loss: {loss.item():0.3e}, PSNR: {psnr:0.1f}')

    plt.subplot(122)
    plt.plot(iteration_dict[method], psnr_dict[method])
    plt.xlabel('Iteration')
    plt.ylabel('PSNR')

    plt.show()

image_dict[method] = np.clip(torch_to_np(idct_2d(w)), 0, 1)

## Robust-DIP

In [None]:
method = 'Robust-DIP'

# Both names alias the same fresh lists stored in the result dicts.
psnrs_logged = psnr_dict[method] = []
iters_logged = iteration_dict[method] = []

input_depth = 64

# Same DIP architecture as the clean baseline, re-initialized from scratch.
net = skip(
    input_depth,
    num_output_channels=1,
    num_channels_down=[8, 16, 32, 64, 128],
    num_channels_up=[8, 16, 32, 64, 128],
    num_channels_skip=[0, 0, 0, 4, 4],
    upsample_mode='bilinear',
    need_sigmoid=True,
).type(dtype)

net_input = torch.randn(1, input_depth, imsize, imsize).type(dtype)
# L1 data fit against the noisy measurements, with no explicit regularizer.
criterion = torch.nn.L1Loss().type(dtype)
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)

for i in range(num_iterations + 1):
    optimizer.zero_grad()
    out = net(net_input)
    loss = criterion(out.reshape(1, -1) @ A.T, noisy_measurement_torch)
    loss.backward()
    optimizer.step()

    if i % save_every == 0:
        clear_output(wait=True)
        out_np = np.clip(torch_to_np(out), 0, 1)
        psnr = peak_signal_noise_ratio(out_np, clean_img_np)
        iters_logged.append(i)
        psnrs_logged.append(psnr)

        plt.figure(figsize=(10, 5))
        plt.subplot(121)
        plt.imshow(out_np.transpose((1, 2, 0)), cmap='gray')
        plt.title(f'Iteration: {i}/{num_iterations}, Loss: {loss.item():0.3e}, PSNR: {psnr:0.1f}')
        plt.subplot(122)
        plt.plot(iters_logged, psnrs_logged)
        plt.xlabel('Iteration')
        plt.ylabel('PSNR')
        plt.show()

image_dict[method] = torch_to_np(net(net_input))

## TV-DIP

In [None]:
method = 'TV-DIP'

# Both names alias the same fresh lists stored in the result dicts.
psnrs_logged = psnr_dict[method] = []
iters_logged = iteration_dict[method] = []

input_depth = 64

# Same DIP architecture as the other runs, re-initialized from scratch.
net = skip(
    input_depth,
    num_output_channels=1,
    num_channels_down=[8, 16, 32, 64, 128],
    num_channels_up=[8, 16, 32, 64, 128],
    num_channels_skip=[0, 0, 0, 4, 4],
    upsample_mode='bilinear',
    need_sigmoid=True,
).type(dtype)

net_input = torch.randn(1, input_depth, imsize, imsize).type(dtype)
criterion = torch.nn.L1Loss().type(dtype)
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)
tv_weight = 1e-4  # weight of the total-variation penalty on the image

for i in range(num_iterations + 1):
    optimizer.zero_grad()
    out = net(net_input)
    data_fit = criterion(torch.matmul(out.reshape(1, -1), A.T), noisy_measurement_torch)
    loss = data_fit + tv_weight * tv_loss(out)
    loss.backward()
    optimizer.step()

    if i % save_every:
        continue

    clear_output(wait=True)
    out_np = np.clip(torch_to_np(out), 0, 1)
    psnr = peak_signal_noise_ratio(out_np, clean_img_np)
    iters_logged.append(i)
    psnrs_logged.append(psnr)

    plt.figure(figsize=(10, 5))
    plt.subplot(121)
    plt.imshow(out_np.transpose((1, 2, 0)), cmap='gray')
    plt.title(f'Iteration: {i}/{num_iterations}, Loss: {loss.item():0.3e}, PSNR: {psnr:0.1f}')
    plt.subplot(122)
    plt.plot(iters_logged, psnrs_logged)
    plt.xlabel('Iteration')
    plt.ylabel('PSNR')
    plt.show()

image_dict[method] = torch_to_np(net(net_input))

## CS-DODIP

In [None]:
input_depth = 2
method = 'CS-DODIP'

psnr_dict[method] = []
iteration_dict[method] = []

net = skip(
    input_depth,
    clean_img_np.shape[0],
    num_channels_down=[128] * 5,
    num_channels_up=[128] * 5,
    num_channels_skip=[0] * 5,
    upsample_mode='nearest',
    filter_skip_size=1,
    filter_size_up=3,
    filter_size_down=3,
    need_sigmoid=True,
    need_bias=True,
    pad='reflection',
    act_fun='LeakyReLU',
).type(dtype)

# 'meshgrid' network input from get_noise (2 channels, image-sized).
net_input = get_noise(input_depth, 'meshgrid', clean_img_np.shape[1:]).type(dtype)
criterion = torch.nn.MSELoss().type(dtype)
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)

# Over-parameterized estimate of the measurement noise, r = r_p**2 - r_n**2,
# with both factors initialized to tiny random values.
r_img_cor_p_torch = torch.zeros_like(noisy_measurement_torch).normal_() * 1e-5
r_img_cor_n_torch = torch.zeros_like(noisy_measurement_torch).normal_() * 1e-5
r_img_cor_p_torch.requires_grad = True
r_img_cor_n_torch.requires_grad = True

# Separate optimizer (plain SGD) for the noise factors.
optimizer_sop = torch.optim.SGD([r_img_cor_p_torch, r_img_cor_n_torch], lr=100)

for i in range(num_iterations + 1):
    optimizer.zero_grad()
    optimizer_sop.zero_grad()

    out = net(net_input)
    r_img_cor_torch = r_img_cor_p_torch ** 2 - r_img_cor_n_torch ** 2
    predicted = torch.matmul(out.reshape(1, -1), A.T) + r_img_cor_torch
    loss = criterion(predicted, noisy_measurement_torch)
    loss.backward()

    optimizer.step()
    optimizer_sop.step()

    if i % save_every:
        continue

    clear_output(wait=True)
    out_np = np.clip(torch_to_np(out), 0, 1)
    psnr = peak_signal_noise_ratio(out_np, clean_img_np)
    iteration_dict[method].append(i)
    psnr_dict[method].append(psnr)

    plt.figure(figsize=(10, 5))
    plt.subplot(121)
    plt.imshow(out_np.transpose((1, 2, 0)), cmap='gray')
    plt.title(f'Iteration: {i}/{num_iterations}, Loss: {loss.item():0.3e}, PSNR: {psnr:0.1f}')

    plt.subplot(122)
    plt.plot(iteration_dict[method], psnr_dict[method])
    plt.xlabel('Iteration')
    plt.ylabel('PSNR')
    # Sum of absolute differences between the estimated and the true noise.
    noise_err = torch.sum(torch.abs(r_img_cor_torch - noise_torch))
    plt.title(f'Noise Diff: {noise_err:0.3f}')

    plt.show()

image_dict[method] = torch_to_np(net(net_input))

In [None]:
# PSNR-vs-iteration curve for every method, plus the noise-free run's final
# PSNR as a horizontal reference line.
for method_name, psnrs in psnr_dict.items():
    plt.plot(iteration_dict[method_name], psnrs, '--', linewidth=3, label=method_name)
# Every loop logs i = 0 first (range(num_iterations + 1) with i % save_every == 0),
# so the reference line starts at 0 to span the same x-range as the curves.
plt.hlines(clean_psnr, 0, num_iterations, linestyles='dashdot', label='No Noise', linewidth=3, colors='black')
plt.xlabel('Iteration', fontsize=15)
plt.ylabel('PSNR', fontsize=15)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.legend()
plt.savefig('gs_psnr.png', dpi=300)

In [None]:
# Final reconstructions side by side, in the same order as the title below.
panels = [clean_recon_np] + [image_dict[name] for name in ('DCT-Lasso', 'Robust-DIP', 'TV-DIP', 'CS-DODIP')]
grid = get_image_grid(panels)
plt.figure(figsize=(15, 10))
plt.imshow(grid.transpose((1, 2, 0)))
plt.axis('off')
plt.title('[Clean, DCT-Lasso, Robust-DIP, TV-DIP, CS-DODIP]', fontsize=20)
plt.savefig('gs_result_images.png', dpi=300)

In [None]:
import pickle

# Bundle everything needed to redraw the figures without re-training.
result_dict = dict(
    iteration=iteration_dict,
    image=image_dict,
    psnr=psnr_dict,
    clean_result=clean_recon_np,
    clean_psnr=clean_psnr,
)

with open('gs_result.pl', 'wb') as result_file:
    pickle.dump(result_dict, result_file, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# NOTE(review): duplicate of the earlier final-images cell (same output file).
# If the intent is to redraw from the reloaded pickle, move this after the load cells.
method_order = ['DCT-Lasso', 'Robust-DIP', 'TV-DIP', 'CS-DODIP']
grid = get_image_grid([clean_recon_np, *(image_dict[m] for m in method_order)])
fig = plt.figure(figsize=(15, 10))
plt.imshow(grid.transpose((1, 2, 0)))
plt.axis('off')
plt.title('[Clean, DCT-Lasso, Robust-DIP, TV-DIP, CS-DODIP]', fontsize=20)
plt.savefig('gs_result_images.png', dpi=300)

In [None]:
import pickle

# Reload saved results. pickle.load can execute arbitrary code, so only load
# result files this notebook produced itself.
with open('gs_result.pl', 'rb') as result_file:
    result_dict = pickle.load(result_file)

In [None]:
# Restore the objects the figure cells read from the reloaded results.
image_dict, clean_recon_np = result_dict['image'], result_dict['clean_result']