# StyleGAN2 Simple pytorch

https://github.com/lucidrains/stylegan2-pytorch

Make sure you have a GPU runtime!

In [None]:
# Mount Google Drive at /content/drive; the dataset and experiment paths used
# in later cells live under /content/drive/My Drive.
from google.colab import drive

drive.mount(mountpoint="/content/drive")

- Set a tag `my_tag` for this experiment in the form, for example your name, the dataset, or whatever else is important.
- Or set the `resume_from` field to the path of a `model_N.pt` checkpoint; training will resume from that checkpoint. For example: "/content/drive/My Drive/IRCMS_GAN_collaborative_database/Experiments/colab-stylegan2-simple/mathis-pad-double-attn-256/0003/models/stock-images/model_41.pt"


In [None]:
# Set a tag for _your_ experiments, for example your name.
# (The "#@param" annotations below render this cell as a Colab form; keep them intact.)
my_tag = "" #@param {type:"string"}

# Zip archive of training images on Drive; it is copied and extracted by later cells.
dataset = "/content/drive/My Drive/IRCMS_GAN_collaborative_database/Research/Daniel/ALL CROP.zip" #@param {type:"string"}

# Path to a model_N.pt checkpoint to continue training; leave empty for a new experiment.
resume_from = "" #@param {type:"string"}

# Side length (pixels) of the square images produced by the resize cell.
desired_size = 256 #@param {type:"integer"}

# Run name passed to the Trainer; replaced by the checkpoint's folder name when resuming.
name = "stock-images" #@param {type: "string"}

# A new experiment needs a tag (it names the output folder); a resumed one needs a checkpoint.
if not resume_from and not my_tag:
    raise ValueError("Please set 'my_tag' for new experiments or 'resume_from' to continue training.")

from pathlib import Path

def print_config(d):
    """Print every entry of mapping *d* as ``=> key: value``, ordered by key."""
    formatted = [f"=> {key}: {value}" for key, value in sorted(d.items())]
    for line in formatted:
        print(line)

# Resolve where checkpoints (models_dir) and samples (results_dir) live: either
# derived from an existing checkpoint path, or a fresh folder named after my_tag.
if resume_from:
    resume_path = Path(resume_from)
    # Fail here with a clear message instead of deep inside the training cell.
    if not resume_path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {resume_path}")
    # Checkpoints are named model_<N>.pt; N is the checkpoint number to load.
    prefix, _, number = resume_path.stem.rpartition('_')
    if prefix != 'model' or not number.isdigit():
        raise ValueError(f"Expected a checkpoint file named 'model_<N>.pt', got '{resume_path.name}'.")
    load_epoch = int(number)
    # Layout: <experiment>/models/<name>/model_<N>.pt, samples in <experiment>/results.
    name = resume_path.parent.name
    models_dir = resume_path.parent.parent
    results_dir = models_dir.parent.joinpath("results")
    print("Resuming from checkpoint:")
    print_config(dict(models_dir=models_dir, results_dir=results_dir, name=name, load_epoch=load_epoch))
else:
    resume_path = None
    # -1: no specific checkpoint requested.
    load_epoch = -1
    experiment_dir = f"/content/drive/My Drive/IRCMS_GAN_collaborative_database/Experiments/colab-stylegan2-simple/{my_tag}"
    models_dir = Path(experiment_dir).joinpath('models')
    results_dir = Path(experiment_dir).joinpath('results')

    # Refuse to reuse a tag so an earlier experiment is never overwritten.
    if Path(experiment_dir).exists():
        raise ValueError(f"The directory {experiment_dir} already exists. Please choose another 'my_tag' to avoid overwriting.")

    print("Running new experiment:")
    print_config(dict(models_dir=models_dir, results_dir=results_dir, name=name))

In [None]:
# Check which GPU we have; P100 is the fast one.
!nvidia-smi -L

In [None]:
# Copy the dataset archive (the `dataset` form field, interpolated by IPython) to /content/dataset.zip.
!rsync -avP "$dataset" /content/dataset.zip

In [None]:
extract_dir = "dataset_extracted"
!unzip -q dataset.zip -d $extract_dir

In [None]:
from pathlib import Path

from PIL import Image, ImageOps
from tqdm import tqdm
import os
import numpy as np

resize_dir = 'dataset_resized'

# 'crop' or 'pad'
resize_method = 'pad'

if resize_method not in ('crop', 'pad'):
    raise ValueError(f"resize_method must be 'crop' or 'pad', got {resize_method!r}")

files = sorted(p for p in Path(extract_dir).rglob("*.*") if p.is_file())
skipped = []

