In [1]:
import numpy as np
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from torch.autograd import Variable

import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision.models import vgg19
import functools

from skimage.io import imread, imsave
from skimage import img_as_float, img_as_uint
from skimage.measure import compare_psnr
from skimage.transform import resize
from skimage.measure import compare_psnr
import cv2

from artificial_bluring import blur_img
from IPython import display

from copy import deepcopy
#from models.instancenormrs import *

from torch.utils.data import Dataset

import matplotlib.pyplot as plt
%matplotlib inline

from tqdm import tqdm, tqdm_notebook
from tensorboardX import SummaryWriter

from utils import *
from network_models import *
from traintest import *

In [2]:
# TensorBoard event files for this experiment are written under ./tbruns/exp21.
writer = SummaryWriter(log_dir='tbruns/exp21')

In [3]:
# Expose only GPU 0 to this process. This only takes effect if it runs before CUDA is
# first initialised (importing torch is fine; the first .cuda() call is not).
os.environ.update(CUDA_VISIBLE_DEVICES='0')

### Data:

In [None]:
# Optional pool of COCO images that can be mixed into the GOPRO data through the
# `include_coco` argument (the data cells below currently pass include_coco=None).
# 1500 files are drawn without replacement: 1000 for train, 500 for test.
COCO_SEED = 0  # fixed seed so the same train/test file split is drawn on every run
coco_path = 'coco/unlabeled2017/'

# Sort before sampling: os.listdir order is arbitrary, so seeding alone would not
# make the split reproducible across machines or filesystems.
coco_files = np.random.RandomState(COCO_SEED).choice(sorted(os.listdir(coco_path)),
                                                     size=1500, replace=False)
coco_files_train = [os.path.join(coco_path, name) for name in coco_files[:1000]]
coco_files_test = [os.path.join(coco_path, name) for name in coco_files[1000:]]

In [4]:
# Training loader: GOPRO_extended with crop=(256, 256), one sample per batch, order
# reshuffled on every pass (DataLoader shuffle=True).
# To mix in COCO images, pass include_coco={'train': coco_files_train, 'test': coco_files_test}.
# NOTE(review): returnLP=3 presumably makes the dataset return 3-level Laplacian
# pyramids alongside the images — confirm in GOPRO_extended.
data = torch.utils.data.DataLoader(
    GOPRO_extended(
        include_sharp=0,
        include_coco=None,
        returnLP=3,
        desired_shape=None,
        transform=False,
        crop=(256, 256),
    ),
    batch_size=1,
    shuffle=True,
)

In [5]:
# Held-out evaluation split (train=False), used for the per-epoch PSNR/visualisation.
# Uses desired_shape=([360, 480], [640]) instead of cropping; see GOPRO_extended for
# how that shape spec is interpreted.
# To mix in COCO images, pass include_coco={'train': coco_files_train, 'test': coco_files_test}.
test_data = GOPRO_extended(
    train=False,
    include_sharp=0,
    include_coco=None,
    returnLP=3,
    transform=False,
    desired_shape=([360, 480], [640]),
)

In [6]:
# Despite its name, this is the *training* split (train=True) built with the same
# evaluation settings as test_data; the training loop evaluates on it ("test on train")
# to compare fit on training images against the held-out split.
# To mix in COCO images, pass include_coco={'train': coco_files_train, 'test': coco_files_test}.
test_data2 = GOPRO_extended(
    train=True,
    include_sharp=0,
    include_coco=None,
    returnLP=3,
    transform=False,
    desired_shape=([360, 480], [640]),
)

### Network:

In [7]:
norm_layer = get_norm_layer()

# PatchGAN critic, moved to the GPU.
# NOTE(review): gpu_ids=[0, 1, 2] while CUDA_VISIBLE_DEVICES is '0' — confirm that
# NLayerDiscriminatorRF never tries to run on devices 1 and 2.
D = NLayerDiscriminatorRF(
    input_nc=3,
    ndf=128,
    n_layers=4,
    norm_layer=norm_layer,
    gpu_ids=[0, 1, 2],
)
D.cuda()

device = torch.device("cuda:0")  # NOTE(review): not referenced by any visible cell

