In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# (Re)load the autoreload extension and use mode 2, so edited modules are
# re-imported automatically before each cell is executed.
%reload_ext autoreload
%autoreload 2

In [None]:
import multiprocessing
from torch import autograd
from fastai.conv_learner import *
from fasterai.images import *
from fasterai.dataset import *
from fasterai.visualize import *
from fasterai.loss import *
from fasterai.modules import *
from fasterai.wgan import *
from fastai.torch_imports import *
from pathlib import Path
from itertools import repeat
import tensorboardX
# Make CUDA device index 1 the current device for everything that follows.
# NOTE(review): the index is hard-coded; presumably this needs at least two
# visible GPUs -- confirm before running on a different machine.
torch.cuda.set_device(1)
# Plot with the 'dark_background' style (plt is presumably matplotlib.pyplot,
# brought in by one of the star imports above -- TODO confirm).
plt.style.use('dark_background')
# Turn on cuDNN benchmark mode (cuDNN auto-tunes its convolution algorithms).
torch.backends.cudnn.benchmark=True


In [None]:
# --- Data locations -------------------------------------------------------
DATA_PATH = Path('data/imagenet/ILSVRC/Data/CLS-LOC')
TRAIN_SOURCE_PATH = DATA_PATH / 'train'
proj_id = 'bw2color'

# --- Weight files for the critic (D) and the generator (G) ----------------
dpath = DATA_PATH / f'{proj_id}_8212_D.h5'
gpath = DATA_PATH / f'{proj_id}_8212_G.h5'

# --- Hyper-parameters -----------------------------------------------------
keep_pct = 1.0   # forwarded as keep_pct= when the model data is built
wd = 1e-7        # presumably weight decay; not referenced in the visible cells
bs = 8           # batch size
sz = 224         # image size
lr = 1e-4        # learning rate handed to the trainer

# Same flag as in the import cell; kept so this cell can be re-run on its own.
torch.backends.cudnn.benchmark = True

## EDSR Model

##### TODO:  Also try making the loss/output based on "classification" like in Zhang et al.
##### TODO:  After making the U-Net version, plug that into a Wasserstein GAN setup (the discriminator looks at the grey image and the colorized image, concatenated together along the channel dimension).
##### TODO:  Try using higher res images (from FloydHub blog?)
##### TODO:  Try perceptual loss again....
##### TODO:  To convert real old photos, could force them to normal grayscale first.
##### TODO:  Add tensorboard graphs

In [None]:
class ImageModifierModel(nn.Module):
    """Generator that maps an input image batch to a tanh-bounded output.

    The input is passed through a pretrained ResNet base (``self.rn``), then
    through the ``self.color`` head. The head's output is concatenated with
    the original input along dim 1 and refined by ``self.out``; the result is
    squashed with tanh.

    NOTE(review): the channel counts (256 into ``color``, 6 into ``out``)
    presumably assume a 3-channel input and a 256-channel backbone output --
    confirm against ``get_pretrained_resnet_base``.
    """

    def __init__(self):
        super().__init__()

        # Only the first of the two values returned by the helper is used.
        backbone, _unused = get_pretrained_resnet_base(1)
        self.rn = backbone
        set_trainable(self.rn, False)

        # Layers are instantiated in the same order as they are applied.
        color_layers = [
            ConvBlock(256, 512),
            ImageModifierModel._generate_res_layer(512),
            UpSampleBlock(512, 512, 16),
            ConvBlock(512, 3),
        ]
        self.color = nn.Sequential(*color_layers)

        out_layers = [
            ConvBlock(6, 12),
            ImageModifierModel._generate_res_layer(12),
            ConvBlock(12, 3, actn=False, bn=False),
        ]
        self.out = nn.Sequential(*out_layers)

    def forward(self, orig):
        """Return ``tanh(out(cat([orig, color(rn(orig))], dim=1)))``."""
        features = self.rn(orig)
        colorized = self.color(features)
        stacked = torch.cat([orig, colorized], dim=1)
        return F.tanh(self.out(stacked))

    def set_trainable(self, trainable):
        """Apply ``trainable`` to the whole model, then reset ``self.rn`` to False.

        The second call always passes ``False`` for the backbone, whatever
        value of ``trainable`` was requested for the rest of the model.
        """
        set_trainable(self, trainable)
        set_trainable(self.rn, False)

    @staticmethod
    def _generate_res_layer(ni: int):
        """Build a ``ResSequential`` of two ConvBlocks: ni -> ni//2 (ks=1) -> ni."""
        return ResSequential([
            ConvBlock(ni, ni//2, ks=1),
            ConvBlock(ni//2, ni),
        ])

## Training

In [None]:
# Instantiate the generator and move it to the current CUDA device.
netG = ImageModifierModel().cuda()
# Uncomment to start from the generator weights stored at gpath
# (presumably loads a saved state into netG -- TODO confirm).
#load_model(netG, gpath)
# Instantiate the critic, constructed with the image size sz, and move it to CUDA.
netD = FeatureCritic(sz).cuda()
# Uncomment to start from the critic weights stored at dpath
# (presumably loads a saved state into netD -- TODO confirm).
#load_model(netD, dpath)

In [None]:
# Build the model-data object from the configuration values defined above.
md = get_matched_image_model_data(
    image_size=sz,
    batch_size=bs,
    root_data_path=DATA_PATH,
    train_root_path=TRAIN_SOURCE_PATH,
    proj_id=proj_id,
    keep_pct=keep_pct,
)

In [None]:
# Wire the critic, generator, data and weight-file paths into the WGAN trainer.
trainer = WGANTrainer(
    netD=netD,
    netG=netG,
    md=md,
    bs=bs,
    sz=sz,
    dpath=dpath,
    gpath=gpath,
    lr=lr,
)

In [None]:
# Start training. NOTE(review): the argument 1 is presumably the number of
# epochs -- confirm against WGANTrainer.train.
trainer.train(1)