for path in tqdm(files):
    relative_path = path.relative_to(extract_dir)

    # Archives may contain entries that are not decodable images (e.g. OS
    # metadata files); skip them instead of aborting the whole run.
    try:
        with Image.open(path) as src:
            im = src.convert("RGB")
    except OSError:
        skipped.append(relative_path)
        continue

    old_w, old_h = im.size
    # 'pad' scales the longer side to desired_size (the remainder becomes black border);
    # 'crop' scales the shorter side to desired_size (the overflow is cut off).
    reference_side = max(old_w, old_h) if resize_method == 'pad' else min(old_w, old_h)
    ratio = desired_size / reference_side
    # round() rather than int(): float error could otherwise leave a 1px border.
    new_size = (max(1, round(old_w * ratio)), max(1, round(old_h * ratio)))
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    im = im.resize(new_size, Image.LANCZOS)

    # Center the resized image on a black square canvas along BOTH axes. In
    # 'crop' mode the offsets are negative, trimming the overflow evenly.
    canvas = Image.new("RGB", (desired_size, desired_size), color=(0, 0, 0))
    offset = ((desired_size - new_size[0]) // 2, (desired_size - new_size[1]) // 2)
    canvas.paste(im, offset)

    destination = Path(resize_dir).joinpath(relative_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    # Encoded as JPEG but kept under the original file name/extension.
    canvas.save(destination, 'JPEG', quality=99)

if skipped:
    print(f"Skipped {len(skipped)} unreadable file(s), e.g.: {[str(p) for p in skipped[:5]]}")

In [None]:
# Check a few images to make sure we didn't mess them up.
from itertools import islice

# Imported under an alias: a bare `Image` here would shadow PIL's `Image`
# (imported by the resize cell) for any cell run afterwards.
from IPython.display import Image as DisplayImage, display

# islice stops after 3 matches instead of listing the whole directory tree first.
for img in islice(Path(resize_dir).rglob("*.*"), 3):
    print(img)
    display(DisplayImage(filename=str(img)))

In [None]:
# Install the package that provides Trainer and NanException for the training cell.
# NOTE(review): the version is not pinned; the Trainer keyword arguments used below
# (cl_reg, fq_layers, attn_layers, ...) may differ between releases — consider pinning.
!pip install stylegan2_pytorch

In [None]:
from tqdm import tqdm
from stylegan2_pytorch import Trainer, NanException
from datetime import datetime

data = resize_dir

# True: call model.clear() and start over; False: load checkpoint `load_from`.
new = False
# Checkpoint number from the config cell: N when resuming from model_N.pt, -1 for a
# new experiment. NOTE(review): -1 is assumed to mean "latest (or none)" in the
# installed stylegan2_pytorch's Trainer.load — confirm.
load_from = load_epoch
# NOTE(review): differs from `desired_size` used by the resize cell; confirm intended.
image_size = 128
network_capacity = 16
transparent = False
batch_size = 3
gradient_accumulate_every = 5
num_train_steps = 150000
learning_rate = 2e-4
num_workers = None
save_every = 1000
generate = False
generate_interpolation = False
save_frames = False
num_image_tiles = 8
trunc_psi = 0.75
fp16 = False
cl_reg = False
fq_layers = []
fq_dict_size = 256
attn_layers = []
no_const = False
aug_prob = 0.
dataset_aug_prob = 0.

model = Trainer(
    name,
    results_dir,
    models_dir,
    batch_size = batch_size,
    gradient_accumulate_every = gradient_accumulate_every,
    image_size = image_size,
    network_capacity = network_capacity,
    transparent = transparent,
    lr = learning_rate,
    num_workers = num_workers,
    save_every = save_every,
    trunc_psi = trunc_psi,
    fp16 = fp16,
    cl_reg = cl_reg,
    fq_layers = fq_layers,
    fq_dict_size = fq_dict_size,
    attn_layers = attn_layers,
    no_const = no_const,
    aug_prob = aug_prob,
    dataset_aug_prob = dataset_aug_prob,
)

# Decide on `new`, not on the truthiness of load_from: -1 is truthy (so the old
# check never cleared) and 0 is falsy (so resuming from model_0.pt would wipe the run).
if new:
    model.clear()
else:
    model.load(load_from)

model.set_data_src(data)

# NanException (imported above) is presumably raised by model.train() when losses
# go NaN; retry a few times before giving up instead of aborting on the first one.
max_nan_retries = 3

for step in tqdm(range(num_train_steps - model.steps), mininterval=10., desc=f'{name}<{data}>'):
    for attempt in range(max_nan_retries):
        try:
            model.train()
            break
        except NanException:
            if attempt == max_nan_retries - 1:
                raise
    if step % 50 == 0:
        model.print_log()