In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.append('./../')

from collections import defaultdict

import torch
import numpy as np

import matplotlib_inline
import matplotlib.pyplot as plt

%matplotlib inline
matplotlib_inline.backend_inline.set_matplotlib_formats('pdf', 'svg')

For the sake of simplicity, most of the additional functions were implemented in the `nb_utils` module.

See `General Results.ipynb` for information about basic operations with a trained model (loading, metrics evaluation, inference).

# Load checkpoints

The semantic modifications we use in our work (`StyleSpace` and `StyleFlow`) produce results only for models with a resolution of $1024 \times 1024$. So, here we will use those models.

In [None]:
import nb_utils

In [None]:
# GPU on which the models in this notebook are evaluated.
device = torch.device('cuda:0')
# Root directory that contains one sub-directory per training run.
# `~` is expanded explicitly: neither `os.path.join` nor `open` expands it,
# so a literal '~/...' would be looked up relative to the working directory.
base_exp_path = os.path.expanduser('~/StyleDomain/DissimilarDomains/training-runs')

In [None]:
# Run directory of the Mega model trained with the Full parameterization,
# resolved against `base_exp_path`.
exp_path = os.path.join(base_exp_path, '00069-Mega-stylegan2-kimg241-resumeffhq1024')
# `G_ema` is the fine-tuned generator used by the cells below. The meaning of the
# values bound to `_` is defined by `nb_utils.load_checkpoint` (not shown here).
_, G_ema, _, G_ema_base, options, metrics = nb_utils.load_checkpoint(
    exp_path=exp_path, chkpt_idx=40, device=device
)

# StyleSpace Modifications

First, let's define the list of available modifications:

In [None]:
# Source: https://github.com/betterze/StyleSpace/blob/main/StyleSpace_single.ipynb
configs_ffhq = {
    'Black Hair'        : [(12, 479)],
    'Blond Hair'        : [(12, 479), (12, 266)],
    'Grey Hair'         : [(11, 286)],
    'Short Hair'        : [(6, 500), (8, 128), (5, 92), (6, 394), (6, 323)],
    'Wavy Hair'         : [(6, 500), (8, 128), (5, 92), (6, 394), (6, 323)],
    'Bangs'             : [(3, 259), (6, 285), (5, 414), (6, 128), (9, 295), (6, 322), (6, 487), (6, 504)],
    'Receding Hairline' : [(5, 414), (6, 322), (6, 497), (6, 504)],
    'Smile'             : [(6, 501)],
    'Lipstick'          : [(15, 45)],
    'Sideburns'         : [(12, 237)],
    'Goatee'            : [(9, 421)],
    'Earrings'          : [(8, 81)],
    'Glasses'           : [(3, 288), (2, 175), (3, 120), (2, 97)],
    'Wear Suit'         : [(9, 441), (8, 292), (11, 358), (6, 223)],
    'Gender'            : [(9, 6)],
    'Bangs'             : [(3, 169)],
    'Gaze'              : [(9, 409)]
}

`Generator` module was modified in order to apply single-channel modifications during inference. The following interface is used: one can pass several new named arguments into `Generator.synthesis` call:

* `style_space_modifications` — List of modifications. Defaults to None.
* `style_space_modifications_first` — Whether to apply Style Space modification before Style Space offsets. Defaults to False.
* `saved_styles` — Buffer to save intermediate Style Space latents. Defaults to None.

The last one is used to save activation statistics during the forward pass, which are then used to rescale the modification magnitude. This is necessary to correctly apply multichannel modifications (e.g., **Wavy Hair**).

`style_space_modifications` should be represented as a list of single-channel modifications. Each modification consists of a layer idx, a channel idx, a modification magnitude, and an additional value to suppress StyleDomain directions in the corresponding channel.

Let's define an additional function that will prepare `style_space_modifications` in the required format:

