In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.append('./../')

from collections import defaultdict

import torch
import tabulate
import numpy as np

import matplotlib_inline
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec

%matplotlib inline
matplotlib_inline.backend_inline.set_matplotlib_formats('pdf', 'svg')

For simplicity, most of the helper functions used below are implemented in the `nb_utils` module.

# Load checkpoints

An existing model can be loaded with `nb_utils.load_checkpoint`.

This function returns the trained Generator (both the raw and the EMA-smoothed copies) as well as the initial Generator passed via the `--resume` CLI training option. It also parses and returns the contents of the `training_options.json` and `metric-*.jsonl` files.

In [None]:
import nb_utils

In [None]:
# Device used for generation and metric computation; falls back to the CPU when no GPU is visible.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Root folder holding all training runs. `~` is expanded explicitly because neither
# os.path.join() nor open() expands it, so a literal '~/...' path would not be found.
base_exp_path = os.path.expanduser('~/StyleDomain/DissimilarDomains/training-runs')

In [None]:
# Load the model fine-tuned on the Dog dataset with the Full parameterization.
exp_path = os.path.join(base_exp_path, '00066-afhqdog-stylegan2-kimg241-resumeffhq512')
checkpoint = nb_utils.load_checkpoint(
    exp_path=exp_path,  # experiment folder
    chkpt_idx=241,      # index of the checkpoint to load
    device=device,      # device that will hold the loaded modules
)
# Unpack, in order: the trained Generator (raw, EMA), the `--resume` Generator (raw, EMA),
# the arguments from `training_options.json`, and the metrics recorded for the run.
G, G_ema, G_base, G_ema_base, options, metrics = checkpoint

In [None]:
# One (metric name, values) pair per line.
print('\n'.join(str(item) for item in metrics.items()))

# Compute metrics

We can compute other metrics with `metric_main.calc_metric`. For example, here we compute $\text{FID5k}$ (referred to simply as $\text{FID}$) for the model fine-tuned on the **Dog** dataset with the **Full** parameterization. For reference, the paper reports $\text{FID} = 20.3$ for this model.

In [None]:
from metrics import metric_main

In [None]:
# Compute `fid5k` for the EMA generator against the dataset it was trained on.
metric = metric_main.calc_metric(
    metric='fid5k',
    G=G_ema,
    device=device,
    dataset_kwargs=options['training_set_kwargs'],
)
metric_name = metric['metric']
metric_value = metric['results'][metric_name]
print('{0} = {1:.3f}'.format(metric_name, metric_value))

# Generate sample images

We can generate images from a pretrained model with `nb_utils.generate_images`, a wrapper around the Generator.

In [None]:
# (rows, cols) of the sample grid drawn from the EMA generator.
grid_size = np.array((4, 4))
images, _ = nb_utils.generate_images(
    G_ema,
    grid_size=grid_size,
    device=device,
)

In [None]:
# Show the generated samples on a grid_size[0] x grid_size[1] grid, 2 inches per cell.
n_rows, n_cols = grid_size
fig, axes = plt.subplots(n_rows, n_cols, figsize=grid_size[::-1] * 2)
nb_utils.prepare_axes(axes)

for ax, image in zip(axes.flat, images):
    ax.imshow(image)

fig.suptitle('Uncurated Dog samples for Full parameterization')
# NOTE(review): tight_layout() recomputes subplot spacing, so the wspace/hspace set
# here are likely overridden — confirm which spacing is intended.
fig.subplots_adjust(wspace=0.01, hspace=0.01)
fig.tight_layout()
plt.show()

# Results reproduction

To reproduce all results, every model must first be trained: one model for each pair in the Cartesian product of datasets and parameterizations ($8 \times 7 = 56$ models in total):

$$[\text{Metfaces}, \text{Mega}, \text{Ukiyoe}, \text{Dog}, \text{Cat}, \text{Car}, \text{Church}, \text{Flowers}] \times [\text{Full}, \text{SyntConv}, \text{Affine}+, \text{Affine}, \text{Mapping}, \text{AffineLight}+, \text{StyleSpace}]$$

Let's define the experiment folders for all of these models:

