In [1]:
import numpy as np
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from torch.autograd import Variable

import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision.models import vgg19
import functools

from skimage.io import imread, imsave
from skimage import img_as_float, img_as_uint
from skimage.measure import compare_psnr
from skimage.transform import resize
from skimage.measure import compare_psnr
import cv2

from artificial_bluring import blur_img
from IPython import display

from copy import deepcopy
#from models.instancenormrs import *

from torch.utils.data import Dataset

import matplotlib.pyplot as plt
%matplotlib inline

from tqdm import tqdm, tqdm_notebook
from tensorboardX import SummaryWriter

from utils import *
from network_models import *
from traintest import *
from unet import *

In [2]:
# TensorBoard event files for this experiment are written under this directory.
TB_LOG_DIR = 'tbruns/uexp1'
writer = SummaryWriter(TB_LOG_DIR)

In [3]:
# Restrict this process to physical GPU 5; PyTorch then sees it as its only device.
# CUDA reads this variable only when it initializes, so it must be set before the first
# CUDA call. torch is already imported above, but it initializes CUDA lazily. Re-running
# this cell after CUDA has been touched silently does nothing, so warn in that case.
os.environ['CUDA_VISIBLE_DEVICES'] = '5'
if torch.cuda.is_initialized():
    import warnings
    warnings.warn("CUDA is already initialized; CUDA_VISIBLE_DEVICES='5' has no effect "
                  "until the kernel is restarted.")

#### Data:

In [4]:
# Reproducible random train/test split of unlabeled COCO images (file paths only).
# NOTE: currently unused -- the GOPRO datasets below are built with include_coco=None.
COCO_SEED = 0
COCO_N_TRAIN, COCO_N_TEST = 1000, 500
coco_path = 'coco/unlabeled2017/'

# os.listdir order is filesystem-dependent; sort first so the seeded draw selects the
# same files on every machine.
coco_rng = np.random.RandomState(COCO_SEED)
coco_files = coco_rng.choice(sorted(os.listdir(coco_path)),
                             size=COCO_N_TRAIN + COCO_N_TEST, replace=False)

coco_files_train = [os.path.join(coco_path, name) for name in coco_files[:COCO_N_TRAIN]]
coco_files_test = [os.path.join(coco_path, name) for name in coco_files[COCO_N_TRAIN:]]

In [5]:
# Training split of GOPRO_extended with crop=(256, 256). returnLP=3 presumably selects
# the number of Laplacian-pyramid levels returned -- TODO confirm in GOPRO_extended.
# Extra COCO images are disabled (include_coco=None).
train_dataset = GOPRO_extended(include_sharp=0,
                               include_coco=None,
                               returnLP=3,
                               desired_shape=None,
                               transform=False,
                               crop=(256, 256))

# `data` is the DataLoader the training loop iterates over. The dataset keeps its own
# name instead of being shadowed by the loader.
data = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)

In [6]:
# Held-out split (train=False), evaluated after every training epoch.
# No cropping; desired_shape presumably controls resizing -- confirm in GOPRO_extended.
# Extra COCO images are disabled (include_coco=None).
test_data = GOPRO_extended(train=False,
                           include_sharp=0,
                           include_coco=None,
                           returnLP=3,
                           transform=False,
                           crop=None,
                           desired_shape=([360, 480], [640]))

In [7]:
# Evaluation copy of the *training* split (train=True). After every epoch it is run
# through the same test routine so results on training images (tags *_train) can be
# compared with the held-out split. `crop` is left at the dataset's default here.
test_data2 = GOPRO_extended(train=True,
                            include_sharp=0,
                            include_coco=None,
                            returnLP=3,
                            transform=False,
                            desired_shape=([360, 480], [640]))

### Network:

In [8]:
# Normalization layer from get_norm_layer() with its defaults; passed to both D and G.
norm_layer = get_norm_layer()

In [9]:
# PatchGAN-style critic on 3-channel images (4 layers, 128 base filters).
# Only the GPUs exposed by CUDA_VISIBLE_DEVICES exist for this process (a single one
# above). Cap the requested ids [0, 1, 2] at the visible device count, so that a
# multi-GPU code path inside the network never references a device that does not exist.
D_GPU_IDS = list(range(min(3, torch.cuda.device_count())))
D = NLayerDiscriminatorRF(input_nc=3, n_layers=4, ndf=128, norm_layer=norm_layer, gpu_ids=D_GPU_IDS)

device = torch.device("cuda:0")
D.to(device);

In [10]:
# U-Net generator: 6 input channels -> 3 output channels, 64 base filters.
# The 6 input channels are presumably two 3-channel images concatenated channel-wise
# (see the torch.cat in the training loop) -- confirm.
G = UNetGenerator(
    input_nc=6,
    output_nc=3,
    ngf=64,
    norm_layer=norm_layer,
)
G.cuda();

### Train:

In [11]:
# Base learning rates of the generator and discriminator Adam optimizers.
init_lr_g = init_lr_d = 1e-4