In [None]:
def get_modifier(name, magnitude, offset_factor, styles_stats=None):
    """
        Build the `style_space_modifications` list for a named FFHQ modification.

        :param str name: Key of `configs_ffhq` that selects the channels to modify
        :param float magnitude: Modification magnitude
        :param float offset_factor: Scalar multiplier in corresponding channels for StyleDomain directions
        :param Optional styles_stats: Per-layer, per-channel standard deviations of the StyleSpace,
            indexed as `styles_stats[layer_idx][channel_idx]`. When given, the magnitude of every
            channel is multiplied by its standard deviation; when None, the magnitude is used as is.

        :return List[Tuple[Tuple[int, int], float, float]]: List of modifications
            in the format of ((layer_idx, channel_idx), magnitude, offset_factor)
    """
    def _channel_scale(layer_idx, channel_idx):
        # Without statistics every channel keeps the raw magnitude.
        if styles_stats is None:
            return 1.0
        return styles_stats[layer_idx][channel_idx]

    return [
        (
            (layer_idx, channel_idx),
            magnitude * _channel_scale(layer_idx, channel_idx),
            offset_factor,
        )
        for layer_idx, channel_idx in configs_ffhq[name]
    ]

Take a modifier:

In [None]:
# Single-channel 'Smile' edit with magnitude -8.0 and offset factor 0.0.
# No `styles_stats` is passed, so the magnitude is not rescaled.
modifier = get_modifier('Smile', -8.0, 0.0)
# Show the prepared list as the cell output.
modifier

Generate images from the fine-tuned model that was loaded previously (**Full** parameterization for the **Mega** domain):

In [None]:
# Generate images with a fixed seed and keep the latents `ws`,
# which the next cell re-uses to render the edited versions.
images, ws = nb_utils.generate_images(
    G_ema, grid_size=6, device=device, truncation_psi=0.9, seed=1
)

Use the same latents in order to modify images:

In [None]:
# Render the same latents again, this time with the StyleSpace modification applied.
# `ws` goes through `torch.from_numpy`, i.e. it is a NumPy array at this point.
modified_images, _ = nb_utils.generate_images(
    G_ema, grid_size=ws.shape[0], device=device, truncation_psi=0.9,
    target_ws=torch.from_numpy(ws).to(device), style_space_modifications=modifier
)

In [None]:
# Two rows (original / edited) and one column per latent in `ws`.
grid_size = np.array([2, ws.shape[0]])
# `figsize` is (width, height), hence the reversed grid; every cell is 2 x 2 inches.
fig, axes = plt.subplots(*grid_size, figsize=grid_size[::-1] * 2)
# Force a 2-D array of axes even if `plt.subplots` squeezed a dimension away.
axes = axes.reshape(grid_size)
nb_utils.prepare_axes(axes)

# Top row: unmodified images.
for ax, image in zip(axes[0].reshape(-1), images):
    ax.imshow(image)
# Bottom row: the same latents with the modification applied.
for ax, image in zip(axes[1].reshape(-1), modified_images):
    ax.imshow(image)

# Label the edited row on its first axis only.
axes[1, 0].set_ylabel('+ Smile', fontdict=dict(fontsize=12, weight='bold'))
fig.suptitle('StyleSpace modifications for Mega')

fig.subplots_adjust(wspace=0.01, hspace=0.01)
plt.show()

## Multichannel modifications

To apply modifications that change multiple channels at once, we have to compute the scale of each channel. See `training.networks.w_to_s` for an exact description of the `saved_styles` argument.

In [None]:
def compute_styles_stats(G, n_batches=128, batch_size=8, **kwargs):
    """
        Estimate the per-channel standard deviation of the StyleSpace latents of a generator.

        Images are generated in `n_batches` batches, using the batch index as the seed.
        The 'initial' styles saved for each batch are stacked per layer and the standard
        deviation is taken over the sample axis.

        :param torch.nn.Module G: Generator
        :param int n_batches: Number of batches to estimate variance
        :param int batch_size: Batch size for `nb_utils.generate_images`
        :param dict kwargs: Additional named parameters that will be passed to `nb_utils.generate_images`
            I.e. device, truncation_psi, etc.
        :return Dict[Any, np.ndarray]: Standard deviation for each layer and channel, keyed by
            the layer keys of the saved 'initial' styles. Empty if `n_batches` is 0.
    """
    saved_styles = defaultdict(list)

    for idx in range(n_batches):
        # Buffer filled in by the generator during the forward pass.
        batch_styles = dict(initial=dict(), final=dict())
        nb_utils.generate_images(
            G, grid_size=batch_size, batch_size=batch_size,
            seed=idx, saved_styles=batch_styles, **kwargs
        )

        for layer, styles in batch_styles['initial'].items():
            # `Tensor.numpy()` raises for CUDA tensors and for tensors that require grad;
            # `detach().cpu()` is a no-op for tensors that are already plain CPU tensors.
            saved_styles[layer].append(styles.detach().cpu().numpy())

    return {
        layer: np.std(np.vstack(styles), axis=0)
        for layer, styles in saved_styles.items()
    }