In [None]:
# Experiment folders (relative to `base_exp_path`), one dict per dataset, keyed by
# parameterization name. Every dict has the same seven keys.
# The 'Full' runs have no `-Gparts-` suffix; the other runs list, after `-Gparts-`,
# what are presumably the generator parts being trained — confirm against the training script.
metfaces_512 = {
    'Full':         '00114-metfaces_512-stylegan2-kimg241-resumeffhq512',
    'Mapping':      '00114-metfaces_512-stylegan2-kimg241-resumeffhq512-Gparts-mapping',
    'SyntConv':     '00314-metfaces_512-stylegan2-kimg241-resumeffhq512-Gparts-synt_conv,tRGB_conv',
    'Affine':       '00115-metfaces_512-stylegan2-kimg241-resumeffhq512-Gparts-synt_affine,tRGB_affine',
    'Affine+':      '00304-metfaces_512-stylegan2-glrate0.02-kimg241-resumeffhq512-Gparts-synt_affine,tRGB_affine,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive',
    'StyleSpace':   '00149-metfaces_512-stylegan2-glrate0.02-kimg241-resumeffhq512-Gparts-synt_offset,tRGB_offset-use-dom-mod-additive',
    'AffineLight+': '00253-metfaces_512-stylegan2-glrate0.008-kimg241-resumeffhq512-Gparts-synt_affine_weights_offset,tRGB_affine_weights_offset,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive,affine_out_in_5_1_additive',
}
Mega_512 = {
    'Full':         '00087-Mega_512-stylegan2-kimg241-resumeffhq512',
    'Mapping':      '00088-Mega_512-stylegan2-kimg241-resumeffhq512-Gparts-mapping',
    'SyntConv':     '00315-Mega_512-stylegan2-kimg241-resumeffhq512-Gparts-synt_conv,tRGB_conv',
    'Affine':       '00088-Mega_512-stylegan2-kimg241-resumeffhq512-Gparts-synt_affine,tRGB_affine',
    'Affine+':      '00312-Mega_512-stylegan2-glrate0.02-kimg241-resumeffhq512-Gparts-synt_affine,tRGB_affine,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive',
    'StyleSpace':   '00149-Mega_512-stylegan2-glrate0.02-kimg241-resumeffhq512-Gparts-synt_offset,tRGB_offset-use-dom-mod-additive',
    'AffineLight+': '00318-Mega_512-stylegan2-glrate0.008-kimg241-resumeffhq512-Gparts-synt_affine_weights_offset,tRGB_affine_weights_offset,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive,affine_out_in_5_1_additive',
}
ukiyoe_512 = {
    'Full':         '00117-ukiyoe_512-stylegan2-kimg241-resumeffhq512',
    'Mapping':      '00117-ukiyoe_512-stylegan2-kimg241-resumeffhq512-Gparts-mapping',
    'SyntConv':     '00315-ukiyoe_512-stylegan2-kimg241-resumeffhq512-Gparts-synt_conv,tRGB_conv',
    'Affine':       '00116-ukiyoe_512-stylegan2-kimg241-resumeffhq512-Gparts-synt_affine,tRGB_affine',
    'Affine+':      '00312-ukiyoe_512-stylegan2-glrate0.02-kimg241-resumeffhq512-Gparts-synt_affine,tRGB_affine,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive',
    'StyleSpace':   '00149-ukiyoe_512-stylegan2-glrate0.02-kimg241-resumeffhq512-Gparts-synt_offset,tRGB_offset-use-dom-mod-additive',
    'AffineLight+': '00319-ukiyoe_512-stylegan2-glrate0.008-kimg241-resumeffhq512-Gparts-synt_affine_weights_offset,tRGB_affine_weights_offset,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive,affine_out_in_5_1_additive',
}
afhqdog = {
    'Full':         '00066-afhqdog-stylegan2-kimg241-resumeffhq512',
    'Mapping':      '00066-afhqdog-stylegan2-kimg241-resumeffhq512-Gparts-mapping',
    'SyntConv':     '00315-afhqdog-stylegan2-kimg241-resumeffhq512-Gparts-synt_conv,tRGB_conv',
    'Affine':       '00068-afhqdog-stylegan2-kimg241-resumeffhq512-Gparts-synt_affine,tRGB_affine',
    'Affine+':      '00313-afhqdog-stylegan2-glrate0.02-kimg241-resumeffhq512-Gparts-synt_affine,tRGB_affine,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive',
    'StyleSpace':   '00147-afhqdog-stylegan2-glrate0.02-kimg241-resumeffhq512-Gparts-synt_offset,tRGB_offset-use-dom-mod-additive',
    'AffineLight+': '00253-afhqdog-stylegan2-glrate0.008-kimg241-resumeffhq512-Gparts-synt_affine_weights_offset,tRGB_affine_weights_offset,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive,affine_out_in_5_1_additive',
}
afhqcat = {   
    'Full':         '00065-afhqcat-stylegan2-kimg241-resumeffhq512',
    'Mapping':      '00065-afhqcat-stylegan2-kimg241-resumeffhq512-Gparts-mapping',
    'SyntConv':     '00315-afhqcat-stylegan2-kimg241-resumeffhq512-Gparts-synt_conv,tRGB_conv',
    'Affine':       '00067-afhqcat-stylegan2-kimg241-resumeffhq512-Gparts-synt_affine,tRGB_affine',
    'Affine+':      '00299-afhqcat-stylegan2-glrate0.02-kimg241-resumeffhq512-Gparts-synt_affine,tRGB_affine,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive',
    'StyleSpace':   '00147-afhqcat-stylegan2-glrate0.02-kimg241-resumeffhq512-Gparts-synt_offset,tRGB_offset-use-dom-mod-additive',
    'AffineLight+': '00253-afhqcat-stylegan2-glrate0.008-kimg241-resumeffhq512-Gparts-synt_affine_weights_offset,tRGB_affine_weights_offset,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive,affine_out_in_5_1_additive',
}
car = {
    'Full':         '00220-lsun_cars_512_10k-stylegan2-kimg241-resumeffhq512',
    'Mapping':      '00220-lsun_cars_512_10k-stylegan2-kimg241-resumeffhq512-Gparts-mapping',
    'SyntConv':     '00315-lsun_cars_512_10k-stylegan2-kimg241-resumeffhq512-Gparts-synt_conv,tRGB_conv',
    'Affine':       '00262-lsun_cars_512_10k-stylegan2-kimg241-resumeffhq512-Gparts-synt_affine,tRGB_affine',
    'Affine+':      '00300-lsun_cars_512_10k-stylegan2-glrate0.02-kimg241-resumeffhq512-Gparts-synt_affine,tRGB_affine,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive',
    'StyleSpace':   '00223-lsun_cars_512_10k-stylegan2-glrate0.02-kimg241-resumeffhq512-Gparts-synt_offset,tRGB_offset-use-dom-mod-additive',
    'AffineLight+': '00318-lsun_cars_512_10k-stylegan2-glrate0.008-kimg241-resumeffhq512-Gparts-synt_affine_weights_offset,tRGB_affine_weights_offset,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive,affine_out_in_5_1_additive',
}
church = {
    'Full':         '00221-lsun_church_outdoor_train_512_10k-stylegan2-kimg241-resumeffhq512',
    'Mapping':      '00221-lsun_church_outdoor_train_512_10k-stylegan2-kimg241-resumeffhq512-Gparts-mapping',
    'SyntConv':     '00315-lsun_church_outdoor_train_512_10k-stylegan2-kimg241-resumeffhq512-Gparts-synt_conv,tRGB_conv',
    'Affine':       '00262-lsun_church_outdoor_train_512_10k-stylegan2-kimg241-resumeffhq512-Gparts-synt_affine,tRGB_affine',
    'Affine+':      '00313-lsun_church_outdoor_train_512_10k-stylegan2-glrate0.02-kimg241-resumeffhq512-Gparts-synt_affine,tRGB_affine,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive',
    'StyleSpace':   '00224-lsun_church_outdoor_train_512_10k-stylegan2-glrate0.02-kimg241-resumeffhq512-Gparts-synt_offset,tRGB_offset-use-dom-mod-additive',
    'AffineLight+': '00253-lsun_church_outdoor_train_512_10k-stylegan2-glrate0.008-kimg241-resumeffhq512-Gparts-synt_affine_weights_offset,tRGB_affine_weights_offset,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive,affine_out_in_5_1_additive',
}
flowers = {
    # NOTE(review): this is the only kimg482 run; every other run is kimg241 and the
    # checkpoint loaded for Flowers is 241 — confirm this is the intended run.
    'Full':         '00142-flowers_102_train_random_01-stylegan2-kimg482-resumeffhq512',
    'Mapping':      '00205-flowers_102_train_random_01-stylegan2-kimg241-resumeffhq512-Gparts-mapping',
    'SyntConv':     '00315-flowers_102_train_random_01-stylegan2-kimg241-resumeffhq512-Gparts-synt_conv,tRGB_conv',
    'Affine':       '00262-flowers_102_train_random_01-stylegan2-kimg241-resumeffhq512-Gparts-synt_affine,tRGB_affine',
    'Affine+':      '00300-flowers_102_train_random_01-stylegan2-glrate0.02-kimg241-resumeffhq512-Gparts-synt_affine,tRGB_affine,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive',
    'StyleSpace':   '00208-flowers_102_train_random_01-stylegan2-glrate0.02-kimg241-resumeffhq512-Gparts-synt_offset,tRGB_offset-use-dom-mod-additive',
    'AffineLight+': '00318-flowers_102_train_random_01-stylegan2-glrate0.008-kimg241-resumeffhq512-Gparts-synt_affine_weights_offset,tRGB_affine_weights_offset,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive,affine_out_in_5_1_additive',
}

