In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='1'

from fastai import *
from fastai.vision import *
from fastai.vision.gan import *
from ArNet.generators import *
from ArNet.critics import *
from ArNet.dataset import *
from ArNet.loss import *
from ArNet.save import *
from ArNet.fid_loss import *
from ArNet.ssim import *
from ArNet.metrics import *

import torchvision
import geffnet # efficient/ mobile net

In [None]:
def save_preds(dl, path_gen, learn):
    """Predict every batch of `dl` with `learn` and save each reconstructed output.

    Each prediction is written into `path_gen` under the file name of the
    dataset item at the same position in `dl.dataset.items`, and the written
    path is printed.
    """
    item_names = dl.dataset.items
    counter = 0
    for batch in dl:
        batch_preds = learn.pred_batch(batch=batch, reconstruct=True, ds_type=DatasetType.Valid)
        for pred in batch_preds:
            out_path = path_gen/item_names[counter].name
            pred.save(out_path)
            print(out_path)
            counter += 1
            

def save_gen_images(data_gen, path_gen, learn):
    """Recreate `path_gen` from scratch and fill it with `learn`'s predictions.

    Any existing folder at `path_gen` is deleted first so images from an earlier
    run never mix with the new ones. Predictions for `data_gen.fix_dl` are
    written first, then those for `data_gen.valid_dl` (via `save_preds`).
    """
    if path_gen.exists():
        shutil.rmtree(path_gen)
    # parents=True: the output folder may sit under directories that do not exist yet.
    path_gen.mkdir(parents=True, exist_ok=True)
    for dl in (data_gen.fix_dl, data_gen.valid_dl):
        save_preds(dl, path_gen, learn)

In [None]:
def get_DIV2k_data_Input(pLow, pFull, bs: int, sz: int,
                         n_train: int = 800, n_valid: int = 100,
                         num_workers: int = 8):
    """Build an image-to-image DataBunch pairing low-res inputs with full-res targets.

    Args:
        pLow: folder of low-resolution input images (listed in sorted order).
        pFull: folder of targets; an input named `name.jpg` is paired with
            `pFull/name.png` (other names are used unchanged).
        bs: batch size.
        sz: `size` passed to the transforms, applied to inputs and targets alike.
        n_train: number of leading (sorted) items used for training.
        n_valid: number of items right after the training ones used for validation.
        num_workers: DataLoader worker processes.

    Returns:
        A DataBunch normalized with `imagenet_stats` on both x and y, with `c = 3`.
    """
    src = ImageImageList.from_folder(pLow, presort=True).split_by_idxs(
        train_idx=list(range(0, n_train)),
        valid_idx=list(range(n_train, n_train + n_valid)))

    def _target_path(x):
        # Swap only a trailing ".jpg" extension; str.replace would also rewrite
        # any ".jpg" appearing elsewhere in the file name.
        name = x.name
        if name.endswith(".jpg"):
            name = name[:-len(".jpg")] + ".png"
        return pFull/name

    data = (src.label_from_func(_target_path)
            .transform(size=sz, tfm_y=True)
            .databunch(bs=bs, num_workers=num_workers, no_check=True)
            .normalize(imagenet_stats, do_y=True))
    data.c = 3  # 3 output channels (presumably RGB)
    return data

In [None]:
def toEven(sz):
    """Return the first two entries of `sz` as a new list, bumping odd ones by 1.

    Any entry not divisible by 2 has 1 added; entries past the second are ignored.
    """
    return [dim + 1 if dim % 2 != 0 else dim for dim in (sz[0], sz[1])]

In [None]:
# Dataset locations, relative to the notebook's working directory.
path = Path('./dataset/')
path_fullRes = path.joinpath('DIV2K_train_HR')              # targets (pFull) for the databunch
path_lowRes_512 = path.joinpath('DIV2K_train_LR_512_QF20')  # low-resolution inputs (pLow)

In [None]:
# Build a generator learner on a dummy databunch (size argument 512); it serves
# as the container into which trained weights are loaded and then exported.
model = geffnet.mobilenetv3_rw
loss_func = lpips_loss()
bs = 1

data_gen = get_dummy_databunch(bs, 512)
learn_gen = gen_learner_wide(
    data=data_gen,
    gen_loss=loss_func,
    arch=model,
    nf_factor=2,
)

# Generate the dataset images with the exported model

