In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='1'
import sys
sys.path.append("..")

from fastai import *
from fastai.vision import *
from superRes.generators import *
from superRes.critics import *
from superRes.dataset import *
from superRes.loss import *
from superRes.save import *
from superRes.fid_loss import *
from superRes.ssim import *
from superRes.metrics import *
from pathlib import Path

import torchvision
import geffnet # efficient/ mobile net

In [None]:
def get_DIV2k_data_QF(pLow, bs:int, sz:int, path_hr=None, train_idx=None, valid_idx=None):
    """Build an image-to-image DataBunch pairing low-res inputs with full-res targets.

    Each image found under ``pLow`` is labelled with the file of the same name in
    ``path_hr`` (a ``.jpg`` in the name is mapped to ``.png``).

    Args:
        pLow: folder containing the low-resolution input images.
        bs: batch size.
        sz: ``size`` handed to the transforms; with ``tfm_y=True`` the same
            transforms (and size) are applied to the targets.
        path_hr: folder containing the full-resolution targets. ``None`` (default)
            falls back to the notebook-level ``path_fullRes`` so existing calls keep
            working; pass it explicitly to avoid depending on that global.
        train_idx: item indices for the training split (default ``range(0, 800)``).
        valid_idx: item indices for the validation split (default ``range(800, 900)``).

    Returns:
        The DataBunch, normalised with ``imagenet_stats`` on both inputs and
        targets, with ``c`` set to 3.
    """
    if path_hr is None:
        # Looked up at call time, so the config cell may define it after this cell.
        path_hr = path_fullRes
    hr_dir = Path(path_hr)
    train_idx = list(range(0, 800)) if train_idx is None else list(train_idx)
    valid_idx = list(range(800, 900)) if valid_idx is None else list(valid_idx)

    def _hr_target(lr_file):
        # Full-res counterpart of a low-res file: same name, ".jpg" -> ".png".
        return hr_dir/lr_file.name.replace(".jpg", ".png")

    src = ImageImageList.from_folder(pLow).split_by_idxs(train_idx=train_idx, valid_idx=valid_idx)

    tfms = get_transforms(max_rotate=30, max_lighting=.4, max_warp=.4)
    data = (src.label_from_func(_hr_target)
               .transform(tfms, size=sz, tfm_y=True)
               .databunch(bs=bs, num_workers=8, no_check=True)
               .normalize(imagenet_stats, do_y=True))
    # NOTE(review): presumably read by the generator builder as its output-channel count (RGB).
    data.c = 3
    return data

In [None]:
def do_fit(learn, epochs,save_name, lrs=slice(1e-3), pct_start=0.9):
    """Run one-cycle training, checkpoint the learner, then show one sample result.

    Args:
        learn: learner whose ``fit_one_cycle``, ``save`` and ``show_results``
            methods are called, in that order.
        epochs: cycle length passed to ``fit_one_cycle``.
        save_name: checkpoint name passed to ``learn.save``.
        lrs: learning rate (or ``slice`` of rates) passed to ``fit_one_cycle``.
        pct_start: forwarded to ``fit_one_cycle`` as ``pct_start``.
    """
    schedule_kwargs = {'pct_start': pct_start}
    learn.fit_one_cycle(epochs, lrs, **schedule_kwargs)
    # Checkpoint first, then render a single row of predictions.
    learn.save(save_name)
    preview_kwargs = {'rows': 1, 'imgsize': 7}
    learn.show_results(**preview_kwargs)

In [None]:
# Dataset folders: full-resolution targets plus the 128px / 256px low-resolution variants.
path = Path('./dataset/')
path_fullRes, path_lowRes_128, path_lowRes_256 = (
    path / sub for sub in ('DIV2K_train_HR', 'DIV2K_train_LR_128', 'DIV2K_train_LR_256')
)

# Experiment naming: checkpoints are saved under these prefixes.
proj_id = 'unet_superRes_mobilenetV3_SSIM'
gen_name, crit_name = (f'{proj_id}_{suffix}' for suffix in ('gen', 'crit'))

# Model / schedule knobs.
nf_factor = 2
# NOTE(review): pct_start is not forwarded to do_fit in the visible training call,
# which therefore uses do_fit's own default — confirm which value is intended.
pct_start = 1e-8

In [None]:
# Sanity-check the resolved full-resolution target folder.
print(str(path_fullRes))