# Dataset display name -> {parameterization -> experiment folder}.
# The key order is also the column order of the tables printed below.
general_experiments = dict(
    Metfaces=metfaces_512,
    Mega=Mega_512,
    Ukiyoe=ukiyoe_512,
    Dog=afhqdog,
    Cat=afhqcat,
    Car=car,
    Church=church,
    Flowers=flowers,
)

# Checkpoint index loaded for each dataset: 241 everywhere except Mega and Ukiyoe,
# which use earlier checkpoints.
chkpt_idxs = {dataset_name: 241 for dataset_name in general_experiments}
chkpt_idxs.update(Mega=40, Ukiyoe=100)

Load all checkpoints:

In [None]:
# Load the EMA generator and the scores of every experiment into
#     all_models[parameterization][dataset] = (G_ema, (fid5k, kid5k, fid50k, kid50k))
# The generators are loaded on the CPU (presumably so that all of them fit in memory at once).
# Loop variables use dedicated `run_*` names so that this cell does not overwrite `G_ema`,
# `metrics` and `exp_path` of the Dog model loaded in the "Load checkpoints" section;
# otherwise re-running those earlier cells would silently use the last model loaded here.
all_models = defaultdict(dict)
cpu_device = torch.device('cpu')
score_names = ('fid5k', 'kid5k', 'fid50k', 'kid50k')

