In [None]:
import os
import sys
sys.path.append("..")

import seaborn as sns
import matplotlib.pyplot as plt
sns.set_context("poster")

In [None]:
import json

# Point JSON_PATH at another config file to try a different model.
JSON_PATH = "configs/stylegan2_ffhq.json"

# Parse the experiment config from disk.
with open(JSON_PATH, "r") as config_file:
    CONFIG_JSON = json.loads(config_file.read())

MODEL_DIR = CONFIG_JSON["MODEL_DIR"]    # directory expected to hold config.yaml and model.pt
DIRECTIONS = CONFIG_JSON["DIRECTIONS"]  # feature name -> [direction id, LAYER_MAPS key, optional alphas]
LAYER_MAPS = CONFIG_JSON["LAYER_MAPS"]  # string key -> generator layer indices fed with the edited code

## Loading

In [None]:
import yaml
from omegaconf import OmegaConf

conf_path = os.path.join(MODEL_DIR, "config.yaml")
model_path = os.path.join(MODEL_DIR, "model.pt")

# Load the training config and wrap it in OmegaConf for attribute access.
# The `with` block closes the file handle as soon as parsing is done.
# FullLoader is acceptable for this local, trusted file; switch to
# yaml.safe_load if configs can come from untrusted sources.
with open(conf_path) as conf_file:
    cfg = yaml.load(conf_file, Loader=yaml.FullLoader)
cfg = OmegaConf.create(cfg)
print(cfg)

In [None]:
import torch
from hydra.utils import instantiate, to_absolute_path

device = cfg.device

# Build each network from its hydra config node and move it onto `device`.
model: torch.nn.Module = instantiate(cfg.model, k=cfg.k).to(device)
generator: torch.nn.Module = instantiate(cfg.generator).to(device)
projector: torch.nn.Module = instantiate(cfg.projector).to(device)

# Restore trained weights for `model` and `projector` only; this checkpoint
# carries nothing for `generator` (presumably it is already pretrained by
# `instantiate` — TODO confirm).
checkpoint_path = to_absolute_path(model_path)
checkpoint = torch.load(checkpoint_path, map_location=device)
for ckpt_key, module in (("model", model), ("projector", projector)):
    module.load_state_dict(checkpoint[ckpt_key])

# Inference only: switch every network to eval mode.
for module in (model, generator, projector):
    module.eval()

## Visualization

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

import numpy as np
import torchvision
import torchvision.transforms as T
from PIL import Image, ImageDraw, ImageFont

import math
def sign(x):
    """Return 1.0 or -1.0 with the sign of ``x`` (signed zero counts, so ``sign(-0.0) == -1.0``)."""
    return math.copysign(1, x)

#  Helper function to edit latent codes
def _edit(z, alpha, ks):
    """Apply each direction in ``ks`` to latent code(s) ``z`` with magnitude ``alpha``.

    Args:
        z: latent batch. It is either a single code (``z.shape[0] == 1``) or one
            code per direction (``z.shape[0] == len(ks)``). A single code receives
            every direction in ``ks``. Otherwise row ``i`` receives ``ks[i]``.
        alpha: edit magnitude. It is written to the global ``model.alpha`` before
            any direction is applied.
        ks: sequence of direction ids, each passed to ``model.forward_single``.

    Returns:
        The ``model.forward_single`` outputs, concatenated along dim 0 in the
        order of ``ks``.

    Raises:
        AssertionError: if ``z`` has neither 1 nor ``len(ks)`` rows.
    """
    assert z.shape[0] == 1 or z.shape[0] == len(
        ks
    ), "Only able to apply all directions to single latent code or apply each direction to single code"
    model.alpha = alpha

    # Apply Directions
    zs = []
    for i, k in enumerate(ks):
        # Broadcast a single code to every direction; otherwise pair row i with ks[i].
        # (Slicing with `i` here would yield an empty slice for a single code.)
        row = i if z.shape[0] > 1 else 0
        zs.append(model.forward_single(z[row : row + 1, ...], k=k))
    return torch.cat(zs, dim=0)

# Helper function to generate images
def _generate(zs, z=None, feed_layers=None):
    """Render latent codes with the global ``generator`` and return a detached CPU tensor.

    Args:
        zs: latent codes to render; ``zs.shape[0]`` is the number of outputs.
        z: original latent code. It is only used when ``feed_layers`` is also given.
        feed_layers: indices of generator layers that should receive ``zs``.
            Every other layer in ``range(generator.n_latent())`` receives ``z``
            expanded to ``zs.shape[0]`` rows, which confines the edit to the
            listed layers.

    Returns:
        ``generator(...)`` output, detached and moved to the CPU.
    """
    # Without both a layer selection and a reference code, feed zs straight in.
    if feed_layers is None or z is None:
        return generator(zs).detach().cpu()

    batch = zs.shape[0]
    per_layer = [
        zs if layer in feed_layers else z.expand(batch, -1)
        for layer in range(generator.n_latent())
    ]
    return generator(per_layer).detach().cpu()

