In [None]:
import gc
import os
# Set before torch is imported (below) so only GPU 0 is visible to this kernel.
os.environ['CUDA_VISIBLE_DEVICES']='0'

import fastai
from fastai import *
from fastai.vision import *
from fastai.callbacks.tensorboard import *
from fastai.vision.gan import *
from superRes.generators import *
from superRes.critics import *
from superRes.dataset import *
from superRes.loss import *
from superRes.save import *
from PIL import Image, ImageDraw, ImageFont
from PIL import ImageFile

import torch
import torchvision

import geffnet # efficient/ mobile net

In [None]:
def get_data(bs:int, sz:int, keep_pct:float, crappy_path=None, good_path=None, random_seed=None):
    """Build the super-resolution DataBunch pairing degraded inputs with full-res targets.

    Args:
        bs: batch size forwarded to ``get_databunch``.
        sz: image size forwarded to ``get_databunch`` (presumably side length in px — confirm).
        keep_pct: forwarded to ``get_databunch``; presumably the fraction of items kept.
        crappy_path: folder of degraded input images. Defaults to the notebook-level
            ``path_lowRes`` (looked up at call time, as before).
        good_path: folder of target images. Defaults to the notebook-level ``path_fullRes``.
        random_seed: forwarded to ``get_databunch``; ``None`` keeps the previous behaviour.

    Returns:
        Whatever ``get_databunch`` returns (used as a fastai DataBunch in this notebook).
    """
    # Fall back to the globals only when the caller did not pass explicit folders,
    # so existing calls keep working while new calls can avoid the hidden dependency.
    if crappy_path is None:
        crappy_path = path_lowRes
    if good_path is None:
        good_path = path_fullRes
    return get_databunch(sz=sz, bs=bs, crappy_path=crappy_path,
                         good_path=good_path,
                         random_seed=random_seed, keep_pct=keep_pct)

def create_training_images(fn, i, p_hr, p_lr, size):
    """Create a low-quality copy of one image from folder ``p_hr`` inside ``p_lr``.

    The copy keeps the path it had relative to ``p_hr``. It is converted to RGB,
    resized with bilinear filtering to the size computed by ``resize_to(..., use_min=True)``,
    and saved with ``quality=60``.

    Args:
        fn: path of the source image; must lie inside ``p_hr``.
        i: index argument passed by ``parallel``; unused here.
        p_hr: root folder of the high-resolution images.
        p_lr: root folder receiving the low-quality copies.
        size: target size passed to ``resize_to``.
    """
    dest = p_lr/fn.relative_to(p_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Context manager guarantees the source file handle is released; this runs for every image.
    with PIL.Image.open(fn) as img:
        # Convert first: Pillow resizes palette ('P') and '1' images with NEAREST regardless of
        # `resample`, so converting after resizing silently skipped the bilinear filter for them.
        rgb = img.convert('RGB')
    targ_sz = resize_to(rgb, size, use_min=True)
    small = rgb.resize(targ_sz, resample=PIL.Image.BILINEAR)
    # Low JPEG quality degrades the input images further.
    small.save(dest, quality=60)

In [None]:
def do_fit(learn, epochs,save_name, lrs=slice(1e-3), pct_start=0.9, wd=None):
    """Train with the one-cycle policy, checkpoint the weights, then preview one result row.

    Args:
        learn: learner exposing ``fit_one_cycle``, ``save`` and ``show_results``.
        epochs: number of epochs for the one-cycle schedule.
        save_name: checkpoint name passed to ``learn.save``.
        lrs: maximum learning rate(s) passed to ``fit_one_cycle``.
        pct_start: ``pct_start`` passed to ``fit_one_cycle``.
        wd: weight decay passed to ``fit_one_cycle``. ``None`` (the default) lets the
            learner use its own configured weight decay, matching the previous behaviour.
    """
    learn.fit_one_cycle(epochs, lrs, pct_start=pct_start, wd=wd)
    learn.save(save_name)
    learn.show_results(rows=1, imgsize=5)

In [None]:
# Root of the Oxford-IIIT Pets dataset (URLs.PETS), obtained via fastai's untar_data.
path = untar_data(URLs.PETS)

# Full-resolution source images and the two downsampled copies generated further down.
path_fullRes = path/'images'
path_lowRes = path/'lowRes-96'
path_medRes = path/'lowRes-256'

# Checkpoint names are derived from the project id.
proj_id = 'unet_superRes'
gen_name = f'{proj_id}_gen'
crit_name = f'{proj_id}_crit'

TENSORBOARD_PATH = Path('data/tensorboard') / proj_id

# Width multiplier handed to gen_learner_wide.
nf_factor = 2
# NOTE(review): not read by any fit call visible here (do_fit uses its own pct_start default).
pct_start = 1e-8

In [None]:
print(f"{path_fullRes}")  # sanity-check where the full-resolution images live

In [None]:
# (destination folder, target size) for each downsampled copy of the dataset.
sets = [(path_lowRes, 96), (path_medRes, 256)]
il = ImageList.from_folder(path_fullRes)

for p, size in sets:
    # Only the folder's existence is checked: a run interrupted part-way leaves a
    # partial folder that will be skipped here — delete it to regenerate.
    if p.exists():
        continue
    print(f"resizing to {size} into {p}")
    parallel(
        partial(create_training_images, p_hr=path_fullRes, p_lr=p, size=size),
        il.items,
    )

In [None]:
# Architecture handed to gen_learner_wide as `arch` in both the 128px and 256px stages.
model = geffnet.mobilenetv3_100
# Alternative architecture to compare against: model = models.resnet34

# Stage 1: generator training at 128px

In [None]:
# Stage-1 (128px) generator hyperparameters.
bs, sz = 25, 128
lr, wd = 1e-3, 1e-3
keep_pct = 1.0
epochs = 10

In [None]:
data_gen = get_data(bs, sz, keep_pct)  # 128px DataBunch built from the stage-1 settings

In [None]:
data_gen.show_batch(rows=1, ds_type=DatasetType.Valid, figsize=(9, 9))  # eyeball one validation row

In [None]:
learn_gen = gen_learner_wide(data=data_gen,
                             gen_loss=FeatureLoss(),
                             arch=model,
                             nf_factor=nf_factor)
# Apply the configured weight decay explicitly. `wd` is set in the hyperparameter cell
# (and logged to W&B) but is not passed to gen_learner_wide or to any fit call in this
# notebook. Per the fastai v1 docs, fits without an explicit `wd` use `learn.wd`.
learn_gen.wd = wd

In [None]:
wandbCallbacks = False  # set to True to log the 128px run to Weights & Biases

if wandbCallbacks:
    # wandb and WandbCallback are not imported at the top of the notebook. Import them
    # here so enabling the flag does not raise NameError, and wandb stays optional otherwise.
    import wandb
    from wandb.fastai import WandbCallback

    # NOTE(review): the fixed run id means every re-run targets the same W&B run — confirm intended.
    wandb.init(project='SuperRes', id="gen_128")

    wandb.config.batch_size = bs
    wandb.config.img_size = (sz, sz)
    wandb.config.learning_rate = lr
    wandb.config.weight_decay = wd
    wandb.config.num_epochs = epochs

    # NOTE(review): re-running this cell appends a second callback factory.
    learn_gen.callback_fns.append(partial(WandbCallback, input_type='images'))

In [None]:
# Optional diagnostics: set to True to run the LR finder, plot it, and print the model summary.
run_lr_find = False

if run_lr_find:
    learn_gen.lr_find()
    learn_gen.recorder.plot()
    # Inside an `if`, the summary is not the cell's last expression, so print it explicitly.
    print(learn_gen.summary())

In [None]:
do_fit(learn_gen, epochs, f"{gen_name}_128px_0", lrs=slice(lr * 10))  # first 128px fit at 10x the base lr

In [None]:
# Make every layer group trainable for the next fit (the learner is presumably frozen as
# returned by gen_learner_wide — confirm).
learn_gen.unfreeze()

In [None]:
do_fit(learn_gen, epochs, f"{gen_name}_128px_1", lrs=slice(1e-5, lr))  # discriminative lrs after unfreezing

# Stage 2: generator training at 256px, starting from the 128px weights

In [None]:
# Release the 128px learner before the 256px stage builds a new one.
learn_gen = None
gc.collect()  # reclaim objects that were only reachable through learn_gen
# gc frees the tensors, but PyTorch's caching allocator keeps their GPU blocks reserved;
# hand them back so the freed memory is also visible outside this process.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Stage-2 (256px) generator hyperparameters.
bs, sz = 10, 256
lr, wd = 1e-3, 1e-3
keep_pct = 1.0
epochs = 10

In [None]:
# NOTE(review): path_medRes (256px copies) is generated earlier but never used; this DataBunch
# still takes its inputs from path_lowRes — confirm that is intended for the 256px stage.
data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)