Compute statistics for the model using $128 \times 8 = 1024$ images:

In [None]:
# 128 batches of 8 images each; the statistics are computed for the loaded `G_ema`.
styles_stats = compute_styles_stats(G_ema, n_batches=128, batch_size=8, device=device)

Apply a multichannel modification:

In [None]:
# Multichannel 'Wavy Hair' edit: magnitude 10.0 is multiplied by the
# standard deviation of every affected channel taken from `styles_stats`.
modifier = get_modifier('Wavy Hair', 10.0, 0.0, styles_stats=styles_stats)
# Show the prepared list as the cell output.
modifier

Use the same model as before to generate modified images:

In [None]:
# Generate images with a fixed seed and keep their latents `ws`.
images, ws = nb_utils.generate_images(
    G_ema, grid_size=6, device=device, truncation_psi=0.9, seed=19
)

# Render the same latents with the multichannel modification applied.
modified_images, _ = nb_utils.generate_images(
    G_ema, grid_size=ws.shape[0], device=device, truncation_psi=0.9,
    target_ws=torch.from_numpy(ws).to(device), style_space_modifications=modifier
)

# Two rows (original / edited) and one column per latent in `ws`.
grid_size = np.array([2, ws.shape[0]])
# `figsize` is (width, height), hence the reversed grid; every cell is 2 x 2 inches.
fig, axes = plt.subplots(*grid_size, figsize=grid_size[::-1] * 2)
axes = axes.reshape(grid_size)
nb_utils.prepare_axes(axes)

# Top row: unmodified images.
for ax, image in zip(axes[0].reshape(-1), images):
    ax.imshow(image)
# Bottom row: edited images.
for ax, image in zip(axes[1].reshape(-1), modified_images):
    ax.imshow(image)

axes[1, 0].set_ylabel('+ Wavy Hair', fontdict=dict(fontsize=12, weight='bold'))
fig.suptitle('StyleSpace modifications for Mega')

fig.subplots_adjust(wspace=0.01, hspace=0.01)
plt.show()

# StyleFlow

