## Super resolution

In [None]:
import fastai
from fastai import *
from fastai.vision import *
from fastai.callbacks import *

In [None]:
from torchvision.models import vgg16_bn

In [None]:
# Pin the notebook to GPU index 2 when the host actually exposes that device;
# on hosts with fewer GPUs keep PyTorch's default device instead of failing
# on an invalid index.
if torch.cuda.device_count() > 2:
    torch.cuda.set_device(2)

First you need to create a folder of lower-quality images with the same folder structure as your original images. One way to do that is with `rsync` (to create the folder structure) and `imagemagick` with `gnu parallel` (to convert the images in parallel); each of these needs to be installed (e.g. using `sudo apt install rsync` on Linux). Here the source directory is called *images* and the target is called *small-96*, and the images are created with longest side 96px and low jpeg quality. The cells below build the *small-96* folder in Python instead, using PIL via `resize_one`.

In [None]:
# Dataset root, the folder holding the full-resolution images, and the folder
# holding their downsized copies.
path = Path('data') / 'oxford-iiit-pet'
path_hr, path_lr = path / 'images', path / 'small-96'

In [None]:
# Item list over everything found under the high-resolution folder; its
# `items` are the file paths consumed by the resize and preview cells.
il = ImageItemList.from_folder(path_hr, label_cls=ImageItemList)

In [None]:
def resize_one(fn,i):
    """Write a downsized, recompressed copy of image `fn` under `path_lr`.

    The copy keeps `fn`'s location relative to `path_hr`, so the low-res tree
    mirrors the high-res one; missing parent folders are created.

    fn: path of the source image, located under `path_hr`.
    i:  unused; presumably present so the function matches the `(item, index)`
        call signature of `parallel` -- confirm.
    """
    dest = path_lr/fn.relative_to(path_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # The context manager closes the source file even if resizing fails.
    with PIL.Image.open(fn) as img:
        targ_sz = resize_to(img,96,use_min=True)
        # Palette / alpha / CMYK sources cannot all be written as JPEG, so the
        # resized copy is normalised to 3-channel RGB before saving.
        small = img.resize(targ_sz, resample=PIL.Image.BILINEAR).convert('RGB')
    small.save(dest, quality=60)

In [None]:
# Uncomment and run once to (re)generate the low-resolution copies under `path_lr`:
# parallel(resize_one, il.items)

In [None]:
# Sanity check: display the low-res file that corresponds to the first high-res item.
open_image(path_lr/il.items[0].relative_to(path_hr)).show()

In [None]:
# Batch size and image size for training; the commented pair is presumably the
# setting used for the "size 256" runs further down -- confirm.
# bs,size=16,256
bs,size=32,160
# Backbone architecture handed to the U-Net learner.
arch = models.resnet34

# Inputs are the low-res images, split 90/10 with a fixed seed for reproducibility.
src =  (ImageItemList.from_folder(path_lr, label_cls=ImageItemList)
       .random_split_by_pct(0.1, seed=42))
# Each input is labelled with the high-res file of the same *name* (only
# `x.name` is used, so sub-folders of `path_lr` are not reflected in the label path).
# `tfm_y=True` / `do_y=True` ask for the transforms and the normalization to be
# applied to the target image as well as the input.
data = (src.label_from_func(lambda x: path_hr/x.name)
       .transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)
       .databunch(bs=bs).normalize(imagenet_stats, do_y=True))

# NOTE(review): presumably the number of output channels (RGB) the learner reads from the data -- confirm.
data.c = 3

In [None]:
# Short handles for the training and validation datasets.
trn,val = data.train_ds,data.valid_ds

In [None]:
# 2x2 grid from the validation set: the callback returns element j of sample i.
# NOTE(review): presumably rows are samples and columns are (input, target) -- confirm.
show_multi(lambda i,j: val[i][j], 2, 2, figsize=(12,12))

## Feature loss

In [None]:
# 3x3 filter used for the "edge" term of FeatureLoss. It is moved to the GPU,
# expanded (as a view) to shape (1,3,3,3) so F.conv2d applies the same 3x3
# pattern to each of 3 input channels and produces one output channel, and
# scaled by 1/6. Because it lives on the GPU, the tensors it is convolved with
# must be CUDA tensors as well.
k = tensor([
    [0.  ,-5/3,1],
    [-5/3,-5/3,1],
    [1.  ,1   ,1],
]).cuda().expand(1,3,3,3)/6

# Distance function shared by every term of FeatureLoss.
base_loss = F.l1_loss

