In [None]:
import os
# Expose only GPU 0 to this process; set before the torch/fastai imports below so it takes effect.
os.environ.update(CUDA_VISIBLE_DEVICES='0')

from fastai import *
from fastai.vision import *
from fastai.vision.gan import *
from ArNet.generators import *
from ArNet.critics import *
from ArNet.dataset import *
from ArNet.loss import *
from ArNet.save import *
from ArNet.fid_loss import *
from ArNet.ssim import *
from ArNet.metrics import *

import torchvision
import geffnet # efficient/ mobile net

In [None]:
def do_fit(learn, epochs, save_name, lrs=slice(1e-3), pct_start=0.9, wd=None,
           show_rows=1, imgsize=20):
    """Train `learn` with the one-cycle policy, save a checkpoint, then preview results.

    Args:
        learn: fastai Learner to train (its weights are updated in place).
        epochs: number of epochs in the one-cycle schedule.
        save_name: checkpoint name passed to ``learn.save``.
        lrs: max learning rate; a float, or a ``slice`` for discriminative rates.
        pct_start: fraction of the cycle spent warming the learning rate up.
        wd: weight decay forwarded to ``fit_one_cycle``. ``None`` (default) omits the
            argument entirely, so existing callers behave exactly as before.
        show_rows: number of rows drawn by ``learn.show_results``.
        imgsize: per-image size forwarded to ``learn.show_results``.
    """
    fit_kwargs = {'pct_start': pct_start}
    # Only pass wd when asked, so fastai's own fallback applies otherwise.
    if wd is not None:
        fit_kwargs['wd'] = wd
    learn.fit_one_cycle(epochs, lrs, **fit_kwargs)
    learn.save(save_name)
    learn.show_results(rows=show_rows, imgsize=imgsize)

In [None]:
# Dataset root: full-resolution DIV2K targets plus JPEG-compressed (QF20) inputs
# pre-generated at 256px, 512px and full size.
path = Path('./dataset/')
path_fullRes = path / 'DIV2K_train_HR'

path_lowRes_256, path_lowRes_512, path_lowRes_Full = (
    path / f'DIV2K_train_LR_{res}_QF20' for res in ('256', '512', 'Full')
)

# Prefix shared by all checkpoint names of this experiment.
proj_id = 'unet_superRes_mobilenetV3_LPIPS'
gen_name, crit_name = (f'{proj_id}_{role}' for role in ('gen', 'crit'))

# Width multiplier forwarded to gen_learner_wide.
nf_factor = 2
# NOTE(review): pct_start is never passed to do_fit (which uses its own default 0.9); confirm intent.
pct_start = 1e-8

In [None]:
# geffnet MobileNetV3 constructor used as the generator backbone (`arch` for gen_learner_wide).
# The function itself is stored here, not an instantiated network.
model = geffnet.mobilenetv3_rw

In [None]:
# Generator loss. The LPIPS perceptual loss is the active choice (matching proj_id);
# the commented-out lines are alternative losses kept for reference.
# loss_func = FeatureLoss()
# loss_func = msssim
# loss_func = fid
# loss_func = F.mse_loss
loss_func = lpips_loss()

# Stage 1: 256px inputs (JPEG QF20)

In [None]:
# 256px stage: batch size, image size, base learning rate, weight decay, epoch count.
# NOTE(review): `epochs` only feeds the W&B config; the do_fit calls use literal counts.
bs, sz, lr, wd, epochs = 10, 256, 1e-2, 1e-3, 1

In [None]:
# Training data for the 256px stage: QF20 low-res inputs paired with full-resolution targets.
data_gen = get_DIV2k_data_QF(path_lowRes_256, path_fullRes, bs=bs, sz=sz)

In [None]:
# U-Net generator on the MobileNetV3 backbone, trained with the loss selected above.
learn_gen = gen_learner_wide(data=data_gen,
                             gen_loss=loss_func,
                             arch=model,
                             nf_factor=nf_factor)