We use the original StyleFlow implementation from [https://github.com/RameenAbdal/StyleFlow](https://github.com/RameenAbdal/StyleFlow). We created a simple wrapper to apply modifications directly in a Jupyter notebook. 

Note that this method requires initial values for attributes that are going to be modified. Those values could be obtained via the [Microsoft Face API](https://azure.microsoft.com/en-us/products/cognitive-services/face/) and [DPR model](https://github.com/zhhoper/DPR). Instead, we chose to use images that were labeled by StyleFlow authors, as they provide $1000$ labeled latent vectors in $W+$ space that came from the same model that we use as an initial checkpoint to finetune our models. 

Since **Affine+**, **AffineLight+**, and **StyleSpace** parameterizations do not update the `Mapping Network` during training, those vectors are valid $W+$ space latents for finetuned models as well. However, in **Full** parameterization the `Mapping Network` changes, so we should perform image inversion into $Z$ space to obtain the correct $W+$ latents. Instead, we chose to ignore changes in the `Mapping Network` and use the $W+$ vectors as they are. We observe that this does not influence modifications' quality due to model alignment (see [StyleAlign](https://openreview.net/pdf?id=Qg2vi4ZbHM9)).

In [None]:
# Download the StyleFlow data archive, unpack it into ../editing/styleflow
# and delete the archive afterwards.

!wget https://nxt.2a2i.org/index.php/s/yxdCXxSWJgAKXkP/download/styleflow_data.zip -O ../editing/styleflow/styleflow_data.zip
!unzip ../editing/styleflow/styleflow_data.zip -d ../editing/styleflow
!rm -rf ../editing/styleflow/styleflow_data.zip

In [None]:
# Locations inside the unpacked `styleflow_data` directory; both are passed to `StyleFlowEditor`.
styleflow_data_path = './../editing/styleflow/styleflow_data/data'
styleflow_model_path = './../editing/styleflow/styleflow_data/flow_weight/modellarge10k.pt'

In [None]:
from editing.styleflow.editor import StyleFlowEditor

Create a wrapper for StyleFlow semantic modifications:

In [None]:
# Wrapper around StyleFlow, built from the data directory and the flow weights.
editor = StyleFlowEditor(styleflow_data_path, styleflow_model_path, device=device)

Initialize the model to modify one of the labeled images:

In [None]:
# Select the labeled image with index 0 as the one to be edited.
editor._allocate_entity(idx=0)

We can change any of those attributes:
```python
attr_order = ['Gender', 'Glasses', 'Yaw', 'Pitch', 'Baldness', 'Beard', 'Age', 'Expression']
```

Change `Gender` using StyleFlow. `edit_power` ranges from $0.0$ to $1.0$, where $0.0$ corresponds to the minimum attribute value and $1.0$ corresponds to the maximum value.

In [None]:
# attr_index=0 is 'Gender' in `attr_order`; edit_power=0.0 requests the minimum attribute value.
# The call returns the original and the edited latents.
ws, ws_modified = editor.real_time_editing(attr_index=0, edit_power=0.0)

Generate images from the latents using **Full** parameterization model for the **Mega** domain and show them:

In [None]:
# Render the original latents ...
images, _ = nb_utils.generate_images(
    G_ema, grid_size=ws.shape[0], device=device, truncation_psi=0.9,
    target_ws=ws
)
# ... and the latents edited by StyleFlow, with the same generator.
modified_images, _ = nb_utils.generate_images(
    G_ema, grid_size=ws_modified.shape[0], device=device, truncation_psi=0.9,
    target_ws=ws_modified
)

# Two rows (original / edited) and one column per latent in `ws`.
grid_size = np.array([2, ws.shape[0]])
# `figsize` is (width, height), hence the reversed grid; every cell is 2 x 2 inches.
fig, axes = plt.subplots(*grid_size, figsize=grid_size[::-1] * 2)
# Needed here in particular: with a single column `plt.subplots` returns a 1-D array.
axes = axes.reshape(grid_size)
nb_utils.prepare_axes(axes)

# Top row: unmodified images.
for ax, image in zip(axes[0].reshape(-1), images):
    ax.imshow(image)
# Bottom row: edited images.
for ax, image in zip(axes[1].reshape(-1), modified_images):
    ax.imshow(image)

axes[1, 0].set_ylabel('Gender', fontdict=dict(fontsize=12, weight='bold'))
fig.suptitle('StyleFlow modifications for Mega')

fig.subplots_adjust(wspace=0.01, hspace=0.01)
plt.show()

Apply `Yaw` modification to multiple images:

In [None]:
# Collect original and edited latents for the first six labeled images.
ws = []
ws_modified = []
for image_idx in range(6):
    # Switch the editor to the next labeled image before editing it.
    editor._allocate_entity(image_idx)
    # attr_index=2 is 'Yaw' in `attr_order`.
    w, w_modified = editor.real_time_editing(attr_index=2, edit_power=0.9)
    ws.append(w)
    ws_modified.append(w_modified)
# Stack the per-image latents along the first axis so they can be rendered together.
ws = torch.vstack(ws)
ws_modified = torch.vstack(ws_modified)

In [None]:
# NOTE(review): `grid_size=1` is passed here although `ws` holds several latents, whereas
# the other cells pass `ws.shape[0]`. Presumably `grid_size` is ignored when `target_ws`
# is given; confirm against `nb_utils.generate_images`.
images, _ = nb_utils.generate_images(
    G_ema, grid_size=1, device=device, truncation_psi=0.9,
    target_ws=ws
)
modified_images, _ = nb_utils.generate_images(
    G_ema, grid_size=1, device=device, truncation_psi=0.9,
    target_ws=ws_modified
)

# Two rows (original / edited) and one column per latent in `ws`.
grid_size = np.array([2, ws.shape[0]])
# `figsize` is (width, height), hence the reversed grid; every cell is 2 x 2 inches.
fig, axes = plt.subplots(*grid_size, figsize=grid_size[::-1] * 2)
axes = axes.reshape(grid_size)
nb_utils.prepare_axes(axes)

# Top row: unmodified images.
for ax, image in zip(axes[0].reshape(-1), images):
    ax.imshow(image)
# Bottom row: edited images.
for ax, image in zip(axes[1].reshape(-1), modified_images):
    ax.imshow(image)

axes[1, 0].set_ylabel('Yaw', fontdict=dict(fontsize=12, weight='bold'))
fig.suptitle('StyleFlow modifications for Mega')

fig.subplots_adjust(wspace=0.01, hspace=0.01)
plt.show()

# Results reproduction

To reproduce all results, all necessary models must be trained (at $1024\times 1024$ resolution). More precisely, one model is needed for every element of the Cartesian product of datasets and parameterizations ($20$ models in total):

$$[\text{Metfaces}, \text{Mega}, \text{Ukiyoe}, \text{Dog}, \text{Cat}] \times [\text{Full}, \text{Affine}+, \text{AffineLight}+, \text{StyleSpace}]$$

Let's define a list of all those experiments:

In [None]:
# One table per dataset: parameterization name -> run directory name
# (joined with `base_exp_path` when the checkpoints are loaded).
metfaces = {
    'Full':         '00103-metfaces-stylegan2-kimg241-resumeffhq1024',
    'Affine+':      '00302-metfaces-stylegan2-glrate0.02-kimg241-resumeffhq1024-Gparts-synt_affine,tRGB_affine,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive',
    'StyleSpace':   '00227-metfaces-stylegan2-glrate0.02-kimg241-resumeffhq1024-Gparts-synt_offset,tRGB_offset-use-dom-mod-additive',
    'AffineLight+': '00325-metfaces-stylegan2-glrate0.008-kimg241-resumeffhq1024-Gparts-synt_affine_weights_offset,tRGB_affine_weights_offset,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive,affine_out_in_5_1_additive',
}
Mega = {
    'Full':         '00069-Mega-stylegan2-kimg241-resumeffhq1024',
    'Affine+':      '00301-Mega-stylegan2-glrate0.02-kimg241-resumeffhq1024-Gparts-synt_affine,tRGB_affine,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive',
    'StyleSpace':   '00228-Mega-stylegan2-glrate0.02-kimg241-resumeffhq1024-Gparts-synt_offset,tRGB_offset-use-dom-mod-additive',
    'AffineLight+': '00325-Mega-stylegan2-glrate0.008-kimg241-resumeffhq1024-Gparts-synt_affine_weights_offset,tRGB_affine_weights_offset,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive,affine_out_in_5_1_additive',
}
ukiyoe = {
    'Full':         '00106-ukiyoe-stylegan2-kimg241-resumeffhq1024',
    'Affine+':      '00302-ukiyoe-stylegan2-glrate0.02-kimg241-resumeffhq1024-Gparts-synt_affine,tRGB_affine,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive',
    'StyleSpace':   '00227-ukiyoe-stylegan2-glrate0.02-kimg241-resumeffhq1024-Gparts-synt_offset,tRGB_offset-use-dom-mod-additive',
    'AffineLight+': '00319-ukiyoe-stylegan2-glrate0.008-kimg241-resumeffhq1024-Gparts-synt_affine_weights_offset,tRGB_affine_weights_offset,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive,affine_out_in_5_1_additive',
}
afhqdog_1024 = {
    'Full':         '00219-afhqdog_1024-stylegan2-kimg241-resumeffhq1024',
    'Affine+':      '00301-afhqdog_1024-stylegan2-glrate0.02-kimg241-resumeffhq1024-Gparts-synt_affine,tRGB_affine,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive',
    'StyleSpace':   '00219-afhqdog_1024-stylegan2-glrate0.02-kimg241-resumeffhq1024-Gparts-synt_offset,tRGB_offset-use-dom-mod-additive',
    'AffineLight+': '00325-afhqdog_1024-stylegan2-glrate0.008-kimg241-resumeffhq1024-Gparts-synt_affine_weights_offset,tRGB_affine_weights_offset,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive,affine_out_in_5_1_additive',
}
afhqcat_1024 = {
    'Full':         '00218-afhqcat_1024-stylegan2-kimg241-resumeffhq1024',
    'Affine+':      '00307-afhqcat_1024-stylegan2-glrate0.02-kimg241-resumeffhq1024-Gparts-synt_affine,tRGB_affine,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive',
    'StyleSpace':   '00218-afhqcat_1024-stylegan2-glrate0.02-kimg241-resumeffhq1024-Gparts-synt_offset,tRGB_offset-use-dom-mod-additive',
    'AffineLight+': '00326-afhqcat_1024-stylegan2-glrate0.008-kimg241-resumeffhq1024-Gparts-synt_affine_weights_offset,tRGB_affine_weights_offset,synt_weights_offset.b64,tRGB_weights_offset.b64-use-dom-mod-out_in_additive,affine_out_in_5_1_additive',
}

# Display name of the dataset -> its table of runs.
semantic_edditing_experiments = {
    'Metfaces': metfaces,
    'Mega': Mega,
    'Ukiyoe': ukiyoe,
    'Dog': afhqdog_1024,
    'Cat': afhqcat_1024,
}

# Checkpoint index to load for each dataset; the same index is used
# for every parameterization of that dataset.
chkpt_idxs = {
    'Metfaces': 180,
    'Mega': 40,
    'Ukiyoe': 100,
    'Dog': 180,
    'Cat': 100,
}

Load all checkpoints:

In [None]:
# Nested mapping: all_models[parameterization_name][dataset_name] -> generator.
# The inner `defaultdict` has no default factory, so it behaves like a plain dict.
all_models = defaultdict(defaultdict)

# NOTE: this loop rebinds the notebook-level names `G_ema`, `G_ema_base`, `metrics`,
# `exp_path` and `chkpt_idx` that earlier cells also use.
for dataset_name, dataset_exps in semantic_edditing_experiments.items():
    chkpt_idx = chkpt_idxs[dataset_name]
    
    for parameterization_name, exp_suffix in dataset_exps.items():
        exp_path = os.path.join(base_exp_path, exp_suffix)
        # Checkpoints are loaded onto the CPU here.
        _, G_ema, _, G_ema_base, _, metrics = nb_utils.load_checkpoint(
            exp_path=exp_path, chkpt_idx=chkpt_idx, device=torch.device('cpu')
        )
        
        all_models[parameterization_name][dataset_name] = G_ema
    # The base generator returned with the last checkpoint of this dataset
    # is stored as the 'Original' model.
    all_models['Original'][dataset_name] = G_ema_base

Let's define a function that will apply modifications for each model to a given image:

In [None]:
def make_modifications_figure(models, modifications, dataset_name, editor, image_idx, device):
    """
        Plot a grid of edited images: one row per parameterization, one column per modification.

        :param models: Nested mapping `models[parameterization_name][dataset_name]` -> generator.
            Rows are drawn for 'Original', 'Full', 'Affine+', 'AffineLight+' and 'StyleSpace', in
            that order; the number of rows of the grid is `len(models)`.
        :param modifications: Sequence of `(attribute, power)` pairs, one per column:
            `(None, _)` leaves the image unmodified, `(str, magnitude)` applies the StyleSpace
            modification with that name, and any other attribute is passed to StyleFlow as
            `attr_index` with `power` as `edit_power`.
        :param str dataset_name: Dataset whose models are used
        :param editor: StyleFlow wrapper that provides the latent of the labeled image
        :param int image_idx: Index of the labeled image to edit
        :param device: Device passed to `compute_styles_stats` and `nb_utils.generate_images`

        :return: The created matplotlib figure (it is also shown via `plt.show()`)
    """
    row_names = ['Original', 'Full', 'Affine+', 'AffineLight+', 'StyleSpace']

    grid_size = np.array([len(models), len(modifications)])
    fig, axes = plt.subplots(*grid_size, figsize=grid_size[::-1] * 2)
    axes = axes.reshape(grid_size)
    nb_utils.prepare_axes(axes)

    # Channel scales always come from the Full model of the dataset.
    styles_stats = compute_styles_stats(models['Full'][dataset_name], n_batches=128, batch_size=8, device=device)
    editor._allocate_entity(image_idx)
    w = editor.initial_w

    def _resolve(attribute, power):
        # Returns (column title, latent to render, StyleSpace modifier or None).
        if attribute is None:
            # Do not apply modifications
            return '', w, None

        if isinstance(attribute, str):
            # Apply StyleSpace modifications
            return attribute, w, get_modifier(attribute, power, 0.0, styles_stats=styles_stats)

        # Apply StyleFlow modifications
        column_title = editor.attr_order[attribute]
        if attribute == 6:
            # Attribute 6 gets a direction-dependent title.
            column_title = 'Rejuvenation' if power < 0.5 else 'Aging'
        _, edited_w = editor.real_time_editing(attr_index=attribute, edit_power=power)
        return column_title, edited_w, None

    for idx, pname in enumerate(row_names):
        for jdx, (attribute, power) in enumerate(modifications):
            title, w_modified, modifier = _resolve(attribute, power)

            modified_images, _ = nb_utils.generate_images(
                models[pname][dataset_name], grid_size=w_modified.shape[0], device=device,
                truncation_psi=0.9, target_ws=w_modified.to(device), style_space_modifications=modifier
            )

            cell = axes[idx, jdx]
            cell.imshow(modified_images[0])
            # Column titles go on the first row, row labels on the first column.
            if idx == 0:
                cell.set_title(title, fontdict=dict(fontsize=12, weight='bold'))
            if jdx == 0:
                cell.set_ylabel(pname, fontdict=dict(fontsize=12, weight='bold'))

    fig.subplots_adjust(wspace=0.01, hspace=0.01)
    plt.show()

    return fig

### Figure 22

In [None]:
# Index of the labeled image plus flags that pick the direction/strength of some edits below.
# NOTE(review): the flags presumably describe the chosen source image; confirm against the data.
image_idx, is_smile, is_yaw_left, is_man, is_gaze_left = 32, False, True, False, True

# One (attribute, power) pair per column: (None, None) is the unmodified image, a string
# selects a StyleSpace modification, an integer is a StyleFlow attribute index
# (0 = Gender, 2 = Yaw, 6 = Age in `attr_order`).
mega_modifications = [
    (None, None), 
    ('Blond Hair', -4.0), ('Black Hair', 4.0), 
    ('Smile', 3.0 if is_smile else -1.0), 
    ('Lipstick', -3.0),
    (2, 0.9 if is_yaw_left else 0.1), 
    (6, 0.1), (6, 0.9), 
    (0, 0.3 if is_man else 0.7),
    ('Short Hair', -4.0), ('Wavy Hair', 4.0),             
    ('Gaze', 6.0 if is_gaze_left else -6.0),
]

fig = make_modifications_figure(
    models=all_models, 
    modifications=mega_modifications, 
    dataset_name='Mega', editor=editor, image_idx=image_idx, device=device
)

### Figure 23

In [None]:
# Index of the labeled image plus flags that pick the direction/strength of some edits below.
# NOTE(review): the flags presumably describe the chosen source image; confirm against the data.
image_idx, is_smile, is_yaw_left, is_man, is_gaze_left = 23, False, False, True, True

# One (attribute, power) pair per column: (None, None) is the unmodified image, a string
# selects a StyleSpace modification, an integer is a StyleFlow attribute index
# (0 = Gender, 2 = Yaw, 6 = Age in `attr_order`).
metfaces_modifications = [
    (None, None), 
    ('Blond Hair', -4.0), ('Black Hair', 4.0), 
    ('Smile', 3.0 if is_smile else -1.0), 
    ('Lipstick', -3.0),
    (2, 0.9 if is_yaw_left else 0.1), 
    (6, 0.1), (6, 0.9), 
    (0, 0.3 if is_man else 0.7),
    ('Short Hair', -4.0), ('Wavy Hair', 4.0),             
    ('Gaze', 6.0 if is_gaze_left else -6.0),
]

fig = make_modifications_figure(
    models=all_models, 
    modifications=metfaces_modifications, 
    dataset_name='Metfaces', editor=editor, image_idx=image_idx, device=device
)

### Figure 24

In [None]:
# Index of the labeled image plus flags that pick the direction/strength of some edits below.
# NOTE(review): the flags presumably describe the chosen source image; confirm against the data.
image_idx, is_smile, is_yaw_left, is_man, is_gaze_left = 22, False, True, False, False

# One (attribute, power) pair per column: (None, None) is the unmodified image, a string
# selects a StyleSpace modification, an integer is a StyleFlow attribute index
# (0 = Gender, 2 = Yaw, 6 = Age in `attr_order`).
ukiyoe_modifications = [
    (None, None), 
    ('Blond Hair', -8.0), ('Black Hair', 8.0), 
    ('Smile', 4.0 if is_smile else -4.0), 
    ('Lipstick', -6.0),
    (2, 0.6 if is_yaw_left else 0.4), 
    (6, 0.2), (6, 1.0), 
    (0, 0.0 if is_man else 1.0),
    ('Short Hair', -5.0), ('Wavy Hair', 3.0),             
    ('Gaze', 4.0 if is_gaze_left else -4.0),
]

fig = make_modifications_figure(
    models=all_models, 
    modifications=ukiyoe_modifications, 
    dataset_name='Ukiyoe', editor=editor, image_idx=image_idx, device=device
)

### Figure 25.1

In [None]:
# Index of the labeled image plus flags that pick the direction of the yaw and gaze edits.
# NOTE(review): the flags presumably describe the chosen source image; confirm against the data.
image_idx, is_yaw_left, is_gaze_left = 10, False, False

# One (attribute, power) pair per column: (None, None) is the unmodified image, a string
# selects a StyleSpace modification, an integer is a StyleFlow attribute index
# (2 = Yaw in `attr_order`).
afhqcat_modifications = [
    (None, None), 
    ('Blond Hair', -5.0), ('Black Hair', 5.0), 
    (2, 1.0 if is_yaw_left else 0.0), 
    ('Gaze', 6.0 if is_gaze_left else -6.0),
    ('Short Hair', -6.0), ('Wavy Hair', 6.0), 
]

fig = make_modifications_figure(
    models=all_models, 
    modifications=afhqcat_modifications, 
    dataset_name='Cat', editor=editor, image_idx=image_idx, device=device
)

### Figure 25.2

In [None]:
# Index of the labeled image plus flags that pick the direction of the yaw and gender edits.
# NOTE(review): the flags presumably describe the chosen source image; confirm against the data.
image_idx, is_yaw_left, is_man = 38, True, False

# One (attribute, power) pair per column: (None, None) is the unmodified image, a string
# selects a StyleSpace modification, an integer is a StyleFlow attribute index
# (0 = Gender, 2 = Yaw in `attr_order`).
afhqdog_modifications = [
    (None, None), 
    ('Blond Hair', -5.0), ('Black Hair', 4.0), 
    (2, 0.9 if is_yaw_left else 0.1), 
    (0, 0.0 if is_man else 1.0),
    ('Short Hair', -4.0), ('Wavy Hair', 4.0), 
]

fig = make_modifications_figure(
    models=all_models, 
    modifications=afhqdog_modifications, 
    dataset_name='Dog', editor=editor, image_idx=image_idx, device=device
)