In [None]:
try:
    from google.colab import drive, files
    drive.mount("/content/gdrive")
    %cd gdrive/My Drive/Colab\ Notebooks/master-thesis-code/
    !git pull
    %pip install -r requirements.txt
except:
    print("working locally")

In [20]:
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline

import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import random
import time

import numpy as np
import cv2
from models.decoder import decoder
#from models.common import ProbabilityDropout

import torch
import torch.optim

from skimage.measure import compare_psnr, compare_ssim
from utils.denoising_utils import *
from utils.bayesian_utils import *
from utils.common_utils import init_normal

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

imsize = (320, 320) # -1
PLOT = True

p_noise = 0.1
sigma = p_noise * 255
sigma_ = p_noise

In [3]:
# deJPEG
# fname = 'data/denoising/snail.jpg'

# denoising
# fname = 'data/denoising/F16_GT.png'
fname = 'GP_DIP/data/denoising/Dataset/image_Peppers512rgb.png'

# shepp logan
# fname = "data/bayesian/SheppLogan_Phantom.png"

In [4]:
# Add synthetic noise
img_pil = crop_image(get_image(fname, imsize)[0], d=32)
img_np = pil_to_np(img_pil)
img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma_)

if PLOT:
    plot_image_grid([img_np, img_noisy_np], 4, 6)

<Figure size 576x1296 with 0 Axes>

In [15]:
num_iter = 25000
num_input_channels = 128
figsize = 4

num_scales = 5
num_channels = [128] * num_scales
num_channels_noise = [4] * num_scales
num_output_channels = 3
upsample_mode = 'bilinear'
need1x1 = True

dm = '2d'
dp = 0.2

In [9]:
INPUT = 'noise' # 'meshgrid'
pad = 'reflection'
OPT_OVER = 'net' # 'net,input'

reg_noise_std = 1./30. # set to 1./20. for sigma=50
LR = 0.01

OPTIMIZER = 'adam' # 'adamw' 'LBFGS'
LOSS = 'mse' # 'nll'
show_every = 100
exp_weight = 0.99

net_input = get_noise(num_input_channels, INPUT, (int(img_pil.size[1]/32), int(img_pil.size[0]/32))).type(dtype).detach()

# Loss
if LOSS == 'mse':
    mse = torch.nn.MSELoss().type(dtype)

img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)

In [11]:
net_input_saved = net_input.detach().clone()
noise = net_input.detach().clone()
out_avg = None
last_net = None
psnr_noisy_last = 0
BACKTRACKING = True

burnin_iter = 11000
MCMC_iter = 50
    