In [8]:
# ResnetGenerator configuration, forwarded verbatim as keyword arguments.
gen_arg = dict(
    input_nc=3,
    input_enc=9,
    output_nc=3,
    ngf=128,
    norm_layer=norm_layer,
    use_dropout=True,
    n_blocks=9,
    gpu_ids=[0, 1, 2],
    use_parallel=False,
    learn_residual=True,
    padding_type='zero',
    partial_downsample=False,
)

G = ResnetGenerator(**gen_arg)
G.cuda();

### Train:

In [9]:
# Initial Adam learning rates for generator and critic (decayed in the training loop).
init_lr_g = init_lr_d = 1e-5

In [10]:
# Adam with default betas for both networks.
# (Alternatives previously left commented out here: betas=(0., 0.5), weight_decay=1e-4.)
optG = torch.optim.Adam(G.parameters(), lr=init_lr_g)
optD = torch.optim.Adam(D.parameters(), lr=init_lr_d)

# Content loss for the generator: PerceptualLoss configured with an MSE criterion.
# The misspelled name `pepceptual` is kept because the training loop refers to it.
pepceptual = PerceptualLoss()
pepceptual.initialize(loss=nn.MSELoss())

#### PatchGAN training (WGAN-GP critic + perceptual content loss):

In [None]:
# WGAN-GP training of the PatchGAN critic D and the generator G, with per-epoch
# checkpointing and evaluation on the test split and on the training split.
# NOTE(review): the loop starts at epoch 259, but no visible cell loads
# ResGen_exp21.pth / ResDisc_exp21.pth / optG_exp21.pth / optD_exp21.pth —
# confirm the states were restored first, otherwise training restarts from scratch.
levelLP = -1            # index into the pyramid lists yielded by the loader
kd = 5                  # critic updates per batch (all on the same batch)
kg = 1                  # generator updates per batch
GP_LAMBDA = 10          # gradient-penalty weight
CONTENT_WEIGHT = 0.5    # weight of the perceptual loss in the generator objective
ANOMALY_THRESHOLD = 80  # |jump| between consecutive critic losses recorded as a peak
START_EPOCH, END_EPOCH = 259, 300
LR_DECAY_START, LR_DECAY_EPOCHS = 150, 150

disc_losses, gen_losses, psnrs, mses = list(), list(), list(), list()
anomal_pos, anomal_neg, peaks = [], [], []


def _linear_decay_lr(init_lr, epoch, decay_start=LR_DECAY_START, decay_epochs=LR_DECAY_EPOCHS):
    """Learning rate to use during `epoch`.

    Constant `init_lr` through epoch `decay_start + 1`, then lowered by
    `init_lr / decay_epochs` per epoch and clamped at 0. This is the same schedule as
    decrementing the lr after every epoch > decay_start, but computed in closed form,
    so it stays correct when training resumes mid-way with freshly built optimizers.
    """
    n_decays = max(0, epoch - 1 - decay_start)
    return max(0.0, init_lr * (1 - n_decays / decay_epochs))


def _set_lr(optimizer, lr):
    """Set the learning rate of every parameter group of `optimizer`."""
    for group in optimizer.param_groups:
        group['lr'] = lr


def _gradient_penalty(critic, real, fake):
    """WGAN-GP term GP_LAMBDA * mean((||grad critic(x)|| - 1)^2) at random real/fake mixes.

    NOTE(review): as in the original cell, the norm is taken over dim=1 only (one norm
    per spatial position), not over the whole flattened sample as in the WGAN-GP paper.
    """
    # One mixing coefficient per sample, broadcast over the remaining dimensions.
    eps_shape = (real.size(0),) + (1,) * (real.dim() - 1)
    eps = torch.rand(eps_shape, device=real.device)
    # Leaf tensor: the penalty must reach D's parameters only, never G.
    interpolates = (eps * real + (1 - eps) * fake.detach()).requires_grad_(True)
    d_interpolates = critic(interpolates)
    grad = autograd.grad(outputs=d_interpolates, inputs=interpolates,
                         grad_outputs=torch.ones_like(d_interpolates),
                         create_graph=True, retain_graph=True)[0]
    return GP_LAMBDA * torch.pow(grad.norm(2, dim=1) - 1, 2).mean()


