In [None]:
# Render matplotlib figures inline, and re-import edited modules (e.g. fasterai)
# automatically before each cell runs, so library edits take effect without a kernel restart.
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
import multiprocessing
from fastai.conv_learner import *
from fasterai.images import *
from fasterai.dataset import *
from fasterai.visualize import *
from pathlib import Path
from itertools import repeat
# Runtime configuration: cuDNN autotuning, GPU selection, plot theme.
torch.backends.cudnn.benchmark = True  # let cuDNN autotune conv algorithms; input size is constant within each training stage
torch.cuda.set_device(2)  # hard-coded GPU index -- change to match the local machine
plt.style.use('dark_background')

In [None]:
# Source images and the tag identifying this experiment.
DATA_PATH = Path('data/imagenet/ILSVRC/Data/CLS-LOC')
TRAIN_SOURCE_PATH = DATA_PATH/'train'
uid = 'bw2color'  # used in preprocessed-image paths and in checkpoint names (see train())

# Fraction of matched training images to sample.
# Other values tried: 0.005 (small subset) and 1.0 (full dataset).
keep_pct = 0.1

In [None]:
def get_model_data(image_size: int, batch_size: int, keep_pct: float, num_workers: int = 16):
    """Build fastai ``ImageData`` of matched (x, y) image files at one resolution.

    Loads the preprocessed x/y file pairs generated for ``image_size`` under
    TRAIN_SOURCE_PATH, keeps a seeded random ``keep_pct`` fraction of them, and
    holds out a validation split sized at roughly 1% of the full file list
    (capped at 10% of the kept files).

    Args:
        image_size: image side length; selects the preprocessed folders and the
            transform size for both x and y.
        batch_size: mini-batch size.
        keep_pct: fraction of matched pairs to keep; must be > 0. Sampling uses
            seed 42, so the same file list yields the same subset on every call.
        num_workers: number of data-loading workers passed to ``ImageData``.

    Returns:
        The ``ImageData`` built from the train/validation split.

    Raises:
        ValueError: if ``keep_pct`` is not positive or no files survive sampling.
    """
    if keep_pct <= 0:
        raise ValueError('keep_pct must be positive, got {}'.format(keep_pct))

    train_x_path = generate_image_preprocess_path(TRAIN_SOURCE_PATH, is_x=True, size=image_size, uid=uid)
    train_y_path = generate_image_preprocess_path(TRAIN_SOURCE_PATH, is_x=False, size=image_size, uid=uid)
    x_paths, y_paths = get_matched_xy_file_lists(train_x_path, train_y_path)
    x_paths_str = convert_paths_to_str(x_paths)
    y_paths_str = convert_paths_to_str(y_paths)
    # Quick visual check that x and y files line up.
    print(x_paths_str[:5])
    print(y_paths_str[:5])

    np.random.seed(42)
    keeps = np.random.rand(len(x_paths_str)) < keep_pct
    # np.asarray rather than np.array(..., copy=False): NumPy 2 raises when copy=False
    # cannot be honoured, which is always the case for a Python list.
    fnames_x = np.asarray(x_paths_str)[keeps]
    fnames_y = np.asarray(y_paths_str)[keeps]
    if len(fnames_x) == 0:
        raise ValueError('no files kept: {} matched pairs found, keep_pct={}'.format(len(x_paths_str), keep_pct))

    # val_pct is relative to the kept files: 0.01/keep_pct of them is ~1% of the full list.
    val_idxs = get_cv_idxs(len(fnames_x), val_pct=min(0.01/keep_pct, 0.1))
    ((val_x, trn_x), (val_y, trn_y)) = split_by_idx(val_idxs, fnames_x, fnames_y)

    # NOTE(review): vgg16 here presumably only selects normalization stats (ImageNet) -- confirm in fastai tfms_from_model.
    tfms = tfms_from_model(vgg16, image_size, tfm_y=TfmType.PIXEL, aug_tfms=transforms_side_on, sz_y=image_size)
    datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x, trn_y), (val_x, val_y), tfms, path=train_y_path.parent)
    return ImageData(DATA_PATH, datasets, batch_size, num_workers=num_workers, classes=None)

