# Import libraries

In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline

import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import numpy as np
from models.resnet import ResNet
from models.unet import UNet
from models.skip import skip
from models import get_net
import torch
import torch.optim

from skimage.metrics import peak_signal_noise_ratio
from util.common_utils import * 
from util.loss import total_variation

# cuDNN autotuning picks the fastest conv algorithms for fixed input sizes;
# note that benchmark mode is itself a source of run-to-run nondeterminism.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# All tensors/modules in this notebook are cast to this CUDA float type.
dtype = torch.cuda.FloatTensor

# Seed the RNGs this notebook relies on: numpy (random image index and TV step
# in the training closures) and torch (weight init, noise.normal_()), so runs
# are repeatable up to cuDNN nondeterminism.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

PLOT = True      # show intermediate plots during training
imsize = -1      # passed to get_image; presumably -1 keeps the original size — TODO confirm
dim_div_by = 64  # NOTE(review): not used anywhere in this notebook

# Load noisy images and their clean references

In [None]:
# Each noisy image in `img_dir` has a clean reference with the same file name in
# `img_clean_dir`; the clean images are only used for plotting and PSNR.
# base_dir = './data/human'
base_dir = './data/power_supply'
# img_dir = os.path.join(base_dir, "images_lpn")
img_dir = os.path.join(base_dir, "images_combined")
img_clean_dir = os.path.join(base_dir, "images_clean")

# os.listdir returns entries in arbitrary order; sort so image indices (and the
# CSV written at the end) are stable across machines and runs.
file_list = sorted(os.listdir(img_dir))
assert file_list, "no images found in {}".format(img_dir)

img_np_list = []     # noisy images as numpy arrays, first channel only
img_gpu_list = []    # the same noisy images as tensors of type `dtype`
img_clean_list = []  # matching clean references, first channel only

for f in file_list:
    _, img_np = get_image(os.path.join(img_dir, f), imsize)
    img_np = img_np[0:1, :, :]  # keep a single channel
    img_np_list.append(img_np)
    img_gpu_list.append(np_to_torch(img_np).type(dtype))

    _, img_clean = get_image(os.path.join(img_clean_dir, f), imsize)
    img_clean = img_clean[0:1, :, :]
    # PSNR later compares these arrays element-wise, so shapes must agree.
    assert img_clean.shape == img_np.shape, "shape mismatch for {}".format(f)
    img_clean_list.append(img_clean)

print("img_np shape: ", img_np_list[0].shape)

_ = plot_image_grid([img_clean_list[0], img_np_list[0]], factor=4, nrow=2)

# Set up the network, its input and the hyperparameters

In [None]:
show_every = 200    # plot every N iterations of the training closure
figsize = 5         # passed as `factor` to plot_image_grid
pad = 'reflection'  # or 'zero'
INPUT = 'noise'
input_depth = 32
OPTIMIZER = 'adam'
OPT_OVER = 'net'    # passed to get_params; presumably "optimise network weights only"

LR = 1e-3
num_iter = 15000
iter_step = 1500    # length of each stage of the TV-step schedule in closure()

reg_noise_std = 3e-5  # std of the perturbation added to the fixed input each step

NET_TYPE = 'skip'
net = get_net(input_depth, NET_TYPE, pad, n_channels=1,
              skip_n33d=128,
              skip_n33u=128,
              skip_n11=4,
              num_scales=5,
              upsample_mode='bilinear').type(dtype)

# Loss (currently unused: the objective in the closures is pure total variation)
mse = torch.nn.MSELoss().type(dtype)

# One fixed network input is shared by all images, so they must share a shape.
img_shape = img_np_list[0].shape
assert all(x.shape == img_shape for x in img_np_list), "all images must have the same shape"
net_input = get_noise(input_depth, INPUT, img_shape[1:]).type(dtype).detach()

num_parameters = sum(p.numel() for p in net.parameters() if p.requires_grad)
print("number parameters: {}".format(num_parameters))

# Main loop: train on a randomly sampled image at every iteration

In [None]:
def closure():
    """Run one training step on a randomly chosen noisy image.

    The network output is treated as a noise estimate; `image - out` is the
    denoised estimate and the loss is its total variation, computed with a
    `step` that shrinks as training progresses (see the schedule below). This
    function only calls backward(); the optimiser step is done by `optimize`
    (which presumably also zeroes the gradients).

    Reads globals: net, net_input_saved, noise, reg_noise_std, img_np_list,
    img_gpu_list, img_clean_list, iter_step, PLOT, show_every, figsize.
    Updates globals: i (iteration counter), net_input (perturbed input).

    Returns:
        The scalar loss tensor.
    """
    global i, net_input

    # Perturb the fixed input slightly at every step.
    if reg_noise_std > 0:
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    out = net(net_input)

    # Sample one training image per iteration.
    index = np.random.randint(0, len(img_gpu_list))
    target = img_np_list[index]

    # TV step schedule: while i <= iter_step * stage, draw the step from
    # [low, high) of that stage; after the last stage use step 1.
    schedule = ((50, 60), (40, 50), (30, 40), (20, 30), (10, 20), (3, 10))
    tv_step = 1
    for stage, (low, high) in enumerate(schedule, start=1):
        if i <= iter_step * stage:
            tv_step = np.random.randint(low, high)
            break

    residual = img_gpu_list[index] - out  # denoised estimate
    total_loss = total_variation(residual, reduction="sum", step=tv_step)
    total_loss.backward()

    learned_noise = out.detach().cpu().numpy()[0]
    out_clean = target - learned_noise
    out_min = np.min(out_clean)
    out_range = np.max(out_clean) - out_min
    if out_range > 0:
        out_clean_norm = (out_clean - out_min) / out_range
    else:
        # Constant output: min-max scaling would be 0/0 and yield NaN PSNR.
        out_clean_norm = np.zeros_like(out_clean)
    psnr = peak_signal_noise_ratio(img_clean_list[index], out_clean_norm)

    print('Iteration %05d    Loss %f PSNR %f' % (i, total_loss.item(), psnr), '\r', end='')

    if PLOT and i % show_every == 0:
        plot_image_grid([np.clip(learned_noise, 0, 1), target, out_clean_norm],
                        factor=figsize, nrow=3)

    i += 1
    return total_loss