In [None]:
# Load the trained generator weights into `learn_gen` and export the learner so
# the next cell can reload it with `load_learner`.
weights = "/data/students_home/fmameli/repos/Artifact_Removal_GAN/models/unet_wideNf2_mobileV3_DivFlickr1k_P64px_SuperRes_gen_3"
learn_gen.load(weights, with_opt=False)

# Must match the file name the loading cell reads ("std_patch.pkl"); exporting to
# "std_path.pkl" left the loader pointing at a file this cell never wrote.
learn_gen.export("/data/students_home/fmameli/repos/Artifact_Removal_GAN/models/std_patch.pkl")

In [None]:
# Reload the exported generator learner from the models folder.
root_model_path = Path("/data/students_home/fmameli/repos/Artifact_Removal_GAN/models/")
exported_model_standard = root_model_path/"std_patch.pkl"
learn_std = load_learner(path=root_model_path, file=exported_model_standard)

In [None]:
# Run the exported generator over every low-resolution image, one at a time at
# its own (even-rounded) size, and save the result as <id>_LPIPS.png.
in_dir = Path("dataset/DIV2K_train_LR_1024_QF20")
out_dir = Path("dataset/MobilenetV3_Patch_GEN")
out_dir.mkdir(parents=True, exist_ok=True)  # saving fails if the folder is missing

n_images = 900  # ids 0001..0900; get_DIV2k_data_Input also assumes 900 images
for i in range(1, n_images + 1):  # range end is exclusive: +1 keeps image 0900
    id_img = str(i).zfill(4)

    img_low = open_image(in_dir/(id_img + ".jpg"))
    # Presumably the databunch size must match the image so predict() processes
    # it whole — TODO confirm against get_dummy_databunch.
    size = toEven(img_low.size)
    data_gen = get_dummy_databunch(1, size)
    learn_std.data = data_gen

    p, _, _ = learn_std.predict(img_low)
    out_file = out_dir/(id_img + "_LPIPS.png")
    p.save(out_file)
    print(out_file)  # log the path that was actually written

# Input images

In [None]:
proj_id = 'unet_superRes_mobilenetV3_Input'

# Output folder for the low-resolution input images.
path_input = path/f'{proj_id}_image_gen'
path_input

In [None]:
# Save the low-resolution inputs as served by the datasets, in the same
# fix_dl-then-valid_dl order that save_gen_images uses for predictions.
bs = 1
sz = 512
nf_factor = 2

data_gen = get_DIV2k_data_Input(path_lowRes_512, path_fullRes, bs=bs, sz=sz)

# Start from an empty folder; parents=True in case the parent folder is missing.
if path_input.exists():
    shutil.rmtree(path_input)
path_input.mkdir(parents=True, exist_ok=True)

# Both loaders are handled by the same loop.
for dl in (data_gen.fix_dl, data_gen.valid_dl):
    ds = dl.dataset
    for item_path, sample in zip(ds.items, ds):
        out_file = path_input/item_path.name
        sample[0].save(out_file)  # presumably sample is (input, target); keep the input
        print(out_file)

# SSIM

In [None]:
proj_id = 'unet_superRes_mobilenetV3_SSIM'

# Output folder for the SSIM generator's predictions.
path_ssim = path/f'{proj_id}_image_gen'
path_ssim

In [None]:
# SSIM generator: rebuild the learner, load its checkpoint and write its
# predictions for fix_dl and valid_dl into `path_ssim`.
model = geffnet.mobilenetv3_rw
loss_func = SSIM()
bs, sz, nf_factor = 1, 512, 2

data_gen = get_DIV2k_data_QF(path_lowRes_512, path_fullRes, bs=bs, sz=sz)
learn_gen = gen_learner_wide(
    data=data_gen,
    gen_loss=loss_func,
    arch=model,
    nf_factor=nf_factor,
)

weights = "/data/students_home/fmameli/repos/SuperRes/models/unet_superRes_mobilenetV3_SSIM_gen_512px_0"
learn_gen.load(weights, with_opt=False)
save_gen_images(data_gen, path_ssim, learn_gen)

# MSE

In [None]:
proj_id = 'unet_superRes_mobilenetV3_MSE'

# Output folder for the MSE generator's predictions.
path_mse = path/f'{proj_id}_image_gen'
path_mse