In [None]:
# Backbone passed to gen_learner_wide as `arch`; the attribute is bound, not called, here.
model = geffnet.mobilenetv3_rw

In [None]:
# Generator loss, passed to gen_learner_wide as `gen_loss`.
loss_func = msssim

# 128px low-res inputs (resized to 256px)

In [None]:
# Hyper-parameters for the generator run.
bs, sz = 10, 256   # batch size; size handed to the data transforms
lr, wd = 1e-2, 1e-3
# NOTE(review): wd is not passed to the learner in the visible cells — confirm it is meant to be used.
epochs = 1

In [None]:
# DataBunch fed by the 128px low-res inputs.
data_gen = get_DIV2k_data_QF(pLow=path_lowRes_128, bs=bs, sz=sz)

In [None]:
# One training batch for the metric experiments below; fastai hands it back
# denormalised and on the CPU (these are one_batch's defaults, spelled out).
x, y = data_gen.one_batch(denorm=True, cpu=True)

# LPIPS

In [None]:
import sys
sys.path.append("../PerceptualSimilarity/")

In [None]:
import PerceptualSimilarity.models as lpips

In [None]:
# LPIPS ('net-lin' heads on an AlexNet backbone), placed on GPU 0.
mod = lpips.PerceptualLoss(model='net-lin', net='alex', use_gpu=True, gpu_ids=[0])
# one_batch() returns CPU tensors already denormalised to [0, 1]. The LPIPS net
# expects [-1, 1] inputs on its own device: normalize=True performs the rescale,
# and .cuda(0) moves the batch next to the model.
d = mod(x.cuda(0), y.cuda(0), normalize=True)

In [None]:
# Shape of the LPIPS output (presumably one distance per image).
d.shape

In [None]:
# Batch-average LPIPS distance.
d.mean()

# BRISQUE

In [None]:
from brisque import BRISQUE

In [None]:
# BRISQUE scorer, reused for every image scored below.
brisque = BRISQUE()

In [None]:
# Channels-first shape of the first input image.
x[0].size()

In [None]:
# Same image as a channels-last numpy view (the layout handed to BRISQUE).
x[0].numpy().transpose(1, 2, 0).shape

In [None]:
# BRISQUE score of the first input, given as a channels-last array.
brisque.get_score(x[0].numpy().transpose(1, 2, 0))

In [None]:
# BRISQUE score for every image of the batch (channels-last numpy arrays).
values = [brisque.get_score(img.permute(1, 2, 0).numpy()) for img in x]

In [None]:
# Number of scores collected; should equal the batch size.
len(values)

# NIQE

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
from skvideo.measure.niqe import *
from torchvision import transforms

In [None]:
# Spatial shape of the first channel of the first image.
x[0, 0].shape

In [None]:
# Render that single channel as an image for a quick visual check.
transforms.ToPILImage()(x[0, 0]).convert("RGB")

In [None]:
# NIQE per image. skvideo's niqe() scores a single luminance channel, and its
# pretrained statistics assume 8-bit (0-255) intensities. one_batch(), however,
# returns denormalised RGB in [0, 1]. So score BT.601 luma rescaled to 0-255,
# not the raw red channel.
values = []
for img in x:
    luma = (0.299 * img[0] + 0.587 * img[1] + 0.114 * img[2]) * 255.0
    score = niqe(luma.numpy())
    values.append(score)
values

In [None]:
# NIQE of the first image, computed on its BT.601 luma scaled to 0-255.
# skvideo's niqe() expects a luminance channel on an 8-bit scale, whereas x holds
# denormalised RGB in [0, 1].
niqe(((0.299 * x[0][0] + 0.587 * x[0][1] + 0.114 * x[0][2]) * 255.0).numpy())

In [None]:
# Generator learner: the backbone `model` with loss_func as generator loss.
learn_gen = gen_learner_wide(
    data=data_gen,
    gen_loss=loss_func,
    arch=model,
    nf_factor=nf_factor,
)

In [None]:
# Attach the LPIPS, BRISQUE and NIQE metrics to the learner, in that order.
learn_gen.metrics.extend(metric() for metric in (LPIPS_Metric, BRISQUE_Metric, NIQE_Metric))

In [None]:
# NOTE(review): the config cell's pct_start is not passed here, so do_fit's default
# applies. The max learning rate is lr * 10. Confirm both are intended.
do_fit(learn_gen, epochs, f"{gen_name}_128px_0", slice(lr * 10))