In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

## Black and White to Color Data

In [2]:
import multiprocessing
from fastai.conv_learner import *
from pathlib import Path
from itertools import repeat
plt.style.use('dark_background')
torch.backends.cudnn.benchmark=True

In [3]:
PATH = Path('data/imagenet/ILSVRC/Data/CLS-LOC')
PATH_TRN = PATH/'train'

In [4]:
sz = 128

In [5]:
train_resized = PATH/('train_' + str(sz))

In [6]:
train_resized_bw = PATH/('train_' + str(sz) + '_bw')

In [7]:
def generate_dest_path(sourceroot: Path, sourcepath: Path, destroot: Path):
    relativepath = sourcepath.relative_to(sourceroot)
    destpath = destroot/relativepath
    return destpath

In [8]:
def dest_path_generator(sourceroot: Path, raw_sourcepaths: [Path], destroot: Path):
    return (generate_dest_path(sourceroot=sourceroot, sourcepath=sourceroot.parent/Path(raw_sourcepath), destroot=destroot) for raw_sourcepath in raw_sourcepaths)

In [9]:
def generate_folders_for_dest(destpaths: [Path]):
    destdirs = set(destpath.parent for destpath in destpaths)
    
    for destdir in destdirs:
        destdir.mkdir(parents=True, exist_ok=True)
    

In [10]:
def transform_image_and_save_new(function, sourcepath: Path, destpath: Path):
    try:
        with Image.open(sourcepath) as image:
            image = function(image)
            image.save(destpath)
    except Exception as ex:
        print(ex)
    

In [14]:
def transform_images_to_new_directory(function, sourceroot: Path, destroot: Path):
    destroot.mkdir(exist_ok=True)
    raw_sourcepaths, _, _ = folder_source(sourceroot.parent, sourceroot.name)
    #First make the destination directories if they don't already exist- we want the subsequent operations to be threadsafe.  Then create
    #another generator of destpaths for use in the image generation
    generate_folders_for_dest(destpaths=dest_path_generator(sourceroot=sourceroot, raw_sourcepaths=raw_sourcepaths, destroot=destroot))   
    destpaths = dest_path_generator(sourceroot=sourceroot, raw_sourcepaths=raw_sourcepaths, destroot=destroot)
    sourcepaths = (sourceroot.parent/Path(raw_sourcepath) for raw_sourcepath in raw_sourcepaths)
    numthreads = multiprocessing.cpu_count()
    
    with ThreadPoolExecutor(numthreads) as e:
        try:
            e.map(partial(transform_image_and_save_new, function), sourcepaths, destpaths)
        except Exception as ex:
            print(ex)

## Resize Images

In [15]:
def resize_image(image: Image, size: int):
    return image.resize((size,size))

In [16]:
transform_images_to_new_directory(partial(resize_image, size=sz), PATH_TRN, train_resized)

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


cannot write mode RGBA as JPEG


  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


## Generate Black and White Versions of Images

In [17]:
def to_grayscale_image(image: Image):
    return image.convert('L')

In [None]:
transform_images_to_new_directory(to_grayscale_image, train_resized, train_resized_bw)

## Model

In [None]:
fnames_full_x,label_arr_full_x,all_labels_x = folder_source(train_resized_bw.parent, train_resized_bw.name)
fnames_full_x = ['/'.join(Path(fn).parts[-2:]) for fn in fnames_full_x]
list(zip(fnames_full_x[:5],label_arr_full_x[:5]))

In [None]:
all_labels_x[:5]

In [None]:
fnames_full_y,label_arr_full_y,all_labels_y = folder_source(train_resized.parent, train_resized.name)
fnames_full_y = ['/'.join(Path(fn).parts[-2:]) for fn in fnames_full_y]
list(zip(fnames_full_y[:5],label_arr_full_y[:5]))

In [None]:
all_labels_y[:5]

In [None]:
bs = 64
np.random.seed(42)
#keep_pct = 1.
keep_pct = 0.02
keeps = np.random.rand(len(fnames_full_x)) < keep_pct
fnames_x = np.array(fnames_full_x, copy=False)[keeps]
label_arr_x = np.array(label_arr_full_x, copy=False)[keeps]
fnames_y = np.array(fnames_full_y, copy=False)[keeps]

In [None]:
arch = vgg16

In [None]:
class MatchedFilesDataset(FilesDataset):
    def __init__(self, fnames_x, fnames_y, transform, path):
        self.y=fnames_y
        assert(len(fnames_x)==len(fnames_y))
        super().__init__(fnames_x, transform, path)
    def get_y(self, i): 
        return open_image(os.path.join(self.path, self.y[i]))
    def get_c(self): return 0

In [None]:
aug_tfms = [RandomDihedral(tfm_y=TfmType.PIXEL)]

