In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='1'

from fastai import *
from fastai.vision import *
from superRes.generators import *
from superRes.critics import *
from superRes.dataset import *
from superRes.loss import *
from superRes.save import *
from superRes.fid_loss import *
from superRes.ssim import *
from PIL import Image, ImageDraw, ImageFont
from PIL import ImageFile
from pathlib import Path

import torchvision
import geffnet # efficient/ mobile net

In [2]:
def get_data(bs:int, sz:int, keep_pct:float):
    """Return the databunch produced by `get_databunch` for the notebook-level image folders.

    bs, sz and keep_pct are forwarded unchanged; `random_seed` is always None.

    NOTE(review): this reads the global `path_lowRes`, which is not assigned in any
    visible cell (only path_lowRes_diff / _mild / _x8 are) — confirm it exists
    before calling.
    """
    databunch_args = dict(
        sz=sz,
        bs=bs,
        crappy_path=path_lowRes,
        good_path=path_fullRes,
        random_seed=None,
        keep_pct=keep_pct,
    )
    return get_databunch(**databunch_args)

def get_DIV2k_data(pLow, bs:int, sz:int):
    """Build a databunch mapping low-resolution DIV2K images to their full-resolution targets.

    pLow: folder holding the low-resolution images. Its last path component must be
          one of the folder names listed in `suffixes` below.
    bs:   batch size handed to `.databunch`.
    sz:   `size` handed to `.transform` (passed together with `tfm_y=True`).

    Returns the databunch, with `data.c` set to 3.
    Raises KeyError when the folder name is not a known low-resolution set.
    """
    # Keyed on the folder *name* rather than on str(pLow), so the lookup does not
    # depend on how the path is spelled (relative vs. absolute, './' prefix,
    # trailing slash, or the separator str() produces on the current OS).
    suffixes = {"DIV2K_train_LR_x8": "x8",
                "DIV2K_train_LR_difficult": "x4d",
                "DIV2K_train_LR_mild": "x4m"}
    folder_name = Path(pLow).name
    if folder_name not in suffixes:
        raise KeyError(f"unknown low-resolution folder {folder_name!r}; "
                       f"expected one of {sorted(suffixes)}")
    lowResSuffix = suffixes[folder_name]

    # Items 0-799 go to training and 800-899 to validation.
    # NOTE(review): assumes the folder lists at least 900 items in a stable order — confirm.
    src = ImageImageList.from_folder(pLow).split_by_idxs(train_idx=list(range(0,800)), valid_idx=list(range(800,900)))

    # Target for each input: the same file name with the suffix removed, under path_fullRes.
    data = (src.label_from_func(lambda x: path_fullRes/(x.name).replace(lowResSuffix, '')).transform(
            get_transforms(
                flip_vert=True,
                max_rotate=30,
                max_zoom=3.,
                max_lighting=.4,
                max_warp=.4,
                p_affine=.85
            ),
            size=sz,
            tfm_y=True,
        ).databunch(bs=bs, num_workers=8, no_check=True).normalize(imagenet_stats, do_y=True))
    # presumably the number of image channels the learner should output — confirm
    data.c = 3
    return data

def create_training_images(fn, i, p_hr, p_lr, size):
    """Write a resized, re-encoded copy of the single image `fn` under `p_lr`.

    The output keeps `fn`'s path relative to `p_hr`; missing parent folders are
    created. The image is resized with bilinear resampling to the size returned by
    `resize_to(img, size, use_min=True)`, converted to RGB and saved with quality=60.
    `i` is accepted but not used (presumably an index supplied by a parallel
    runner — confirm).
    """
    out_path = p_lr/fn.relative_to(p_hr)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    source_img = PIL.Image.open(fn)
    new_size = resize_to(source_img, size, use_min=True)
    resized = source_img.resize(new_size, resample=PIL.Image.BILINEAR)
    resized.convert('RGB').save(out_path, quality=60)

In [3]:
# Root folder that holds every dataset variant referenced below.
path = Path('./dataset/')

