
# Training Projected GAN
This is a self-contained notebook for training Projected GAN.

## Setup

Make sure you're running a GPU runtime; if not, select "GPU" as the hardware accelerator in Runtime > Change Runtime Type in the menu. 

Now, get the repo and install missing dependencies.

In [3]:
#%cd /content/drive/MyDrive/Places/Projected-GAN

/content/drive/MyDrive/Places/Projected-GAN


In [2]:
%%capture
%%bash
# # clone repo
# git clone https://github.com/autonomousvision/projected_gan
# pip install timm dill

In [1]:
!pip install timm dill

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm
  Downloading timm-0.6.13-py3-none-any.whl (549 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m549.1/549.1 KB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting dill
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 KB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting huggingface-hub
  Downloading huggingface_hub-0.13.4-py3-none-any.whl (200 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m200.1/200.1 KB[0m [31m25.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: dill, huggingface-hub, timm
Successfully installed dill-0.3.6 huggingface-hub-0.13.4 timm-0.6.13


In [2]:
!pip install timm==0.5.4
!pip install ftfy
!pip install Ninja
!pip install setuptools==59.5.0
!pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm==0.5.4
  Downloading timm-0.5.4-py3-none-any.whl (431 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/431.5 KB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━[0m [32m235.5/431.5 KB[0m [31m7.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m431.5/431.5 KB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
  Attempting uninstall: timm
    Found existing installation: timm 0.6.13
    Uninstalling timm-0.6.13:
      Successfully uninstalled timm-0.6.13
Successfully installed timm-0.5.4
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ftfy
  Downloading ftfy-6.1.1-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.1/5

## Data Preparation
We need to download and prepare the data. In this example, we use the few-shot datasets of the [FastGAN repo](https://github.com/odegeasslbc/FastGAN-pytorch).

In [4]:
#!gdown https://drive.google.com/u/0/uc?id=1aAJCZbXNHyraJ6Mi13dSbe7pTyfPXha0&export=download

Downloading...
From: https://drive.google.com/u/0/uc?id=1aAJCZbXNHyraJ6Mi13dSbe7pTyfPXha0
To: /content/drive/MyDrive/Places/Projected-GAN/few-shot-image-datasets.zip
100% 913M/913M [00:07<00:00, 124MB/s]


In [12]:
# %%capture
# #!unzip few-shot-image-datasets.zip
# !mv few-shot-images data

In [8]:
%cd /content/drive/MyDrive/Places/Projected-GAN/projected_gan

/content/drive/MyDrive/Places/Projected-GAN/projected_gan


In [None]:
# %%bash
# python dataset_tool.py --source=/content/drive/MyDrive/Places/Dataset --dest=/content/drive/MyDrive/Places/Projected-GAN/alter-Image --resolution=256x256

## Training

Now that the data is prepared, we can start training!  The training loop tracks FID, but the computations seems to lead to problems in colab. Hence, it is disable by default (```metrics=[]```). The loop also generates fixed noise samples after a defined amount of ticks, eg. below ```--snap=1```.

In [10]:
import os
import json
import re
import dnnlib

from training import training_loop
from torch_utils import training_stats
from train import init_dataset_kwargs
from metrics import metric_main

In [11]:
def launch_training(c, desc, outdir, rank=0):
    # Pick output directory.
    prev_run_dirs = []
    if os.path.isdir(outdir):
        prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]

    matching_dirs = [re.fullmatch(r'\d{5}' + f'-{desc}', x) for x in prev_run_dirs if re.fullmatch(r'\d{5}' + f'-{desc}', x) is not None]
    if c.restart_every > 0 and len(matching_dirs) > 0:  # expect unique desc, continue in this directory
        assert len(matching_dirs) == 1, f'Multiple directories found for resuming: {matching_dirs}'
        c.run_dir = os.path.join(outdir, matching_dirs[0].group())
    else:                     # fallback to standard
        prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
        prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
        cur_run_id = max(prev_run_ids, default=-1) + 1
        c.run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{desc}')
        assert not os.path.exists(c.run_dir)


    # Print options.
    print()
    print('Training options:')
    print(json.dumps(c, indent=2))
    print()
    print(f'Output directory:    {c.run_dir}')
    print(f'Number of GPUs:      {c.num_gpus}')
    print(f'Batch size:          {c.batch_size} images')
    print(f'Training duration:   {c.total_kimg} kimg')
    print(f'Dataset path:        {c.training_set_kwargs.path}')
    print(f'Dataset size:        {c.training_set_kwargs.max_size} images')
    print(f'Dataset resolution:  {c.training_set_kwargs.resolution}')
    print(f'Dataset labels:      {c.training_set_kwargs.use_labels}')
    print(f'Dataset x-flips:     {c.training_set_kwargs.xflip}')
    print()

    # Create output directory.
    print('Creating output directory...')
    os.makedirs(c.run_dir, exist_ok=c.restart_every > 0)
    with open(os.path.join(c.run_dir, 'training_options.json'), 'wt+') as f:
        json.dump(c, f, indent=2)

    # Start training
    dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=False)
    sync_device = torch.device('cuda', rank) if c.num_gpus > 1 else None
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
    training_loop.training_loop(rank=rank, **c)

In [12]:
def train(**kwargs):
    # Initialize config.
    opts = dnnlib.EasyDict(kwargs) # Command line arguments.
    c = dnnlib.EasyDict() # Main config dict.
    c.G_kwargs = dnnlib.EasyDict(class_name=None, z_dim=64, w_dim=128, mapping_kwargs=dnnlib.EasyDict())
    c.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8)
    c.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8)
    c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, prefetch_factor=2)

    # Training set.
    c.training_set_kwargs, dataset_name = init_dataset_kwargs(data=opts.data)
    if opts.cond and not c.training_set_kwargs.use_labels:
        raise ValueError('--cond=True requires labels specified in dataset.json')
    c.training_set_kwargs.use_labels = opts.cond
    c.training_set_kwargs.xflip = opts.mirror

    # Hyperparameters & settings.
    c.num_gpus = opts.gpus
    c.batch_size = opts.batch
    c.batch_gpu = opts.batch_gpu or opts.batch // opts.gpus
    c.G_kwargs.channel_base = opts.cbase
    c.G_kwargs.channel_max = opts.cmax
    c.G_kwargs.mapping_kwargs.num_layers = 2
    c.G_opt_kwargs.lr = (0.002 if opts.cfg == 'stylegan2' else 0.0025) if opts.glr is None else opts.glr
    c.D_opt_kwargs.lr = opts.dlr
    c.metrics = opts.metrics
    c.total_kimg = opts.kimg
    c.kimg_per_tick = opts.tick
    c.image_snapshot_ticks = c.network_snapshot_ticks = opts.snap
    c.random_seed = c.training_set_kwargs.random_seed = opts.seed
    c.data_loader_kwargs.num_workers = opts.workers

    # Sanity checks.
    if c.batch_size % c.num_gpus != 0:
        raise ValueError('--batch must be a multiple of --gpus')
    if c.batch_size % (c.num_gpus * c.batch_gpu) != 0:
        raise ValueError('--batch must be a multiple of --gpus times --batch-gpu')
    if any(not metric_main.is_valid_metric(metric) for metric in c.metrics):
        raise ValueError('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))

    # Base configuration.
    c.ema_kimg = c.batch_size * 10 / 32
    if opts.cfg == 'stylegan2':
        c.G_kwargs.class_name = 'pg_modules.networks_stylegan2.Generator'
        c.G_kwargs.fused_modconv_default = 'inference_only' # Speed up training by using regular convolutions instead of grouped convolutions.
        use_separable_discs = True

    elif opts.cfg == 'fastgan':
        c.G_kwargs = dnnlib.EasyDict(class_name='pg_modules.networks_fastgan.Generator', cond=opts.cond)
        c.G_opt_kwargs.lr = c.D_opt_kwargs.lr = 0.0002
        use_separable_discs = False

    # Restart.
    c.restart_every = opts.restart_every

    # Description string.
    desc = f'{opts.cfg:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}'
    if opts.desc is not None:
        desc += f'-{opts.desc}'

    # Projected and Multi-Scale Discriminators
    c.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.ProjectedGANLoss')
    c.D_kwargs = dnnlib.EasyDict(
        class_name='pg_modules.discriminator.ProjectedDiscriminator',
        diffaug=True,
        interp224=(c.training_set_kwargs.resolution < 224),
        backbone_kwargs=dnnlib.EasyDict(),
    )

    c.D_kwargs.backbone_kwargs.cout = 64
    c.D_kwargs.backbone_kwargs.expand = True
    c.D_kwargs.backbone_kwargs.proj_type = 2
    c.D_kwargs.backbone_kwargs.num_discs = 4
    c.D_kwargs.backbone_kwargs.separable = use_separable_discs
    c.D_kwargs.backbone_kwargs.cond = opts.cond

    # Launch.
    launch_training(c=c, desc=desc, outdir=opts.outdir)

In [None]:
# start training!

train(
    outdir='training-runs', 
    cfg='fastgan',
    data='/content/drive/MyDrive/Places/LUSN_dataset_bedroom/lsun_train_output_dir', 
    gpus=1, 
    batch=64, 
    cond=False, 
    mirror=1, 
    batch_gpu=8, 
    cbase=32768, 
    cmax=512, 
    glr=None, 
    dlr=0.002, 
    desc='', 
    metrics=[],
    kimg=10000, 
    tick=4, 
    snap=1, 
    seed=0, 
    workers=1,
    restart_every=999999,
)


Training options:
{
  "G_kwargs": {
    "class_name": "pg_modules.networks_fastgan.Generator",
    "cond": false
  },
  "G_opt_kwargs": {
    "class_name": "torch.optim.Adam",
    "betas": [
      0,
      0.99
    ],
    "eps": 1e-08,
    "lr": 0.0002
  },
  "D_opt_kwargs": {
    "class_name": "torch.optim.Adam",
    "betas": [
      0,
      0.99
    ],
    "eps": 1e-08,
    "lr": 0.0002
  },
  "data_loader_kwargs": {
    "pin_memory": true,
    "prefetch_factor": 2,
    "num_workers": 1
  },
  "training_set_kwargs": {
    "class_name": "training.dataset.ImageFolderDataset",
    "path": "/content/drive/MyDrive/Places/LUSN_dataset_bedroom/lsun_train_output_dir",
    "use_labels": false,
    "max_size": 8098,
    "xflip": 1,
    "resolution": 256,
    "random_seed": 0
  },
  "num_gpus": 1,
  "batch_size": 64,
  "batch_gpu": 8,
  "metrics": [],
  "total_kimg": 10000,
  "kimg_per_tick": 4,
  "image_snapshot_ticks": 1,
  "network_snapshot_ticks": 1,
  "random_seed": 0,
  "ema_kimg": 20.0

To inspect the samples, click on the folder symbol on the left and navigate to 

```projected_gan/training-runs/YOUR_RUN```

The files ```fakesXXXXXX.png``` are the samples for a fixed noise vector at point.

In [15]:
!python calc_metrics.py --metrics=fid50k_full --data=/content/drive/MyDrive/Places/Projected-GAN/projected_gan/training-runs/00001-fastgan-lsun_train_output_dir-gpus1-batch64-  \
--network=/content/drive/MyDrive/Places/Projected-GAN/projected_gan/training-runs/00001-fastgan-lsun_train_output_dir-gpus1-batch64-/network-snapshot.pkl

Loading network from "/content/drive/MyDrive/Places/Projected-GAN/projected_gan/training-runs/00001-fastgan-lsun_train_output_dir-gpus1-batch64-/network-snapshot.pkl"...
Dataset options:
{
  "class_name": "training.dataset.ImageFolderDataset",
  "path": "/content/drive/MyDrive/Places/Projected-GAN/projected_gan/training-runs/00001-fastgan-lsun_train_output_dir-gpus1-batch64-",
  "resolution": 256,
  "use_labels": false
}
Launching processes...

Generator              Parameters  Buffers  Output shape        Datatype
---                    ---         ---      ---                 ---     
mapping                -           -        [1, 1, 256]         float32 
synthesis.init.init    16785408    16385    [1, 2048, 4, 4]     float32 
synthesis.feat_8.0     -           -        [1, 2048, 8, 8]     float32 
synthesis.feat_8.1     37748736    20480    [1, 2048, 8, 8]     float32 
synthesis.feat_8.2     1           -        [1, 2048, 8, 8]     float32 
synthesis.feat_8.3     4096        4097 

In [14]:
!python calc_metrics.py --help

Usage: calc_metrics.py [OPTIONS]

  Calculate quality metrics for previous training
  run or pretrained network pickle.

  Examples:

  # Previous training run: look up options automatically, save result to JSONL file.
  python calc_metrics.py --metrics=eqt50k_int,eqr50k \
      --network=~/training-runs/00000-stylegan3-r-mydataset/network-snapshot-000000.pkl

  # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
  python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq-1024x1024.zip --mirror=1 \
      --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl

  Recommended metrics:
    fid50k_full  Frechet inception distance against the full dataset.
    kid50k_full  Kernel inception distance against the full dataset.
    pr50k3_full  Precision and recall againt the full dataset.
    ppl2_wend    Perceptual path length in W, endpoints, full image.
    eqt50k_int   Equivariance w.r.

In [16]:
!python calc_metrics.py --metrics=fid50k_full --network=/content/drive/MyDrive/Places/Projected-GAN/projected_gan/training-runs/00001-fastgan-lsun_train_output_dir-gpus1-batch64-/network-snapshot.pkl

Loading network from "/content/drive/MyDrive/Places/Projected-GAN/projected_gan/training-runs/00001-fastgan-lsun_train_output_dir-gpus1-batch64-/network-snapshot.pkl"...
Dataset options:
{
  "class_name": "training.dataset.ImageFolderDataset",
  "path": "/content/drive/MyDrive/Places/LUSN_dataset_bedroom/lsun_train_output_dir",
  "use_labels": false,
  "max_size": 8098,
  "xflip": 1,
  "resolution": 256,
  "random_seed": 0
}
Launching processes...

Generator              Parameters  Buffers  Output shape        Datatype
---                    ---         ---      ---                 ---     
mapping                -           -        [1, 1, 256]         float32 
synthesis.init.init    16785408    16385    [1, 2048, 4, 4]     float32 
synthesis.feat_8.0     -           -        [1, 2048, 8, 8]     float32 
synthesis.feat_8.1     37748736    20480    [1, 2048, 8, 8]     float32 
synthesis.feat_8.2     1           -        [1, 2048, 8, 8]     float32 
synthesis.feat_8.3     4096        4