## Super resolution

In [None]:
from fastai.vision.all import *
import torch
import time
# Session timestamp as a string; do_fit appends it to the checkpoint names it saves.
now = str(time.time())

In [None]:
# Dataset root and the per-resolution image directories beneath it.
path = Path('/home/lleonard/Documents/datasets/danbooru/0/danbooru2020/')
path_hr = path/'512px'  # targets: get_y resolves every target file under this directory
path_lr = path/'96px'   # inputs: this is the source handed to dblock.dataloaders
path_mr = path/'256px'  # not referenced by the other cells visible in this notebook

In [None]:
# List the image files found under the 512px directory.
items = get_image_files(path_hr)

In [None]:
# Batch size and image size for the first training stage (passed to get_dls).
bs,size=32,128
# Backbone architecture handed to unet_learner.
arch = resnet50

In [None]:
def short_id(full_id):
    """Return ``full_id`` modulo 1000 as a zero-padded, four-character string.

    ``full_id`` may be anything ``int()`` accepts (an int or a numeric string).
    """
    bucket = int(full_id) % 1000
    return format(bucket, '04d')

def short_id_from_filename(filename):
    """Return the short (four-character) id for an image file path.

    Args:
        filename: A path object whose file name is the numeric image id
            followed by an extension, e.g. ``1234567.jpg``.

    Returns:
        The result of ``short_id`` applied to the numeric id.

    Raises:
        ValueError: If the part of the name before the extension is not an integer.
    """
    # Use the stem rather than slicing off four characters, so extensions of any
    # length ('.jpeg', '.webp', ...) are handled, not only three-letter ones.
    return short_id(int(filename.stem))

def get_y(x):
    """Map an input image path to the path of its target under ``path_hr``.

    The target keeps the input's file name and sits in the sub-folder named by
    the short id derived from that file name.
    """
    return path_hr / short_id_from_filename(x) / x.name

def get_subset_images(start_folder, end_folder):
    """Build a ``get_items`` function restricted to a range of short ids.

    Args:
        start_folder: Lowest short id to keep, as a string (inclusive).
        end_folder: Highest short id to keep, as a string (inclusive).

    Returns:
        A function that takes a source directory and returns the list of image
        files under it whose short id lies between the two bounds.
    """
    def _get_items(source):
        # Short ids are compared as strings; short_id always yields four
        # zero-padded characters, so string order matches numeric order when the
        # bounds are written the same way. The chained comparison parses each
        # file name once instead of twice.
        return [img for img in get_image_files(source)
                if start_folder <= short_id_from_filename(img) <= end_folder]
    return _get_items

def get_dls(bs,size):
    """Create the image-to-image DataLoaders for a given batch size and image size.

    Items are the files under ``path_lr`` whose short id is between '0000' and
    '0001'; each target is looked up with ``get_y``. ``Resize(size)`` is supplied
    both to the DataBlock and to the ``dataloaders`` call.
    """
    block_kwargs = dict(
        blocks=(ImageBlock, ImageBlock),
        get_items=get_subset_images('0000','0001'),
        get_y=get_y,
        splitter=RandomSplitter(),
        item_tfms=Resize(size),
        batch_tfms=[*aug_transforms(max_zoom=2.), Normalize()],
    )
    loaders = DataBlock(**block_kwargs).dataloaders(
        path_lr, bs=bs, path=path, item_tfms=Resize(size))
    # presumably the channel count the U-Net builder reads from the loaders — TODO confirm
    loaders.c = 3
    return loaders

In [None]:
# Build the first-stage loaders from the batch size and image size chosen above.
dls = get_dls(bs,size)

In [None]:
# Visual sanity check: display up to four training samples.
dls.train.show_batch(max_n=4, figsize=(18,9))

## Feature loss

In [None]:
# Take the target of the first validation item and turn it into a float tensor.
# permute(2,0,1) moves the last axis to the front — presumably HWC -> CHW; the
# division by 255 assumes 8-bit pixel values — TODO confirm.
t = tensor(dls.valid_ds[0][1]).float().permute(2,0,1)/255.
# Stack two copies to get a small batch for trying out gram_matrix.
t = torch.stack([t,t])

In [None]:
def gram_matrix(x):
    """Return the per-sample channel Gram matrix of a 4-D batch.

    Args:
        x: Tensor whose size unpacks as ``(n, c, h, w)``.

    Returns:
        Tensor of shape ``(n, c, c)``: for each sample, the matrix of inner
        products between its flattened channels, divided by ``c*h*w``.
    """
    n,c,h,w = x.size()
    # reshape (not view) so that non-contiguous inputs, e.g. spatially transposed
    # tensors, are accepted; reshape still returns a view whenever one is possible.
    x = x.reshape(n, c, -1)
    return (x @ x.transpose(1,2))/(c*h*w)