In [None]:
class FeatureLoss(nn.Module):
    """Loss that sums pixel, edge and intermediate-activation differences.

    `m_feat` is run on both the prediction and the target; the outputs of the
    layers `m_feat[i] for i in layer_ids` are captured through hooks and
    compared with `base_loss`, each term multiplied by the matching entry of
    `layer_wgts`. Two unweighted terms are added: `base_loss` on the images
    themselves and `base_loss` on their `k`-filtered versions. The class relies
    on the module-level names `base_loss` and `k`.

    After every call, `feat_losses` holds the individual terms (pixel, edge,
    then one per hooked layer) and `metrics` maps `metric_names` to them.

    Raises:
        ValueError: if `layer_ids` and `layer_wgts` differ in length.
    """
    def __init__(self, m_feat, layer_ids, layer_wgts):
        super().__init__()
        # zip() in forward() would silently drop terms on a mismatch while
        # metric_names would still advertise one entry per layer id.
        if len(layer_ids) != len(layer_wgts):
            raise ValueError(f'layer_ids and layer_wgts must have the same length, '
                             f'got {len(layer_ids)} and {len(layer_wgts)}')
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False is passed so the stored activations stay usable for backprop
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.metric_names = ['L1      ','edge    '] + [f'feat_{i}  ' for i in range(len(layer_ids))]

    def make_feature(self, bs, o, clone=False):
        """Flatten activation `o` to shape (bs, -1); return a copy if `clone`."""
        feat = o.view(bs, -1)
        return feat.clone() if clone else feat

    def make_features(self, x, clone=False):
        """Run `m_feat` on `x` and return the flattened hooked activations."""
        # The forward pass is run only for its side effect of filling hooks.stored
        self.m_feat(x)
        return [self.make_feature(x.shape[0], o, clone) for o in self.hooks.stored]

    def forward(self, input, target):
        """Return the summed loss between prediction `input` and `target`."""
        # Target activations are cloned because the second pass below refills
        # the same hook storage.
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        px_loss   = base_loss(input,target)
        # Same (prediction, target) argument order as every other term
        edge_loss = base_loss(F.conv2d(input, k), F.conv2d(target, k))
        self.feat_losses = [px_loss,edge_loss]
        self.feat_losses += [base_loss(f_in, f_out)*w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # `hooks` is absent when __init__ raised before registering them
        hooks = getattr(self, 'hooks', None)
        if hooks is not None: hooks.remove()

In [None]:
# Feature (convolutional) part of a pretrained VGG16-BN, moved to the GPU and
# put in eval mode; gradients are switched off for it so it serves only as a
# fixed feature extractor for the loss.
vgg_m = vgg16_bn(True).features.cuda().eval()
requires_grad(vgg_m, False)

In [None]:
# Index of the layer immediately before each MaxPool2d in the VGG feature
# stack; the second line displays those indices together with the modules.
blocks = [i-1 for i,o in enumerate(children(vgg_m)) if isinstance(o,nn.MaxPool2d)]
blocks, [vgg_m[i] for i in blocks]

In [None]:
# del(feat_loss,learn)

In [None]:
# Compare activations at the 3rd, 4th and 5th entries of `blocks`, each weighted by 5.
feat_loss = FeatureLoss(vgg_m, blocks[2:5], [5,5,5])

In [None]:
def get_preds():
    """Return `(x, y, preds)` for one validation batch of the global `learn`.

    `x` and `y` come from `learn.data.one_batch`; `preds` is the detached
    output of `learn.pred_batch`, denormalized with the same data object so
    the result stays consistent when `learn.data` is replaced.

    NOTE(review): assumes `pred_batch` draws the same validation batch that
    `one_batch` returned -- confirm.
    """
    dbunch = learn.data
    # The two positional True flags are presumably `detach` and `denorm` -- confirm.
    x,y = dbunch.one_batch(DatasetType.Valid, True,True)
    preds = dbunch.denorm(learn.pred_batch(DatasetType.Valid).detach())
    return x,y,preds

def show_pred(i=0):
    """Plot sample `i` of the globals `x`, `preds` and `y` side by side, in that order."""
    _, panels = plt.subplots(1, 3, figsize=(15, 5))
    for batch, ax in zip((x, preds, y), panels):
        show_image(batch[i], ax)

## Small

In [None]:
# Weight decay handed to the learner.
wd = 1e-3

In [None]:
# U-Net learner on the `arch` backbone, trained with the feature loss.
# LossMetrics presumably reports the per-term values FeatureLoss stores in
# `metrics` -- confirm.
# NOTE(review): the effect of `all_wn` / `blur` is not visible in this file.
learn = Learner.create_unet(data, arch, wd=wd, loss_func=feat_loss,
                            callback_fns=LossMetrics, all_wn=True, blur=True)

In [None]:
# Run the learning-rate finder and plot what the recorder captured.
learn.lr_find()
learn.recorder.plot()

In [None]:
# Base learning rate; the training cells below use it as `lr*10` and `lr`.
lr = 1e-3

In [None]:
# Restore the checkpoint named '2'. That name is only written by
# `learn.save('2')` further down, so it must already exist from an earlier run.
learn.load('2');

In [None]:
#size 160
# 10 epochs of one-cycle training with a maximum rate of 10x the base `lr`.
# NOTE(review): pct_start=0.9 presumably puts 90% of the cycle in the warm-up phase -- confirm.
learn.fit_one_cycle(10, slice(lr*10), pct_start=0.9)

In [None]:
#size 256
# Same schedule as the size-160 cell, but with the base `lr` as the maximum rate.
learn.fit_one_cycle(10, slice(lr), pct_start=0.9)

In [None]:
# Plot the losses recorded during the training run above.
learn.recorder.plot_losses()

In [None]:
# Switch the model to eval mode, collect every layer whose type is in
# `bn_types`, and list the largest running mean / running variance of each
# (a quick look at the normalization statistics).
m = learn.model.eval()
bn = [o for o in flatten_model(m) if isinstance(o,bn_types)]
[(t.running_mean.max(),t.running_var.max()) for t in bn]

In [None]:
# Refresh the globals that show_pred reads.
x,y,preds = get_preds()

In [None]:
#160
# Sample 0 of the current x / preds / y (output kept from the size-160 run).
show_pred(0)

In [None]:
#256
# Sample 0 of the current x / preds / y (output kept from the size-256 run).
show_pred(0)

In [None]:
# Checkpoint '1': reloaded later by `learn.load('1')` before fine-tuning.
learn.save('1')

In [None]:
# Second checkpoint under a separate name ('1b'); presumably the size-256 counterpart of '1' -- confirm.
learn.save('1b')

In [None]:
# Unfreeze the model for the fine-tuning stage.
learn.unfreeze()

In [None]:
# Learning rates derived from the range 1..100.
# NOTE(review): presumably one value per layer group, with only the ratios
# mattering since the next cell rescales them -- confirm.
max_lr = learn.lr_range(slice(1,100))

In [None]:
# LR finder for the unfrozen model, swept from max_lr*1e-10 up to max_lr, then plotted.
learn.lr_find(max_lr*1e-10,max_lr)
learn.recorder.plot()

In [None]:
# Learning-rate range for fine-tuning, 1e-5 to 1e-3 (presumably spread from
# the earliest to the last layer group -- confirm).
lrs = slice(1e-5,1e-3)

In [None]:
# Start fine-tuning from the checkpoint saved as '1'.
learn.load('1');

In [None]:
# 10 epochs of one-cycle fine-tuning with the `lrs` range.
learn.fit_one_cycle(10, lrs)

In [None]:
#256
# Same fine-tuning call, kept separately for the size-256 run.
learn.fit_one_cycle(10, lrs)

In [None]:
# Checkpoint '2': the one `learn.load('2')` near the top of this section reads.
learn.save('2');

In [None]:
# Second fine-tuned checkpoint under a separate name ('2b'); presumably the size-256 counterpart of '2' -- confirm.
learn.save('2b');

In [None]:
# Plot the losses recorded during fine-tuning.
learn.recorder.plot_losses()

In [None]:
# Refresh the globals that show_pred reads, now with the fine-tuned model.
x,y,preds = get_preds()

In [None]:
#160
# Sample 0 after fine-tuning (output kept from the size-160 run).
show_pred(0)

In [None]:
#256
# Sample 0 after fine-tuning (output kept from the size-256 run).
show_pred(0)

In [None]:
# Same pipeline as `data`, but at size 1024, then swapped into the learner so
# the trained model can be run at high resolution.
# NOTE(review): `bs` is unchanged from training; memory use at size=1024 is worth checking.
data_hr = (src.label_from_func(lambda x: path_hr/x.name)
          .transform(get_transforms(max_zoom=2.), size=1024, tfm_y=True)
          .databunch(bs=bs).normalize(imagenet_stats, do_y=True))

# NOTE(review): presumably the number of output channels, mirroring `data.c` -- confirm.
data_hr.c = 3
learn.data = data_hr

In [None]:
# Element [1] of the first validation sample -- presumably the target
# (high-res) image of the pair; it is what gets fed to the model next.
img = data_hr.valid_ds[0][1]

In [None]:
# predict() returns a 3-tuple; only the middle element is used afterwards
# (it is denormalized and displayed in a later cell).
_,img_hr,b = learn.predict(img)

In [None]:
# Reference view: element [1] of the first validation sample from the
# training-size `data`, for comparison with the model output.
data.valid_ds[0][1].show(figsize=(18,18))

In [None]:
# Undo the normalization on the model output and display it.
Image(data_hr.denorm(img_hr)).show(figsize=(18,18))

## fin