In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='0'
import sys
sys.path.insert(1, 'pytorch-msssim')

import fastai
from fastai import *
from fastai.vision import *
from fastai.callbacks.tensorboard import *
from fastai.vision.gan import *
from superRes.generators import *
from superRes.critics import *
from superRes.dataset import *
from superRes.loss import *
from superRes.save import *
from superRes.fid_loss import *
from PIL import Image, ImageDraw, ImageFont
from PIL import ImageFile
from pathlib import Path
import torch.nn.functional as F

import torchvision

import geffnet # efficient/ mobile net
import pytorch_msssim # ssim loss

In [2]:
def get_data(bs:int, sz:int, keep_pct:float):
    """Build a DataBunch with ``get_databunch`` from the global low-/full-res folders.

    ``path_lowRes`` is passed as ``crappy_path`` (inputs) and ``path_fullRes`` as
    ``good_path`` (targets); ``random_seed`` is passed as ``None``.
    """
    folder_kwargs = {
        'crappy_path': path_lowRes,
        'good_path': path_fullRes,
        'random_seed': None,
    }
    return get_databunch(sz=sz, bs=bs, keep_pct=keep_pct, **folder_kwargs)

def det_DIV2k_data(bs:int, sz:int, low_res_suffix:str='x4m', n_train:int=800, n_valid:int=100,
                   max_zoom:float=2., num_workers:int=8, presort:bool=True):
    """Build a paired (low-res input, full-res target) DataBunch from the DIV2K folders.

    Inputs are the images under the global ``path_lowRes``. The target for each input is
    the file in the global ``path_fullRes`` whose name is the input's name with
    ``low_res_suffix`` removed (e.g. ``0001x4m.png`` -> ``0001.png``).

    Args:
        bs: Batch size.
        sz: Size passed to the transforms; applied to inputs and targets (``tfm_y=True``).
        low_res_suffix: Token stripped from low-res file names to obtain the target name.
        n_train: Number of items (in listing order) used for training.
        n_valid: Number of items right after the training items used for validation.
        max_zoom: ``max_zoom`` forwarded to ``get_transforms``.
        num_workers: Dataloader worker processes.
        presort: Forwarded to ``from_folder``; sorts the collected file list.

    Returns:
        A DataBunch normalized with ImageNet stats on both x and y, with ``data.c = 3``.
    """
    # The split is positional, so the listing order decides which files land in the
    # validation set. Sorting the listing keeps that choice the same on every machine,
    # instead of following the filesystem's directory order.
    src = (ImageImageList.from_folder(path_lowRes, presort=presort)
           .split_by_idxs(train_idx=list(range(0, n_train)),
                          valid_idx=list(range(n_train, n_train + n_valid))))

    data = (src.label_from_func(lambda x: path_fullRes/x.name.replace(low_res_suffix, ''))
               .transform(get_transforms(max_zoom=max_zoom), size=sz, tfm_y=True)
               .databunch(bs=bs, num_workers=num_workers, no_check=True)
               .normalize(imagenet_stats, do_y=True))
    # 3 channels for the RGB image targets (presumably read as the generator's output count).
    data.c = 3
    return data