for dataset_name, dataset_exps in general_experiments.items():
    chkpt_idx = chkpt_idxs[dataset_name]

    for parameterization_name, exp_suffix in dataset_exps.items():
        run_path = os.path.join(base_exp_path, exp_suffix)
        _, run_G_ema, _, _, _, run_metrics = nb_utils.load_checkpoint(
            exp_path=run_path, chkpt_idx=chkpt_idx, device=cpu_device
        )

        scores = []
        for score_name in score_names:
            # Each metric entry is converted with dict(), i.e. presumably a sequence of
            # (checkpoint_idx, value) pairs — confirm against nb_utils.load_checkpoint.
            per_checkpoint = dict(run_metrics[score_name])
            if chkpt_idx not in per_checkpoint:
                # Same exception type as before, but it names the metric and the run.
                raise KeyError(
                    f'{score_name} was not computed for checkpoint {chkpt_idx} of {run_path}'
                )
            scores.append(per_checkpoint[chkpt_idx])

        all_models[parameterization_name][dataset_name] = (run_G_ema, tuple(scores))

### Table 2

In [None]:
# Table 2: FID5k of every parameterization (rows) on every dataset (columns).
table_02 = []
for pname in ['Full', 'SyntConv', 'Affine', 'Mapping', 'Affine+', 'AffineLight+', 'StyleSpace']:
    row = [pname]
    table_02.append(row)
    for dname in general_experiments:
        _, scores = all_models[pname][dname]
        fid5k, _, _, _ = scores
        row.append(fid5k)