# Full-resolution targets and the three low-resolution input sets.
path_fullRes = path.joinpath('DIV2K_train_HR')
path_lowRes_diff = path.joinpath('DIV2K_train_LR_difficult')  # suffix "x4d" ~300px
path_lowRes_mild = path.joinpath('DIV2K_train_LR_mild')  # suffix "x4m" ~300px
path_lowRes_x8 = path.joinpath('DIV2K_train_LR_x8')  # suffix "x8" ~150px


# Project identifier and the generator / critic names derived from it.
proj_id = 'unet_superRes_mobilenetV3_SSIM'

gen_name = f'{proj_id}_gen'
crit_name = f'{proj_id}_crit'

# Notebook-level hyper-parameters.
nf_factor = 2
pct_start = 1e-8

In [5]:
# Backbone for the generator. The geffnet factory itself is stored (it is not
# called here); it is later passed to gen_learner_wide as `arch`.
model = geffnet.mobilenetv3_large_100

In [6]:
# Objective later passed to gen_learner_wide as `gen_loss`.
# NOTE(review): this only makes sense if `ssim` returns a quantity to minimize
# (e.g. 1 - SSIM) rather than a raw similarity — confirm against its definition.
loss_func = ssim

In [7]:
# Databunch over the x8 low-resolution folder, with batch size 2 and size 64.
data_gen = get_DIV2k_data(path_lowRes_x8, bs=2, sz=64)

In [8]:
# Pull one batch to sanity-check the loss functions on
# (presumably x = inputs, y = targets — confirm).
x, y = data_gen.one_batch()

In [16]:
# Indexing with None adds a leading axis back, so one sample is shaped like a batch of one.
x[0][None].size()

torch.Size([1, 3, 64, 64])

In [19]:
# ssim on the first sample alone (as a batch of one); the trailing `; s1` displays it.
s1 = ssim(x[0][None], y[0][None]); s1

tensor(0.0190)

In [20]:
# ssim on the second sample alone (as a batch of one).
s2 = ssim(x[1][None], y[1][None]); s2

tensor(0.0530)

In [21]:
# ssim on the whole two-sample batch, for comparison with the per-sample values.
ssim(x, y)

tensor(0.0360)

In [24]:
# Mean of the two per-sample ssim values. It equals the batched result to the
# printed precision, i.e. the batched call averages over the samples.
(s1+s2)/2

tensor(0.0360)

In [25]:
# Same check for msssim: first sample alone. Note that this rebinds s1.
s1 = msssim(x[0][None], y[0][None]); s1

tensor(0.0015)

In [26]:
# msssim on the second sample alone. Note that this rebinds s2.
s2 = msssim(x[1][None], y[1][None]); s2

tensor(0.0054)

In [27]:
# msssim on the whole two-sample batch, for comparison with the per-sample values.
msssim(x, y)

tensor(0.0034)

In [28]:
# Mean of the two per-sample msssim values (s1 and s2 now hold msssim results).
# It equals the batched msssim result to the printed precision.
(s1+s2)/2

tensor(0.0034)

In [7]:
# Build the generator learner from the databunch, using `loss_func` as the
# generator loss, `model` as the backbone and the notebook-level `nf_factor`.
learn_gen = gen_learner_wide(data=data_gen,
                             gen_loss=loss_func,
                             arch = model,
                             nf_factor=nf_factor)

In [8]:
# Display the class to confirm which module it comes from.
SSIM_Metric

superRes.ssim.SSIM_Metric

In [9]:
# Register the SSIM metric only once. The guard keeps this cell idempotent, so
# re-running it does not append a duplicate metric to the learner.
if not any(isinstance(m, SSIM_Metric) for m in learn_gen.metrics):
    learn_gen.metrics.append(SSIM_Metric())

In [10]:
# Train for one epoch with the one-cycle schedule.
# Note: pct_start is passed literally as 0.9 here; the notebook-level
# `pct_start = 1e-8` from the configuration cell is not used by this call.
learn_gen.fit_one_cycle(1, slice(1e-3), pct_start=0.9)

epoch,train_loss,valid_loss,ssim__metric,time
0,0.977897,0.888341,0.888341,00:47