In [None]:
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), arch=model, nf_factor=nf_factor)
# Warm-start from the final 128px checkpoint.
learn_gen.load(gen_name + "_128px_1")
# NOTE(review): the learner was already built with data_gen, so this looks redundant;
# kept in case load() replaces the attached data.
learn_gen.data = data_gen
# Apply the configured weight decay explicitly. `wd` is set in the hyperparameter cell
# (and logged to W&B) but is not passed to gen_learner_wide or to any fit call.
learn_gen.wd = wd
learn_gen.freeze()

In [None]:
wandbCallbacks = False  # set to True to log the 256px run to Weights & Biases

if wandbCallbacks:
    # wandb and WandbCallback are not imported at the top of the notebook. Import them
    # here so enabling the flag does not raise NameError, and wandb stays optional otherwise.
    import wandb
    from wandb.fastai import WandbCallback

    # NOTE(review): the fixed run id means every re-run targets the same W&B run — confirm intended.
    wandb.init(project='SuperRes', id="gen_256")

    wandb.config.batch_size = bs
    wandb.config.img_size = (sz, sz)
    wandb.config.learning_rate = lr
    wandb.config.weight_decay = wd
    wandb.config.num_epochs = epochs

    # NOTE(review): re-running this cell appends a second callback factory.
    learn_gen.callback_fns.append(partial(WandbCallback, input_type='images'))

In [None]:
do_fit(learn_gen, epochs, f"{gen_name}_256px_0")  # frozen fit with do_fit's default lrs/pct_start

In [None]:
# Undo the explicit freeze() applied after loading the 128px weights, so the final
# 256px fit trains every layer group.
learn_gen.unfreeze()

In [None]:
do_fit(learn_gen, epochs, f"{gen_name}_256px_1", lrs=slice(1e-6, 1e-4), pct_start=0.3)  # low-lr full fine-tune

In [None]:
# Release the 256px learner now that training is finished.
learn_gen = None
gc.collect()  # reclaim objects that were only reachable through learn_gen
# gc frees the tensors, but PyTorch's caching allocator keeps their GPU blocks reserved;
# hand them back so the freed memory is also visible outside this process.
if torch.cuda.is_available():
    torch.cuda.empty_cache()