print(tabulate.tabulate(table_02, headers=general_experiments, floatfmt='.1f'))

### Table 7

In [None]:
# Table 7: FID50k of the main parameterizations (rows) on every dataset (columns).
table_07 = []
for pname in ['Full', 'SyntConv', 'Affine+', 'Affine', 'Mapping']:
    row = [pname]
    table_07.append(row)
    for dname in general_experiments:
        _, scores = all_models[pname][dname]
        _, _, fid50k, _ = scores
        row.append(fid50k)

print(tabulate.tabulate(table_07, headers=general_experiments, floatfmt='.1f'))

### Table 8

In [None]:
# Table 8: KID50k (multiplied by 1000) of the main parameterizations on every dataset.
table_08 = []
for pname in ['Full', 'SyntConv', 'Affine+', 'Affine', 'Mapping']:
    row = [pname]
    table_08.append(row)
    for dname in general_experiments:
        _, scores = all_models[pname][dname]
        _, _, _, kid50k = scores
        row.append(kid50k * 1000)

print(tabulate.tabulate(table_08, headers=general_experiments, floatfmt='.1f'))

### Table 12

In [None]:
# Table 12: FID50k of Full versus the lightweight parameterizations on every dataset.
table_12 = []
for pname in ['Full', 'Affine+', 'AffineLight+', 'StyleSpace']:
    row = [pname]
    table_12.append(row)
    for dname in general_experiments:
        _, scores = all_models[pname][dname]
        _, _, fid50k, _ = scores
        row.append(fid50k)

print(tabulate.tabulate(table_12, headers=general_experiments, floatfmt='.1f'))

### Table 13

In [None]:
# Table 13: KID50k (multiplied by 1000) of Full versus the lightweight parameterizations.
table_13 = []
for pname in ['Full', 'Affine+', 'AffineLight+', 'StyleSpace']:
    row = [pname]
    table_13.append(row)
    for dname in general_experiments:
        _, scores = all_models[pname][dname]
        _, _, _, kid50k = scores
        row.append(kid50k * 1000)

print(tabulate.tabulate(table_13, headers=general_experiments, floatfmt='.1f'))

### Figure 4

In [None]:
# Figure 4: one sample per parameterization (columns) for Cat and Dog (rows).
# Every sample in a row comes from the same seed, so the columns are directly comparable.
# NOTE(review): the generators in `all_models` were loaded on the CPU; this assumes
# `generate_images` moves them to `device` — confirm.
fig04_pnames = ['Full', 'SyntConv', 'Affine', 'Mapping', 'Affine+', 'AffineLight+', 'StyleSpace']
fig04_dataset_seeds = [('Cat', 964), ('Dog', 912)]
label_font = dict(fontsize=20, weight='bold')

grid_size = np.array([2, 7])
fig, axes = plt.subplots(*grid_size, figsize=grid_size[::-1] * np.array([3, 3.2]))
nb_utils.prepare_axes(axes)

for row, (dname, seed) in enumerate(fig04_dataset_seeds):
    for col, pname in enumerate(fig04_pnames):
        G_ema, _ = all_models[pname][dname]
        images, _ = nb_utils.generate_images(G_ema, grid_size=1, device=device, seed=seed, truncation_psi=0.9)

        ax = axes[row, col]
        ax.imshow(images[0])
        if row == 0:
            ax.set_title(pname, fontdict=label_font)
        if col == 0:
            ax.set_ylabel(dname, fontdict=label_font)

fig.subplots_adjust(wspace=0.05, hspace=0.05)
fig.tight_layout()
plt.show()

### Figure 13