# Visualizes images
def visualize(
    dir_ids,
    feed_layers,
    alphas=(-9, -6, -3, 0, 3, 6, 9),
    feat_name=None,
    seeds=(0,),
    iterative=False,
    scale=5,
):
    """Plot edited images in a grid with one row per direction and one column per alpha.

    For every seed, a latent code is sampled and each direction in ``dir_ids`` is
    applied at every non-zero alpha. The unedited image is always shown as an
    ``α=0`` column between the negative and positive alphas. One figure is
    created per seed.

    Args:
        dir_ids: direction ids to apply; each one becomes a grid row.
        feed_layers: generator layers that receive the edited code (passed to
            ``_generate``). ``None`` feeds the edited code straight to the generator.
        alphas: edit magnitudes. Zeros are ignored because the ``α=0`` column
            is always drawn. An empty or all-negative sequence is fine.
        feat_name: optional figure title.
        seeds: numpy and torch are seeded with each value before sampling.
        iterative: if True, each alpha is reached by stepping the previous edit
            by the alpha difference, instead of editing the original code.
        scale: size in inches of each grid cell.
    """
    neg_alphas, pos_alphas = _split_alphas(alphas)
    # _edit returns one code per direction, so the grid has one row per direction.
    n_rows = len(dir_ids)

    for seed in seeds:
        # The RNG behind sample_latent is not visible here, so seed both libraries.
        np.random.seed(seed)
        torch.manual_seed(seed)
        z = generator.sample_latent(1).to(device)

        with torch.no_grad():
            orj_img = _generate(z)
            # Walk negatives outward from 0 (-1, -2, ...) so iterative steps are
            # small, then flip the result back to ascending alpha order.
            neg_imgs = _render_sweep(z, reversed(neg_alphas), dir_ids, feed_layers, iterative)[::-1]
            pos_imgs = _render_sweep(z, pos_alphas, dir_ids, feed_layers, iterative)

            orig_col = orj_img.repeat((n_rows, 1, 1, 1))
            columns = neg_imgs + [orig_col] + pos_imgs
            titles = (
                [f"α= {alpha:.3f}" for alpha in neg_alphas]
                + ["α=0"]
                + [f"α= {alpha:.3f}" for alpha in pos_alphas]
            )

            # Stack to (rows, cols, ...) and move axis 2 to the end. This presumably
            # turns (C, H, W) images into (H, W, C) for imshow.
            grid = torch.stack(columns, dim=1).permute(0, 1, 3, 4, 2)

        _plot_grid(grid, titles, dir_ids, feat_name, scale)


def _split_alphas(alphas):
    """Return (sorted negative alphas, sorted positive alphas); zeros are dropped."""
    ordered = sorted(alphas)
    return [a for a in ordered if a < 0], [a for a in ordered if a > 0]


def _render_sweep(z_orig, alphas, dir_ids, feed_layers, iterative):
    """Edit ``z_orig`` at each alpha in iteration order and render one image batch per alpha."""
    imgs = []
    z, prev_alpha = z_orig, 0
    for alpha in alphas:
        if iterative:
            # Step from the previous edit by the alpha difference.
            z = _edit(z, alpha - prev_alpha, ks=dir_ids)
        else:
            z = _edit(z_orig, alpha, ks=dir_ids)
        imgs.append(_generate(z, z_orig, feed_layers=feed_layers))
        prev_alpha = alpha
    return imgs


def _plot_grid(grid, titles, dir_ids, feat_name, scale):
    """Draw ``grid[r][c]`` images as a subplot matrix labelled with alphas and direction ids."""
    n_rows, n_cols = grid.shape[0], grid.shape[1]
    # squeeze=False keeps `axs` 2-D even for a single row/column or a 1x1 grid.
    fig, axs = plt.subplots(
        nrows=n_rows,
        ncols=n_cols,
        figsize=(n_cols * scale, n_rows * scale),
        squeeze=False,
    )
    if feat_name is not None:
        fig.suptitle(feat_name)
    for r in range(n_rows):
        axs[r][0].set_ylabel(f"k= {dir_ids[r]}")
        for c in range(n_cols):
            ax = axs[r][c]
            ax.set_xlabel(titles[c])
            ax.set_xticks([])
            ax.set_yticks([])
            ax.imshow(grid[r][c])

In [None]:
# Every key of DIRECTIONS is an annotated feature that can be visualized.
feat_list = [*DIRECTIONS]

print("Annotated features:", feat_list, sep="\n")

In [None]:
import random
# Pick one annotated feature at random (`random` is not seeded in this cell,
# so the choice can differ between runs).
feat_name = feat_list[random.randint(0, len(feat_list) - 1)]

# A DIRECTIONS entry is [direction id, LAYER_MAPS key, optional list of alphas].
feat_spec = DIRECTIONS[feat_name]
dir_id = feat_spec[0]
feed_layers = LAYER_MAPS[str(feat_spec[1])]
if len(feat_spec) > 2:
    alphas = feat_spec[2]
else:
    alphas = [-7, -5, -1, 0, 1, 5, 7]

visualize(
    dir_ids=[dir_id],
    feed_layers=feed_layers,
    feat_name=feat_name,
    alphas=alphas,
    seeds=[0, 1],
);