In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
import multiprocessing
from fastai.conv_learner import *
from fasterai.images import *
from fasterai.dataset import *
from fasterai.visualize import *
from fasterai.loss import *
from fasterai.models import *
from pathlib import Path
from itertools import repeat
# Select CUDA device index 3 for this process (hard-coded; adjust for the machine in use).
torch.cuda.set_device(3)
plt.style.use('dark_background')
# Turn on cuDNN benchmark mode.
torch.backends.cudnn.benchmark=True

  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  from numpy.core.umath_tests import inner1d


In [3]:
# Root folder of the image data, relative to the working directory.
DATA_PATH = Path('data/imagenet/ILSVRC/Data/CLS-LOC')
# The 'train' subfolder of DATA_PATH; passed as train_root_path when building model data.
TRAIN_SOURCE_PATH = DATA_PATH/'train'
# Project identifier: passed to the data builder and used as the prefix of saved checkpoint names.
proj_id = 'bw2color'
# Weight decay, forwarded as `wds` to learn.fit and learn.lr_find.
wd=1e-7
# Forwarded as `keep_pct` to get_matched_image_model_data.
# NOTE(review): presumably the fraction of images to keep — confirm in fasterai.dataset.
keep_pct=0.20

## Model

##### TODO:  Try using unet instead of SrResnet. Also suspect that using a pretrained model as the base will work much better.
##### TODO:  Also try making the loss/output based on "classification" like in Zhang et al.
##### TODO:  After making the unet version, plug that into a Wasserstein GAN setup (the discriminator looks at the grey image and the colorized image, concatenated along the channel dimension).
##### TODO:  Try using higher res images (from FloydHub blog?)
##### TODO:  Why freeze_to(1)...?  Seems like we should be training everything after the model cut at the very least....
##### TODO:  Try perceptual loss again....

In [4]:
class Unet34(nn.Module):
    """U-Net style image generator built on a truncated ``resnet34`` encoder.

    ``SaveFeatures`` objects are attached to encoder children 2, 4, 5 and 6.
    The four ``UnetBlock`` stages (``up1``..``up4``) each take the running
    tensor plus one saved feature map, starting from the last saved one and
    working back to the first. ``up5`` and ``out`` then produce the result.

    Args:
        nf: Channel width used by every decoder stage (default 256).
    """

    @staticmethod
    def _generate_base_model():
        """Create the truncated encoder.

        Returns:
            tuple: ``(encoder, lr_cut)`` where ``encoder`` is the cut model
            wrapped in ``nn.Sequential`` and ``lr_cut`` is the second value
            stored in ``model_meta`` for ``resnet34``.
        """
        arch = resnet34
        cut, lr_cut = model_meta[arch]
        # NOTE(review): arch(True) presumably requests pretrained weights — confirm in fastai.
        encoder = nn.Sequential(*cut_model(arch(True), cut))
        return encoder, lr_cut

    def __init__(self, nf=256):
        super().__init__()
        self.rn, self.lr_cut = Unet34._generate_base_model()
        # One SaveFeatures object per tapped encoder child; forward() reads their `.features`.
        self.sfs = [SaveFeatures(self.rn[idx]) for idx in (2, 4, 5, 6)]
        # Decoder stages are created and registered strictly in the order up1..up4
        # so that parameter order and attribute names stay stable.
        stage_args = ((512, 256), (nf, 128), (nf, 64), (nf, 64))
        for position, (first_arg, second_arg) in enumerate(stage_args, start=1):
            setattr(self, 'up{}'.format(position), UnetBlock(first_arg, second_arg, nf))
        self.up5 = UpSampleBlock(nf, nf, 2)
        # Output head: batch norm followed by a 1x1 ConvBlock with 3 output
        # channels, created with actn=False and bn=False.
        norm = nn.BatchNorm2d(nf)
        head = ConvBlock(nf, 3, ks=1, actn=False, bn=False)
        self.out = nn.Sequential(norm, head)

    def forward(self, x):
        """Run the encoder, then the decoder stages, then the output head."""
        x = F.relu(self.rn(x))
        decoder_stages = (self.up1, self.up2, self.up3, self.up4)
        # up1 pairs with the last saved feature map, up4 with the first.
        for stage, saved in zip(decoder_stages, reversed(self.sfs)):
            x = stage(x, saved.features)
        return self.out(self.up5(x))

    def close(self):
        """Call ``remove()`` on every ``SaveFeatures`` object held by the model."""
        for saved in self.sfs:
            saved.remove()

