In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# (Re)load the autoreload extension and use mode 2, so imported modules are
# re-imported before each cell execution and edits to the local packages are
# picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
import multiprocessing
from torch import autograd
from fastai.conv_learner import *
from fasterai.images import *
from fasterai.dataset import *
from fasterai.visualize import *
from fasterai.loss import *
from fasterai.modules import *
from fasterai.wgan import *
from fastai.torch_imports import *
from pathlib import Path
from itertools import repeat
import tensorboardX
# Make CUDA device index 0 the current device for everything that follows.
torch.cuda.set_device(0)
# `plt` arrives via the star imports above (presumably matplotlib.pyplot — confirm).
plt.style.use('dark_background')
# Let cuDNN benchmark and pick convolution algorithms for the input sizes it sees.
torch.backends.cudnn.benchmark=True


In [None]:
# --- Data locations ---
DATA_PATH = Path('data/imagenet/ILSVRC/Data/CLS-LOC')
TRAIN_SOURCE_PATH = DATA_PATH/'train'
# Project identifier: prefixes the weight files below and is passed to the data loader.
proj_id = 'bw2color'

# --- Weight files for the critic (D) and the generator (G) ---
dpath = DATA_PATH/(proj_id + '_5004_D.h5')
gpath = DATA_PATH/(proj_id + '_5004_G.h5')

# --- Hyperparameters ---
keep_pct=1.0  # forwarded to get_matched_image_model_data; presumably the fraction of data kept — confirm
wd=1e-7       # not used by the cells visible here; presumably weight decay — confirm
bs=8          # batch size (data loader and trainer)
sz=128        # image size (data loader, critic and trainer)
lr=1e-4       # passed to WGANTrainer as `lr`

# NOTE: torch.backends.cudnn.benchmark is already enabled in the imports cell,
# so it is not repeated here.

## EDSR Model

##### TODO:  Also try making the loss/output based on "classification" like in Zhang et al.
##### TODO:  After making the unet version, plug that into a Wasserstein GAN setup (the discriminator looks at the grey image and the colorized image, concatenated together via channels).
##### TODO:  Try using higher res images (from FloydHub blog?)
##### TODO:  Try perceptual loss again....
##### TODO:  To convert real old photos, could force them to normal grayscale first.
##### TODO:  Add tensorboard graphs

In [None]:
class ImageModifierModel(nn.Module):
    """Generator: a frozen pretrained ResNet base feeding a convolutional decoder.

    The base network from `get_pretrained_resnet_base()` is run on the input
    only so that the `SaveFeatures` objects attached to four of its layers
    capture intermediate activations. Each captured activation goes through
    its own `UpSampleBlock`; the results are concatenated with the input
    along the channel dimension, passed through a stack of `ConvBlock`s, and
    finally merged with the input again by a 1x1 `ConvBlock` followed by tanh.

    The channel bookkeeping (`nf_up = 256+128+64+64+3` and the 6-channel
    output block) assumes a 3-channel input — TODO confirm against the data
    pipeline.
    """

    def __init__(self):
        super().__init__() 
        rn, lr_cut = get_pretrained_resnet_base()

        self.rn = rn
        # The pretrained base is kept non-trainable; only the decoder learns.
        set_trainable(rn, False)
        self.lr_cut = lr_cut
        # One feature saver per tapped base layer; forward() reads `.features`.
        self.sfs = [SaveFeatures(rn[i]) for i in [2,4,5,6]]
        
        # NOTE(review): the arguments look like (in_channels, out_channels,
        # upsampling scale) — confirm against UpSampleBlock's definition.
        self.up1 = UpSampleBlock(256, 256, 16)  #256 in
        self.up2 = UpSampleBlock(128, 128, 8)  #128 in
        self.up3 = UpSampleBlock(64, 64, 4)    #64 in
        self.up4 = UpSampleBlock(64, 64, 2)   #64 in  
        # Channels of cat([x, x1, x2, x3, x4]) as built in forward().
        nf_up = 256+128+64+64+3
        nf_mid = 256  
 
        mid_layers = []
        mid_layers += [ConvBlock(nf_up,nf_mid, bn=True, actn=False)]
        
        # Eight repetitions of a default ConvBlock followed by one with
        # bn=False and actn=False.
        for _ in range(8): 
            mid_layers.append(ConvBlock(nf_mid, nf_mid))
            mid_layers.append(ConvBlock(nf_mid, nf_mid, bn=False, actn=False))
            
        mid_layers += [ConvBlock(nf_mid,nf_mid, actn=False), 
                       ConvBlock(nf_mid, 3, bn=False, actn=False)]
        self.upconv = nn.Sequential(*mid_layers)
             
        # 6 input channels = the input (3) concatenated with the upconv output (3).
        out_layers = []
        out_layers += [ConvBlock(6, 3, ks=1, bn=False, actn=False)]
        self.out = nn.Sequential(*out_layers)

    def set_trainable(self, trainable):
        """Apply `trainable` to the whole model, then re-freeze the base.

        Delegates to the module-level `set_trainable` helper; the second call
        ensures `self.rn` stays non-trainable whatever `trainable` is.
        """
        set_trainable(self, trainable)
        set_trainable(self.rn, False)
        
    def forward(self, x): 
        """Return the tanh-bounded output for input batch `x`.

        `x` is concatenated with the decoder tensors along dim=1, so every
        upsampled feature map must match `x` in all other dimensions.
        """
        # Run only for its side effect: the return value is discarded, and
        # the per-layer activations are read from self.sfs below (presumably
        # filled in by forward hooks registered by SaveFeatures — confirm).
        self.rn(x)
        x1 = self.up1(self.sfs[3].features)
        x2 = self.up2(self.sfs[2].features)
        x3 = self.up3(self.sfs[1].features)
        x4 = self.up4(self.sfs[0].features) 
        x5 = self.upconv(torch.cat([x, x1, x2, x3, x4], dim=1))
        # torch.tanh replaces the deprecated F.tanh; the result is the same.
        return torch.tanh(self.out(torch.cat([x, x5], dim=1)))

## Training

In [None]:
# Build the generator and the critic on the GPU. Uncomment the load_model
# lines to load previously saved weights from gpath / dpath instead of
# starting from the freshly constructed networks.
netG = ImageModifierModel().cuda()
#load_model(netG, gpath)
netD = FeatureCritic(sz).cuda()
#load_model(netD, dpath)

In [None]:
# Model data for the matched image pairs, driven by the settings in the config cell.
md = get_matched_image_model_data(
    image_size=sz,
    batch_size=bs,
    root_data_path=DATA_PATH,
    train_root_path=TRAIN_SOURCE_PATH,
    proj_id=proj_id,
    keep_pct=keep_pct,
)

In [None]:
# Wire the critic, generator and data into the WGAN trainer; dpath / gpath are
# the weight-file locations defined in the config cell.
trainer = WGANTrainer(
    netD=netD,
    netG=netG,
    md=md,
    bs=bs,
    sz=sz,
    dpath=dpath,
    gpath=gpath,
    lr=lr,
)

In [None]:
# Kick off training. NOTE(review): the argument is presumably the number of
# epochs — confirm against WGANTrainer.train.
trainer.train(1)