<a href="https://colab.research.google.com/github/dvschultz/ml-art-colabs/blob/master/StyleGAN2_activations_and_pca_projection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## StyleGAN2 Activations and PCA Projection

by [duskvirkus](https://github.com/duskvirkus)

Ever wondered what's happening inside your StyleGAN2 (SG2) model? This notebook generates the per-layer activations of a StyleGAN2 model in rosinality's PyTorch format. Optionally, PCA can be used at the end to visualize a representation of the lower-level network layers.

To convert a `.pkl` model to the rosinality format, use: https://github.com/dvschultz/stylegan2-ada-pytorch/blob/main/SG2_ADA_PT_to_Rosinality.ipynb

Thanks to [Derrick Schultz](https://github.com/dvschultz) for the notebook this is based on, which can be found at: [https://github.com/dvschultz/stylegan2-ada-pytorch/blob/eps/Advanced_StyleGAN_Network_bending.ipynb](https://github.com/dvschultz/stylegan2-ada-pytorch/blob/eps/Advanced_StyleGAN_Network_bending.ipynb)

## Prep

In [None]:
# List the GPU assigned to this runtime; the generation script below runs on device 'cuda'.
!nvidia-smi -L

In [None]:
# Install libraries.
# Start from /content so a re-run (after the %cd below) does not clone into
# network-bending/network-bending or download libtorch into the wrong directory.
%cd /content
!git clone -b audio-animate https://github.com/dvschultz/network-bending
# Pin torch 1.5.0 / CUDA 10.1 to match the libtorch build downloaded below.
!pip uninstall torch torchvision -y
!pip install torch==1.5.0+cu101 torchvision==0.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
!pip install Ninja kmeans-pytorch
!apt-get install -y vim make gdb libopencv-dev
# -nc: skip the ~1 GB download if it is already present; -o: overwrite without
# prompting, otherwise unzip blocks waiting for input on a re-run.
!wget -nc https://download.pytorch.org/libtorch/cu101/libtorch-shared-with-deps-1.5.0%2Bcu101.zip
!unzip -o /content/libtorch-shared-with-deps-1.5.0+cu101.zip -d /root/
%cd /content/network-bending

#build custom pytorch transformations
!chmod +x /content/network-bending/build_custom_transforms.sh
!/content/network-bending/build_custom_transforms.sh /root/libtorch/

Cloning into 'network-bending'...
remote: Enumerating objects: 369, done.[K
remote: Counting objects: 100% (37/37), done.[K
remote: Compressing objects: 100% (28/28), done.[K
remote: Total 369 (delta 22), reused 21 (delta 9), pack-reused 332[K
Receiving objects: 100% (369/369), 21.44 MiB | 69.69 MiB/s, done.
Resolving deltas: 100% (213/213), done.
Found existing installation: torch 1.9.0+cu102
Uninstalling torch-1.9.0+cu102:
[31mERROR: Operation cancelled by user[0m
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.5.0+cu101
  Downloading https://download.pytorch.org/whl/cu101/torch-1.5.0%2Bcu101-cp37-cp37m-linux_x86_64.whl (703.8 MB)
[K     |██▉                             | 62.4 MB 1.3 MB/s eta 0:08:17
[31mERROR: Operation cancelled by user[0m
[?25hTraceback (most recent call last):
  File "<frozen importlib._bootstrap>", line 677, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 728, in exec_module
  File "<froz

In [None]:
# Download the example checkpoint. Model Credit: Derrick Schultz
# The full-URL form works on old and new gdown releases (`--id` is deprecated/removed).
!gdown "https://drive.google.com/uc?id=1rL-J63eFfn80IYU2GfVY977GI2qOG6dw" -O /content/ladiesblack.pt

Downloading...
From: https://drive.google.com/uc?id=1rL-J63eFfn80IYU2GfVY977GI2qOG6dw
To: /content/ladiesblack.pt
133MB [00:01, 124MB/s]


## Fix the existing scripts

In [None]:
%%writefile generate_activations.py

import argparse
import torch
import yaml
import os
import copy

from torchvision import utils
from model import Generator
from tqdm import tqdm
from util import *

def generate(args, g_ema, device, mean_latent, t_dict_list):
    """Render ``args.pics`` images and save them as ``sample/NNNNNN.png``.

    Each image is rendered with the transforms in ``t_dict_list`` plus one
    extra ``{'layerID': -1, 'index': i}`` entry carrying the image index ``i``
    into the generator (presumably so the network-bending generator can tag
    the activations it dumps -- TODO confirm in model.py).

    Args:
        args: parsed CLI namespace; reads ``pics`` (number of images),
            ``sample`` (latents sampled per image), ``latent`` (latent size)
            and ``truncation``.
        g_ema: generator, called as ``g_ema([z], truncation=...,
            truncation_latent=..., transform_dict_list=...)``.
        device: device on which the latent noise is sampled.
        mean_latent: truncation centre, or ``None`` when truncation is off.
        t_dict_list: base list of transform dicts; it is never modified.
    """
    # Create the output directory once, before the loop. This also guarantees
    # it exists for the config.yaml written by the caller when --pics is 0.
    os.makedirs('sample', exist_ok=True)
    with torch.no_grad():
        g_ema.eval()
        for i in tqdm(range(args.pics)):
            # Deep copy so the per-image index entry never leaks into the
            # caller's list or the dicts it shares across iterations.
            extra_t_dict_list = copy.deepcopy(t_dict_list)
            extra_t_dict_list.append({'layerID': -1, 'index': i})
            sample_z = torch.randn(args.sample, args.latent, device=device)
            sample, _ = g_ema([sample_z],
                              truncation=args.truncation,
                              truncation_latent=mean_latent,
                              transform_dict_list=extra_t_dict_list)
            utils.save_image(
                sample,
                f'sample/{str(i).zfill(6)}.png',
                nrow=1,
                normalize=True,
                range=(-1, 1))


if __name__ == '__main__':
    device = 'cuda'

    parser = argparse.ArgumentParser()

    parser.add_argument('--size', type=int, default=1024)
    parser.add_argument('--sample', type=int, default=1)
    parser.add_argument('--pics', type=int, default=20)
    parser.add_argument('--truncation', type=float, default=0.5)
    parser.add_argument('--truncation_mean', type=int, default=4096)
    parser.add_argument('--ckpt', type=str, default="stylegan2-ffhq-config-f.pt")
    parser.add_argument('--channel_multiplier', type=int, default=2)
    parser.add_argument('--config', type=str, default="configs/empty_transform_config.yaml")

    args = parser.parse_args()

    args.latent = 512
    args.n_mlp = 8

    # safe_load: a bare yaml.load() without a Loader is deprecated (and a
    # TypeError on PyYAML >= 6). An empty config file parses to None, which
    # used to crash the config dump at the end, so normalise it to {}.
    yaml_config = {}
    with open(args.config, 'r') as stream:
        try:
            yaml_config = yaml.safe_load(stream) or {}
        except yaml.YAMLError as exc:
            print(exc)

    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    # Read the checkpoint once and restore the EMA weights strictly. The old
    # code loaded the file twice and re-applied identical weights a second time.
    checkpoint = torch.load(args.ckpt)
    g_ema.load_state_dict(checkpoint['g_ema'])
    g_ema.eval()

    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    # Highest layer ID for this resolution: 16 for 1024px (the value that used
    # to be hard-coded), 14 for 512px, 12 for 256px. Assumes the rosinality
    # layout of one 4x4 conv plus two convs per resolution from 8px upward --
    # TODO confirm for non-standard models.
    log_size = args.size.bit_length() - 1
    n_layers = 2 * (log_size - 2)
    layer_channel_dims = create_layer_channel_dim_dict(args.channel_multiplier, n_layers)
    # No cluster assignments are available in this script; an empty dict makes
    # 'cluster' transforms get reported and skipped instead of indexing None.
    transform_dict_list = create_transforms_dict_list(yaml_config, {}, layer_channel_dims)
    generate(args, g_ema, device, mean_latent, transform_dict_list)

    # Record the transforms used for this run next to the generated samples.
    config_out = {'transforms': yaml_config.get('transforms', [])}
    os.makedirs('sample', exist_ok=True)
    with open('sample/config.yaml', 'w') as file:
        yaml.dump(config_out, file)

Writing generate_activations.py


In [None]:
%%writefile configs/nothing.yaml
# No transforms. An explicit empty list keeps the parsed config a dict
# (an empty file loads as None, which crashed generate_activations.py).
transforms: []


In [None]:
%%writefile util.py

import random

def create_layer_channel_dim_dict(channel_multiplier, n_layers=16):
    """Map generator layer IDs 0-16 to their channel counts.

    Layers 0-6 always have 512 channels. From layer 7 on, each consecutive
    pair of layers uses half the base width of the previous pair
    (256, 128, 64, 32, 16), scaled by ``channel_multiplier``.

    Args:
        channel_multiplier: StyleGAN2 channel multiplier.
        n_layers: highest layer ID to keep (inclusive).

    Returns:
        dict ``{layer_id: channels}`` for layer IDs ``0..min(n_layers, 16)``,
        in ascending order.
    """
    dims = dict.fromkeys(range(7), 512)
    base_width = 256
    for first_of_pair in range(7, 17, 2):
        scaled = base_width * channel_multiplier
        dims[first_of_pair] = scaled
        dims[first_of_pair + 1] = scaled
        base_width //= 2
    return {layer_id: width for layer_id, width in dims.items() if layer_id <= n_layers}

def create_random_transform_dict(layer, layer_channel_dict, transform, params, percentage):
    """Build a transform dict targeting a random subset of one layer's features.

    Args:
        layer: layer ID; must be a key of ``layer_channel_dict``.
        layer_channel_dict: ``{layer_id: channel_count}``.
        transform: transform ID, stored under ``transformID``.
        params: transform parameters, stored under ``params``.
        percentage: fraction of the layer's channels to pick; the count is
            truncated with ``int()``.

    Returns:
        dict with ``layerID``, ``transformID``, ``indicies`` (distinct channel
        indices drawn with ``random.sample``) and ``params``.
    """
    channel_count = layer_channel_dict[layer]
    chosen = random.sample(range(0, channel_count), int(channel_count * percentage))
    return {
        "layerID": layer,
        "transformID": transform,
        "indicies": chosen,
        "params": params,
    }

def create_layer_wide_transform_dict(layer, layer_channel_dict, transform, params):
    """Build a transform dict that targets every feature of a layer.

    Args:
        layer: layer ID; must be a key of ``layer_channel_dict``.
        layer_channel_dict: ``{layer_id: channel_count}``.
        transform: transform ID, stored under ``transformID``.
        params: transform parameters, stored under ``params``.

    Returns:
        dict whose ``indicies`` is ``range(0, channel_count)``.
    """
    return dict(
        layerID=layer,
        transformID=transform,
        indicies=range(0, layer_channel_dict[layer]),
        params=params,
    )

def create_multiple_transforms_dict(layer, layer_channel_dict, transform, params):
    """Build one layer-wide transform dict per entry of parallel lists.

    Args:
        layer: sequence of layer IDs, indexed in step with ``transform``.
        layer_channel_dict: ``{layer_id: channel_count}``.
        transform: sequence of transform IDs; its length sets the output length.
        params: sequence of parameter objects, indexed in step with ``transform``.

    Returns:
        list of dicts, each targeting ``range(0, channel_count)`` of its layer.
    """
    return [
        {
            "layerID": layer[idx],
            "transformID": transform[idx],
            "indicies": range(0, layer_channel_dict[layer[idx]]),
            "params": params[idx],
        }
        for idx in range(len(transform))
    ]

def create_cluster_transform_dict(layer, layer_channel_dict, cluster_config, transform, params, cluster_ID):
    """Build a transform dict targeting the features of one cluster on a layer.

    Args:
        layer: layer ID; must be a key of both ``layer_channel_dict`` and
            ``cluster_config``.
        layer_channel_dict: ``{layer_id: channel_count}``; used here only to
            validate ``layer`` (KeyError if absent).
        cluster_config: ``{layer_id: [{'cluster_index': ..., 'feature_index': ...}, ...]}``.
        transform: transform ID, stored under ``transformID``.
        params: transform parameters, stored under ``params``.
        cluster_ID: cluster to select; converted with ``int()``, so strings
            such as ``'3'`` are accepted.

    Returns:
        dict with ``layerID``, ``transformID``, ``indicies`` (matching
        ``feature_index`` values in ``cluster_config`` order) and ``params``.
        A message is printed when no feature belongs to the cluster.
    """
    if layer not in layer_channel_dict:
        raise KeyError(layer)
    target_cluster = int(cluster_ID)
    indicies = [c_dict['feature_index'] for c_dict in cluster_config[layer]
                if c_dict['cluster_index'] == target_cluster]
    if len(indicies) == 0:
        print("No indicies found for clusterID: " +str(cluster_ID) + " on layer: " +str(layer))
    transform_dict ={
        "layerID": layer,
        "transformID": transform,
        "indicies": indicies,
        "params": params
    }
    return transform_dict

def create_transforms_dict_list(yaml_config, cluster_config, layer_channel_dict):
    """Convert the ``transforms`` section of a parsed YAML config into transform dicts.

    Args:
        yaml_config: parsed YAML config; may be ``None`` (empty file) or lack a
            ``'transforms'`` key, in which case no transforms are produced.
        cluster_config: per-layer cluster assignments, or ``None``/``{}`` when
            no clustering is available.
        layer_channel_dict: ``{layer_id: channel_count}``.

    Returns:
        list of transform dicts. Entries with an unknown ``features`` mode, a
        missing ``features`` key, or ``'cluster'`` without a cluster config
        are reported and skipped.
    """
    transform_dict_list = []
    if not yaml_config or 'transforms' not in yaml_config:
        return transform_dict_list

    # `transforms:` with no value parses to None; treat it as an empty list.
    for transform in yaml_config['transforms'] or []:
        features = transform.get('features')
        if features == 'all':
            transform_dict_list.append(
                create_layer_wide_transform_dict(transform['layer'],
                    layer_channel_dict,
                    transform['transform'],
                    transform['params']))
        elif features == 'random':
            transform_dict_list.append(
                create_random_transform_dict(transform['layer'],
                    layer_channel_dict,
                    transform['transform'],
                    transform['params'],
                    transform['feature-param']))
        elif features == 'cluster' and cluster_config:
            # Truthiness, not `!= {}`: callers pass None when no cluster config
            # exists, and None would otherwise be indexed and crash.
            transform_dict_list.append(
                create_cluster_transform_dict(transform['layer'],
                    layer_channel_dict,
                    cluster_config,
                    transform['transform'],
                    transform['params'],
                    transform['feature-param']))
        else:
            print('transform type: ' + str(transform) + ' not recognised')

    return transform_dict_list
        

## Run the script

Activations can be found in `/content/network-bending/activations`, or you can download the activations.zip made by the last cell.

In [None]:
# Absolute path: the install cell may already have changed into network-bending,
# and a relative `%cd network-bending` fails from inside that directory.
%cd /content/network-bending

/content/network-bending


In [None]:
# Use the checkpoint downloaded earlier in this notebook (/content/ladiesblack.pt).
!python generate_activations.py --size 1024 --ckpt /content/ladiesblack.pt --pics 5 --config /content/network-bending/configs/nothing.yaml --truncation 0.8 --channel_multiplier 2

100% 5/5 [02:59<00:00, 35.81s/it]
Traceback (most recent call last):
  File "generate_activations.py", line 84, in <module>
    config_out['transforms'] = yaml_config['transforms']
TypeError: 'NoneType' object is not subscriptable


## Principal Component Analysis (PCA) Script

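The output colors from PCA have no inherent meaning. Each layer's activations are projected onto 3 principal components (instead of 1) so that the result can be shown as a 3-channel color image.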

In [None]:
import cv2
import numpy as np
import os

ACTIVATIONS_DIR = '/content/network-bending/activations'
OUTPUT_DIR = '/content/out-test-2'
N_COMPONENTS = 3  # one principal component per output color channel
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

# Sample 0 of layers 1-16; output file NN.png corresponds to layer NN + 1.
layer_paths = [os.path.join(ACTIVATIONS_DIR, str(layer), '0') for layer in range(1, 17)]


def load_layer_activations(layer_path):
    """Load every activation image under ``layer_path`` as a 2-D grayscale array.

    Files are visited in sorted path order so the result is deterministic;
    non-image files and images OpenCV cannot read are skipped.
    """
    paths = sorted(
        os.path.join(root, name)
        for root, _, files in os.walk(layer_path)
        for name in files
        if name.lower().endswith(IMAGE_EXTENSIONS)
    )
    images = (cv2.imread(path, cv2.IMREAD_GRAYSCALE) for path in paths)
    return [img for img in images if img is not None]


def project_pca(features, n_components=N_COMPONENTS):
    """Project the rows of ``features`` onto its top ``n_components`` principal axes.

    Args:
        features: array of shape (n_samples, n_features).
        n_components: number of output columns.

    Returns:
        float64 array of shape (n_samples, n_components). The axes come from
        the eigen-decomposition of the (n_features x n_features) scatter
        matrix, which stays small even for 1024x1024 layers. Component signs
        are arbitrary. If there are fewer features than components, the
        missing columns are zero.
    """
    data = np.asarray(features, dtype=np.float64)
    centered = data - data.mean(axis=0)
    eigvals, eigvecs = np.linalg.eigh(centered.T @ centered)
    top = np.argsort(eigvals)[::-1][:n_components]
    projected = centered @ eigvecs[:, top]
    if projected.shape[1] < n_components:
        pad = np.zeros((projected.shape[0], n_components - projected.shape[1]))
        projected = np.hstack([projected, pad])
    return projected


def to_uint8(image):
    """Min-max scale each channel of an (H, W, C) array to uint8 in 0-255.

    cv2.imwrite only writes 8- or 16-bit PNGs, and raw PCA scores are signed
    floats, so they are rescaled explicitly. Constant channels map to 0.
    """
    lo = image.min(axis=(0, 1), keepdims=True)
    hi = image.max(axis=(0, 1), keepdims=True)
    span = np.where(hi > lo, hi - lo, 1.0)
    return np.round((image - lo) / span * 255.0).astype(np.uint8)


for j, layer_path in enumerate(layer_paths):
    activations = load_layer_activations(layer_path)
    if not activations:
        print(f'No activation images found in {layer_path}; skipping.')
        continue

    height, width = activations[0].shape
    # One row per pixel, one column per activation map: PCA compresses each
    # pixel's activation vector down to N_COMPONENTS values.
    pixels_by_map = np.stack([img.reshape(-1) for img in activations], axis=1)
    projected = project_pca(pixels_by_map)
    output = to_uint8(projected.reshape(height, width, N_COMPONENTS))

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    cv2.imwrite(os.path.join(OUTPUT_DIR, str(j).zfill(2) + '.png'), output)

In [None]:
%cd /content/
# PCA projections written by the cell above.
!zip -r out-test-2.zip out-test-2
# Raw per-layer activations, as described in the "Run the script" section.
!zip -r activations.zip network-bending/activations

/content
  adding: out-test-2/ (stored 0%)
  adding: out-test-2/12.png (deflated 0%)
  adding: out-test-2/03.png (stored 0%)
  adding: out-test-2/10.png (deflated 0%)
  adding: out-test-2/13.png (deflated 0%)
  adding: out-test-2/06.png (stored 0%)
  adding: out-test-2/14.png (deflated 0%)
  adding: out-test-2/01.png (stored 0%)
  adding: out-test-2/11.png (deflated 0%)
  adding: out-test-2/04.png (stored 0%)
  adding: out-test-2/05.png (stored 0%)
  adding: out-test-2/09.png (stored 0%)
  adding: out-test-2/07.png (stored 0%)
  adding: out-test-2/00.png (stored 0%)
  adding: out-test-2/02.png (stored 0%)
  adding: out-test-2/08.png (stored 0%)
  adding: out-test-2/15.png (deflated 2%)