#i = 0
def closure():

    global i, out_avg, psnr_noisy_last, last_net, net_input, losses, psnrs, ssims, img_mean, sample_count #, relus

    if reg_noise_std > 0:
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)
    
    out = net(net_input)
    
    # Smoothing
    if out_avg is None:
        out_avg = out.detach()
    else:
        out_avg = out_avg * exp_weight + out.detach() * (1 - exp_weight)
    
    if LOSS == 'mse':
        _loss = mse(out[:,:3], img_noisy_torch)
    elif LOSS == 'nll':
        if num_channels > 6:
            # shall i convert to grayscale?
            mu = torch.tensor([out[0::3].mean(axis=0), out[1::3].mean(axis=0), out[2::3].mean(axis=0)])
            # do i need log here?
            logvar = torch.tensor([out[0::3].var(axis=0), out[1::3].var(axis=0), out[2::3].var(axis=0)])
        else:
            mu = out[:,:3]
            logvar = out[:,3:]
            
        _loss = gaussian_nll(mu, logvar, img_noisy_torch)
        
    losses.append(_loss)
    _loss.backward()

    _out = out.detach().cpu().numpy()[0,:3]
    _out_avg = out_avg.detach().cpu().numpy()[0,:3]

    psnr_noisy = compare_psnr(img_noisy_np, _out)
    psnr_gt    = compare_psnr(img_np, _out)
    psnr_gt_sm = compare_psnr(img_np, _out_avg)
    
    if i % MCMC_iter == 0 and i > burnin_iter:
        img_mean += _out
        sample_count += 1
    
    _out = swap_channels(_out)
    _out_avg = swap_channels(_out_avg)
    ssim_noisy = compare_ssim(swap_channels(img_noisy_np), _out, multichannel=True)
    ssim_gt    = compare_ssim(swap_channels(img_np), _out, multichannel=True)
    ssim_gt_sm = compare_ssim(swap_channels(img_np), _out_avg, multichannel=True)
    
    psnrs.append([psnr_noisy, psnr_gt, psnr_gt_sm])
    ssims.append([ssim_noisy, ssim_gt, ssim_gt_sm])

    # Note that we do not have GT for the "snail" example
    # So 'PSRN_gt', 'PSNR_gt_sm' make no sense
    print('Iteration: {}    Loss: {}   PSNR_noisy: {}   PSRN_gt: {} PSNR_gt_sm: {}'.format(i, _loss.item(), psnr_noisy, psnr_gt, psnr_gt_sm))
    
    if  PLOT and i % show_every == 0:
        #out_np = torch_to_np(out)

        #img_mean, ale_exp, epi, uncert = calc_uncert(net, net_input_saved, noise, reg_noise_std)
        #print('uncert {}    ale {}    epi {}'.format(uncert.mean().item(), ale_exp.mean().item(), epi.mean().item()))
        
        img_list_np = []

        with torch.no_grad():
            net_input = net_input_saved + (noise.normal_() * reg_noise_std)
            for _ in range(25):
                img = net(net_input)
                img_list_np.append(torch_to_np(img)[:3])

        #uncertainty_map = np.mean(np.std(np.array(img_list_np), axis=0), axis = 0)

        #np_plot(uncertainty_map, 'Uncertainty map', opt = 'map')
        
        out_np = np.mean(np.array(img_list_np), axis=0)
        
        print('#################')
        psnr_noisy = compare_psnr(img_noisy_np, out_np)
        psnr_gt    = compare_psnr(img_np, out_np)
        if sample_count != 0:
            psnr_mean = compare_psnr(img_np, img_mean / sample_count)
        else:
            psnr_mean = 0
        print('psnr_noisy_avg: {} | psnr_gt_avg: {} | psnr_mean: {}'.format(psnr_noisy, psnr_gt, psnr_mean))
        print('###################')
        
        to_plot = [np.clip(out_np[:3], 0, 1), np.clip(torch_to_np(out_avg)[:3], 0, 1)]
        if i > burnin_iter:
            to_plot.append(np.clip(img_mean / sample_count, 0, 1))

        plot_image_grid(to_plot, factor=figsize, nrow=1)

    # Backtracking
    # we want to leave it till the end and see if net overfits
    if BACKTRACKING:
        if i % show_every:
            if psnr_noisy - psnr_noisy_last < -5:
                print('Falling back to previous checkpoint.')

                for new_param, net_param in zip(last_net, net.parameters()):
                    net_param.data.copy_(new_param.cuda())

                return total_loss*0
            else:
                last_net = [x.detach().cpu() for x in net.parameters()]
                psnr_noisy_last = psnr_noisy

    i += 1

    return _loss

In [22]:
net = decoder(num_input_channels, num_output_channels, num_channels, num_channels_noise, 
              need1x1=need1x1, dm=dm, dp=dp).type(dtype)
    
net.apply(init_normal)

# Compute number of parameters
s  = sum([np.prod(list(p.size())) for p in net.parameters()]);
print ('Number of params: %d' % s)

losses = []
psnrs = []
ssims = []

img_mean = 0
sample_count = 0

i = 0

p = get_params(OPT_OVER, net, net_input)
optimize(OPTIMIZER, p, closure, LR, num_iter)

out_mean = img_mean / sample_count
final_psnr = compare_psnr(img_np, out_mean)
print('Final psnr: {}'.format(final_psnr))

#out_np = torch_to_np(net(net_input))[:3]

Number of params: 84227
Starting optimization with ADAM


  "See the documentation of nn.Upsample for details.".format(mode))


Iteration: 0    Loss: 0.09933929890394211   PSNR_noisy: 10.028786178152805   PSRN_gt: 10.325058141233253 PSNR_gt_sm: 10.325058141233253


KeyboardInterrupt: 

In [None]:
q = plot_image_grid([img_np, out_mean, _out_avg, out_np], factor=13)

In [None]:
fig, ax0 = plt.subplots(1, 1)

ax0.plot(range(len(losses)), losses)
ax0.set_title(LOSS)
ax0.set_xlabel('iteration')
ax0.set_ylabel('mse loss')
#ax0.set_ylim(0,0.04)
    
plt.show()

In [None]:
fig, axs = plt.subplots(1, 3, constrained_layout=True)
labels = ["psnr_noisy", "psnr_gt", "psnr_gt_sm"]

_psnrs = np.array(psnrs)
for i in range(_psnrs.shape[1]):
    axs[i].plot(range(_psnrs.shape[0]), _psnrs[:,i])
    axs[i].set_title(labels[i])
    axs[i].set_xlabel('iteration')
    axs[i].set_ylabel('psnr')
    
plt.show()