In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
import multiprocessing
from torch import autograd
from fastai.conv_learner import *
from fasterai.images import *
from fasterai.dataset import *
from fasterai.visualize import *
from fasterai.loss import *
from fastai.torch_imports import *
from pathlib import Path
from itertools import repeat
import tensorboardX
# Run on CUDA device 1; change this index to match the machine.
torch.cuda.set_device(1)
plt.style.use('dark_background')
# Let cuDNN benchmark and cache the fastest conv algorithms (pays off when input shapes stay constant).
torch.backends.cudnn.benchmark=True


In [None]:
# ImageNet CLS-LOC images; the training split is read from TRAIN_SOURCE_PATH.
DATA_PATH = Path('data/imagenet/ILSVRC/Data/CLS-LOC')
TRAIN_SOURCE_PATH = DATA_PATH/'train'
proj_id = 'bw2color'

# Checkpoint paths for the critic (D) and generator (G); progress_update saves to these.
D_MODEL_SAVE_PATH = DATA_PATH/(proj_id + '_8212_D.h5')
G_MODEL_SAVE_PATH = DATA_PATH/(proj_id + '_8212_G.h5')
# Passed to get_matched_image_model_data; presumably the fraction of training images
# kept (1.0 = all; 0.02 and 0.005 were tried in earlier runs).
keep_pct=1.0

In [None]:
# NOTE(review): this 224px / batch-size-128 `md` is overwritten in the Training section
# (batch size `bs`) before train() runs, so it appears to be unused; `denorm` is not
# referenced anywhere else in the visible notebook.
md = get_matched_image_model_data(image_size=224, batch_size=128, root_data_path=DATA_PATH, train_root_path=TRAIN_SOURCE_PATH, proj_id=proj_id, keep_pct=keep_pct)
denorm = md.val_ds.denorm

## EDSR Model

##### TODO:  Also try basing the loss/output on color "classification", as in Zhang et al.
##### TODO:  After making the U-Net version, plug it into a Wasserstein GAN setup (the discriminator looks at the grey image and the colorized image, concatenated along the channel dimension).
##### TODO:  Try using higher-resolution images (from the FloydHub blog?).
##### TODO:  Try perceptual loss again.
##### TODO:  To convert real old photos, could force them to plain grayscale first.
##### TODO:  Add TensorBoard graphs.