In [5]:
class UnetModel():
    """Wrapper that exposes a GPU-resident ``Unet34`` through ``model``,
    ``name`` and ``get_layer_groups``; the notebook passes it to ``ConvLearner``."""

    def __init__(self):
        self.model = to_gpu(Unet34())
        self.name = 'unet'

    def get_layer_groups(self, precompute):
        """Return the layer groups used for per-group learning rates and freezing.

        The encoder's children are split at ``lr_cut``; everything in the
        model after its first child is appended as one final group.

        Args:
            precompute: Accepted but not used by this implementation.

        Returns:
            list: The encoder groups followed by the final group.
        """
        net = self.model
        groups = list(split_by_idxs(children(net.rn), [net.lr_cut]))
        groups.append(children(net)[1:])
        return groups

## Training

In [6]:
def train(lrs, session_num: int, cycle_len=2, use_clr_beta=(20,10,0.95,0.85)):
    """Run one training session on the global ``learn`` and checkpoint it.

    For any session after the first (``session_num > 0``) the checkpoint of
    the previous session is loaded before fitting. The result is always saved
    under ``'<proj_id>_1_<session_num>'``.

    Args:
        lrs: Learning rate(s), passed straight to ``learn.fit``.
        session_num: Index of this session; selects the checkpoint names.
        cycle_len: Forwarded to ``learn.fit`` as ``cycle_len``.
        use_clr_beta: Forwarded unchanged to ``learn.fit``.

    Relies on the notebook globals ``learn``, ``proj_id`` and ``wd``.
    """
    def checkpoint_name(index):
        return '{}_1_{}'.format(proj_id, index)

    resume_from_previous = session_num > 0
    if resume_from_previous:
        learn.load(checkpoint_name(session_num - 1))
    # A single cycle per call; weight decay comes from the global `wd`.
    learn.fit(lrs, 1, cycle_len=cycle_len, wds=wd, use_clr_beta=use_clr_beta)
    learn.save(checkpoint_name(session_num))

In [7]:
# First (lowest-resolution) stage: build the model data with image_size=64 and batch_size=256.
md = get_matched_image_model_data(image_size=64, batch_size=256, root_data_path=DATA_PATH, train_root_path=TRAIN_SOURCE_PATH, proj_id=proj_id, keep_pct=keep_pct)

In [8]:
#m = SrResnet(64, 1)
# Build the U-Net wrapper and hand it to ConvLearner together with the 64px data.
unet = UnetModel()
learn = ConvLearner(md, unet)
# No extra metrics are tracked.
learn.metrics = []
learn.opt_fn=optim.Adam
#learn.crit = F.mse_loss #(turns sepia/blurry)
# Active criterion; the commented-out lines around it are alternatives that were tried.
learn.crit = FeatureLoss()
#learn.crit = F.l1_loss

  subkernel = init(subkernel)
  subkernel = init(subkernel)
  subkernel = init(subkernel)
  subkernel = init(subkernel)
  subkernel = init(subkernel)


In [9]:
# Freeze up to layer group 1 before the first training run.
# NOTE(review): presumably this leaves only group 0 (encoder children before lr_cut) frozen — confirm fastai's freeze_to semantics.
learn.freeze_to(1)

In [None]:
# Learning-rate range test between 1e-4 and 1e1 with the configured weight decay.
# NOTE: the saved output only shows progress up to 272/949 iterations.
learn.lr_find(1e-4, 1e1, wds=wd, linear=False)

HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

 29%|██▊       | 272/949 [15:12<37:50,  3.35s/it, loss=0.0759]

In [None]:
# Preview the current model's output on the data in `md`.
# NOTE(review): 220 and 8 are presumably the start index and number of samples — confirm in fasterai.visualize.
visualize_image_gen_model(md, unet.model, 220, 8, immediate_display=False)

In [None]:
# Plot what the scheduler recorded during the preceding lr_find, skipping no points at either end.
learn.sched.plot(n_skip=0, n_skip_end=0)

In [None]:
# Base learning rate for the 64px stage.
lr=8e-2
# Three rates (lr/100, lr/10, lr), passed to train() as `lrs` by later cells.
# NOTE(review): presumably one rate per layer group from UnetModel.get_layer_groups — confirm.
lrs = np.array([lr/100,lr/10,lr])

In [None]:
# Session 0: first run at 64px; train() loads nothing for session 0 and saves '<proj_id>_1_0'.
# NOTE(review): this passes the scalar `lr`, while the frozen-stage runs at 128px and 224px
# pass the per-group `lrs` — confirm the difference is intentional.
train(lr,0,cycle_len=2, use_clr_beta=(5,8,0.95,0.85)) 
visualize_image_gen_model(md, unet.model, 220, 8, immediate_display=False)

In [None]:
# Unfreeze the learner for fine-tuning.
# NOTE(review): bn_freeze(True) presumably keeps batch-norm layers fixed while the rest trains — confirm against fastai.
learn.unfreeze()
learn.bn_freeze(True)

In [None]:
# Session 1: resumes from the session 0 checkpoint and trains at a quarter of the per-group rates.
train(lrs/4,1,cycle_len=2,use_clr_beta=(20,10,0.95,0.85)) 
visualize_image_gen_model(md, unet.model, 220, 8, immediate_display=False)

## 128 x 128

In [None]:
# Go back to the freeze_to(1) configuration before switching to larger images.
learn.freeze_to(1)

In [None]:
# Rebuild the data with image_size=128 and batch_size=64, then attach it to the existing learner.
md = get_matched_image_model_data(image_size=128, batch_size=64, root_data_path=DATA_PATH, train_root_path=TRAIN_SOURCE_PATH, proj_id=proj_id, keep_pct=keep_pct)
learn.set_data(md)

In [None]:
# Use one eighth of the previous base rate for the 128px stage.
# NOTE: this cell is not idempotent — every re-run divides `lr` by 8 again.
lr=lr/8
lrs = np.array([lr/100,lr/10,lr])

In [None]:
# Session 2: first 128px run; resumes from the session 1 checkpoint.
train(lrs,2,cycle_len=2, use_clr_beta=(5,5,0.95,0.85)) 
visualize_image_gen_model(md, unet.model, 220, 8, immediate_display=False)

In [None]:
# `visualize` is never defined in this notebook, so on a fresh kernel the original
# call would raise NameError. Use the helper the other preview cells call, with the
# data and model passed explicitly.
visualize_image_gen_model(md, unet.model, 40, 64, figsize=(20,160), immediate_display=False)

In [None]:
# Unfreeze the learner for fine-tuning at 128px.
# NOTE(review): bn_freeze(True) presumably keeps batch-norm layers fixed while the rest trains — confirm against fastai.
learn.unfreeze()
learn.bn_freeze(True)

In [None]:
# Session 3: resumes from the session 2 checkpoint and trains at a quarter of the 128px per-group rates.
train(lrs/4,3,cycle_len=2, use_clr_beta=(20,8,0.95,0.85)) 
visualize_image_gen_model(md, unet.model, 220, 8, immediate_display=False)

## 224 x 224

In [None]:
# Go back to the freeze_to(1) configuration before the 224px stage.
learn.freeze_to(1)

In [None]:
# Rebuild the data with image_size=224 and batch_size=16, then attach it to the existing learner.
md = get_matched_image_model_data(image_size=224, batch_size=16, root_data_path=DATA_PATH, train_root_path=TRAIN_SOURCE_PATH, proj_id=proj_id, keep_pct=keep_pct)
learn.set_data(md)