# Fresh state for the training run: reset the iteration counter that closure()
# increments, keep an untouched copy of the network input, and allocate a
# same-shaped buffer that closure() refills via noise.normal_() each step.
i = 0
last_net = None

net_input_saved = net_input.detach().clone()
noise = net_input_saved.clone()

# Train on all images; closure() samples one image per optimiser step.
p = get_params(OPT_OVER, net, net_input)
optimize(OPTIMIZER, p, closure, LR=LR, num_iter=num_iter)

In [None]:
def closure():
    """Run one fine-tuning step on the image selected by the global `index`.

    Loss = TV(residual, step=1) + 0.1 * TV(residual, step=s), where
    residual = image - network output (the denoised estimate) and s is drawn
    uniformly from [3, 10) at every call. Only backward() is called here; the
    optimiser step is done by `optimize`.

    Reads globals: index, net, net_input_saved, noise, reg_noise_std,
    img_np_list, img_gpu_list, img_clean_list.
    Updates globals: i (iteration counter), net_input (perturbed input).

    Returns:
        The scalar loss tensor.
    """
    global i, net_input

    # Perturb the fixed input slightly at every step.
    if reg_noise_std > 0:
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    out = net(net_input)

    tv_step = np.random.randint(3, 10)
    residual = img_gpu_list[index] - out  # denoised estimate
    total_loss = (total_variation(residual, reduction="sum", step=1)
                  + 1e-1 * total_variation(residual, reduction="sum", step=tv_step))
    total_loss.backward()

    out_clean = img_np_list[index] - out.detach().cpu().numpy()[0]
    out_min = np.min(out_clean)
    out_range = np.max(out_clean) - out_min
    if out_range > 0:
        out_clean_norm = (out_clean - out_min) / out_range
    else:
        # Constant output: min-max scaling would be 0/0 and yield NaN PSNR.
        out_clean_norm = np.zeros_like(out_clean)
    psnr = peak_signal_noise_ratio(img_clean_list[index], out_clean_norm)

    print('Iteration %05d    Loss %f PSNR %f' % (i, total_loss.item(), psnr), '\r', end='')

    i += 1
    return total_loss


# Fine-tune the trained network separately on each image and keep the min-max
# normalised denoised result. Weights are restored to the snapshot taken here
# before every image (and again at the end), so results do not depend on the
# processing order and re-running this cell is idempotent.
# Dedicated names so the training cell's `num_iter`/`LR` are not overwritten.
FINETUNE_ITERS = 50
FINETUNE_LR = 1e-4

# state_dict() tensors alias the live parameters, so clone them.
pretrained_state = {k: v.detach().clone() for k, v in net.state_dict().items()}

out_list = []
max_psnr = -1
max_index = -1

# `index` must stay a global: the fine-tuning closure() reads it.
for index, img_np in enumerate(img_np_list):
    net.load_state_dict(pretrained_state)
    last_net = None
    i = 0

    # Run
    p = get_params(OPT_OVER, net, net_input)
    optimize(OPTIMIZER, p, closure, LR=FINETUNE_LR, num_iter=FINETUNE_ITERS)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        out_np = torch_to_np(net(net_input))
    out_clean = img_np - out_np[0]
    out_min = np.min(out_clean)
    out_range = np.max(out_clean) - out_min
    if out_range > 0:
        out_clean_norm = (out_clean - out_min) / out_range
    else:
        # Constant output: min-max scaling would be 0/0 and yield NaN PSNR.
        out_clean_norm = np.zeros_like(out_clean)
    out_list.append(out_clean_norm)

    psnr = peak_signal_noise_ratio(img_clean_list[index], out_clean_norm)
    if psnr > max_psnr:
        max_psnr = psnr
        max_index = index

# Leave the network in its pre-fine-tuning state.
net.load_state_dict(pretrained_state)



In [None]:
import cv2

# Write each denoised image as an 8-bit PNG and log its PSNR against the clean
# reference. The first CSV column holds the source file name.
# Uses `k` (not `i`/`index`) so the globals used by the closures are not clobbered.
with open("psnr_ps.csv", 'w') as f:
    f.write("file, psnr\n")
    for k, out_clean in enumerate(out_list):
        # Round to the nearest 8-bit level instead of truncating toward zero.
        image = np.clip(np.rint(out_clean[0] * 255), 0, 255).astype(np.uint8)

        cv2.imwrite("out_tv_dip-{}.png".format(k), image)

        psnr = peak_signal_noise_ratio(img_clean_list[k], out_clean)
        info = "{}, {}\n".format(file_list[k], psnr)
        f.write(info)
        print(info)

In [None]:
# Best fine-tuned PSNR and the file that achieved it (None if no images).
max_psnr, (file_list[max_index] if max_index >= 0 else None)