# `wd` from the hyperparameter cell was otherwise never applied (only logged to W&B).
# fastai v1's Learner.fit uses learn.wd when no explicit wd is given, so set it here.
learn_gen.wd = wd

In [None]:
# Per-epoch quality metrics: SSIM for generator output and input, plus BRISQUE and NIQE
# for output, input and target. LPIPS_Metric_gen / LPIPS_Metric_input are left disabled.
for _metric in (SSIM_Metric_gen, SSIM_Metric_input,
                BRISQUE_Metric_gen, BRISQUE_Metric_input, BRISQUE_Metric_target,
                NIQE_Metric_gen, NIQE_Metric_input, NIQE_Metric_target):
    learn_gen.metrics.append(_metric())

In [None]:
# Optional Weights & Biases logging; flip to True to enable.
wandbCallbacks = False

if wandbCallbacks:
    from datetime import datetime

    import wandb
    from wandb.fastai import WandbCallback

    # NOTE(review): "num_epochs" records `epochs` (1), while the do_fit calls below run
    # 1 + 3 epochs at this stage; confirm what should be logged.
    config = {
        "batch_size": bs,
        "img_size": (sz, sz),
        "learning_rate": lr,
        "weight_decay": wd,
        "num_epochs": epochs,
    }
    # The run id is derived from proj_id, replacing the stale "_FID" suffix (this run uses LPIPS).
    # W&B run ids may not contain ':', so hour and minute are joined with '-'.
    run_id = proj_id + datetime.now().strftime('_%m-%d_%H-%M')
    wandb.init(project='SuperRes', config=config, id=run_id)

    learn_gen.callback_fns.append(partial(WandbCallback, input_type='images'))

In [None]:
# learn_gen.lr_find()
# learn_gen.recorder.plot()
# learn_gen.summary()

In [None]:
# One epoch before unfreezing. NOTE(review): lr*10 is 0.1 here, unusually high; confirm.
do_fit(learn_gen, 1, f"{gen_name}_256px_0", lrs=slice(lr * 10))

In [None]:
# Make every layer group trainable for the next fit.
learn_gen.unfreeze()

In [None]:
# Three epochs on the whole generator at the base learning rate.
do_fit(learn_gen, 3, f"{gen_name}_256px_1", lrs=lr)

# Stage 2: 512px inputs (JPEG QF20)

In [None]:
# 512px stage: smaller batch for the larger images.
bs, sz, epochs = 2, 512, 1

In [None]:
# Rebuild the training data at 512px (QF20) with the batch size set above.
data_gen = get_DIV2k_data_QF(path_lowRes_512, path_fullRes, bs=bs, sz=sz)

In [None]:
# Point the learner at the 512px data and re-freeze for the first phase of this stage.
# gc.collect() presumably releases the previous data; its return value is the cell output.
learn_gen.data = data_gen
learn_gen.freeze()
gc.collect()

6563

In [None]:
# learn_gen.load(gen_name+"_256px_1")

In [None]:
# learn_gen.lr_find()
# learn_gen.recorder.plot()

In [None]:
print("Upsize to gen_512")

# First 512px phase (after the freeze above): 3 epochs, discriminative max lr 1e-3.
do_fit(learn_gen, 3, f"{gen_name}_512px_0", lrs=slice(1e-3))

In [None]:
# Train all layer groups for the second 512px phase.
learn_gen.unfreeze()

In [None]:
# Second 512px phase: 3 epochs with a uniform max lr of 1e-3.
do_fit(learn_gen, 3, f"{gen_name}_512px_1", lrs=1e-3)

In [None]:
# Visual check: 10 result rows after the QF20 512px stage.
learn_gen.show_results(imgsize=20, rows=10)

# Stage 3: 512px inputs, JPEG quality factor 35

In [None]:
bs, sz, epochs = 2, 512, 1

# Switch the 512px inputs to JPEG quality factor 35 (rebinds path_lowRes_512).
path_lowRes_512 = path / 'DIV2K_train_LR_512_QF35'

In [None]:
# Training data for the QF35 stage (path_lowRes_512 now points at the QF35 inputs).
data_gen = get_DIV2k_data_QF(path_lowRes_512, path_fullRes, bs=bs, sz=sz)