def _save_atomic(obj, path):
    """torch.save to `path` via a temporary file and a rename, so a failed write
    (e.g. a full disk) cannot leave a truncated checkpoint in place of the last good one."""
    tmp_path = path + '.tmp'
    torch.save(obj, tmp_path)
    os.replace(tmp_path, path)


def _evaluate_and_log(dataset, split, epoch):
    """Run test_deblurring_cur with G as the level-3 net on `dataset` and log the last
    Laplacian-pyramid level of the reconstruction and ground truth as
    'recon_lp_<split>' / 'gt_lp_<split>'.

    Returns ((recon, gt, orig), psnr) as produced by test_deblurring_cur.
    """
    (recon, gt, orig), psnr, _ = test_deblurring_cur(net={'lvl3_net': G, 'lvl2_net': None, 'lvl1_net': None},
                                                   test_dataset=dataset,
                                                   lvl=[3],
                                                   return_sharp_blurred=True)
    for tag, img in (('recon_lp_' + split, recon), ('gt_lp_' + split, gt)):
        lap_vis = get_laplacian_pyramid(img)[-1].clip(0, 1)
        writer.add_image(tag, torch.FloatTensor(np.transpose(lap_vis, [2, 0, 1])), epoch)
    return (recon, gt, orig), psnr