In [None]:
# Figure 13: one sample for every (parameterization, dataset) pair. Rows are
# parameterizations and columns are datasets; each column uses its own fixed seed.
# NOTE(review): the generators in `all_models` were loaded on the CPU; this assumes
# `generate_images` moves them to `device` — confirm.
fig13_pnames = ['Full', 'SyntConv', 'Affine+', 'Affine', 'AffineLight+', 'StyleSpace', 'Mapping']
fig13_dataset_seeds = [
    ('Metfaces', 518), ('Mega', 312), ('Ukiyoe', 520), ('Dog', 527),
    ('Cat', 514), ('Car', 530), ('Church', 506), ('Flowers', 324),
]
label_font = dict(fontsize=10, weight='bold')

grid_size = np.array([7, 8])
fig, axes = plt.subplots(*grid_size, figsize=grid_size[::-1] * 1.25)
nb_utils.prepare_axes(axes)

for row, pname in enumerate(fig13_pnames):
    for col, (dname, seed) in enumerate(fig13_dataset_seeds):
        G_ema, _ = all_models[pname][dname]
        images, _ = nb_utils.generate_images(G_ema, grid_size=1, device=device, seed=seed, truncation_psi=0.9)

        ax = axes[row, col]
        ax.imshow(images[0])
        if row == 0:
            ax.set_title(dname, fontdict=label_font)
        if col == 0:
            ax.set_ylabel(pname, fontdict=label_font)

fig.subplots_adjust(wspace=0.01, hspace=0.01)
fig.tight_layout()
plt.show()

### Figure 14

In [None]:
# Figure 14: a grid of per-dataset blocks, `n_block_cols` blocks per row. Inside each
# block there is one row per parameterization, with `n_samples_per_row` samples per row.
# The seed of every row is (parameterization offset + dataset offset).
# NOTE(review): the generators in `all_models` were loaded on the CPU; this assumes
# `generate_images` moves them to `device` — confirm.
dataset_seeds = [
    ('Metfaces', 400), ('Mega', 401), ('Ukiyoe', 402), ('Dog', 403),
    ('Cat', 404), ('Car', 405), ('Church', 406), ('Flowers', 407),
]
parameterization_seeds = [
    ('Full', 1000), ('SyntConv', 2000), ('Affine+', 3000), ('Affine', 4000),
    ('AffineLight+', 6000), ('StyleSpace', 7000), ('Mapping', 5000),
]
n_block_cols = 4
n_samples_per_row = 6

# Ceiling division, so that a dataset count not divisible by `n_block_cols` still fits.
meta_grid_size = np.array([-(-len(dataset_seeds) // n_block_cols), n_block_cols])
# One block row per plotted parameterization (rather than tying it to an unrelated dict).
block_grid_size = np.array([len(parameterization_seeds), n_samples_per_row])
grid_size = meta_grid_size * block_grid_size
fig = plt.figure(figsize=grid_size[::-1] * 1.0)

gs_outer = GridSpec(*meta_grid_size, hspace=0.05, wspace=0.05)
# Plain list in row-major order (the same order as flattening a 2-D grid). This avoids
# relying on NumPy to treat GridSpec objects as scalars inside an object array.
gs_blocks = [
    GridSpecFromSubplotSpec(
        *block_grid_size, subplot_spec=gs_outer[idx, jdx], hspace=0.05, wspace=0.05
    )
    for idx in range(meta_grid_size[0])
    for jdx in range(meta_grid_size[1])
]

for idx, (dname, seed_jdx) in enumerate(dataset_seeds):
    axes = np.array([
        [fig.add_subplot(gs_blocks[idx][bidx, bjdx]) for bjdx in range(block_grid_size[1])]
        for bidx in range(block_grid_size[0])
    ])
    nb_utils.prepare_axes(axes)

    for jdx, (pname, seed_idx) in enumerate(parameterization_seeds):
        seed = seed_idx + seed_jdx

        G_ema, _ = all_models[pname][dname]
        images, _ = nb_utils.generate_images(G_ema, grid_size=len(axes[jdx]), device=device, seed=seed, truncation_psi=0.9)
        for ax, image in zip(axes[jdx], images):
            ax.imshow(image)

        # Only blocks in the leftmost column get parameterization labels.
        if idx % meta_grid_size[1] == 0:
            axes[jdx][0].set_ylabel(pname, fontdict=dict(fontsize=6, weight='bold'))

plt.show()