In [None]:
# 64px stage: build the data and keep a handle to the validation-set denormalizer.
# NOTE(review): `denorm` is not used in the cells shown and is not refreshed when `md` is rebuilt at 128/224px.
md = get_model_data(image_size=64, batch_size=24, keep_pct=keep_pct)
denorm = md.val_ds.denorm

In [None]:
def generate_denormed_image_pairs(ds: FilesDataset, batches: [(ndarray,ndarray)], idx:int = 0):
    """Return one denormalized (x, y) image pair per batch.

    ``ds.denorm`` is applied to each whole batch, and element ``idx`` of each result is kept.
    """
    pairs = []
    for x_batch, y_batch in batches:
        x_img = ds.denorm(x_batch)[idx]
        y_img = ds.denorm(y_batch)[idx]
        pairs.append((x_img, y_img))
    return pairs

## Model

##### TODO: Also try a classification-based loss/output (predicting quantized colors), as in Zhang et al., "Colorful Image Colorization".
##### TODO: After building a U-Net version, plug it into a Wasserstein GAN setup (the discriminator sees the grayscale image and the colorized image, concatenated along the channel dimension).
##### TODO:  Try using higher res images (from FloydHub blog?)
##### TODO:  Try perceptual loss again....
##### TODO: To colorize real old photos, it may help to convert them to standard grayscale first.

