# Import libs

In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline

import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import numpy as np
from models.resnet import ResNet
from models.unet import UNet
from models.skip import skip
from models import get_net
import torch
import torch.optim

from skimage.metrics import peak_signal_noise_ratio
from util.common_utils import * 
from util.loss import total_variation

# cuDNN settings: enable the backend and let it benchmark algorithms.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Tensor type used for every `.type(dtype)` cast in this notebook.
# Fall back to the CPU float type when no CUDA device is available so the
# notebook can still be run (slowly) on a CPU-only machine.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

PLOT = True      # when True, closure() periodically plots intermediate results
imsize = -1      # forwarded to get_image(); presumably "keep original size" -- TODO confirm
dim_div_by = 64  # not referenced in the visible cells

# Choose figure

In [None]:
# Dataset layout: <base_dir>/images_lpn holds the input images and
# <base_dir>/images_clean holds reference images under the same file names.
base_dir = './data/human'
# base_dir = './data/power_supply'
img_dir = os.path.join(base_dir, "images_lpn")
# img_dir = os.path.join(base_dir, "images_lpn_2")
img_clean_dir = os.path.join(base_dir, "images_clean")

# os.listdir() returns entries in arbitrary order; sort so that list indices
# (and therefore the previews below) are reproducible across runs and machines.
file_list = sorted(os.listdir(img_dir))

img_np_list = []     # input images as numpy arrays, first channel only
img_gpu_list = []    # the same images as torch tensors cast to `dtype`
img_clean_list = []  # matching clean references, first channel only

for f in file_list:
    _, img_np = get_image(os.path.join(img_dir, f), imsize)
    img_np = img_np[0:1, :, :]  # keep a single channel, preserving the channel axis
    img_np_list.append(img_np)
    img_gpu_list.append(np_to_torch(img_np).type(dtype))

    # The clean reference is looked up by the same file name.
    _, img_clean = get_image(os.path.join(img_clean_dir, f), imsize)
    img_clean = img_clean[0:1, :, :]
    img_clean_list.append(img_clean)

print("img_np shape: ", img_np_list[0].shape)

# Show the first clean/input pair as a quick sanity check.
_ = plot_image_grid([img_clean_list[0], img_np_list[0]], factor=4, nrow=2)

# Set up everything

In [None]:
# Low-pass / high-pass filtering helpers (OpenCV based)
import cv2

def high_pass_blur(img_np, gauss_ksize):
    """Return the high-frequency residual of an image after a box blur.

    Args:
        img_np: Channel-first image array of shape (C, H, W).
        gauss_ksize: Side length of the square kernel given to ``cv2.blur``.
            Despite the parameter name, ``cv2.blur`` is a normalized box
            filter, not a Gaussian.

    Returns:
        Array of shape (C, H, W) holding ``img_np`` minus its blurred version.
    """
    # OpenCV works on channel-last images, so go (C, H, W) -> (H, W, C).
    img_np_cv = np.ascontiguousarray(np.moveaxis(img_np, 0, 2))
    img_lp = cv2.blur(img_np_cv, (gauss_ksize, gauss_ksize))
    # cv2.blur returns a 2-D array for single-channel input; restoring the
    # (H, W, C) shape keeps the subtraction element-wise for any channel count.
    img_lp = img_lp.reshape(img_np_cv.shape)
    img_hp = img_np_cv - img_lp
    return np.moveaxis(img_hp, 2, 0)

def high_pass_nlm(img_np):
    """Return ``img_np`` minus a non-local-means denoised copy of its first channel.

    Args:
        img_np: Channel-first image array; only channel 0 is denoised.

    Returns:
        ``img_np - denoised`` where ``denoised`` has shape (1, H, W) and dtype
        float32 (broadcast against ``img_np``).
    """
    # NOTE(review): the cast to uint8 truncates; if images are floats in
    # [0, 1] this collapses almost everything to 0 -- confirm the value range.
    first_channel = img_np[0, :, :].astype(np.uint8)
    # NOTE(review): in the OpenCV Python binding the second positional
    # parameter appears to be `dst`, so (1, 3, 7) may not map to the intended
    # (h, templateWindowSize, searchWindowSize) -- confirm against the docs.
    denoised = cv2.fastNlMeansDenoising(first_channel, 1, 3, 7)
    low_pass = denoised[np.newaxis].astype(np.float32)
    return img_np - low_pass


def blur(img_np, gauss_ksize):
    """Box-blur a channel-first image and return it channel-first.

    Args:
        img_np: Channel-first image array of shape (C, H, W).
        gauss_ksize: Side length of the square kernel given to ``cv2.blur``.
            Despite the parameter name, ``cv2.blur`` is a normalized box
            filter, not a Gaussian.

    Returns:
        Blurred image of shape (C, H, W).
    """
    # OpenCV works on channel-last images, so go (C, H, W) -> (H, W, C).
    img_np_cv = np.ascontiguousarray(np.moveaxis(img_np, 0, 2))
    img_lp = cv2.blur(img_np_cv, (gauss_ksize, gauss_ksize))
    # cv2.blur returns a 2-D array for single-channel input; restore the
    # channel axis before moving it back to the front so any C is handled.
    img_lp = img_lp.reshape(img_np_cv.shape)
    return np.moveaxis(img_lp, 2, 0)


In [None]:
# --- Display ---------------------------------------------------------------
show_every = 200   # closure() plots every `show_every` iterations when PLOT is set
figsize = 5        # forwarded as `factor` to plot_image_grid() in closure()