n_batches = len(data)
for epoch in tqdm(range(START_EPOCH, END_EPOCH)):

    # Learning rate for this epoch, derived from the epoch number (robust to resuming).
    _set_lr(optD, _linear_decay_lr(init_lr_d, epoch))
    _set_lr(optG, _linear_decay_lr(init_lr_g, epoch))

    # --- train ---
    D.train()
    G.train()
    for i, ((blurred_LP, blurred), (sharp_LP, sharp)) in tqdm_notebook(enumerate(data),
                                                                       total=n_batches,
                                                                       leave=False):
        # Monotonic TensorBoard step across epochs (`i` alone restarts at 0 every epoch).
        step = epoch * n_batches + i

        main_blurred, aux_blurred = get_network_tensors(blurred_LP, base_lvl=levelLP)
        # Move the batch to the GPU once instead of once per D/G update.
        main_blurred, aux_blurred = main_blurred.cuda(), aux_blurred.cuda()
        X_real = sharp_LP[levelLP].cuda()

        # update discriminator (critic)
        for k in range(kd):
            # G is not optimised in this phase: don't build its graph. This avoids a
            # full backward pass through G on every critic step and the retained graph.
            with torch.no_grad():
                X_fake = G(main_blurred, aux_blurred)

            critic_loss = D(X_fake).mean() - D(X_real).mean()
            GP = _gradient_penalty(D, X_real, X_fake)
            Dloss = critic_loss + GP

            D.zero_grad()
            Dloss.backward()
            optD.step()

            d_step = step * kd + k  # one TensorBoard point per critic update
            writer.add_scalar('critic loss', critic_loss.item(), global_step=d_step)
            writer.add_scalar('GPs', GP.item(), global_step=d_step)
            writer.add_scalar('D_loss', Dloss.item(), global_step=d_step)

            # Keep the last two critic losses and record sudden jumps between them.
            disc_losses.append(Dloss.item())
            if len(disc_losses) > 2:
                del disc_losses[0]
            if len(disc_losses) == 2:  # was `disc_losses == 2` (list vs int): always False
                delta = disc_losses[0] - disc_losses[1]
                if abs(delta) >= ANOMALY_THRESHOLD:
                    peaks.append(delta)
                    # CPU copies without autograd history, so snapshots don't pin GPU memory.
                    snapshot = (X_fake.detach().cpu(), X_real.detach().cpu())
                    (anomal_pos if delta > 0 else anomal_neg).append(snapshot)
                    tqdm.write('anomal')

        # update generator
        for _ in range(kg):
            X_fake = G(main_blurred, aux_blurred)

            content_loss = pepceptual.get_loss(X_fake, X_real)
            Adv_loss = -D(X_fake).mean()
            Gloss = Adv_loss + CONTENT_WEIGHT * content_loss

            G.zero_grad()
            Gloss.backward()
            optG.step()

            writer.add_scalar('Gen loss', Gloss.item(), global_step=step)
            writer.add_scalar('Gen adv loss', Adv_loss.item(), global_step=step)
            writer.add_scalar('Gen content loss', CONTENT_WEIGHT * content_loss.item(), global_step=step)

    # --- save (atomically, so a failed write keeps the previous checkpoint intact) ---
    _save_atomic(G.state_dict(), 'ResGen_exp21.pth')
    _save_atomic(D.state_dict(), 'ResDisc_exp21.pth')
    _save_atomic(optG.state_dict(), 'optG_exp21.pth')
    _save_atomic(optD.state_dict(), 'optD_exp21.pth')

    # --- evaluate on the held-out split, then on the training split ---
    (recon, gt, orig), psnr = _evaluate_and_log(test_data, 'test', epoch)
    writer.add_scalar('PSNRs', psnr, global_step=epoch)

    (recon, gt, orig), psnr = _evaluate_and_log(test_data2, 'train', epoch)
    writer.add_scalar('PSNRs_train', psnr, global_step=epoch)


  0%|          | 0/41 [00:00<?, ?it/s][A

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Exception in thread Thread-224:
Traceback (most recent call last):
  File "/data/install/anaconda3/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/data/install/anaconda3/lib/python3.6/site-packages/tqdm/_monitor.py", line 63, in run
    for instance in self.tqdm_cls._instances:
  File "/data/install/anaconda3/lib/python3.6/_weakrefset.py", line 60, in __iter__
    for itemref in self.data:
RuntimeError: Set changed size during iteration

  warn("Anti-aliasing will be enabled by default in skimage 0.15 to "
  2%|▏         | 1/41 [1:08:53<45:55:32, 4133.31s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

  5%|▍         | 2/41 [2:19:44<45:25:01, 4192.34s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

  7%|▋         | 3/41 [3:30:42<44:29:03, 4214.29s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

 10%|▉         | 4/41 [4:40:34<43:15:22, 4208.72s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

 12%|█▏        | 5/41 [5:50:10<42:01:19, 4202.20s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

 15%|█▍        | 6/41 [6:59:47<40:48:47, 4197.92s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

 17%|█▋        | 7/41 [8:09:32<39:37:46, 4196.08s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

 20%|█▉        | 8/41 [9:19:07<38:26:23, 4193.44s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

 22%|██▏       | 9/41 [10:28:40<37:15:16, 4191.14s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

 24%|██▍       | 10/41 [11:38:15<36:04:35, 4189.54s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

 27%|██▋       | 11/41 [12:47:39<34:53:36, 4187.21s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

 29%|██▉       | 12/41 [13:56:18<33:41:03, 4181.50s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

 32%|███▏      | 13/41 [15:05:05<32:29:24, 4177.31s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

 34%|███▍      | 14/41 [16:12:31<31:15:34, 4167.95s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

 37%|███▋      | 15/41 [17:18:49<30:00:37, 4155.27s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

 39%|███▉      | 16/41 [18:24:43<28:46:07, 4142.71s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

 41%|████▏     | 17/41 [19:30:58<27:33:08, 4132.86s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

 44%|████▍     | 18/41 [20:37:00<26:20:37, 4123.35s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

 46%|████▋     | 19/41 [21:43:52<25:09:45, 4117.50s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

 49%|████▉     | 20/41 [22:50:24<23:58:56, 4111.24s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

 51%|█████     | 21/41 [23:58:48<22:50:17, 4110.87s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Exception in thread Thread-4:
Traceback (most recent call last):
  File "/data/install/anaconda3/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/data/install/anaconda3/lib/python3.6/site-packages/tensorboardX/event_file_writer.py", line 189, in run
    self._ev_writer.write_event(event)
  File "/data/install/anaconda3/lib/python3.6/site-packages/tensorboardX/event_file_writer.py", line 71, in write_event
    return self._write_serialized_event(event.SerializeToString())
  File "/data/install/anaconda3/lib/python3.6/site-packages/tensorboardX/event_file_writer.py", line 75, in _write_serialized_event
    self._py_recordio_writer.write(event_str)
  File "/data/install/anaconda3/lib/python3.6/site-packages/tensorboardX/record_writer.py", line 31, in write
    self._writer.flush()
OSError: [Errno 28] No space left on device