In [None]:
# MSE generator: rebuild the learner, load its checkpoint and write its
# predictions into `path_mse`.
model = geffnet.mobilenetv3_rw
loss_func = nn.MSELoss()

bs = 1
sz = 512
nf_factor = 2

data_gen = get_DIV2k_data_QF(path_lowRes_512, path_fullRes, bs=bs, sz=sz)

learn_gen = gen_learner_wide(data=data_gen,
                             gen_loss=loss_func,
                             arch=model,
                             nf_factor=nf_factor)

# Load the MSE checkpoint: loading the SSIM one here made the "MSE" images
# identical to the SSIM section's output.
# NOTE(review): name follows the SSIM checkpoint's naming pattern — confirm the file exists.
weights = "/data/students_home/fmameli/repos/SuperRes/models/unet_superRes_mobilenetV3_MSE_gen_512px_0"
learn_gen.load(weights, with_opt=False)

save_gen_images(data_gen, path_mse, learn_gen)

# LPIPS Patch

In [None]:
proj_id = 'unet_superRes_mobilenetV3_LPIPS_Patch'

# Output folder for the patch-trained LPIPS generator's predictions.
path_lpips_patch = path/f'{proj_id}_image_gen'
path_lpips_patch

In [None]:
# Patch-trained LPIPS generator: rebuild the learner, load its checkpoint and
# write its predictions into `path_lpips_patch`.
model = geffnet.mobilenetv3_rw
loss_func = lpips_loss()

bs = 2
sz = 512
nf_factor = 2

data_gen = get_DIV2k_data_QF(path_lowRes_512, path_fullRes, bs=bs, sz=sz)

learn_gen = gen_learner_wide(data=data_gen,
                             gen_loss=loss_func,
                             arch=model,
                             nf_factor=nf_factor)

# with_opt=False, as in every other section: inference only needs the model
# weights, so do not create/restore optimizer state.
weights = "/data/students_home/fmameli/repos/SuperRes/models/unet_wideNf2_superRes_mobilenetV3_Patches64px_gen_64px_0"
learn_gen.load(weights, with_opt=False)

save_gen_images(data_gen, path_lpips_patch, learn_gen)

# LPIPS

In [None]:
proj_id = 'unet_superRes_mobilenetV3_LPIPS'

# Output folder for the LPIPS generator's predictions.
path_lpips = path/f'{proj_id}_image_gen'
path_lpips

In [None]:
# LPIPS-tuned generator: rebuild the learner, load its checkpoint and write its
# predictions into `path_lpips`.
model = geffnet.mobilenetv3_rw
loss_func = lpips_loss()
bs, sz, nf_factor = 2, 512, 2

data_gen = get_DIV2k_data_QF(path_lowRes_512, path_fullRes, bs=bs, sz=sz)
learn_gen = gen_learner_wide(
    data=data_gen,
    gen_loss=loss_func,
    arch=model,
    nf_factor=nf_factor,
)

weights = "/data/students_home/fmameli/repos/SuperRes/models/unet_superRes_mobilenetV3_LPIPS_Tuned_gen_512px_0"
learn_gen.load(weights, with_opt=False)
save_gen_images(data_gen, path_lpips, learn_gen)

# GAN

In [None]:
proj_id = 'unet_superRes_mobilenetV3_GAN'

# Output folder for the GAN generator's predictions.
path_gan = path/f'{proj_id}_image_gen'
path_gan

In [None]:
# GAN generator: rebuild the learner, load its checkpoint and write its
# predictions into `path_gan`.
model = geffnet.mobilenetv3_rw
loss_func = lpips_loss()
bs, sz, nf_factor = 2, 512, 2

data_gen = get_DIV2k_data_QF(path_lowRes_512, path_fullRes, bs=bs, sz=sz)
learn_gen = gen_learner_wide(
    data=data_gen,
    gen_loss=loss_func,
    arch=model,
    nf_factor=nf_factor,
)

# NOTE(review): this is the same checkpoint the LPIPS section loads, so
# `path_gan` will presumably receive the same images as `path_lpips`.
# Confirm and point `weights` at the GAN-trained generator checkpoint.
weights = "/data/students_home/fmameli/repos/SuperRes/models/unet_superRes_mobilenetV3_LPIPS_Tuned_gen_512px_0"
learn_gen.load(weights, with_opt=False)
save_gen_images(data_gen, path_gan, learn_gen)