In [None]:
# Point the learner at the QF35 data and re-freeze for the first phase of this stage.
# gc.collect() presumably releases the previous data; its return value is the cell output.
learn_gen.data = data_gen
learn_gen.freeze()
gc.collect()

26433

In [None]:
# First QF35 phase (after the freeze above): 3 epochs, max lr 1e-3.
do_fit(learn_gen, 3, f"{gen_name}_512px_2", lrs=1e-3)

In [None]:
# learn_gen.load(gen_name+"_512px_2")

In [None]:
# Train all layer groups for the second QF35 phase.
learn_gen.unfreeze()

In [None]:
# Second QF35 phase: 3 epochs, discriminative max lr 1e-3.
do_fit(learn_gen, 3, f"{gen_name}_512px_3", lrs=slice(1e-3))

# Stage 4: 512px inputs, JPEG quality factor 50

In [None]:
bs, sz, epochs = 2, 512, 1

# Switch the 512px inputs to JPEG quality factor 50 (rebinds path_lowRes_512).
path_lowRes_512 = path / 'DIV2K_train_LR_512_QF50'

In [None]:
# Training data for the QF50 stage (path_lowRes_512 now points at the QF50 inputs).
data_gen = get_DIV2k_data_QF(path_lowRes_512, path_fullRes, bs=bs, sz=sz)

In [None]:
# Point the learner at the QF50 data built above and re-freeze for the first phase.
# These two lines used to be commented out. The QF50 stage then silently trained on
# the QF35 data, with the model still unfrozen from the previous stage.
learn_gen.data = data_gen
learn_gen.freeze()
gc.collect()

0

In [None]:
# learn_gen.lr_find()
# learn_gen.recorder.plot()

In [None]:
# First QF50 phase: a single epoch, max lr 1e-3.
do_fit(learn_gen, 1, f"{gen_name}_512px_4", lrs=1e-3)

In [None]:
# Reload the checkpoint written by the previous cell, so the next phase can resume from it.
learn_gen.load(f"{gen_name}_512px_4")

In [None]:
# Train all layer groups for the second QF50 phase.
learn_gen.unfreeze()

In [None]:
# Second QF50 phase: 3 epochs, discriminative max lr 1e-3.
do_fit(learn_gen, 3, f"{gen_name}_512px_5", lrs=slice(1e-3))

# Fine-tune the patch-trained generator at 512px

In [None]:
# Fine-tuning settings: batch size, image size, base lr, weight decay, epoch count.
bs, sz, lr, wd, epochs = 2, 512, 1e-2, 1e-3, 1

In [None]:
# Fine-tuning data at 512px.
# NOTE(review): path_lowRes_512 still holds the QF50 path from the previous section, while
# the patch checkpoint loaded below was trained on QF20 patches; confirm this is intended.
data_gen = get_DIV2k_data_QF(path_lowRes_512, path_fullRes, bs=bs, sz=sz)

In [None]:
# Fresh generator learner; its weights are replaced by the patch checkpoint loaded below.
learn_gen = gen_learner_wide(data=data_gen,
                             gen_loss=loss_func,
                             arch=model,
                             nf_factor=nf_factor)
# `wd` from the settings cell was otherwise never applied. fastai v1's Learner.fit
# uses learn.wd when no explicit wd is given, so set it here.
learn_gen.wd = wd

In [None]:
# Re-attach the quality metrics to the new learner:
# SSIM (output, input), BRISQUE and NIQE (output, input, target).
for _metric in (SSIM_Metric_gen, SSIM_Metric_input,
                BRISQUE_Metric_gen, BRISQUE_Metric_input, BRISQUE_Metric_target,
                NIQE_Metric_gen, NIQE_Metric_input, NIQE_Metric_target):
    learn_gen.metrics.append(_metric())

