## Mount Google Drive

In [None]:
# Mount Google Drive so the files under /content/drive/MyDrive (pre-trained
# weights referenced later in this notebook) are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

## Install Required Packages

In [None]:
# Pin fastai to the 1.x API this notebook is written against
# (fastai.vision, fastai.callbacks, fastai.vision.gan are imported below).
!pip install fastai==1.0.61
# !pip install deoldify==0.1.0
# NOTE(review): deoldify is installed unpinned; pin the version this notebook was verified with.
!pip install deoldify
# NOTE(review): upgrading torch/torchvision to the latest release may not be compatible with the
# pinned fastai==1.0.61 — confirm, and pin a known-good torch version if imports or training fail.
!pip install torch torchvision --upgrade

In [1]:
#NOTE:  This must be the first call in order to work properly!
# Select the compute device for deoldify before the deoldify/fastai imports in
# the cells below are executed.
from deoldify import device
from deoldify.device_id import DeviceId
#choices:  CPU, GPU0...GPU7
device.set(device=DeviceId.GPU0)

<DeviceId.GPU0: 0>

## Import Necessary Libraries

In [None]:
from deoldify.dataset import get_colorize_data
from deoldify.generators import gen_learner_deep

In [None]:
import os
import fastai
from fastai import *
from fastai.vision import *
from fastai.callbacks.tensorboard import *
from fastai.vision.gan import *
from deoldify.generators import *
from deoldify.critics import *
from deoldify.dataset import *
from deoldify.loss import *
from deoldify.save import *
from PIL import Image, ImageDraw, ImageFont
from PIL import ImageFile

## Set the Variables

In [None]:

# --- Data locations --------------------------------------------------------
path = Path('./data')            # base directory for all data
path_hr = path                   # high-resolution target images: the base directory itself
path_lr = path / 'bandw'         # low-resolution (grayscale) input images

# --- Naming ----------------------------------------------------------------
proj_id = 'ArtisticModel'        # project id used for model/folder naming

gen_name = f'{proj_id}_gen'            # 'ArtisticModel_gen'   - generator checkpoint prefix
pre_gen_name = f'{gen_name}_0'         # 'ArtisticModel_gen_0' - initial generator checkpoint
crit_name = f'{proj_id}_crit'          # 'ArtisticModel_crit'  - critic (discriminator) prefix

name_gen = f'{proj_id}_image_gen'      # folder name for generated images
path_gen = path / name_gen             # where generated images are written

TENSORBOARD_PATH = Path('data/tensorboard') / proj_id

# --- Model / schedule settings --------------------------------------------
nf_factor = 1.5      # passed as nf_factor to gen_learner_deep
pct_start = 1e-8     # NOTE(review): not referenced by any cell in this notebook

# NOTE(review): not referenced by any cell; the pre-trained load cell uses a different hard-coded path.
pre_trained_model_path = "/content/drive/MyDrive/models/ColorizeArtistic_gen"

## Helper Functions

In [None]:

def get_data(bs:int, sz:int, keep_pct:float, random_seed:int=None):
    """Build the generator's DataBunch (grayscale inputs -> high-resolution targets).

    Inputs are read from ``path_lr`` and targets from ``path_hr``.

    Args:
        bs: batch size.
        sz: image size passed to ``get_colorize_data``.
        keep_pct: fraction of the dataset to keep (e.g. 0.03 for a short GAN pass).
        random_seed: forwarded to ``get_colorize_data``. Defaults to None, the
            previous hard-coded value. Presumably an int gives a reproducible
            subset/split — confirm in ``deoldify.dataset``.

    Returns:
        The DataBunch returned by ``get_colorize_data``.
    """
    data = get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr,
                             random_seed=random_seed, keep_pct=keep_pct)
    print(data)
    # On a fastai-v1 DataBunch, ``data.items`` falls through to the *training*
    # dataset only, so report both splits explicitly instead of one ambiguous count.
    print(f"Number of items: train={len(data.train_ds)}, valid={len(data.valid_ds)}")
    return data