def create_training_images(fn, i, p_hr, p_lr, size):
    """Write a degraded (downscaled, re-encoded) copy of one high-res image.

    The output keeps ``fn``'s path relative to ``p_hr``, re-rooted under ``p_lr``.
    Missing parent folders are created.

    Args:
        fn: Path of the source image; must live under ``p_hr``.
        i: Unused; presumably the item index supplied by a parallel-map helper. It is kept
            so that call signature still matches.
        p_hr: Root folder of the high-res images.
        p_lr: Root folder that receives the low-res copies.
        size: Target size handed to ``resize_to(..., use_min=True)``.
    """
    dest = p_lr/fn.relative_to(p_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # The context manager releases the source file handle once we are done reading it.
    with PIL.Image.open(fn) as src:
        targ_sz = resize_to(src, size, use_min=True)
        # Convert before resizing. PIL ignores the requested filter for palette ('P') and
        # '1' images and falls back to NEAREST, so converting afterwards would skip the
        # bilinear smoothing.
        img = src.convert('RGB').resize(targ_sz, resample=PIL.Image.BILINEAR)
    # quality=60 only has an effect for lossy output formats such as JPEG.
    img.save(dest, quality=60)

In [3]:
def do_fit(learn, epochs, save_name, lrs=slice(1e-3), pct_start=0.9):
    """Train ``learn`` with a one-cycle schedule, checkpoint it, then preview predictions.

    Args:
        learn: Learner to train.
        epochs: Number of epochs for ``fit_one_cycle``.
        save_name: Checkpoint name passed to ``learn.save``.
        lrs: Learning rate(s) passed to ``fit_one_cycle``.
        pct_start: Fraction of the cycle spent warming up, forwarded to ``fit_one_cycle``.
    """
    learn.fit_one_cycle(epochs, lrs, pct_start=pct_start)
    learn.save(save_name)
    # Show a single row of results at a fixed image size.
    preview_opts = dict(rows=1, imgsize=5)
    learn.show_results(**preview_opts)

In [4]:
# Dataset locations, relative to the notebook's working directory.
path = Path('./dataset/')
path_fullRes = path / 'DIV2K_train_HR'       # targets (good_path / label folder)
path_lowRes = path / 'DIV2K_train_LR_mild'   # inputs (crappy_path / source folder)

# Name prefixes derived from the project id; gen_name prefixes the generator checkpoints.
proj_id = 'unet_superRes_mobilenetV3_SSIM'
gen_name, crit_name = (f'{proj_id}_{role}' for role in ('gen', 'crit'))

nf_factor = 2  # forwarded to gen_learner_wide(nf_factor=...)
# NOTE(review): pct_start is not passed to any do_fit call in this notebook, so do_fit's own
# default (0.9) is what actually applies. Confirm whether this value was meant to be used.
pct_start = 1e-8

In [5]:
print(f"{path_lowRes}")  # sanity-check the low-res input folder

dataset/DIV2K_train_LR_mild


In [6]:
# Encoder architecture handed to gen_learner_wide as `arch`; commented alternatives kept for comparison.
model = geffnet.mobilenetv3_100
# model = models.resnet34
# model= geffnet.efficientnet_b4

# Stage 1: train the generator at 256px

In [7]:
# 256px stage hyperparameters.
bs, sz = 10, 256   # batch size, training image size (px)
lr = 1e-2          # base learning rate; the first do_fit call below uses lr*10
wd = 1e-3          # weight decay; only referenced by the wandb config in this notebook
epochs = 5

In [None]:
data_gen = det_DIV2k_data(sz=sz, bs=bs)  # paired low-res/full-res data at the current size

In [None]:
# data_gen.show_batch(ds_type=DatasetType.Valid, rows=1, figsize=(9,9))

In [8]:
# SSIM is a *similarity* score (1.0 for identical images). Used directly as a loss, the
# optimizer would minimize similarity and push predictions away from the targets, so it
# is turned into a loss via 1 - SSIM.
def ssim_loss(pred, target):
    """Return ``1 - SSIM(pred, target)``: 0 for a perfect reconstruction, larger as similarity drops."""
    return 1 - pytorch_msssim.ssim(pred, target)

# NOTE(review): targets are ImageNet-normalized (det_DIV2k_data calls
# .normalize(imagenet_stats, do_y=True)), so pixel values are not in [0, 1]. Confirm how
# this SSIM implementation infers or assumes the data range.
# loss_func = FeatureLoss()
loss_func = ssim_loss
# loss_func = calculate_frechet_distance

In [12]:
# Generator learner: `model` is the encoder arch and `loss_func` the training loss.
learn_gen = gen_learner_wide(
    data=data_gen,
    arch=model,
    gen_loss=loss_func,
    nf_factor=nf_factor,
)

In [None]:
# Optional Weights & Biases logging; off by default.
wandbCallbacks = False

if wandbCallbacks:
    # Nothing else in this notebook imports wandb or WandbCallback, so enabling the flag
    # would raise NameError. Importing here also keeps wandb optional while logging is off.
    from functools import partial
    import wandb
    from wandb.fastai import WandbCallback

    # NOTE(review): the run id says 128 but this cell runs with sz=256. Confirm the id;
    # presumably a fixed id also makes repeated runs target the same W&B run.
    wandb.init(project='SuperRes', id="gen_128")

    # Record the run's hyperparameters.
    wandb.config.batch_size = bs
    wandb.config.img_size = (sz, sz)
    wandb.config.learning_rate = lr
    wandb.config.weight_decay = wd
    wandb.config.num_epochs = epochs

    # Attach the W&B callback to the generator learner.
    learn_gen.callback_fns.append(partial(WandbCallback, input_type='images'))

In [None]:
# learn_gen.lr_find()
# learn_gen.recorder.plot()
# learn_gen.summary()

In [None]:
do_fit(learn_gen, epochs, f"{gen_name}_256px_0", slice(lr*10))  # NOTE(review): lr*10 = 0.1 with lr=1e-2 — unusually high; confirm

In [None]:
# Make all layer groups trainable for the next 256px fit.
learn_gen.unfreeze()

In [None]:
do_fit(learn_gen, 3, f"{gen_name}_256px_1", slice(lr))  # unfrozen 256px stage; checkpoint reloaded at 512px

# Stage 2: fine-tune the generator at 512px

In [10]:
# 512px stage: larger images, smaller batches (lr and wd carry over from the 256px stage).
bs, sz = 4, 512
epochs = 5

In [11]:
data_gen = det_DIV2k_data(bs=bs, sz=sz)  # rebuild the paired data at the 512px size

In [None]:
# Point the existing learner at the 512px data and freeze it again before the first fit at this size.
learn_gen.data = data_gen
learn_gen.freeze()
# NOTE(review): `gc` is not imported explicitly in this notebook; presumably it comes from the fastai star imports.
gc.collect()

In [None]:
learn_gen.load(f"{gen_name}_256px_1")  # resume from the last 256px checkpoint

In [None]:
# Learning-rate range test at 512px; inspect the plot before choosing the next fit's LR.
learn_gen.lr_find()
learn_gen.recorder.plot()

In [None]:
print("Upsize to gen_512")

# First 512px fit (learner frozen above); the learning rate is hard-coded here rather than taken from `lr`.
do_fit(learn_gen, epochs, f"{gen_name}_512px_0", slice(1e-6))

In [None]:
# Make all layer groups trainable for the final 512px fit.
learn_gen.unfreeze()

In [14]:
do_fit(learn_gen, 1, f"{gen_name}_512px_1", slice(lr))  # NOTE(review): lr is still 1e-2 while the previous 512px fit used 1e-6 — confirm

epoch,train_loss,valid_loss,time


KeyboardInterrupt: 

In [None]:
# Drop the learner reference and run the garbage collector so its memory can be reclaimed.
learn_gen = None
gc.collect()