In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.append('./../')

import glob

from collections import defaultdict

import torch
import numpy as np

import matplotlib_inline
import matplotlib.pyplot as plt

%matplotlib inline
matplotlib_inline.backend_inline.set_matplotlib_formats('pdf', 'svg')

For the sake of simplicity, most of the additional functions were implemented in the `nb_utils` module.

See `General Results.ipynb` for information about basic operations with a trained model (loading, metrics evaluation, inference).

# $Z$ space projection

We use the same approach to invert images into the latent space of the `Generator` as the authors of the [StyleAlign](https://openreview.net/pdf?id=Qg2vi4ZbHM9) paper. We reimplemented [their approach](https://github.com/betterze/StyleAlign/blob/main/projector_z.py) on top of the original $W$ space inversion from StyleGAN-ADA.

In [None]:
import nb_utils

In [None]:
# Device used for inference throughout the notebook.
device = torch.device('cuda:0')
# Root directory of the training runs. `~` is expanded explicitly: neither
# `os.path.join`/`open` nor a double-quoted shell argument (`"$network_path"`)
# resolves it to the home directory.
base_exp_path = os.path.expanduser('~/StyleDomain/DissimilarDomains/training-runs')

Inversion can be performed using the `projector.py` script by specifying the path to a single image (or to a folder of images). The inversion of a sample image, obtained with a model fine-tuned to the **Cat** domain in **Full** parameterization, is shown below.

In [None]:
# Image to invert, output folder for the projector, and the generator snapshot
# that the inversion is performed with.
image_path = './../samples/41_0.jpg'
outdir_path = './inverted_samples/41_0/'
network_path = os.path.join(
    base_exp_path,
    '00065-afhqcat-stylegan2-kimg241-resumeffhq512',
    'network-snapshot-000241.pkl',
)

In [None]:
# Invert `image_path` with `projector.py` (Z space, 1000 steps); outputs are written to `outdir_path`.
! python ./../projector.py --target "$image_path" --outdir "$outdir_path" --network "$network_path" \
    --space 'z' --truncation-psi 0.7  --num-steps 1000 \
    --gpu 0 --save-image --save-video --save-all-steps

As a result, the script generates the following files (`x` is a per-image prefix; it is `0` for the single image used here):

* `x_target.png` — image that was projected to the latent space
* `x_proj.png` — image obtained from the resulting latent 
* `x_proj.mp4` — animation of the projection procedure


* `x_projected_z.npz` — stores dictionary `{'z': torch.Tensor, 'z_steps': torch.Tensor}` — final and all intermediate latents during optimization for the image
* `projected_z.npz` — stores dictionary `{'z': torch.Tensor, 'z_steps': torch.Tensor}` — final and all intermediate latents during optimization for all images

In [None]:
# List the files written to the output folder.
! ls "$outdir_path"

In [None]:
# Load `0_projected_z.npz` from the output folder and display the shapes of
# the two arrays stored in it.
result = np.load(
    os.path.join(outdir_path, '0_projected_z.npz'),
    allow_pickle=True,
)
tuple(result[key].shape for key in ('z', 'z_steps'))

| Original image                | Image after $Z$ space inversion                        |
|-------------------------------|--------------------------------------------------------|
| ![img](./../samples/41_0.jpg) | ![img](./inverted_samples/41_0/0_proj.png) |

# Unconditional I2I

Then we can use models from other domains to perform inference from the obtained latent vector. For example, here we use the model for the **Dog** domain in **Full** parameterization:

In [None]:
# Run directory of the Dog-domain model (Full parameterization).
exp_path = os.path.join(
    base_exp_path,
    '00066-afhqdog-stylegan2-kimg241-resumeffhq512',
)
# `load_checkpoint` is unpacked into six values; only the second one is kept, as `G_ema`.
_, G_ema, _, _, _, _ = nb_utils.load_checkpoint(
    exp_path=exp_path,
    chkpt_idx=241,
    device=device,
)

In [None]:
# Feed the inverted Z latent of the source image to the generator loaded above.
images, _ = nb_utils.generate_images(
    G_ema,
    target_zs=torch.tensor(result['z']).to(device),
    truncation_psi=0.8,
    grid_size=1,
    device=device,
)

In [None]:
# One row, two columns: the input image next to the generated one.
grid_size = np.array([1, 2])
fig, axes = plt.subplots(*grid_size, figsize=grid_size[::-1] * 4)
nb_utils.prepare_axes(axes)

panels = (
    ('Input', plt.imread(image_path)),
    ('Full', images[0]),
)
for ax, (title, picture) in zip(axes, panels):
    ax.imshow(picture)
    ax.set_title(title, fontdict=dict(fontsize=12, weight='bold'))

fig.subplots_adjust(wspace=0.01, hspace=0.01)
plt.show()

# Reference-based I2I

For reference-based I2I, we need to get the latent vector of the reference image. Here we will perform I2I using the same source **Cat** image and a **Dog** reference image.

In [None]:
# Reference image, output folder for its inversion, and the generator snapshot
# that the inversion is performed with.
ref_image_path = './../samples/pixabay_dog_003552.jpg'
ref_outdir_path = './inverted_samples/pixabay_dog_003552/'
ref_network_path = os.path.join(
    base_exp_path,
    '00066-afhqdog-stylegan2-kimg241-resumeffhq512',
    'network-snapshot-000241.pkl',
)

In [None]:
# Invert `ref_image_path` with `projector.py` (Z space, 1000 steps); outputs are written to `ref_outdir_path`.
! python ./../projector.py --target "$ref_image_path" --outdir "$ref_outdir_path" --network "$ref_network_path" \
    --space 'z' --truncation-psi 0.7  --num-steps 1000 \
    --gpu 0 --save-image --save-video --save-all-steps

In [None]:
# Load the latents saved for the reference image.
ref_result = np.load(
    os.path.join(ref_outdir_path, '0_projected_z.npz'),
    allow_pickle=True,
)

Then, we should combine the latent vectors of the source and reference images. We combine the first $6$ style codes from the source image with the remaining codes from the reference image.

Note that although transformation is defined in Style Space we can define it equivalently in $W+$ space. For example, the first $6$ style codes have one-to-one correspondence to the first $4$ latent vectors in $W+$ space (for StyleGAN2 in $512\times 512$ resolution).

In [None]:
# Index of the first W+ vector that is taken from the reference image.
reference_slice = 4
G_ema.to(device).eval()

# Zero-width tensor passed as the second (conditioning) argument of `mapping`.
c = torch.empty([result['z'].shape[0], 0])

# Map both Z latents (source first, then reference) with the same truncation.
ws, ref_ws = (
    G_ema.mapping(torch.tensor(latents['z']).to(device), c.to(device), truncation_psi=0.8)
    for latents in (result, ref_result)
)
# Keep the source vectors below `reference_slice`; overwrite the rest with the
# reference ones (equivalent to combining style codes).
ws[:, reference_slice:] = ref_ws[:, reference_slice:]

# Render the image from the mixed latent.
images, _ = nb_utils.generate_images(G_ema, grid_size=1, device=device, target_ws=ws)

In [None]:
# One row, three columns: source image, reference image, generated result.
grid_size = np.array([1, 3])
fig, axes = plt.subplots(*grid_size, figsize=grid_size[::-1] * 4)
nb_utils.prepare_axes(axes)

panels = (
    ('Source', plt.imread(image_path)),
    ('Reference', plt.imread(ref_image_path)),
    ('Full', images[0]),
)
for ax, (title, picture) in zip(axes, panels):
    ax.imshow(picture)
    ax.set_title(title, fontdict=dict(fontsize=12, weight='bold'))

fig.subplots_adjust(wspace=0.01, hspace=0.01)
plt.show()

# Results reproduction

To reproduce all results, all necessary models must be trained (in $512\times 512$ resolution). More specifically, these are the following configurations ($9$ models in total):

$$[\text{Dog}, \text{Cat}, \text{Wild}] \times [\text{Full}, \text{Affine}+, \text{AffineLight}+]$$

Let's define a list of all those experiments:

In [None]:
# Run-directory names of the Dog-domain models, keyed by parameterization name.
# The names are joined with `base_exp_path` when the checkpoints are loaded.
afhqdog = {
    'Full':         '00066-afhqdog-stylegan2-kimg241-resumeffhq512',
    'Affine+':      '00313-afhqdog-stylegan2-glrate0.02-kimg241-resumeffhq512-Gparts-synt_affine,tRGB_affine,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive',
    # NOTE(review): run index `00253` is also used by the Cat `AffineLight+` entry below — confirm both directory names.
    'AffineLight+': '00253-afhqdog-stylegan2-glrate0.008-kimg241-resumeffhq512-Gparts-synt_affine_weights_offset,tRGB_affine_weights_offset,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive,affine_out_in_5_1_additive',
}
# Run-directory names of the Cat-domain models, keyed by parameterization name.
afhqcat = {   
    'Full':         '00065-afhqcat-stylegan2-kimg241-resumeffhq512',
    'Affine+':      '00299-afhqcat-stylegan2-glrate0.02-kimg241-resumeffhq512-Gparts-synt_affine,tRGB_affine,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive',
    'AffineLight+': '00253-afhqcat-stylegan2-glrate0.008-kimg241-resumeffhq512-Gparts-synt_affine_weights_offset,tRGB_affine_weights_offset,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive,affine_out_in_5_1_additive',
}

# Run-directory names of the Wild-domain models, keyed by parameterization name.
afhqwild = {   
    'Full':         '00318-afhqwild-stylegan2-kimg241-resumeffhq512',
    'Affine+':      '00319-afhqwild-stylegan2-glrate0.02-kimg241-resumeffhq512-Gparts-synt_affine,tRGB_affine,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive',
    'AffineLight+': '00323-afhqwild-stylegan2-glrate0.008-kimg241-resumeffhq512-Gparts-synt_affine_weights_offset,tRGB_affine_weights_offset,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive,affine_out_in_5_1_additive',
}

# Domain name -> {parameterization name -> run-directory name}.
i2i_experiments = {
    'Dog': afhqdog,
    'Cat': afhqcat,
    'Wild': afhqwild,
}

# Domain name -> checkpoint index; formatted into `network-snapshot-<index, 6 digits>.pkl`
# by the inversion loop.
chkpt_idxs = {
    'Dog': 241,
    'Cat': 241,
    'Wild': 241,
}

Then, you need to perform $Z$ space inversion for images from the validation part of the AFHQ dataset (for all three domains). Suppose you have this dataset at the following path:

In [None]:
# Root folder of the AFHQ dataset; the cells below read images from its `val/<domain>/` subfolders.
afhq_path = './afhq'

For each parameterization and for each domain, perform image inversion: 

In [None]:
# Invert every validation `.jpg` of every domain with the model of every parameterization.
# NOTE: this loop rebinds the notebook-level `image_path`, `outdir_path` and `network_path`
# that the earlier single-image cells use.
for dataset_name, dataset_exps in i2i_experiments.items():
    chkpt_idx = chkpt_idxs[dataset_name]
    
    for pname, exp_suffix in dataset_exps.items():
        for image_path in glob.glob(os.path.join(afhq_path, f'val/{dataset_name.lower()}/*.jpg')):
            # Each image gets its own output folder, named after the image file without its extension.
            outdir_path = os.path.join(
                f'./inverted_samples/afhq/val/{dataset_name.lower()}/{pname}', 
                os.path.splitext(os.path.basename(image_path))[0]
            )
            # Skip images whose output folder already contains `projected_z.npz`.
            if os.path.exists(os.path.join(outdir_path, 'projected_z.npz')):
                continue
            
            network_path = os.path.join(
                base_exp_path, exp_suffix, 'network-snapshot-{0:06d}.pkl'.format(chkpt_idx)
            )

            # Same projector settings as in the single-image example (Z space, 1000 steps).
            ! python ./../projector.py --target "$image_path" --outdir "$outdir_path" --network "$network_path" \
                --space 'z' --truncation-psi 0.7  --num-steps 1000 \
                --gpu 0 --save-image --save-video --save-all-steps

Load all checkpoints:

In [None]:
# Generators indexed as all_models[parameterization name][domain name].
# `defaultdict(dict)` creates the per-parameterization dict on first access; the inner
# level needs no default factory.
all_models = defaultdict(dict)

for dataset_name, dataset_exps in i2i_experiments.items():
    chkpt_idx = chkpt_idxs[dataset_name]

    for pname, exp_suffix in dataset_exps.items():
        # The checkpoint is bound to a dedicated name (and the run path is passed inline)
        # so the notebook-level `G_ema` and `exp_path` defined by the earlier cells are not
        # overwritten by the last checkpoint loaded here.
        # Checkpoints are loaded onto the CPU.
        _, generator, _, _, _, _ = nb_utils.load_checkpoint(
            exp_path=os.path.join(base_exp_path, exp_suffix),
            chkpt_idx=chkpt_idx,
            device=torch.device('cpu'),
        )

        all_models[pname][dataset_name] = generator

### Figure 35

In [None]:
def make_i2i_inference(
    source_dataset_name, target_dataset_name, images_paths, all_models, device, *,
    truncation_psi=0.8
):
    """Plot unconditional I2I results for a set of source images.

    Each row shows one input image followed by the images generated from its
    inverted Z latent by the ``target_dataset_name`` model of every
    parameterization that provides one in ``all_models``.

    Args:
        source_dataset_name: Domain of the input images; its lower-cased form
            selects the folder with the precomputed inversions.
        target_dataset_name: Domain key of the generators in ``all_models``.
        images_paths: Paths of the input images, one per figure row.
        all_models: Mapping ``{parameterization name: {domain name: generator}}``.
        device: Device the latents are moved to; also passed to
            ``nb_utils.generate_images``.
        truncation_psi: Truncation value passed to ``nb_utils.generate_images``.

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If ``images_paths`` is empty.
    """
    if len(images_paths) == 0:
        raise ValueError('images_paths must contain at least one image path')

    # Columns are derived from the models that actually exist for the target
    # domain, so the grid width and the loop below can never disagree. The
    # canonical parameterizations come first, in their usual order.
    canonical = ('Full', 'Affine+', 'AffineLight+')
    available = [name for name, models in all_models.items() if target_dataset_name in models]
    pnames = [name for name in canonical if name in available]
    pnames += [name for name in available if name not in canonical]

    grid_size = np.array([len(images_paths), len(pnames) + 1])
    # squeeze=False always returns a 2-D array of axes, even for a 1x1 grid.
    fig, axes = plt.subplots(*grid_size, figsize=grid_size[::-1] * 3, squeeze=False)

    nb_utils.prepare_axes(axes)
    title_font = dict(fontsize=16, weight='bold')
    for idx, image_path in enumerate(images_paths):
        axes[idx, 0].imshow(plt.imread(image_path))
        image_name = os.path.splitext(os.path.basename(image_path))[0]

        for jdx, pname in enumerate(pnames):
            result_path = os.path.join(
                f'./inverted_samples/afhq/val/{source_dataset_name.lower()}/{pname}',
                image_name
            )
            # NOTE: allow_pickle=True can execute code stored in the file; only
            # load inversion results produced locally by projector.py.
            # The context manager closes the .npz archive once the latent is read.
            with np.load(os.path.join(result_path, '0_projected_z.npz'), allow_pickle=True) as result:
                target_zs = torch.tensor(result['z']).to(device)

            images, _ = nb_utils.generate_images(
                all_models[pname][target_dataset_name], grid_size=1, device=device,
                truncation_psi=truncation_psi, target_zs=target_zs
            )
            axes[idx, jdx + 1].imshow(images[0])

            # Column titles are only needed on the first row.
            if idx == 0:
                axes[idx, jdx + 1].set_title(pname, fontdict=title_font)

    axes[0, 0].set_title('Input', fontdict=title_font)

    fig.subplots_adjust(wspace=0.01, hspace=0.01)
    plt.show()

    return fig

#### Cat2Dog

In [None]:
source_dataset_name, target_dataset_name = 'Cat', 'Dog'
# Validation images of the source domain, one figure row each.
images_paths = [
    os.path.join(afhq_path, relative_path)
    for relative_path in (
        'val/cat/flickr_cat_000816.jpg',
        'val/cat/flickr_cat_000320.jpg',
        'val/cat/pixabay_cat_000081.jpg',
    )
]
fig = make_i2i_inference(source_dataset_name, target_dataset_name, images_paths, all_models, device)
fig.savefig('./../images/I2I_3_Cat2Dog.pdf', bbox_inches='tight', pad_inches=0)

#### Cat2Wild

In [None]:
source_dataset_name, target_dataset_name = 'Cat', 'Wild'
# Validation images of the source domain, one figure row each.
images_paths = [
    os.path.join(afhq_path, relative_path)
    for relative_path in (
        'val/cat/pixabay_cat_000615.jpg',
        'val/cat/flickr_cat_000265.jpg',
        'val/cat/pixabay_cat_000668.jpg',
    )
]
fig = make_i2i_inference(source_dataset_name, target_dataset_name, images_paths, all_models, device)
fig.savefig('./../images/I2I_3_Cat2Wild.pdf', bbox_inches='tight', pad_inches=0)

#### Dog2Wild

In [None]:
source_dataset_name, target_dataset_name = 'Dog', 'Wild'
# Validation images of the source domain, one figure row each.
images_paths = [
    os.path.join(afhq_path, relative_path)
    for relative_path in (
        'val/dog/pixabay_dog_000307.jpg',
        'val/dog/pixabay_dog_000818.jpg',
        'val/dog/flickr_dog_000619.jpg',
    )
]
fig = make_i2i_inference(source_dataset_name, target_dataset_name, images_paths, all_models, device)
fig.savefig('./../images/I2I_3_Dog2Wild.pdf', bbox_inches='tight', pad_inches=0)

#### Dog2Cat

In [None]:
source_dataset_name, target_dataset_name = 'Dog', 'Cat'
# Validation images of the source domain, one figure row each.
images_paths = [
    os.path.join(afhq_path, relative_path)
    for relative_path in (
        'val/dog/flickr_dog_000176.jpg',
        'val/dog/flickr_dog_000569.jpg',
        'val/dog/pixabay_dog_000494.jpg',
    )
]
fig = make_i2i_inference(source_dataset_name, target_dataset_name, images_paths, all_models, device)
fig.savefig('./../images/I2I_3_Dog2Cat.pdf', bbox_inches='tight', pad_inches=0)

#### Wild2Cat

In [None]:
source_dataset_name, target_dataset_name = 'Wild', 'Cat'
# Validation images of the source domain, one figure row each.
images_paths = [
    os.path.join(afhq_path, relative_path)
    for relative_path in (
        'val/wild/flickr_wild_001627.jpg',
        'val/wild/flickr_wild_003586.jpg',
        'val/wild/pixabay_wild_000267.jpg',
    )
]
fig = make_i2i_inference(source_dataset_name, target_dataset_name, images_paths, all_models, device)
fig.savefig('./../images/I2I_3_Wild2Cat.pdf', bbox_inches='tight', pad_inches=0)

#### Wild2Dog

In [None]:
source_dataset_name, target_dataset_name = 'Wild', 'Dog'
# Validation images of the source domain, one figure row each.
images_paths = [
    os.path.join(afhq_path, relative_path)
    for relative_path in (
        'val/wild/flickr_wild_003060.jpg',
        'val/wild/flickr_wild_002036.jpg',
        'val/wild/flickr_wild_000063.jpg',
    )
]
fig = make_i2i_inference(source_dataset_name, target_dataset_name, images_paths, all_models, device)
fig.savefig('./../images/I2I_3_Wild2Dog.pdf', bbox_inches='tight', pad_inches=0)

### Figure 36

In [None]:
def make_ref_i2i_inference(
    source_dataset_name, target_dataset_name, images_paths, all_models, device, *,
    truncation_psi=0.8, reference_slice=4
):
    """Plot reference-based I2I results for (source, reference) image pairs.

    Each row shows a source image, a reference image and, for every
    parameterization that provides a ``target_dataset_name`` model in
    ``all_models``, the image generated from a mixed W+ latent: vectors below
    ``reference_slice`` come from the source image, the remaining ones from
    the reference image.

    Args:
        source_dataset_name: Domain of the source images; its lower-cased form
            selects the folder with their precomputed inversions.
        target_dataset_name: Domain of the reference images and key of the
            generators in ``all_models``.
        images_paths: Sequence of ``(source_path, reference_path)`` pairs, one
            per figure row.
        all_models: Mapping ``{parameterization name: {domain name: generator}}``.
        device: Device the generators and latents are moved to.
        truncation_psi: Truncation value passed to the mapping network for
            both latents.
        reference_slice: Index of the first W+ vector taken from the reference.

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If ``images_paths`` is empty.
    """
    if len(images_paths) == 0:
        raise ValueError('images_paths must contain at least one (source, reference) pair')

    def load_latent(dataset_name, pname, path):
        """Read the inverted Z latent stored for ``path`` and move it to ``device``."""
        result_path = os.path.join(
            f'./inverted_samples/afhq/val/{dataset_name.lower()}/{pname}',
            os.path.splitext(os.path.basename(path))[0]
        )
        # NOTE: allow_pickle=True can execute code stored in the file; only
        # load inversion results produced locally by projector.py.
        # The context manager closes the .npz archive once the latent is read.
        with np.load(os.path.join(result_path, '0_projected_z.npz'), allow_pickle=True) as result:
            return torch.tensor(result['z']).to(device)

    # Columns are derived from the models that actually exist for the target
    # domain, so the grid width and the loop below can never disagree. The
    # canonical parameterizations come first, in their usual order.
    canonical = ('Full', 'Affine+', 'AffineLight+')
    available = [name for name, models in all_models.items() if target_dataset_name in models]
    pnames = [name for name in canonical if name in available]
    pnames += [name for name in available if name not in canonical]

    grid_size = np.array([len(images_paths), len(pnames) + 2])
    # squeeze=False always returns a 2-D array of axes, even for a single row.
    fig, axes = plt.subplots(*grid_size, figsize=grid_size[::-1] * 2.5, squeeze=False)

    nb_utils.prepare_axes(axes)
    title_font = dict(fontsize=14, weight='bold')
    for idx, (image_path, ref_image_path) in enumerate(images_paths):
        axes[idx, 0].imshow(plt.imread(image_path))
        axes[idx, 1].imshow(plt.imread(ref_image_path))

        for jdx, pname in enumerate(pnames):
            generator = all_models[pname][target_dataset_name]
            generator.to(device).eval()

            source_z = load_latent(source_dataset_name, pname, image_path)
            ref_z = load_latent(target_dataset_name, pname, ref_image_path)
            # Zero-width tensor passed as the conditioning argument of `mapping`.
            c = torch.empty([source_z.shape[0], 0]).to(device)

            # The caller-supplied truncation is used for both latents.
            ws = generator.mapping(source_z, c, truncation_psi=truncation_psi)
            ref_ws = generator.mapping(ref_z, c, truncation_psi=truncation_psi)
            # Keep the source vectors below `reference_slice`; take the rest from the reference.
            ws[:, reference_slice:] = ref_ws[:, reference_slice:]

            images, _ = nb_utils.generate_images(
                generator, grid_size=1, device=device,
                target_ws=ws
            )
            axes[idx, jdx + 2].imshow(images[0])

            # Column titles are only needed on the first row.
            if idx == 0:
                axes[idx, jdx + 2].set_title(pname, fontdict=title_font)

    axes[0, 0].set_title('Source', fontdict=title_font)
    axes[0, 1].set_title('Reference', fontdict=title_font)

    fig.subplots_adjust(wspace=0.01, hspace=0.01)
    plt.show()

    return fig

#### Cat2Dog

In [None]:
source_dataset_name, target_dataset_name = 'Cat', 'Dog'
# (source image, reference image) pairs, one figure row each.
images_paths = [
    (os.path.join(afhq_path, source_rel), os.path.join(afhq_path, reference_rel))
    for source_rel, reference_rel in (
        ('val/cat/flickr_cat_000123.jpg', 'val/dog/pixabay_dog_002544.jpg'),
        ('val/cat/pixabay_cat_002582.jpg', 'val/dog/flickr_dog_000452.jpg'),
        ('val/cat/pixabay_cat_003562.jpg', 'val/dog/pixabay_dog_000607.jpg'),
    )
]
fig = make_ref_i2i_inference(source_dataset_name, target_dataset_name, images_paths, all_models, device)

#### Cat2Wild

In [None]:
source_dataset_name, target_dataset_name = 'Cat', 'Wild'
# (source image, reference image) pairs, one figure row each.
images_paths = [
    (os.path.join(afhq_path, source_rel), os.path.join(afhq_path, reference_rel))
    for source_rel, reference_rel in (
        ('val/cat/pixabay_cat_000343.jpg', 'val/wild/pixabay_wild_001082.jpg'),
        ('val/cat/pixabay_cat_000147.jpg', 'val/wild/flickr_wild_001137.jpg'),
        ('val/cat/pixabay_cat_004765.jpg', 'val/wild/flickr_wild_003947.jpg'),
    )
]
fig = make_ref_i2i_inference(source_dataset_name, target_dataset_name, images_paths, all_models, device)

#### Dog2Wild

In [None]:
source_dataset_name, target_dataset_name = 'Dog', 'Wild'
# (source image, reference image) pairs, one figure row each.
images_paths = [
    (os.path.join(afhq_path, source_rel), os.path.join(afhq_path, reference_rel))
    for source_rel, reference_rel in (
        ('val/dog/pixabay_dog_003449.jpg', 'val/wild/flickr_wild_001625.jpg'),
        ('val/dog/pixabay_dog_001651.jpg', 'val/wild/pixabay_wild_000265.jpg'),
        ('val/dog/flickr_dog_001100.jpg', 'val/wild/pixabay_wild_000536.jpg'),
    )
]
fig = make_ref_i2i_inference(source_dataset_name, target_dataset_name, images_paths, all_models, device)

#### Dog2Cat

In [None]:
source_dataset_name, target_dataset_name = 'Dog', 'Cat'
# (source image, reference image) pairs, one figure row each.
images_paths = [
    (os.path.join(afhq_path, source_rel), os.path.join(afhq_path, reference_rel))
    for source_rel, reference_rel in (
        ('val/dog/pixabay_dog_000504.jpg', 'val/cat/flickr_cat_000446.jpg'),
        ('val/dog/pixabay_dog_002307.jpg', 'val/cat/pixabay_cat_004217.jpg'),
        ('val/dog/pixabay_dog_002423.jpg', 'val/cat/flickr_cat_000585.jpg'),
    )
]
fig = make_ref_i2i_inference(source_dataset_name, target_dataset_name, images_paths, all_models, device)

#### Wild2Cat

In [None]:
source_dataset_name, target_dataset_name = 'Wild', 'Cat'
# (source image, reference image) pairs, one figure row each.
images_paths = [
    (os.path.join(afhq_path, source_rel), os.path.join(afhq_path, reference_rel))
    for source_rel, reference_rel in (
        ('val/wild/flickr_wild_001230.jpg', 'val/cat/pixabay_cat_001632.jpg'),
        ('val/wild/flickr_wild_002336.jpg', 'val/cat/pixabay_cat_001029.jpg'),
        ('val/wild/flickr_wild_000418.jpg', 'val/cat/pixabay_cat_002559.jpg'),
    )
]
fig = make_ref_i2i_inference(source_dataset_name, target_dataset_name, images_paths, all_models, device)

#### Wild2Dog

In [None]:
source_dataset_name, target_dataset_name = 'Wild', 'Dog'
# (source image, reference image) pairs, one figure row each.
images_paths = [
    (os.path.join(afhq_path, source_rel), os.path.join(afhq_path, reference_rel))
    for source_rel, reference_rel in (
        ('val/wild/flickr_wild_003854.jpg', 'val/dog/pixabay_dog_002700.jpg'),
        ('val/wild/flickr_wild_003169.jpg', 'val/dog/pixabay_dog_002680.jpg'),
        ('val/wild/flickr_wild_002867.jpg', 'val/dog/pixabay_dog_002597.jpg'),
    )
]
fig = make_ref_i2i_inference(source_dataset_name, target_dataset_name, images_paths, all_models, device)