In [None]:
def icnr(x, scale=2, init=None):
    """Return an ICNR-initialized kernel with the same shape as ``x``.

    ICNR ("Checkerboard artifact free sub-pixel convolution", Aitken et al., 2017)
    initializes a conv that feeds ``nn.PixelShuffle(scale)`` so that each group of
    ``scale**2`` consecutive output channels shares one kernel. After the shuffle,
    all sub-pixels of an output cell therefore start out equal, which is equivalent
    to nearest-neighbour upsampling.

    Args:
        x: conv weight of shape (out_channels, in_channels, *kernel_dims). Only its
            shape is read; ``x`` is not modified.
        scale: PixelShuffle upscale factor.
        init: initializer that fills a tensor in place and returns it. It is applied
            to the (out_channels // scale**2, in_channels, *kernel_dims) sub-kernel.
            Defaults to Kaiming-normal.

    Returns:
        A new tensor with the same shape as ``x``.

    Raises:
        ValueError: if out_channels is not divisible by ``scale**2``.
    """
    if init is None:
        # Prefer the in-place name; the un-suffixed alias is deprecated (and warns) in newer PyTorch.
        init = getattr(nn.init, 'kaiming_normal_', None) or nn.init.kaiming_normal
    n_out, n_in = x.shape[0], x.shape[1]
    group = int(scale ** 2)
    if n_out % group != 0:
        raise ValueError('icnr: out_channels ({}) must be divisible by scale**2 ({})'.format(n_out, group))

    subkernel = torch.zeros([n_out // group] + list(x.shape[1:]))
    subkernel = init(subkernel)
    # (out/g, in, *k) -> (in, out/g, prod(k))
    subkernel = subkernel.transpose(0, 1).contiguous().view(n_in, n_out // group, -1)
    # Tile along the last axis. Viewing the result as (in, out, *k) places copy r of
    # sub-kernel c at output channel c*g + r, the channel order PixelShuffle groups together.
    kernel = subkernel.repeat(1, 1, group)
    kernel = kernel.contiguous().view([n_in, n_out] + list(x.shape[2:]))
    return kernel.transpose(0, 1)

In [None]:
class SaveFeatures:
    """Capture the latest forward output of a module through a forward hook.

    After each forward pass of the hooked module, its output is available as
    ``features``. The value is ``None`` until the first pass.
    """
    features = None

    def __init__(self, m):
        handle = m.register_forward_hook(self.hook_fn)
        self.hook = handle  # kept so remove() can detach the hook

    def hook_fn(self, module, input, output):
        # Signature dictated by register_forward_hook; only the output is stored.
        self.features = output

    def remove(self):
        """Unregister the hook; the last captured ``features`` stays available."""
        self.hook.remove()

In [None]:
class FeatureLoss(nn.Module):
    """Perceptual loss: scaled pixel L1 plus weighted L1 distances between VGG16 activations.

    Hooks the layer just before each of the first ``len(block_wgts)`` MaxPool2d
    layers of fastai's ``vgg16`` (frozen, in eval mode, moved to the GPU). It then
    compares the activations of ``input`` and ``target`` at those layers.

    Args:
        block_wgts: weight for each hooked layer, shallowest first. Defaults to
            [0.2, 0.7, 0.1], i.e. three hooked layers.

    Raises:
        ValueError: if ``block_wgts`` is empty or longer than the number of pooling blocks.
    """
    def __init__(self, block_wgts: [float] = None):
        super().__init__()
        if block_wgts is None:  # avoid a shared mutable default list
            block_wgts = [0.2, 0.7, 0.1]
        m_vgg = vgg16(True)

        # Indices of the layers immediately preceding each max-pool.
        blocks = [i-1 for i,o in enumerate(children(m_vgg))
                  if isinstance(o,nn.MaxPool2d)]
        if not 1 <= len(block_wgts) <= len(blocks):
            raise ValueError('FeatureLoss: expected 1 to {} block weights, got {}'.format(len(blocks), len(block_wgts)))
        layer_ids = blocks[:len(block_wgts)]

        # Keep only the layers up to the deepest hooked one; later layers cannot affect the loss.
        vgg_layers = children(m_vgg)[:layer_ids[-1] + 1]
        m_vgg = nn.Sequential(*vgg_layers).cuda().eval()
        set_trainable(m_vgg, False)

        self.m,self.wgts = m_vgg,block_wgts
        self.sfs = [SaveFeatures(m_vgg[i]) for i in layer_ids]

    def forward(self, input, target, sum_layers=True):
        """Return the summed loss, or the list of its terms when ``sum_layers`` is False.

        The first term is L1(input, target) / 100. Each further term is the weighted
        L1 distance between flattened hooked activations of ``input`` and ``target``.
        """
        # Run the target first and snapshot its hooked features (.data.clone() detaches them);
        # the next pass over `input` overwrites the hooks.
        self.m(VV(target.data))
        res = [F.l1_loss(input,target)/100]
        targ_feat = [V(o.features.data.clone()) for o in self.sfs]
        self.m(input)
        res += [F.l1_loss(self.flatten(inp.features),self.flatten(targ))*wgt
               for inp,targ,wgt in zip(self.sfs, targ_feat, self.wgts)]
        if sum_layers: res = sum(res)
        return res

    def flatten(self, x):
        """Reshape to (batch, -1)."""
        return x.view(x.size(0), -1)

    def close(self):
        """Remove all forward hooks registered on the VGG layers."""
        for o in self.sfs: o.remove()

In [None]:
def conv(ni, nf, kernel_size=3, actn=False):
    """Conv2d (ni -> nf channels) in an nn.Sequential, optionally followed by an in-place ReLU.

    Padding is ``kernel_size // 2``, which preserves spatial size for odd kernel sizes.
    """
    conv_layer = nn.Conv2d(ni, nf, kernel_size, padding=kernel_size//2)
    return nn.Sequential(conv_layer, nn.ReLU(True)) if actn else nn.Sequential(conv_layer)

In [None]:
class UpSampleBlock(nn.Module):
    """Sub-pixel upsampler: log2(scale) rounds of conv (-> 4*nf channels) + PixelShuffle(2).

    Args:
        ni: input channels.
        nf: output channels; also the channel count between rounds.
        scale: total upscale factor; must be a power of two >= 2.

    Raises:
        ValueError: if ``scale`` is not a power of two >= 2.
    """
    def __init__(self, ni, nf, scale=2):
        super().__init__()
        n_rounds = int(round(math.log2(scale))) if scale >= 2 else 0
        if n_rounds == 0 or 2 ** n_rounds != scale:
            raise ValueError('UpSampleBlock: scale must be a power of two >= 2, got {}'.format(scale))
        layers = []
        for i in range(n_rounds):
            # PixelShuffle(2) turns nf*4 channels into nf, so rounds after the first take nf inputs.
            layers += [conv(ni if i == 0 else nf, nf*4), nn.PixelShuffle(2)]
        self.sequence = nn.Sequential(*layers)
        self.icnr_init()

    def icnr_init(self):
        """ICNR-initialize the conv of every round (each one feeds a PixelShuffle(2))."""
        for layer in self.sequence:
            if isinstance(layer, nn.Sequential):  # the conv(...) wrappers, not the shuffles
                conv_shuffle = layer[0]
                kernel = icnr(conv_shuffle.weight)
                conv_shuffle.weight.data.copy_(kernel)

    def forward(self, x):
        return self.sequence(x)

In [None]:
class ResSequential(nn.Module):
    """Residual wrapper computing ``x + layers(x) * res_scale``."""

    def __init__(self, layers, res_scale=1.0):
        super().__init__()
        self.res_scale = res_scale
        self.m = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.m(x)
        return x + residual * self.res_scale

In [None]:
def res_block(nf):
    """Residual block: conv+ReLU then conv (nf -> nf), added back to the input scaled by 0.1."""
    body = [conv(nf, nf, actn=True), conv(nf, nf)]
    return ResSequential(body, res_scale=0.1)

In [None]:
class ImageModifierModel(nn.Module):
    """Colorization net: frozen ResNet34 encoder, four upsampling branches, residual head.

    Forward hooks capture the outputs of encoder layers 2, 4, 5 and 6. Each branch
    upsamples one of them to the input's spatial size, which the final
    concatenation with ``x`` requires. The branches are concatenated and refined
    by a stack of residual blocks down to 3 channels. That result is then fused
    with the 3-channel input by a 1x1 conv and passed through tanh.

    Args:
        nf_up: channels produced by each upsampling branch.
        nf_mid: channels inside the residual refinement stack.
        n_res_blocks: number of residual blocks in that stack (default 8).
    """
    @staticmethod
    def generate_base_model():
        """Return (cut pretrained ResNet34 as nn.Sequential, lr_cut) from fastai's model_meta."""
        f = resnet34
        cut,lr_cut = model_meta[f]
        layers = cut_model(f(True), cut)
        return nn.Sequential(*layers), lr_cut

    def set_trainable(self, trainable):
        """Set trainability of the whole model while keeping the encoder frozen."""
        set_trainable(self, trainable)
        set_trainable(self.rn, False)

    def __init__(self, nf_up=64, nf_mid=256, n_res_blocks=8):
        super().__init__()
        rn, lr_cut = ImageModifierModel.generate_base_model()
        self.rn = rn
        set_trainable(rn, False)
        self.lr_cut = lr_cut
        # Hooks record encoder activations during self.rn(x); forward() reads them back.
        self.sfs = [SaveFeatures(rn[i]) for i in [2,4,5,6]]

        # Child registration order (rn, up1..up4, upconv, out) is relied on by
        # ImageModifierModelWrapper.get_layer_groups (children(model)[1:]) and by saved checkpoints.
        self.up1 = nn.Sequential(*[UpSampleBlock(256,nf_up, 2), UpSampleBlock(nf_up,nf_up, 8)])  #256 in
        self.up2 = nn.Sequential(*[UpSampleBlock(128, nf_up, 2), UpSampleBlock(nf_up,nf_up, 4)])  #128 in
        self.up3 = nn.Sequential(*[UpSampleBlock(64,nf_up), UpSampleBlock(nf_up,nf_up, 2)])    #64 in
        self.up4 = UpSampleBlock(64, nf_up)   #64 in

        mid_layers = [conv(nf_up * 4, nf_mid), nn.BatchNorm2d(nf_mid)]
        mid_layers += [res_block(nf_mid) for _ in range(n_res_blocks)]
        mid_layers += [nn.BatchNorm2d(nf_mid), conv(nf_mid, 3, kernel_size=1)]
        self.upconv = nn.Sequential(*mid_layers)

        # Fuse the input image (3 channels) with the predicted image (3 channels).
        self.out = nn.Sequential(conv(6, 3, kernel_size=1))

    def forward(self, x):
        self.rn(x)  # output unused; the call populates self.sfs via the forward hooks
        x1 = self.up1(self.sfs[3].features)
        x2 = self.up2(self.sfs[2].features)
        x3 = self.up3(self.sfs[1].features)
        x4 = self.up4(self.sfs[0].features)
        x5 = self.upconv(torch.cat([x1, x2, x3, x4], dim=1))
        # torch.tanh replaces the deprecated F.tanh (identical result).
        return torch.tanh(self.out(torch.cat([x, x5], dim=1)))

In [None]:
class ImageModifierModelWrapper:
    """Adapter that hands ImageModifierModel (on the GPU via to_gpu) to fastai's ConvLearner."""

    def __init__(self):
        self.model = to_gpu(ImageModifierModel())
        self.name = 'imod'

    def get_layer_groups(self, precompute):
        """Encoder split at ``lr_cut`` into groups, plus one group with every non-encoder child.

        ``precompute`` is accepted for interface compatibility and ignored.
        """
        model = self.model
        groups = list(split_by_idxs(children(model.rn), [model.lr_cut]))
        head = children(model)[1:]  # skips model.rn, the first registered child
        groups.append(head)
        return groups

## Training

In [None]:
def train(lrs, session_num: int, cycle_len=2, use_clr_beta=(20,10,0.95,0.85)):
    """Run one training session on the global ``learn`` and checkpoint the result.

    Every session after the first starts by reloading the checkpoint of session
    ``session_num - 1``. It then calls ``learn.fit(lrs, 1, ...)`` with the global
    weight decay ``wd`` and saves to ``<uid>_2_<session_num>``.
    """
    def checkpoint_name(n):
        return uid + '_2_' + str(n)

    if session_num > 0:
        learn.load(checkpoint_name(session_num - 1))
    learn.fit(lrs, 1, cycle_len=cycle_len, wds=wd, use_clr_beta=use_clr_beta)
    learn.save(checkpoint_name(session_num))

In [None]:
# Learner setup. (SrResnet(64, 1) was an earlier alternative to ImageModifierModelWrapper.)
imod = ImageModifierModelWrapper()
learn = ConvLearner(md, imod)
learn.metrics = []
learn.opt_fn=optim.Adam
learn.crit = F.mse_loss #(turns sepia/blurry)
# Alternative criteria that were tried: the FeatureLoss perceptual loss, or plain L1.
#learn.crit = FeatureLoss()
#learn.crit = F.l1_loss
wd=1e-7  # weight decay; read as a global by train() and the lr_find calls
# Multi-GPU option (GPUs 0-3):
#learn.models.model = nn.DataParallel(learn.models.model, [0,1,2,3])

In [None]:
# Presumably trains only layer groups >= 1 (see ImageModifierModelWrapper.get_layer_groups) -- confirm fastai freeze_to semantics.
learn.freeze_to(1)

In [None]:
# Learning-rate range test from 1e-3 to 1e2 using the training weight decay; plotted in the next cell.
learn.lr_find(1e-3, 1e2, wds=wd, linear=False)

In [None]:
# Plot the lr_find results without skipping points at either end.
learn.sched.plot(n_skip=0, n_skip_end=0)

In [None]:
# 64px base learning rate; layer groups get lr/100, lr/10 and lr (shallowest first).
lr = 5e-4
lrs = np.array([lr / divisor for divisor in (100, 10, 1)])

In [None]:
# Session 0 (64px, frozen): uses the single scalar `lr` rather than the per-group `lrs`, then previews outputs.
train(lr,0,cycle_len=2, use_clr_beta=(5,8,0.95,0.85))
visualize_image_gen_model(md, imod.model, 220, 8, immediate_display=False)

In [None]:
# Unfreeze all layer groups; bn_freeze(True) presumably keeps BatchNorm statistics fixed -- confirm against fastai.
learn.unfreeze()
learn.bn_freeze(True)

In [None]:
# Session 1 (64px, unfrozen): per-group rates divided by 4; train() reloads the session-0 checkpoint first.
train(lrs/4,1,cycle_len=2,use_clr_beta=(20,10,0.95,0.85))
visualize_image_gen_model(md, imod.model, 220, 8, immediate_display=False)

## 128 x 128

In [None]:
# Re-freeze before moving to 128px images.
learn.freeze_to(1)

In [None]:
# 128px stage: rebuild the data at the new size and swap it into the existing learner.
md = get_model_data(image_size=128, batch_size=64, keep_pct=keep_pct)
learn.set_data(md)

In [None]:
# 128px stage: 1/8 of the 64px base rate (5e-4). Written out explicitly (instead of
# `lr = lr/8`) so that re-running this cell does not keep shrinking lr.
lr = 5e-4/8
lrs = np.array([lr/100,lr/10,lr])

In [None]:
# Session 2 (128px, frozen), then preview outputs.
train(lrs,2,cycle_len=2, use_clr_beta=(5,5,0.95,0.85))
# visualize_image_gen_model takes (md, model, ...) first, as in the 64px cells.
visualize_image_gen_model(md, imod.model, 220, 8, immediate_display=False)

In [None]:
# Larger preview (tall figure) of model outputs on the 128px data.
visualize_image_gen_model(md, imod.model, 40, 64, figsize=(20,160), immediate_display=False)

In [None]:
# Unfreeze all layer groups for the 128px fine-tuning session (BatchNorm kept frozen via bn_freeze).
learn.unfreeze()
learn.bn_freeze(True)

In [None]:
# Session 3 (128px, unfrozen): per-group rates divided by 4, then preview outputs.
train(lrs/4,3,cycle_len=2, use_clr_beta=(20,8,0.95,0.85))
# visualize_image_gen_model takes (md, model, ...) first, as in the 64px cells.
visualize_image_gen_model(md, imod.model, 220, 8, immediate_display=False)

## 224 x 224

In [None]:
# Re-freeze before moving to 224px images.
learn.freeze_to(1)

In [None]:
# 224px stage: rebuild the data (batch size reduced to 16) and swap it into the learner.
md = get_model_data(image_size=224, batch_size=16, keep_pct=keep_pct)
learn.set_data(md)

In [None]:
# Learning-rate range test at 224px, from 1e-4 to 1e1; plotted in the next cell.
learn.lr_find(1e-4, 1e1, wds=wd, linear=False)

In [None]:
# Plot the 224px lr_find results without skipping points at either end.
learn.sched.plot(n_skip=0, n_skip_end=0)

In [None]:
# 224px base learning rate, presumably chosen from the lr_find plot above.
lr = 1e-3
lrs = np.array([lr / divisor for divisor in (100, 10, 1)])

In [None]:
# Preview outputs on the 224px data before the 224px training session.
visualize_image_gen_model(md, imod.model, 40, 64, figsize=(20,160), immediate_display=False)

In [None]:
# Session 4 (224px, frozen): per-group rates divided by 10, then preview outputs.
train(lrs/10,4,cycle_len=2, use_clr_beta=(20,8,0.95,0.85))
# visualize_image_gen_model takes (md, model, ...) first, as in the other preview cells.
visualize_image_gen_model(md, imod.model, 220, 8, immediate_display=False)