In [None]:
def icnr(x, scale=2, init=nn.init.kaiming_normal_):
    """Build an ICNR kernel for a sub-pixel (conv + PixelShuffle) layer.

    Every group of ``scale**2`` consecutive output channels receives the same
    randomly initialised sub-kernel. nn.PixelShuffle(scale) folds each such group
    into one output channel, so at initialisation the layer behaves like
    nearest-neighbour upsampling of a single convolution (avoids checkerboard
    artefacts at the start of training).

    Args:
        x: the conv weight to mimic, shape (out, in, *kernel); only its shape is read.
            ``out`` must be divisible by ``scale**2``.
        scale: upsampling factor of the PixelShuffle that follows the conv.
        init: initializer applied to the (out // scale**2, in, *kernel) sub-kernel;
            its return value is used (the in-place ``nn.init.*_`` functions return
            the tensor they fill).

    Returns:
        A new tensor with the same shape as ``x``.

    Raises:
        ValueError: if ``out`` is not divisible by ``scale**2``.
    """
    out_channels = x.shape[0]
    groups = scale ** 2
    if out_channels % groups != 0:
        raise ValueError(f'icnr: {out_channels} output channels are not divisible by scale**2 = {groups}')
    subkernel = torch.zeros([out_channels // groups] + list(x.shape[1:]))
    subkernel = init(subkernel)
    # kernel[o] == subkernel[o // groups]: consecutive channel groups share a sub-kernel,
    # matching the channel layout nn.PixelShuffle collapses into one output channel.
    kernel = subkernel.unsqueeze(1).expand(-1, groups, *subkernel.shape[1:])
    return kernel.contiguous().view(x.shape)

In [None]:
def conv(ni, nf, kernel_size=3, actn=False, stride=1, normalizer=None):
    """Conv2d with padding kernel_size//2, optionally followed by extra modules.

    Args:
        ni, nf: input / output channel counts.
        kernel_size, stride: forwarded to nn.Conv2d.
        actn: when truthy, append an nn.LeakyReLU() (default slope) at the end.
        normalizer: an already-constructed module placed right after the conv.

    Returns:
        nn.Sequential of [Conv2d, normalizer?, LeakyReLU?].
    """
    same_pad = kernel_size // 2
    head = nn.Conv2d(ni, nf, kernel_size, padding=same_pad, stride=stride)
    tail = [] if normalizer is None else [normalizer]
    if actn:
        tail.append(nn.LeakyReLU())
    return nn.Sequential(head, *tail)

In [None]:
class ResSequential(nn.Module):
    """Residual wrapper computing ``x + layers(x) * res_scale``.

    Args:
        layers: modules applied in order to form the residual branch (stored as ``self.m``).
        res_scale: factor applied to the branch output before it is added to ``x``.
    """
    def __init__(self, layers, res_scale=1.0):
        super().__init__()
        self.res_scale = res_scale
        self.m = nn.Sequential(*layers)

    def forward(self, x):
        branch = self.m(x)
        return x + branch * self.res_scale

In [None]:
def res_block_upsample(nf):
    """Residual block (conv -> LeakyReLU -> conv, nf channels) whose branch is scaled by 0.1."""
    return ResSequential([conv(nf, nf, actn=True), conv(nf, nf)], 0.1)

In [None]:
class UpSampleBlock(nn.Module):
    """Sub-pixel upsampler: log2(scale) stages of conv(nf -> 4*nf) + PixelShuffle(2).

    Each stage doubles the spatial resolution and keeps the channel count at nf.
    ``scale`` must be a power of two; ``scale=1`` gives an empty (identity) block.
    """
    def __init__(self, nf, scale=2):
        super().__init__()
        # math.log(scale, 2) divides two float logs and is not guaranteed to be exact,
        # so int() could drop a stage; math.log2 is exact for powers of two.
        n_stages = int(math.log2(scale))
        if 2 ** n_stages != scale:
            raise ValueError(f'UpSampleBlock: scale must be a power of 2, got {scale}')

        layers = []
        for _ in range(n_stages):
            layers += [conv(nf, nf*4), nn.PixelShuffle(2)]

        self.sequence = nn.Sequential(*layers)
        self.icnr_init()

    def icnr_init(self):
        """ICNR-initialise the conv of every conv + PixelShuffle stage.

        Previously only the first stage's conv was initialised, so for scale > 2 the
        later sub-pixel convs kept their default init.
        """
        for stage in self.sequence:
            # conv(...) returns an nn.Sequential wrapping the Conv2d; PixelShuffle has no weights.
            if isinstance(stage, nn.Sequential):
                conv_shuffle = stage[0]
                conv_shuffle.weight.data.copy_(icnr(conv_shuffle.weight))

    def forward(self, x):
        return self.sequence(x)

In [None]:
class ConvBlock(nn.Module):
    """Bias-free Conv2d -> LeakyReLU(0.2) -> optional BatchNorm2d.

    Note the activation is applied before the batch norm. When ``pad`` is None the
    padding defaults to ``ks//2//stride``.
    """
    def __init__(self, ni, no, ks, stride, bn=True, pad=None):
        super().__init__()
        padding = ks//2//stride if pad is None else pad
        self.conv = nn.Conv2d(ni, no, ks, stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(no) if bn else None
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        activated = self.relu(self.conv(x))
        if self.bn:
            return self.bn(activated)
        return activated

In [None]:
def conv_layer(ni, nf, ks=3, stride=1):
    """Conv2d (with bias, padding ks//2) -> BatchNorm2d -> LeakyReLU, as one nn.Sequential."""
    conv2d = nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks // 2)
    norm = nn.BatchNorm2d(nf)
    act = nn.LeakyReLU()
    return nn.Sequential(conv2d, norm, act)

In [None]:
class ResLayer(nn.Module):
    """Bottleneck residual layer: 1x1 conv_layer to ni//2 channels, 3x3 conv_layer back to ni, plus skip."""
    def __init__(self, ni):
        super().__init__()
        half = ni // 2
        self.conv1 = conv_layer(ni, half, ks=1)
        self.conv2 = conv_layer(half, ni, ks=3)

    def forward(self, x):
        branch = self.conv2(self.conv1(x))
        return x.add(branch)

In [None]:
class ImageModifierModel(nn.Module):
    """Colorization generator: frozen, truncated resnet34 encoder plus a sub-pixel decoder.

    Backbone features (256 channels) are widened to 512 channels, upsampled 16x with an
    UpSampleBlock and projected to a 3-channel colour map. That map is concatenated with
    the input image (6 channels total) and refined into the 3-channel output, squashed
    by tanh into (-1, 1).
    """
    @staticmethod
    def generate_base_model():
        """Return resnet34 cut one child earlier than fastai's model_meta cut.

        NOTE(review): f(True) presumably requests pretrained weights, and cut - 1 presumably
        makes the backbone end at the 256-channel stage that make_group_layer(256, 1) expects.
        """
        f = resnet34
        cut, lr_cut = model_meta[f]
        cut -= 1
        layers = cut_model(f(True), cut)
        return nn.Sequential(*layers)

    def set_trainable(self, trainable):
        """Toggle requires_grad on the whole model while always keeping the backbone frozen."""
        # Calls the module-level fastai helper of the same name, not this method.
        set_trainable(self, trainable)
        set_trainable(self.rn, False)

    def make_group_layer(self, ch_in, num_blocks, stride=1):
        """conv_layer doubling the channels (ch_in -> 2*ch_in), followed by num_blocks ResLayers."""
        layers = [conv_layer(ch_in, ch_in*2, stride=stride)]
        layers += [ResLayer(ch_in*2) for _ in range(num_blocks)]
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()

        self.rn = ImageModifierModel.generate_base_model()
        set_trainable(self.rn, False)

        # 256-ch backbone features -> 512 ch -> 16x sub-pixel upsampling -> 3-ch colour map.
        self.color = nn.Sequential(
            self.make_group_layer(256, 1),
            UpSampleBlock(512, 16),
            conv_layer(512, 3)
        )

        # Input image (3 ch) + colour map (3 ch) = 6 ch -> 12 ch -> 3-ch output.
        self.out = nn.Sequential(
            self.make_group_layer(6, 1),
            conv(12, 3)
        )

    def forward(self, orig):
        """Colorize ``orig``; the result is passed through tanh.

        NOTE(review): assumes the backbone downsamples by exactly 16 so the upsampled colour
        map matches orig's spatial size for the channel concatenation.
        """
        x = self.rn(orig)
        x = self.color(x)
        # torch.tanh replaces the deprecated F.tanh (same values).
        return torch.tanh(self.out(torch.cat([orig, x], dim=1)))

## WGAN Critic Model

In [None]:
class Critic2(nn.Module):
    """Conditional WGAN critic that scores ``input`` given the conditioning image ``orig``.

    Both images go through a frozen resnet34 backbone. Activations captured from four
    backbone children are concatenated pairwise (orig, input) along channels, and each
    pair is scored by its own strided-conv + LayerNorm head. A fifth head scores the raw
    6-channel (orig, input) pixel concatenation. The result is the sum of the five
    heads' mean outputs: a single scalar for the whole batch.
    """
    @staticmethod
    def generate_base_model():
        """Return (backbone, lr_cut), cutting resnet34 at fastai's model_meta cut point.

        NOTE(review): f(True) presumably loads pretrained weights.
        """
        f = resnet34
        cut,lr_cut = model_meta[f]
        layers = cut_model(f(True), cut)
        return nn.Sequential(*layers), lr_cut
    
    def set_trainable(self, trainable):
        """Toggle requires_grad on the critic while always keeping the backbone frozen."""
        # Calls the module-level fastai helper of the same name, not this method.
        set_trainable(self, trainable)
        set_trainable(self.rn, False)
        
    @staticmethod
    def generate_eval_layers(nf_in, nf_mid, sz):
        """Build a scoring head for an (nf_in, sz, sz) input.

        A stride-2 ConvBlock maps to nf_mid channels at sz//2, a stride-1 ConvBlock
        keeps that shape, then stride-2 ConvBlocks double the channels and halve the
        size while it is above 8. A final unpadded 4x4 conv maps to one channel.
        Each ConvBlock (batch norm disabled) is followed by a LayerNorm whose shape is
        tracked explicitly in (cndf, csize).
        NOTE(review): the final 4x4 conv needs the last csize to be >= 4.
        """
        layers = [] 
        layers.append(ConvBlock(nf_in, nf_mid, 4, 2, bn=False))
        csize,cndf = sz//2,nf_mid
        layers.append(nn.LayerNorm([cndf, csize, csize]))
        layers.append(ConvBlock(cndf, cndf, 3, 1, bn=False))
        layers.append(nn.LayerNorm([cndf, csize, csize]))

        while csize > 8:
            layers.append(ConvBlock(cndf, cndf*2, 4, 2, bn=False))
            cndf = int(cndf*2)
            csize = int(csize//2)
            layers.append(nn.LayerNorm([cndf, csize, csize]))
        
        layers.append(nn.Conv2d(cndf, 1, 4, padding=0, bias=False))    
        return nn.Sequential(*layers) 
            
    def __init__(self, sz):
        """Build the critic for square images of side ``sz``."""
        super().__init__()
         
        rn, lr_cut = Critic2.generate_base_model()
        self.rn = rn
        set_trainable(self.rn, False)
        self.lr_cut = lr_cut  # not read anywhere in the visible notebook
        # Forward hooks on four backbone children. NOTE(review): for the fastai resnet34 cut
        # these are presumably the stem ReLU (64 ch, sz/2), layer1 (64 ch, sz/4),
        # layer2 (128 ch, sz/8) and layer3 (256 ch, sz/16); the head arguments below
        # assume exactly those channel counts and sizes.
        self.sfs = [SaveFeatures(rn[i]) for i in [2,4,5,6]]
        
        # Each head sees orig and input features concatenated, hence the *2 input channels.
        self.feature_eval_1 = Critic2.generate_eval_layers(256*2, 256, sz//16)
        self.feature_eval_2 = Critic2.generate_eval_layers(128*2, 128, sz//8)
        self.feature_eval_3 = Critic2.generate_eval_layers(64*2, 64, sz//4)
        self.feature_eval_4 = Critic2.generate_eval_layers(64*2, 64, sz//2)     
        # 6 channels = orig + input concatenated (presumably 3 channels each — confirm).
        self.pixel_eval = Critic2.generate_eval_layers(6, 64, sz)
        
    def forward(self, input, orig):
        """Return a scalar score for ``input`` conditioned on ``orig`` (higher == more real)."""
        # First backbone pass on orig; read the hooked activations before the second pass
        # replaces them. NOTE(review): relies on SaveFeatures rebinding .features to each
        # new output rather than writing into the previously stored tensor.
        self.rn(orig)
        x1 = self.sfs[3].features
        x2 = self.sfs[2].features
        x3 = self.sfs[1].features
        x4 = self.sfs[0].features
        
        # Second pass on the image being judged.
        self.rn(input)
        y1 = self.sfs[3].features
        y2 = self.sfs[2].features
        y3 = self.sfs[1].features
        y4 = self.sfs[0].features 

        # One head per hooked depth (deepest first); orig features come before input's.
        f1 = self.feature_eval_1(torch.cat([x1, y1], dim=1))
        f2 = self.feature_eval_2(torch.cat([x2, y2], dim=1))
        f3 = self.feature_eval_3(torch.cat([x3, y3], dim=1))
        f4 = self.feature_eval_4(torch.cat([x4, y4], dim=1))
  
        p = self.pixel_eval(torch.cat([orig, input], dim=1))
        # NOTE(review): each head's output is averaged over the whole batch, so the gradient
        # w.r.t. one example is roughly 1/B of that example's own score gradient; this
        # scales the per-example norms used by calc_gradient_penalty.
        return f1.mean() + f2.mean() + f3.mean()  + f4.mean() + p.mean()

## Training

In [None]:
# Training configuration for the GAN loop below.
# NOTE(review): wd is not referenced anywhere else in the visible notebook.
wd=1e-7
bs = 8
sz = 224  # also passed to Critic2(sz), which sizes its LayerNorms from it
# Rebuilds `md` at the training image size / batch size, replacing the loader built earlier.
md = get_matched_image_model_data(image_size=sz, batch_size=bs, root_data_path=DATA_PATH, train_root_path=TRAIN_SOURCE_PATH, proj_id=proj_id, keep_pct=keep_pct)

In [None]:
# Generator and critic, moved to the current CUDA device.
netG = ImageModifierModel().cuda()
# To resume, uncomment the loads of the checkpoints that progress_update saves:
#load_model(netG, G_MODEL_SAVE_PATH)
netD = Critic2(sz).cuda()
#load_model(netD, D_MODEL_SAVE_PATH)

In [None]:
# Earlier runs used RMSprop (lr=1e-4, then lr=1e-3) for both networks.
#
# Adam with betas=(0., 0.9) is the setting from Algorithm 1 of the WGAN-GP paper
# (Gulrajani et al., 2017, which pairs it with lr=1e-4); here the learning rate is 1e-5.
# Only parameters that require grad at construction time are handed to the optimizers,
# so the backbones frozen via set_trainable(self.rn, False) in each __init__ are excluded.
optimizerD = optim.Adam(filter(lambda p: p.requires_grad,netD.parameters()), lr=1e-5, betas=(0., 0.9))
optimizerG = optim.Adam(filter(lambda p: p.requires_grad,netG.parameters()), lr=1e-5, betas=(0., 0.9))

In [None]:
def calc_gradient_penalty(netD, real_data, fake_data, orig_data, lamda=10):
    """WGAN-GP gradient penalty for a conditional critic.

    Draws one random interpolation point per example between ``real_data`` and
    ``fake_data``, evaluates ``netD(interpolates, orig_data)``, and penalises the
    squared deviation of each example's input-gradient L2 norm from 1.

    Args:
        netD: critic called as ``netD(images, orig_data)``. Its output is reduced with
            ``grad_outputs`` of ones, so any output shape works.
        real_data, fake_data: image batches of identical shape (B, ...) without autograd
            history (train() passes ``.data``).
        orig_data: conditioning input forwarded unchanged to ``netD``.
        lamda: penalty weight (default 10, as before).

    Returns:
        Scalar tensor ``lamda * mean((||grad||_2 - 1) ** 2)``, built with
        ``create_graph=True`` so it can be backpropagated into netD's parameters.

    NOTE(review): Critic2 returns a mean over the whole batch, so each example's gradient
    here is roughly 1/B of its per-example critic gradient. Kept as-is to preserve the
    existing training behaviour; confirm whether this is intended.
    """
    batch_size = real_data.size(0)
    # One alpha per example, broadcast over the remaining dims. Deriving the shape from the
    # data (instead of the global bs/sz and a hard-coded 3 channels) keeps this working for
    # a short final batch, any image size, and CPU or GPU tensors.
    alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_data.device, dtype=real_data.dtype)
    interpolates = real_data + alpha * (fake_data - real_data)
    # Fresh leaf tensor that requires grad (replaces the deprecated autograd.Variable wrapper).
    interpolates = interpolates.detach().requires_grad_(True)
    disc_interpolates = netD(interpolates, orig_data)
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradients = gradients.view(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lamda

In [None]:
def progress_update(i, w_dist, gradient_penalty, disc_real, disc_fake, md, netG, netD, ecount):
    """Periodic logging / checkpoint hook called from train().

    Every 50 iterations prints the latest critic statistics. Every 500 iterations it also
    calls visualize_image_gen_model and saves both networks to D_/G_MODEL_SAVE_PATH.
    """
    log_due = i % 50 == 0
    checkpoint_due = i % 500 == 0

    if log_due:
        stats = (f'\nWDist {to_np(w_dist)}; GPenalty {to_np(gradient_penalty)}; RScore {to_np(disc_real)};'
                 f' FScore {to_np(disc_fake)}; ECount: {ecount}')
        print(stats)

    if not checkpoint_due:
        return
    visualize_image_gen_model(md, netG, 500, 8)
    save_model(netD, D_MODEL_SAVE_PATH)
    save_model(netG, G_MODEL_SAVE_PATH)

def is_equilibrium(disc_real, disc_fake):
    """Heuristic check for whether the critic can hand over to the generator.

    Returns False while the critic scores fakes above reals. Otherwise it returns whether
    |real + fake| is below 30% of |real| + |fake|, i.e. the two scores roughly cancel out.
    """
    critic_prefers_fake = disc_real < disc_fake
    if critic_prefers_fake:
        return False
    net_score = abs(disc_real + disc_fake)
    total_magnitude = abs(disc_real) + abs(disc_fake)
    return net_score < total_magnitude * 0.30

def train(niter, first=True):
    """Alternate WGAN-GP critic and generator updates for ``niter`` passes over md.trn_dl.

    Each round trains the critic (netD) on successive batches until is_equilibrium()
    reports balance, 10000 critic steps have run, or the epoch's batch budget is used up.
    The generator (netG) then takes one step on the next batch. Relies on the notebook
    globals md, netD, netG, optimizerD, optimizerG, bs, V, trange and tqdm.

    Args:
        niter: number of epochs.
        first: unused; kept so existing calls such as ``train(1, True)`` keep working.
    """
    gen_iterations = 0
    for epoch in trange(niter):
        # NOTE(review): if visualize_image_gen_model (via progress_update) switches netG to
        # eval mode, it stays there until the next epoch starts here; confirm.
        netD.train(); netG.train()
        data_iter = iter(md.trn_dl)
        i, n = 0, len(md.trn_dl)
        # NOTE(review): n counts batches, so trimming it by n % bs does not align to anything
        # obvious; kept to preserve the original per-epoch batch budget.
        n = n - (n % bs)
        with tqdm(total=n) as pbar:
            while i < n:
                # --- critic phase ---
                netD.set_trainable(True)
                netG.set_trainable(False)
                j = 0
                equilibrium = False
                while (not equilibrium) and (i < n) and j < 10000:
                    j += 1; i += 1
                    x, y = next(data_iter)
                    orig_image = V(x)
                    real_image = V(y)
                    # Higher critic score == more "real".
                    disc_real = netD(real_image, orig_image)
                    fake_image = netG(orig_image)
                    disc_fake = netD(V(fake_image.data), orig_image)
                    equilibrium = is_equilibrium(disc_real, disc_fake)

                    netD.zero_grad()
                    gradient_penalty = calc_gradient_penalty(netD, real_image.data, fake_image.data, orig_image)
                    w_dist = disc_fake - disc_real
                    disc_cost = w_dist + gradient_penalty
                    disc_cost.backward()
                    optimizerD.step()
                    pbar.update()

                    progress_update(i, w_dist, gradient_penalty, disc_real, disc_fake, md, netG, netD, gen_iterations)

                # The generator step consumes a batch as well. The original loop did not count
                # it in i, so an epoch drew more than n batches and next(data_iter) could raise
                # StopIteration; end the epoch instead when no batch is left.
                if i >= n:
                    break

                # --- generator phase ---
                netD.set_trainable(False)
                netG.set_trainable(True)
                netG.zero_grad()

                i += 1
                x, _ = next(data_iter)
                orig_image = V(x)
                fake_image = netG(orig_image)
                gen_cost = -netD(fake_image, orig_image)
                gen_cost.backward()
                optimizerG.step()
                gen_iterations += 1
                pbar.update()

                # Reports the stats from the last critic step. Because i advanced for the
                # generator batch, this no longer repeats the print/visualise/save that the
                # critic loop just performed for the same i.
                progress_update(i, w_dist, gradient_penalty, disc_real, disc_fake, md, netG, netD, gen_iterations)

In [None]:
# NOTE(review): redundant; cudnn.benchmark is already enabled in the imports cell.
torch.backends.cudnn.benchmark=True

In [None]:
# Run a single epoch first. train() does not use its second (`first`) argument.
train(1, True)
#train(1, False)

In [None]:
# Continue for 9 more epochs (gen_iterations restarts from 0 inside this call).
train(9, False)