In [None]:
# Learning-rate range test at 224px, this time between 1e-4 and 1e-1.
learn.lr_find(1e-4, 1e-1, wds=wd, linear=False)

In [None]:
# Plot what the scheduler recorded during the preceding lr_find, skipping no points at either end.
learn.sched.plot(n_skip=0, n_skip_end=0)

In [None]:
# The 224px base rate is set explicitly; the commented-out line below is the
# alternative of scaling down the previous stage's rate.
#lr=lr/8
lr = 2e-2
lrs = np.array([lr/100,lr/10,lr])

In [None]:
# First 224px run. Session numbering restarts at 0 here, so train() loads no checkpoint
# (training continues from the weights currently in memory) and the save overwrites the
# '<proj_id>_1_0' checkpoint written during the 64px stage.
train(lrs,0,cycle_len=2, use_clr_beta=(5,4,0.95,0.85)) 
visualize_image_gen_model(md, unet.model, 220, 8, immediate_display=False)

In [None]:
# Unfreeze the learner for fine-tuning at 224px.
# NOTE(review): bn_freeze(True) presumably keeps batch-norm layers fixed while the rest trains — confirm against fastai.
learn.unfreeze()
learn.bn_freeze(True)

In [None]:
# 224px session 1: resumes from the 224px session 0 checkpoint and trains at lrs/10.
# Saving reuses the name '<proj_id>_1_1', replacing the checkpoint from the 64px stage.
train(lrs/10,1,cycle_len=2,use_clr_beta=(20,8,0.95,0.85)) 
visualize_image_gen_model(md, unet.model, 220, 8, immediate_display=False)

In [None]:
# 224px session 2: resumes from session 1 and halves the rates again (lrs/20).
train(lrs/20,2,cycle_len=2,use_clr_beta=(20,8,0.95,0.85)) 
visualize_image_gen_model(md, unet.model, 220, 8, immediate_display=False)

In [None]:
# `visualize` is never defined in this notebook, so on a fresh kernel the original
# call would raise NameError. Use the helper the other preview cells call, with the
# data and model passed explicitly.
visualize_image_gen_model(md, unet.model, 40, 64, figsize=(20,160))

In [None]:
# 224px session 3: resumes from session 2 and trains at lrs/40.
train(lrs/40,3,cycle_len=2,use_clr_beta=(20,8,0.95,0.85)) 
visualize_image_gen_model(md, unet.model, 220, 8, immediate_display=False)

In [None]:
# `visualize` is never defined in this notebook, so on a fresh kernel the original
# call would raise NameError. Use the helper the other preview cells call, with the
# data and model passed explicitly.
visualize_image_gen_model(md, unet.model, 200, 64, figsize=(20,160))

In [None]:
# 224px session 4: resumes from session 3 and trains at lrs/80.
train(lrs/80,4,cycle_len=2,use_clr_beta=(20,8,0.95,0.85)) 
visualize_image_gen_model(md, unet.model, 220, 8, immediate_display=False)

In [None]:
# Larger preview on a tall (20, 160) figure.
# NOTE(review): 40 and 64 are presumably the start index and number of samples — confirm in fasterai.visualize.
visualize_image_gen_model(md, unet.model, 40, 64, figsize=(20,160), immediate_display=False)

In [None]:
# 224px session 5: resumes from session 4 and trains at lrs/160.
train(lrs/160,5,cycle_len=2,use_clr_beta=(20,8,0.95,0.85)) 
visualize_image_gen_model(md, unet.model, 220, 8, immediate_display=False)

In [None]:
# Final larger preview on a tall (20, 160) figure after the last training session.
# NOTE(review): 40 and 64 are presumably the start index and number of samples — confirm in fasterai.visualize.
visualize_image_gen_model(md, unet.model, 40, 64, figsize=(20,160), immediate_display=False)