def get_crit_data(classes, bs, sz):
    """Build the critic's DataBunch, labelling each image by its top-level folder.

    Args:
        classes: sub-folders of ``path`` to include; each folder name becomes a
            class label (e.g. the generated-image folder and 'test').
        bs: batch size.
        sz: image size applied by the transforms.

    Returns:
        An ImageNet-normalised DataBunch with a 10% random validation split
        (seed 42), augmented with ``get_transforms(max_zoom=2.)``.
    """
    labelled = (ImageList.from_folder(path, include=classes, recurse=True)
                .split_by_rand_pct(0.1, seed=42)
                .label_from_folder(classes=classes))
    augmentations = get_transforms(max_zoom=2.)
    bunch = labelled.transform(augmentations, size=sz).databunch(bs=bs)
    return bunch.normalize(imagenet_stats)

def create_training_images(fn):
    """Write a grayscale copy of one high-resolution image into ``path_lr``.

    The copy keeps ``fn``'s path relative to ``path_hr`` (parent folders are
    created as needed) and is saved as 3-channel RGB.

    Args:
        fn: Path of a source image located under ``path_hr``.
    """
    target = path_lr / fn.relative_to(path_hr)
    target.parent.mkdir(parents=True, exist_ok=True)
    # 'LA' keeps only luminance (+ alpha); converting back to 'RGB' replicates
    # that gray level across three channels.
    gray_rgb = PIL.Image.open(fn).convert('LA').convert('RGB')
    gray_rgb.save(target)

def save_preds(dl):
    """Run the global ``learn_gen`` over ``dl`` and save every prediction to ``path_gen``.

    Each reconstructed prediction is written under the file name of the
    corresponding entry in ``dl.dataset.items``. This assumes ``dl`` yields
    batches in the same order as ``dl.dataset.items`` (presumably true for an
    unshuffled loader such as ``fix_dl`` — confirm).

    Args:
        dl: data loader whose ``dataset.items`` are the input image paths.
    """
    source_items = dl.dataset.items
    position = 0
    for batch in dl:
        for pred in learn_gen.pred_batch(batch=batch, reconstruct=True):
            pred.save(path_gen / source_items[position].name)
            position += 1

def save_gen_images(keep_pct:float=0.085):
    """Regenerate the folder of generator outputs used as a critic class.

    Deletes and recreates ``path_gen``, builds a ``keep_pct`` subset of the
    colourisation data using the global ``bs``/``sz``, and saves the global
    ``learn_gen``'s predictions for it via ``save_preds``.

    Args:
        keep_pct: fraction of the dataset to render (default 0.085, the previous
            hard-coded value).

    Returns:
        The first generated image (so a cell ending in ``save_gen_images()``
        shows a preview), or ``None`` if no images were written.
    """
    import shutil  # explicit, instead of relying on the fastai star import

    if path_gen.exists():
        shutil.rmtree(path_gen)
    # parents=True so a fresh runtime without the base data folder does not fail here.
    path_gen.mkdir(parents=True, exist_ok=True)
    data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)
    save_preds(data_gen.fix_dl)
    generated = sorted(path_gen.iterdir())
    # A bare expression inside a function is never displayed, so the previous
    # `PIL.Image.open(...)` preview was silently discarded; return it instead,
    # and avoid an IndexError when nothing was generated.
    return PIL.Image.open(generated[0]) if generated else None

## Checkpoint numbering: increment `old_checkpoint_num` by one after each training round

In [None]:
old_checkpoint_num = 0  # bump by one after every completed training round
checkpoint_num = old_checkpoint_num + 1

# Checkpoint names follow '<prefix>_<number>' for both generator and critic.
gen_old_checkpoint_name = f'{gen_name}_{old_checkpoint_num}'
gen_new_checkpoint_name = f'{gen_name}_{checkpoint_num}'
crit_old_checkpoint_name = f'{crit_name}_{old_checkpoint_num}'
crit_new_checkpoint_name = f'{crit_name}_{checkpoint_num}'