In [12]:
# Adam for both networks with default betas and no weight decay.
# (Previously commented-out options: betas=(0., 0.5), weight_decay=1e-4.)
optD = torch.optim.Adam(D.parameters(), lr=init_lr_d)
optG = torch.optim.Adam(G.parameters(), lr=init_lr_g)

# Perceptual content loss for the generator, compared with MSE. It presumably uses
# VGG features (vgg19 is imported) -- confirm in PerceptualLoss. The misspelled name is
# kept because the training loop refers to it.
pepceptual = PerceptualLoss()
pepceptual.initialize(loss=nn.MSELoss())

In [13]:
# Lower the learning rate of every parameter group of both optimizers to 1e-5,
# overriding init_lr_d / init_lr_g. Iterating all groups keeps this correct if an
# optimizer is ever built with more than one parameter group.
RESUME_LR = 1e-5
for opt in (optD, optG):
    for group in opt.param_groups:
        group['lr'] = RESUME_LR

### Training loop: WGAN-GP with a PatchGAN critic

In [None]:
# Resumed WGAN-GP training of the U-Net generator G against the critic D.
# For every batch:
#   1. kd critic updates on the current (real, fake) pair, with instance noise;
#   2. extra critic updates replayed from a small buffer of past (real, fake) pairs;
#   3. kg generator updates (adversarial loss + weighted perceptual content loss).
# After every epoch: lr decay, checkpointing, and evaluation on the held-out split
# (test_data) and on the training-split copy (test_data2).
levelLP = -1           # index into the Laplacian-pyramid lists; -1 selects the last entry
kd = 5                 # critic updates per batch
kg = 1                 # generator updates per batch
GP_WEIGHT = 10         # WGAN-GP gradient-penalty coefficient
CONTENT_WEIGHT = 0.5   # weight of the perceptual content loss in the generator objective
BUFFER_SIZE = 20       # replay buffer drops one random pair once it holds more than this
START_EPOCH, END_EPOCH = 134, 300  # this run resumes at epoch 134
DECAY_START = 150      # lrs decay linearly to 0 over epochs DECAY_START+1 .. END_EPOCH-1

disc_losses, gen_losses, psnrs, mses = list(), list(), list(), list()  # not filled by this cell
reals, fakes = [], []  # replay buffer of past (real, fake) pairs


def add_instance_noise(x):
    """Return ``x`` plus Gaussian noise with std 1/k, k a random integer in [20, 49]."""
    return x + torch.randn_like(x) / np.random.randint(20, 50)


def gradient_penalty(critic, real, fake, weight=GP_WEIGHT):
    """WGAN-GP penalty ``weight * mean((||d critic(x_hat) / d x_hat||_2 - 1)^2)``.

    ``x_hat`` are random per-sample interpolates between ``real`` and ``fake``. The
    gradient norm is taken over all non-batch dimensions of the input. A norm over
    dim=1 only would penalize each pixel's channel vector separately.
    """
    eps_shape = (real.size(0),) + (1,) * (real.dim() - 1)
    eps = torch.rand(eps_shape, device=real.device)
    # Fresh leaf tensor: the penalty must not backpropagate into what produced real/fake.
    interpolates = (eps * real + (1 - eps) * fake).detach().requires_grad_(True)
    critic_out = critic(interpolates)
    grad = autograd.grad(outputs=critic_out, inputs=interpolates,
                         grad_outputs=torch.ones_like(critic_out),
                         create_graph=True, retain_graph=True)[0]
    grad_norm = grad.reshape(grad.size(0), -1).norm(2, dim=1)
    return weight * ((grad_norm - 1) ** 2).mean()


def critic_step(real, fake):
    """Run one optimizer step of D on ``critic_loss + GP``.

    Returns (critic_loss, GP, Dloss) as floats. ``fake`` must already be detached
    from G, so that this backward pass touches only D.
    """
    critic_loss = D(fake).mean() - D(real).mean()
    GP = gradient_penalty(D, real, fake)
    Dloss = critic_loss + GP
    D.zero_grad()
    Dloss.backward()
    optD.step()
    return critic_loss.item(), GP.item(), Dloss.item()


def evaluate_and_log(dataset, split, epoch):
    """Evaluate G on ``dataset`` and log LP visualisations tagged ``*_lp_<split>``.

    Returns ((recon, gt, orig), psnr) as produced by test_deblurring_unet.
    """
    (recon, gt, orig), psnr, _ = test_deblurring_unet(net={'lvl3_net': G, 'lvl2_net': None, 'lvl1_net': None},
                                                       test_dataset=dataset,
                                                       lvl=[3],
                                                       return_sharp_blurred=True)
    recon_lap_vis = get_laplacian_pyramid(recon)[-1].clip(0, 1)
    gt_lap_vis = get_laplacian_pyramid(gt)[-1].clip(0, 1)
    # HWC numpy image -> CHW tensor for TensorBoard
    writer.add_image('recon_lp_' + split, torch.FloatTensor(np.transpose(recon_lap_vis, [2, 0, 1])), epoch)
    writer.add_image('gt_lp_' + split, torch.FloatTensor(np.transpose(gt_lap_vis, [2, 0, 1])), epoch)
    return (recon, gt, orig), psnr