In [None]:
val_idxs = get_cv_idxs(len(fnames_x), val_pct=min(0.01/keep_pct, 0.1))
((val_x,trn_x),(val_y,trn_y)) = split_by_idx(val_idxs, np.array(fnames_x), np.array(fnames_y))
len(val_x),len(trn_x)

In [None]:
img_fn = train_resized_bw/'n01558993'/'n01558993_9684.JPEG'

In [None]:
tfms = tfms_from_model(arch, sz, tfm_y=TfmType.PIXEL, aug_tfms=aug_tfms, sz_y=sz)
datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=train_resized_bw.parent)
md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)

In [None]:
denorm = md.val_ds.denorm

In [None]:
def show_img(ims, idx, figsize=(5,5), normed=True, ax=None):
    if ax is None: fig,ax = plt.subplots(figsize=figsize)
    if normed: ims = denorm(ims)
    else:      ims = np.rollaxis(to_np(ims),1,4)
    ax.imshow(np.clip(ims,0,1)[idx])
    ax.axis('off')

In [None]:
x,y = next(iter(md.val_dl))
x.size(),y.size()

In [None]:
idx=1
fig,axes = plt.subplots(1, 2, figsize=(9,5))
show_img(x,idx, ax=axes[0])
show_img(y,idx, ax=axes[1])

In [None]:
batches = [next(iter(md.aug_dl)) for i in range(9)]

In [None]:
fig, axes = plt.subplots(3, 6, figsize=(18, 9))
for i,(x,y) in enumerate(batches):
    show_img(x,idx, ax=axes.flat[i*2])
    show_img(y,idx, ax=axes.flat[i*2+1])

## Model

