# Install Required Libraries

In [None]:
# Install dependencies. `%pip` (unlike `!pip`) installs into the environment of the
# running kernel, so the packages are importable from this notebook.
# NOTE(review): no visible cell imports jax -- confirm it is still needed.
%pip uninstall jax jaxlib -y
%pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# NOTE(review): stylegan3's README lists a newer PyTorch than 1.8.1 -- confirm this pin still works.
%pip install torch==1.8.1 torchvision==0.9.1
# Clone only when the checkout is missing, so re-running the cell does not fail on `git clone`.
![ -d stylegan2-ada-pytorch ] || git clone https://github.com/NVlabs/stylegan2-ada-pytorch.git
![ -d stylegan3 ] || git clone https://github.com/NVlabs/stylegan3.git
%pip install ninja

## Dataset

In [3]:
# Show which image folders are available under gan/images.
!ls gan/images

celeba_hq_256


## Convert Your Images

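StyleGAN does not train directly on a folder of raw images; you must first convert them into the dataset format that its PyTorch training code reads. The following command runs `stylegan3/dataset_tool.py` to convert the images in `gan/images/celeba_hq_256` and write the resulting dataset to `gan/dataset/celeba_hq_256`.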

In [4]:
# Convert the raw image folder (--source) into the dataset that train.py reads (--dest).
# The pieces are joined by implicit string concatenation inside the parentheses.
CMD = ("python stylegan3/dataset_tool.py "
       "--source gan/images/celeba_hq_256 "
       "--dest gan/dataset/celeba_hq_256")

!{CMD}


  img = img.resize((ww, hh), PIL.Image.LANCZOS)

  1%|▏         | 14/1000 [00:00<00:07, 135.66it/s]
  3%|▎         | 28/1000 [00:00<00:07, 137.57it/s]
  4%|▍         | 44/1000 [00:00<00:06, 144.79it/s]
  6%|▌         | 59/1000 [00:00<00:06, 144.42it/s]
  7%|▋         | 74/1000 [00:00<00:06, 144.52it/s]
  9%|▉         | 89/1000 [00:00<00:06, 140.42it/s]
 10%|█         | 104/1000 [00:00<00:06, 142.65it/s]
 12%|█▏        | 119/1000 [00:00<00:06, 141.35it/s]
 13%|█▎        | 134/1000 [00:00<00:06, 142.99it/s]
 15%|█▌        | 150/1000 [00:01<00:05, 145.42it/s]
 17%|█▋        | 166/1000 [00:01<00:05, 147.41it/s]
 18%|█▊        | 182/1000 [00:01<00:05, 149.82it/s]
 20%|█▉        | 197/1000 [00:01<00:05, 148.78it/s]
 21%|██        | 212/1000 [00:01<00:05, 146.67it/s]
 23%|██▎       | 228/1000 [00:01<00:05, 148.58it/s]
 24%|██▍       | 243/1000 [00:01<00:05, 142.71it/s]
 26%|██▌       | 258/1000 [00:01<00:05, 143.87it/s]
 27%|██▋       | 274/1000 [00:01<00:04, 146.83it/s]
 29%|██▉       | 289

## Process Image Data

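All images must have the same dimensions and the same color mode (RGB). The following code reports every image that breaks these requirements, so problem files can be fixed or removed. It is most useful before running the conversion step above.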

In [None]:
from os import listdir
from os.path import isfile, join
import os
from PIL import Image
from tqdm.auto import tqdm  # picks the notebook widget when available, plain text otherwise

IMAGE_PATH = 'gan/images/celeba_hq_256'
# Sorted so the report lists files in a stable, reproducible order.
files = sorted(f for f in listdir(IMAGE_PATH) if isfile(join(IMAGE_PATH, f)))

# The first valid RGB image fixes the reference size; every later image must
# match it and be RGB. Problems are printed, not raised, so one scan reports them all.
base_size = None
for file in tqdm(files):
  file2 = os.path.join(IMAGE_PATH, file)
  try:
    # Image.open is lazy and leaves the file open; the `with` block closes it
    # once the size/mode (read from the header) have been captured.
    with Image.open(file2) as img:
      sz = img.size
      mode = img.mode
  except OSError as e:
    # Pillow's UnidentifiedImageError subclasses OSError; report stray
    # non-image or corrupt files instead of aborting the whole scan.
    print(f"Unreadable image: {file2} ({e})")
    continue
  if base_size and sz != base_size:
    print(f"Inconsistent size: {file2}")
  elif mode != 'RGB':
    print(f"Inconsistent color format: {file2}")
  else:
    base_size = sz

## Training

In [7]:
import os

# Modify these to suit your needs
EXPERIMENTS = "gan/experiments"    # train.py writes each run into a numbered subdirectory here
DATA = "gan/dataset/celeba_hq_256" # the --dest produced by the conversion cell
SNAP = 10                          # snapshot interval passed to --snap (see `train.py --help` for units)
cfg = "stylegan2"
gpus = 1
BATCH = 16                         # total batch size (--batch)
GAMMA = 10                         # R1 regularization weight (--gamma)

# Build the command and run it
cmd = (f"python stylegan3/train.py --cfg {cfg} --gpus {gpus} "
       f"--batch {BATCH} --gamma {GAMMA} --snap {SNAP} "
       f"--outdir {EXPERIMENTS} --data {DATA}")
!{cmd}

^C


In [None]:
# The `!` shell escape is required: without it the kernel parses this line as Python (SyntaxError).
!python stylegan3/train.py --cfg "stylegan2" --gpus 1 --batch 16 --gamma 10 --snap 10 --outdir "gan/experiments" --data "gan/dataset/celeba_hq_256"

In [None]:
# List train.py's command-line options. A plain string is used because there is nothing to interpolate.
cmd = "python stylegan3/train.py --help"
!{cmd}

## Resume Training

In [None]:
import os

# Modify these to suit your needs
EXPERIMENTS = "gan/experiments"
NETWORK = "network-snapshot-000100.pkl"
# NOTE(review): this run directory and DATA refer to a "circuit" dataset, not the
# celeba_hq_256 dataset prepared earlier in this notebook -- point both at the run
# you actually want to continue.
RESUME = os.path.join(EXPERIMENTS, "00008-circuit-auto1-resumecustom", NETWORK)
DATA = "gan/dataset/circuit"
SNAP = 10
# stylegan3/train.py requires --cfg, --gpus, --batch and --gamma even when resuming
# (see the --help cell); keep them equal to the settings of the run being resumed.
cfg = "stylegan2"
gpus = 1
BATCH = 16
GAMMA = 10

# Build the command and run it. Invoke train.py the same way as the training cell
# rather than through Colab-specific absolute paths (/usr/bin/python3, /content/...).
cmd = (f"python stylegan3/train.py --cfg {cfg} --gpus {gpus} "
       f"--batch {BATCH} --gamma {GAMMA} --snap {SNAP} "
       f"--resume {RESUME} --outdir {EXPERIMENTS} --data {DATA}")
!{cmd}

In [None]:
# Resume from a saved snapshot. The `!` shell escape is required, otherwise the kernel raises SyntaxError.
# NOTE(review): the resumed run's directory name says batch8, but this passes --batch 16 -- confirm which is intended.
!python stylegan3/train.py --cfg "stylegan2" --gpus 1 --batch 16 --gamma 10 --snap 10 --outdir "gan/experiments" --data "gan/dataset/celeba_hq_256" --resume "gan/experiments/00000-stylegan2-celeba_hq_256-gpus1-batch8-gamma10/network-snapshot-000400.pkl"