n_batches = len(data)

for epoch in tqdm(range(START_EPOCH, END_EPOCH)):

    # ---------------- train ----------------
    D.train(True)
    G.train(True)
    for i, ((blurred_LP, blurred), (sharp_LP, sharp)) in tqdm_notebook(enumerate(data), total=n_batches):
        # Monotonic TensorBoard step; `i` alone restarts every epoch.
        step = epoch * n_batches + i

        target = sharp_LP[levelLP]
        X_real = target.cuda()

        # Generator input: the selected LP level concatenated channel-wise with the full
        # blurred image. The image is presumably channels-last (B, H, W, C) and is permuted
        # to channels-first -- confirm in GOPRO_extended.
        blurred_img = blurred.permute(0, 3, 1, 2).float()
        main_blurred = torch.cat([blurred_LP[levelLP], blurred_img], 1).cuda()

        # --- critic updates on the current batch (no autograd graph through G) ---
        d_logs = []
        for _ in range(kd):
            with torch.no_grad():
                X_fake = G(main_blurred)
            X_fake_noise = add_instance_noise(X_fake)
            X_real_noise = add_instance_noise(X_real)
            d_logs.append(critic_step(X_real_noise, X_fake_noise))

        if d_logs:
            critic_loss_avg, GP_avg, Dloss_avg = np.mean(d_logs, axis=0)
            writer.add_scalar('critic loss', float(critic_loss_avg), global_step=step)
            writer.add_scalar('GPs', float(GP_avg), global_step=step)
            writer.add_scalar('D_loss', float(Dloss_avg), global_step=step)

        # --- extra critic updates replayed from past (real, fake) pairs ---
        # Sample with replacement, 0 .. len(fakes)//2 - 1 pairs. Each side gets instance
        # noise with probability 0.5.
        if len(fakes) >= 2:
            n_replay = np.random.randint(0, len(fakes) // 2)
            for j in np.random.randint(0, len(fakes), size=n_replay):
                fake_buf, real_buf = fakes[j], reals[j]
                X_fake_noise = add_instance_noise(fake_buf) if np.random.rand() < 0.5 else fake_buf
                X_real_noise = add_instance_noise(real_buf) if np.random.rand() < 0.5 else real_buf
                critic_step(X_real_noise, X_fake_noise)

        # --- generator updates: adversarial (WGAN) + weighted perceptual content loss ---
        for _ in range(kg):
            X_fake = G(main_blurred)

            content_loss = pepceptual.get_loss(X_fake, X_real)
            Adv_loss = -D(X_fake).mean()
            Gloss = Adv_loss + CONTENT_WEIGHT * content_loss

            # D also receives gradients here; critic_step zeroes them before D's next update.
            G.zero_grad()
            Gloss.backward()
            optG.step()

            writer.add_scalar('Gen loss', Gloss.item(), global_step=step)
            writer.add_scalar('Gen Wasserstein loss', Adv_loss.item(), global_step=step)
            writer.add_scalar('Gen content loss', content_loss.item(), global_step=step)

        # Keep the replay buffer bounded: drop one random pair once over capacity.
        if len(fakes) > BUFFER_SIZE:
            rm = np.random.randint(0, len(fakes))
            del fakes[rm]
            del reals[rm]

        fakes.append(X_fake.detach())
        reals.append(X_real.detach())

    # ---------------- lr schedule ----------------
    # Linear decay that reaches exactly 0 after the last epoch, starting from whatever lr
    # is in effect. Subtracting a fixed init_lr/150 per epoch would drive the lr negative
    # (gradient ascent) once it has been lowered, e.g. to 1e-5 on resume.
    if epoch > DECAY_START:
        decay = (END_EPOCH - epoch - 1) / (END_EPOCH - epoch)
        for opt in (optD, optG):
            for group in opt.param_groups:
                group['lr'] *= decay

    # ---------------- checkpoint (overwritten every epoch) ----------------
    torch.save(G.state_dict(), 'unet_gen_exp1.pth')
    torch.save(D.state_dict(), 'unet_disc_exp1.pth')
    torch.save(optG.state_dict(), 'uoptG_exp1.pth')
    torch.save(optD.state_dict(), 'uoptD_exp1.pth')

    # ---------------- evaluation ----------------
    (recon, gt, orig), psnr = evaluate_and_log(test_data, 'test', epoch)
    writer.add_scalar('PSNRs', psnr, global_step=epoch)

    (recon, gt, orig), psnr = evaluate_and_log(test_data2, 'train', epoch)
    writer.add_scalar('PSNRs_train', psnr, global_step=epoch)

  0%|          | 0/166 [00:00<?, ?it/s]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