# --- Network input / architecture -------------------------------------------
pad = 'reflection'  # 'zero'
INPUT = 'noise'     # input kind forwarded to get_noise()
input_depth = 32    # number of channels of the network input
NET_TYPE = 'skip'

# --- Optimisation -----------------------------------------------------------
OPTIMIZER = 'adam'
OPT_OVER = 'net'      # forwarded to get_params()
LR = 1e-3
num_iter = 4000
reg_noise_std = 3e-5  # std of the per-iteration perturbation added to the input

net = get_net(input_depth, NET_TYPE, pad, n_channels=1,
              skip_n33d=128,
              skip_n33u=128,
              skip_n11=4,
              num_scales=5,
              upsample_mode='bilinear').type(dtype)

# Loss
mse = torch.nn.MSELoss().type(dtype)

# Size the input from the first loaded image (the same one `img_zero` is built
# from) instead of the leftover loop variable `img_np`, which other cells
# overwrite.
net_input = get_noise(input_depth, INPUT, img_np_list[0].shape[1:]).type(dtype).detach()

num_parametes = sum(p.numel() for p in net.parameters() if p.requires_grad)
print("number parameters: {}".format(num_parametes))

# Main loop

In [None]:
# All-zero tensor used as the MSE target for the network output.
img_zero = torch.zeros_like(img_gpu_list[0])


def closure():
    """Run one optimisation step and log its progress.

    A randomly chosen input image is used as the target. The loss is
    ``total_variation(target - out)`` plus ``mse(0, out)``, i.e. the network
    output is treated as a pattern that is subtracted from the image, while
    the MSE term pulls that pattern toward zero.

    Reads and updates the notebook globals ``i`` and ``net_input``.

    Returns:
        The scalar loss tensor, after ``backward()`` has been called on it.
    """
    global i, last_net, net_input

    # Perturb the saved input with fresh Gaussian noise on every step.
    if reg_noise_std > 0:
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    out = net(net_input)

    # Pick one training image at random for this step.
    index = np.random.randint(0, len(img_gpu_list))
    target = img_np_list[index]
    target_gpu = img_gpu_list[index]

    # The step passed to the total-variation term is re-drawn every iteration.
    tv_step = np.random.randint(10, 30)

    total_loss = total_variation(target_gpu - out, step=tv_step) + mse(img_zero, out)
    total_loss.backward()

    # Copy the output to the host once and reuse it for metrics and plotting.
    learned_noise = out.detach().cpu().numpy()[0]
    out_clean = target - learned_noise
    # Min-max normalise to [0, 1] before comparing with the clean reference.
    out_clean_norm = (out_clean - np.min(out_clean)) / (np.max(out_clean) - np.min(out_clean))
    psnr = peak_signal_noise_ratio(img_clean_list[index], out_clean_norm)

    print ('Iteration %05d    Loss %f PSNR %f' % (i, total_loss.item(), psnr),'\r', end='')

    if PLOT and i % show_every == 0:
        plot_image_grid([np.clip(learned_noise, 0, 1), target, out_clean_norm], factor=figsize, nrow=3)

    i += 1

    return total_loss

# Notebook-level state that closure() reads and updates.
i = 0
last_net = None

# Unperturbed copy of the network input, plus a same-shaped buffer that
# closure() refills in place with Gaussian noise on every step.
net_input_saved = net_input.detach().clone()
noise = net_input_saved.clone()

# Collect the parameters to optimise and start the optimisation loop.
p = get_params(OPT_OVER, net, net_input)
optimize(OPTIMIZER, p, closure, LR=LR, num_iter=num_iter)

In [None]:
# Inference only: skip autograd bookkeeping for the final forward pass.
with torch.no_grad():
    out_np = torch_to_np(net(net_input))

out_list = []

# Subtract the network output from (up to) the first 30 inputs and min-max
# normalise each result to [0, 1] for display. The loop variable is
# deliberately not called `img_np`, so the notebook-level `img_np` set by the
# loading cell is left untouched.
for noisy_np in img_np_list[:30]:
    out_clean = noisy_np - out_np[0]
    out_clean_norm = (out_clean - np.min(out_clean)) / (np.max(out_clean) - np.min(out_clean))
    out_list.append(out_clean_norm)


q = plot_image_grid(out_list, factor=13)

In [None]:
import cv2

# Select the image to export by its file name.
index = file_list.index("24.jpg") #20

# Remove the network output from that image and min-max normalise to [0, 1].
out_clean = img_np_list[index] - out_np[0]
out_clean_norm = (out_clean - out_clean.min()) / (out_clean.max() - out_clean.min())

# Scale the first channel to 0..255 and write it as an 8-bit PNG
# (the cast truncates rather than rounds).
image = (out_clean_norm[0] * 255).astype(np.uint8)
cv2.imwrite("out_tv_dip.png", image)


# Report PSNR of the normalised result against the clean reference.
psnr = peak_signal_noise_ratio(img_clean_list[index], out_clean_norm)
print("PSNR: {}".format(psnr))

In [None]:
# Export the first channel of the network output as an 8-bit image, scaled so
# that its maximum maps to 255.
# NOTE(review): assumes the output is non-negative with a non-zero maximum;
# negative values would wrap in the uint8 cast -- confirm the output range.
noise_lpn = out_np[0]
noise_lpn = (noise_lpn / np.max(noise_lpn) * 255).astype(np.uint8)
cv2.imwrite("out_tv_lpn.png", noise_lpn)