In [None]:
# Display the shape of the test batch.
t.shape

In [None]:
# Display the Gram matrices of the test batch.
gram_matrix(t)

In [None]:
# Elementwise loss that FeatureLoss applies to pixels, features and Gram matrices.
base_loss = F.l1_loss

In [None]:
# Load the saved checkpoint and take its 'model' entry as the feature network.
# NOTE: torch.load unpickles arbitrary objects — only load checkpoints you trust.
data = torch.load('./models/danbooru_vgg_rating_classifier.pth')
# eval() puts the network in inference mode; requires_grad_(False) freezes its
# weights so that backward passes through the loss network do not compute or
# accumulate gradients for parameters that are never optimised here.
vgg_m = data['model'].eval().requires_grad_(False)

In [None]:
# Indices of the child modules that sit immediately before each MaxPool2d in vgg_m.
blocks = [i-1 for i,o in enumerate(vgg_m.children()) if isinstance(o,nn.MaxPool2d)]
# Display those indices together with the modules they select.
blocks, [vgg_m[i] for i in blocks]

In [None]:
class FeatureLoss(Module):
    """Loss made of a pixel term plus weighted feature and Gram-matrix terms.

    All terms use the module-level ``base_loss``. The feature terms compare the
    outputs that selected layers of ``m_feat`` produce for the prediction and for
    the target; the Gram terms compare ``gram_matrix`` of those same outputs.

    Args:
        m_feat: Indexable network used to extract features.
        layer_ids: Indices into ``m_feat`` of the layers whose outputs are compared.
        layer_wgts: One weight per selected layer. Feature terms are scaled by
            ``w`` and Gram terms by ``w**2 * 5e3``.
    """
    def __init__(self, m_feat, layer_ids, layer_wgts):
        self.m_feat = m_feat
        self.loss_features = [m_feat[i] for i in layer_ids]
        # detach=False keeps the stored outputs attached to the autograd graph.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        # Names under which forward() exposes the individual terms as attributes.
        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
              ] + [f'gram_{i}' for i in range(len(layer_ids))]

    def make_features(self, x, clone=False):
        """Run ``x`` through ``m_feat`` and return the hooked layers' stored outputs.

        With ``clone=True`` each stored output is cloned before being returned.
        """
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target, reduction='mean'):
        """Return the sum of the pixel, feature and Gram terms.

        The individual terms are also kept in ``self.feat_losses`` and set as
        attributes named after ``self.metric_names``. With ``reduction='none'``
        every term is averaged over its non-batch dimensions.
        """
        # Target features are cloned — presumably because the hooks overwrite their
        # stored outputs on the next forward pass; confirm against hook_outputs.
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        # 4-D terms: the pixel loss followed by one feature loss per layer.
        spatial_losses = [base_loss(input,target,reduction=reduction)]
        spatial_losses += [base_loss(f_in, f_out,reduction=reduction)*w
                           for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        # 3-D terms: one Gram-matrix loss per layer.
        gram_losses = [base_loss(gram_matrix(f_in), gram_matrix(f_out),reduction=reduction)*w**2 * 5e3
                       for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        if reduction=='none':
            # Reduce each group over its own non-batch dimensions. Keeping the two
            # groups separate avoids a hard-coded split index, so any number of
            # layers works (previously the split assumed exactly three).
            spatial_losses = [f.mean(dim=[1,2,3]) for f in spatial_losses]
            gram_losses = [f.mean(dim=[1,2]) for f in gram_losses]
        self.feat_losses = spatial_losses + gram_losses
        for n,l in zip(self.metric_names, self.feat_losses): setattr(self, n, l)
        return sum(self.feat_losses)

    def __del__(self):
        # hooks is missing if __init__ failed before registering them.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None: hooks.remove()

In [None]:
# Feature loss over the layers selected by blocks[2:5], weighted 5, 15 and 2.
feat_loss = FeatureLoss(vgg_m, blocks[2:5], [5,15,2])

## Train

In [None]:
# U-Net learner trained with the feature loss; each loss term is reported as a metric.
learn = unet_learner(dls, arch, loss_func=feat_loss, metrics=LossMetrics(feat_loss.metric_names),
                     blur=True, norm_type=NormType.Weight)

In [None]:
# Run the learning-rate finder to help choose lr below.
learn.lr_find()

In [None]:
# Base learning rate and weight decay used by do_fit.
lr = 3e-3
wd = 1e-3

In [None]:
def do_fit(save_name, lrs=slice(lr), pct_start=0.9):
    """Train ``learn`` for 10 one-cycle epochs, save it, and show two results.

    Args:
        save_name: Checkpoint name; '_' and the session timestamp ``now`` are
            appended before saving.
        lrs: Learning rate(s) passed to ``fit_one_cycle``. The default is built
            from the value ``lr`` had when this function was defined.
        pct_start: Passed through to ``fit_one_cycle``.
    """
    learn.fit_one_cycle(10, lrs, pct_start=pct_start, wd=wd)
    checkpoint_name = '_'.join((save_name, now))
    learn.save(checkpoint_name)
    learn.show_results(max_n=2, figsize=(15,11))

In [None]:
# Stage 1a: first training run, with a maximum learning rate of 10 * lr.
do_fit('danbooru_custom_classifier_1a', slice(lr*10))

In [None]:
# Show two results from the dataset at index 1.
learn.show_results(ds_idx=1, max_n=2, figsize=(15,11))

In [None]:
# Unfreeze the learner before the next training stage.
learn.unfreeze()

In [None]:
# Stage 1b: train the unfrozen learner with learning rates from 1e-5 up to lr.
do_fit('danbooru_custom_classifier_1b', slice(1e-5,lr))

In [None]:
# Rebuild the loaders at twice the image size, with a batch size of 12.
dls = get_dls(12,size*2)

In [None]:
# Point the existing learner at the larger-image loaders and freeze it again.
learn.dls = dls
learn.freeze()

In [None]:
# Reload the stage-1b checkpoint. do_fit saves under save_name + '_' + now, so the
# same suffix is needed here for the name to match the file written in this session.
learn.load('danbooru_custom_classifier_1b' + '_' + now);

In [None]:
# Stage 2a: train at the larger image size with do_fit's default learning rate.
do_fit('danbooru_custom_classifier_2a')

In [None]:
# Unfreeze the learner before the final training stage.
learn.unfreeze()

In [None]:
# Stage 2b: train the unfrozen learner with learning rates 1e-6..1e-4 and pct_start=0.3.
do_fit('danbooru_custom_classifier_2b', slice(1e-6,1e-4), pct_start=0.3)

In [None]:
# Show two results from the dataset at index 1 after the final stage.
learn.show_results(ds_idx=1, max_n=2, figsize=(15,11))

## Test

In [None]:
# Build a fresh learner for testing: it is constructed with batch-size-1 loaders at
# `size`, then pointed at batch-size-1 loaders built at twice that size.
dls = get_dls(1,size)
loaded_learn = unet_learner(dls, arch, loss_func=feat_loss, metrics=LossMetrics(feat_loss.metric_names),
                     blur=True, norm_type=NormType.Weight)
dls = get_dls(1,size * 2)
loaded_learn.dls = dls

In [None]:
# Load the final checkpoint by its bare name.
# NOTE(review): do_fit saves checkpoints as save_name + '_' + now, so this bare name
# only resolves if a file with exactly this name exists — confirm the intended file.
loaded_learn.load('danbooru_custom_classifier_2b');
# presumably moves the model to the GPU — confirm
loaded_learn.cuda()

In [None]:
# Show two results from the dataset at index 1 using the reloaded learner.
loaded_learn.show_results(ds_idx=1, max_n=2, figsize=(15,11))

In [None]:
import glob
import torchvision.transforms as T
# Pick one random image rated 's', enlarge a centre crop of it, run the enlarged
# crop through the trained model and display the input and the output.
ratings = pd.read_csv('clean_0000.csv')
# The first column of the sampled row is used as the image id.
image_id = ratings[ratings['rating'] == 's'].sample(1).values[0][0]

# NOTE: this shadows the builtin `id`; the name is kept in case later cells read it.
id = short_id(image_id)

# The extension is not known in advance, so match any file named after the id.
matches = glob.glob(str(path_hr / id / str(image_id)) + '.*')
if not matches:
    raise FileNotFoundError(f'no image file found for id {image_id} under {path_hr / id}')
image = PILImage.create(matches[0])
show_image(image)
print(image.shape)

# Crop a part_size x part_size square around the image centre and scale it to 512x512.
part_size = 64
w, h = image.size
image = image.crop((w/2-(part_size / 2), h/2-(part_size / 2),w/2+(part_size / 2), h/2+(part_size / 2))).resize((512,512))

#convert to fastai image type
image = PILImage.create(np.array(image.convert('RGB')))

show_image(image.resize((256,256), Image.LANCZOS))
print(image.shape)

# Inference mode: eval() so layers such as batch norm use their stored statistics,
# and no_grad() so no autograd graph is built for this forward pass.
# NOTE(review): the input is only scaled by 1/255, whereas the training loaders
# also apply Normalize() — confirm whether the same normalisation is needed here.
loaded_learn.model.eval()
with torch.no_grad():
    # The model output is unpacked along its first axis; the first element is kept.
    img_hr,*_ = loaded_learn.model(ToTensor()(image).unsqueeze(0).float().cuda() / 255.)
print(img_hr.shape)
show_image(img_hr)

In [None]:
# Save the whole model object (not just its weights) under the 'model' key.
torch.save({'model': loaded_learn.model}, './super_res.pth')