In [None]:
def conv(ni, nf, kernel_size=3, actn=False):
    layers = [nn.Conv2d(ni, nf, kernel_size, padding=kernel_size//2)]
    if actn: layers.append(nn.ReLU(True))
    return nn.Sequential(*layers)

In [None]:
class ResSequential(nn.Module):
    def __init__(self, layers, res_scale=1.0):
        super().__init__()
        self.res_scale = res_scale
        self.m = nn.Sequential(*layers)

    def forward(self, x): return x + self.m(x) * self.res_scale

In [None]:
def res_block(nf):
    return ResSequential(
        [conv(nf, nf, actn=True), conv(nf, nf)],
        0.1)

In [None]:
def upsample(ni, nf, scale):
    layers = []
    for i in range(int(math.log(scale,2))):
        layers += [conv(ni, nf*(scale**2)), nn.PixelShuffle(2)]
    return nn.Sequential(*layers)

In [None]:
class SrResnet(nn.Module):
    def __init__(self, nf, scale):
        super().__init__()
        features = [conv(3, 64)]
        for i in range(8): features.append(res_block(64))
        features += [conv(64,64), upsample(64, 64, scale),
                     nn.BatchNorm2d(64),
                     conv(64, 3)]
        self.features = nn.Sequential(*features)
        
    def forward(self, x): return self.features(x)

## Perceptual loss

In [None]:
def icnr(x, scale=2, init=nn.init.kaiming_normal):
    new_shape = [int(x.shape[0] / (scale ** 2))] + list(x.shape[1:])
    subkernel = torch.zeros(new_shape)
    subkernel = init(subkernel)
    subkernel = subkernel.transpose(0, 1)
    subkernel = subkernel.contiguous().view(subkernel.shape[0],
                                            subkernel.shape[1], -1)
    kernel = subkernel.repeat(1, 1, scale ** 2)
    transposed_shape = [x.shape[1]] + [x.shape[0]] + list(x.shape[2:])
    kernel = kernel.contiguous().view(transposed_shape)
    kernel = kernel.transpose(0, 1)
    return kernel

In [None]:
m_vgg = vgg16(True)

blocks = [i-1 for i,o in enumerate(children(m_vgg))
              if isinstance(o,nn.MaxPool2d)]
blocks, [m_vgg[i] for i in blocks]

In [None]:
vgg_layers = children(m_vgg)[:23]
m_vgg = nn.Sequential(*vgg_layers).cuda().eval()
set_trainable(m_vgg, False)

In [None]:
m_vgg

In [None]:
def flatten(x): return x.view(x.size(0), -1)

In [None]:
class SaveFeatures():
    features=None
    def __init__(self, m): self.hook = m.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output): self.features = output
    def remove(self): self.hook.remove()        

In [None]:
class FeatureLoss(nn.Module):
    def __init__(self, m, layer_ids, layer_wgts):
        super().__init__()
        self.m,self.wgts = m,layer_wgts
        self.sfs = [SaveFeatures(m[i]) for i in layer_ids]

    def forward(self, input, target, sum_layers=True):
        self.m(VV(target.data))
        res = [F.l1_loss(input,target)/100]
        targ_feat = [V(o.features.data.clone()) for o in self.sfs]
        self.m(input)
        res += [F.l1_loss(flatten(inp.features),flatten(targ))*wgt
               for inp,targ,wgt in zip(self.sfs, targ_feat, self.wgts)]
        if sum_layers: res = sum(res)
        return res
    
    def close(self):
        for o in self.sfs: o.remove()

In [None]:
m = SrResnet(64, scale)

In [None]:
conv_shuffle = m.features[10][0][0]
kernel = icnr(conv_shuffle.weight, scale=scale)
conv_shuffle.weight.data.copy_(kernel);

In [None]:
m = to_gpu(m)

In [None]:
learn = Learner(md, SingleModel(m), opt_fn=optim.Adam)

In [None]:
m = nn.DataParallel(m, [0,1,2,3])
learn = Learner(md, SingleModel(m), opt_fn=optim.Adam)

In [None]:
learn.set_data(md)

In [None]:
learn.crit = FeatureLoss(m_vgg, blocks[:3], [0.2,0.7,0.1])

In [None]:
learn.lr_find(1e-4, 0.1, wds=wd, linear=False)

In [None]:
learn.sched.plot(n_skip_end=1)

In [None]:
lr=3e-3
wd=1e-7

In [None]:
learn.fit(lr, 1, cycle_len=2, wds=wd, use_clr_beta=(20,10,0.95,0.85))

In [None]:
learn.save('bwtoc0')

In [None]:
learn.load('bwtoc0')

In [None]:
lr=lr/2

In [None]:
learn.fit(lr, 1, cycle_len=1, wds=wd, use_clr_beta=(20,10,0.95,0.85))

In [None]:
learn.save('bwtoc1')

In [None]:
learn.load('bwtoc1')

In [None]:
learn.unfreeze()

In [None]:
learn.fit(lr/3, 1, cycle_len=1, wds=wd, use_clr_beta=(20,10,0.95,0.85))

In [None]:
learn.save('bwtoc2')

In [None]:
learn.load('bwtoc2')

In [None]:
learn.sched.plot_loss()

In [None]:
def plot_ds_img(idx, ax=None, figsize=(7,7), normed=True):
    if ax is None: fig,ax = plt.subplots(figsize=figsize)
    im = md.val_ds[idx][0]
    if normed: im = denorm(im)[0]
    else:      im = np.rollaxis(to_np(im),0,3)
    ax.imshow(im)
    ax.axis('off')

In [None]:
fig,axes=plt.subplots(6,6,figsize=(20,20))
for i,ax in enumerate(axes.flat): plot_ds_img(i+200,ax=ax, normed=True)

In [None]:
x,y=md.val_ds[211]

In [None]:
y=y[None]

In [None]:
learn.model.eval()
preds = learn.model(VV(x[None]))
x.shape,y.shape,preds.shape

In [None]:
learn.crit(preds, V(y), sum_layers=False)

In [None]:
_,axes=plt.subplots(1,2,figsize=(14,7))
show_img(x[None], 0, ax=axes[0])
show_img(preds,0, normed=True, ax=axes[1])

In [None]:
bs = 32
sz = 256

In [None]:
t = torch.load(learn.get_model_path('bwtoc2'), map_location=lambda storage, loc: storage)
learn.model.load_state_dict(t, strict=False)

In [None]:
learn.freeze_to(999)

In [None]:
for i in range(10,13): set_trainable(learn.model.module.features[i], True)

In [None]:
conv_shuffle = learn.model.module.features[10][2][0]
kernel = icnr(conv_shuffle.weight, scale=scale)
conv_shuffle.weight.data.copy_(kernel);

In [None]:
lr=6e-3
wd=1e-7

In [None]:
learn.fit(lr, 1, cycle_len=2, wds=wd, use_clr_beta=(20,10,0.95,0.85))

In [None]:
learn.save('bwtoc3')

In [None]:
learn.load('bwtoc3')

In [None]:
learn.fit(lr/2, 1, cycle_len=2, wds=wd, use_clr_beta=(20,10,0.95,0.85))

In [None]:
learn.save('bwtoc4')

In [None]:
learn.load('bwtoc4')

In [None]:
learn.fit(lr/6, 1, cycle_len=1, wds=wd, use_clr=(20,10))

In [None]:
learn.save('bwtoc5')

In [None]:
learn.load('bwtoc5')

In [None]:
train_tfms,val_tfms = tfms_from_model(arch, sz)

In [None]:
image_root = 'data/style/'
image_name = 'csi_enhance'
img_fn = f'{image_root}{image_name}.jpg'
img = open_image(img_fn)
img_tfm = val_tfms(img)
preds = learn.model(VV(img_tfm[None]))

In [None]:
_,axes=plt.subplots(1,2,figsize=(14,7))
show_img(img_tfm[None], 0, ax=axes[0])
show_img(preds,0, normed=True, ax=axes[1])