In [None]:
# Generator data settings: batch size, image size, fraction of the data kept.
bs=16
sz=192
keep_pct=1.0

In [None]:
data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)

In [None]:
# Run this first time while loading a pre-trained model
# (run EITHER this cell OR the next one — both assign learn_gen).
# NOTE(review): this loads a hard-coded Drive path rather than `pre_trained_model_path`
# from the settings cell ("/content/drive/MyDrive/models/ColorizeArtistic_gen");
# confirm which weights are intended.
learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load("/content/drive/MyDrive/Temples Data/bandw/models/Pre-Trained_Model", with_opt=False)

In [None]:
# Resume the generator from the previous round's checkpoint
# (gen_old_checkpoint_name), without optimizer state (with_opt=False).
learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)

In [None]:
# Render generator outputs into path_gen; that folder (name_gen) is one of the
# two classes the critic is trained on via get_crit_data([name_gen, 'test'], ...).
save_gen_images()

In [None]:
# Critic batch/image size (same values as the generator settings above).
bs=16
sz=192

In [None]:
# Two classes taken from folder names under the data path: generated images
# (name_gen) and the 'test' folder (presumably real colour images — confirm).
data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)

In [None]:
data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)

In [None]:
# Resume the critic from the previous round's checkpoint, without optimizer state.
learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)

In [None]:
# Log critic pre-training to TensorBoard (base_dir=TENSORBOARD_PATH, run name 'CriticPre').
learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))

In [None]:
# One-cycle schedule: 2 epochs with max learning rate 1e-4.
learn_critic.fit_one_cycle(2, 1e-4)

In [None]:
# Save as the next checkpoint number; the GAN phase below loads this name.
learn_critic.save(crit_new_checkpoint_name)

In [None]:
# Release the pre-training learners before building the GAN.
# The critic was bound as `learn_critic` in the pre-training cells, so that is the
# name that must be cleared: setting only `learn_crit` (not yet assigned at this
# point) kept the critic, and the GPU memory it holds, alive.
import gc
learn_critic = None
learn_crit = None
learn_gen = None
gc.collect()

In [None]:
# GAN-phase settings: learning rate, image size, and a smaller batch size.
lr=1e-5
sz=192
bs=9

In [None]:
# Rebuild the critic data with the GAN-phase batch size.
data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)

In [None]:
# Load the critic saved at the end of pre-training (crit_new_checkpoint_name).
# Note the name here is learn_crit, not learn_critic as in the pre-training cells.
learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)

In [None]:
# Reload the generator from the previous round's checkpoint. data_gen was built
# with the earlier bs=16; learn.data is replaced before fitting below.
learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)

In [None]:
# Alternate generator/critic training with AdaptiveGANSwitcher; critic_thresh=0.65
# is presumably the critic-loss threshold for switching — confirm in fastai.vision.gan.
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
# weights_gen=(1.0, 2.0) weights the generator's loss terms (order per GANLearner.from_learners — confirm);
# Adam with betas=(0., 0.9) and weight decay 1e-3.
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,2.0), show_img=True, switcher=switcher, #-> Updated
                                opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)
# mult_lr=5. presumably scales the critic's learning rate relative to the generator's — confirm.
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))
learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))
# Periodically save learn_gen to gen_new_checkpoint_name (every 100 iterations).
learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))

In [None]:
# Short GAN pass on a 3% subset using the GAN-phase bs/sz; freeze_to(-1) leaves
# only the generator's last layer group trainable.
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)
learn_gen.freeze_to(-1)
learn.fit(1,lr)

In [None]:
# Save the generator as the next checkpoint; afterwards increment old_checkpoint_num.
learn_gen.save(gen_new_checkpoint_name)