In [None]:
# Start from the generator trained on 64px QF20 patches.
# The path is built from `path` instead of a user-specific absolute path. This assumes the
# notebook runs from the repository root that contains ./dataset (TODO confirm).
# resolve() makes it absolute; fastai's Learner.load would otherwise join a relative
# name onto learn.path/model_dir.
patch_ckpt = (path / 'DIV2K_train_LR_Patches' / '64px_FullQF20' / 'models'
              / 'unet_superRes_mobilenetV3_Patches64px_gen_64px_2').resolve()
learn_gen.load(patch_ckpt)

In [None]:
# First fine-tuning phase, before unfreeze(). Saved as "_512px_6": the old name
# "_512px_5" is the Quality-50 checkpoint, which it silently overwrote.
do_fit(learn_gen, 1, gen_name+"_512px_6", 1e-3)

In [None]:
# Train all layer groups for the second fine-tuning phase.
learn_gen.unfreeze()

In [None]:
# Full fine-tuning after unfreeze(). Saved as "_512px_7" so it does not overwrite the
# Quality-50 checkpoint ("_512px_5").
do_fit(learn_gen, 3, gen_name+"_512px_7", 1e-3)

In [None]:
# Visual check: 5 result rows after fine-tuning.
learn_gen.show_results(imgsize=15, rows=5)

# Test: predict on a full-size validation image

In [None]:
# Test settings: one image per batch at 512px.
bs, sz, epochs = 1, 512, 1

In [None]:
# Test inputs: back to the QF20 512px set, loaded at (512, 680).
path_lowRes_512 = path / 'DIV2K_train_LR_512_QF20'
size = (512, 680)

In [None]:
def _hr_target_for(lr_path):
    """Return the full-resolution PNG in path_fullRes matching a low-res JPEG input."""
    return path_fullRes/lr_path.name.replace(".jpg", ".png")

# Fixed split over the sorted files: first 800 for training, next 100 for validation.
# A named label function (not a lambda) keeps the DataBunch picklable for export.
# The batch size comes from the test settings cell instead of a hard-coded 1.
data_1k = (ImageImageList.from_folder(path_lowRes_512, presort=True)
           .split_by_idxs(train_idx=list(range(0, 800)), valid_idx=list(range(800, 900)))
           .label_from_func(_hr_target_for)
           .transform(get_transforms(), size=size, tfm_y=True)
           .databunch(bs=bs)
           .normalize(imagenet_stats, do_y=True))
# 3 output channels (presumably RGB), set explicitly on the DataBunch.
data_1k.c = 3

In [None]:
# Point the learner at the full-size test data and freeze it.
# gc.collect() presumably releases the previous data; its return value is the cell output.
learn_gen.data = data_1k
learn_gen.freeze()
gc.collect()

20088

In [None]:
# Evaluate the generator trained on 64px QF20 patches.
# The path is built from `path` instead of a user-specific absolute path. This assumes the
# notebook runs from the repository root that contains ./dataset (TODO confirm).
# resolve() makes it absolute; fastai's Learner.load would otherwise join a relative
# name onto learn.path/model_dir.
patch_ckpt = (path / 'DIV2K_train_LR_Patches' / '64px_FullQF20' / 'models'
              / 'unet_superRes_mobilenetV3_Patches64px_gen_64px_2').resolve()
learn_gen.load(patch_ckpt)

In [None]:
# Use the third validation file as the example and display its path.
fn = data_1k.valid_dl.x.items[2]
fn

PosixPath('dataset/DIV2K_train_LR_512_QF20/0803.jpg')

In [None]:
# Load the example input and run the generator on it.
img = open_image(fn)
print(img.shape)
# NOTE(review): fastai v1 predict returns (reconstructed output, prediction tensor, raw
# prediction), so img_hr is presumably the generator's prediction, not the HR target; confirm.
p, img_hr, b = learn_gen.predict(img)

torch.Size([3, 512, 680])


In [None]:
# Show the compressed input large; nearest-neighbour display avoids smoothing its artifacts.
show_image(img, figsize=(15,15), interpolation='nearest');

In [None]:
# Show the second value returned by predict above, wrapped as a fastai Image.
Image(img_hr).show(figsize=(15,15))