<a href="https://colab.research.google.com/github/duskvirkus/ml-art-colab/blob/main/StyleGAN2_generate_feature_grids.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

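# Generate Feature Vector Grids

Compute latent directions for a trained StyleGAN2-ADA generator by closed-form factorization (SVD) of its style `affine` weights. Then render image grids that sweep each direction across a set of random seeds.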



## Dependencies

In [None]:
!git clone https://github.com/NVlabs/stylegan2-ada-pytorch
%cd stylegan2-ada-pytorch

!pip install ninja

## Load Model

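Load your model: either download a pretrained network pickle, or mount Google Drive to use your own checkpoint. Then set `network_pkl` below accordingly.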

In [None]:
# Optional: download a pretrained network pickle (presumably the cat model, per
# the file name) into the current directory (the cloned repo after `%cd`).
# NOTE(review): `network_pkl` below points at a Google Drive checkpoint, not this
# download -- set it to 'stylegan2-cat-config-f.pkl' to use this file instead.
!wget http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-cat-config-f.pkl

In [None]:
# Mount Google Drive at /content/drive so checkpoints and outputs stored there
# can be read and written by the cells below.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

## Setup

In [None]:
#@title imports (hidden code)
import os
import io

import numpy as np
import PIL.Image
import cv2
import IPython.display
import dnnlib
import torch

import legacy

In [None]:
# Network pickle to load (here: a checkpoint on the mounted Google Drive).
network_pkl = os.path.join(
    '/content/drive/MyDrive/art/stylegan/suicide-girls/pytorch-v1-cp1',
    'suicide-girls-pytorch-v1-cp1.pkl',
)

In [None]:
# Read the 'G_ema' generator out of the network pickle, then move it to the GPU.
device = torch.device('cuda')
with dnnlib.util.open_url(network_pkl) as f:
    G = legacy.load_network_pkl(f)['G_ema']
G = G.to(device)

## Generate Feature Vectors

In [None]:
features_save_path = '/content/drive/MyDrive/art/stylegan/suicide-girls/pytorch-v1-cp1/feature-vectors/feature-vectors.pt'

# Closed-form factorization: collect every parameter of G whose name contains
# both 'affine' and 'weight', skipping the 'torgb' ones except the one in 'b4'.
# Stack them row-wise and take the SVD; the right singular vectors are the
# candidate latent directions.
modulate = {
    name: param
    for name, param in G.named_parameters()
    if "affine" in name and "weight" in name and ("torgb" not in name or "b4" in name)
}
if not modulate:
    raise ValueError("G has no 'affine' weight parameters to factorize")

weight_mat = list(modulate.values())
W = torch.cat(weight_mat, 0)

# torch.linalg.svd returns (U, S, Vh) and has no `.V` field (that belonged to
# the deprecated torch.svd). Transpose Vh so the COLUMNS of eigvec are the right
# singular vectors, i.e. the same layout as torch.svd(W).V.
# full_matrices=False avoids materialising the large square U.
eigvec = torch.linalg.svd(W, full_matrices=False).Vh.T.to("cpu")

directory = os.path.dirname(features_save_path)
if directory:  # a bare file name means "current directory"; nothing to create
    os.makedirs(directory, exist_ok=True)

torch.save({"ckpt": network_pkl, "eigvec": eigvec}, features_save_path)

## Generate Grids

In [None]:
#@title functions (hidden code)
# define functions

def lerp(zs, steps):
  """Linearly interpolate between consecutive keyframes.

  For each adjacent pair (start, end) of `zs`, emits `steps` points at
  t = 0, 1/steps, ..., (steps-1)/steps, computed as end*t + start*(1-t).
  The last keyframe itself is never emitted. Returns a list with
  max(len(zs) - 1, 0) * steps items.
  """
  return [
      end * (step / float(steps)) + start * (1 - step / float(steps))
      for start, end in zip(zs[:-1], zs[1:])
      for step in range(steps)
  ]

def concat_tile(im_list_2d):
    """Tile a 2-D nested list of images into a single image.

    Each inner list is joined left-to-right with cv2.hconcat, and the
    resulting row strips are stacked top-to-bottom with cv2.vconcat.
    """
    row_strips = [cv2.hconcat(row) for row in im_list_2d]
    return cv2.vconcat(row_strips)

def save_grid(images, save_path, save_name, tile_size=256):
  """Resize images into square tiles, arrange them as a grid and write it out.

  Args:
    images: 2-D nested sequence (rows of tiles). Each tile is array-like
      (numpy array or CPU tensor); gen_feature_grid passes (H, W, 3) RGB
      tiles with values 0..255 -- TODO confirm for any other caller.
    save_path: output directory; created if it does not exist. An empty
      string writes into the current directory.
    save_name: file name; its extension selects cv2's image encoder.
    tile_size: edge length in pixels each tile is resized to (default 256).

  Raises:
    ValueError: if `images` has no rows or some row has no tiles.
    IOError: if cv2.imwrite reports that the file could not be written.
  """
  if not images or any(len(row) == 0 for row in images):
    raise ValueError("save_grid needs at least one image in every row")

  # cv2 stores channels in BGR order, so swap the RGB tiles before writing.
  # (RGB2BGR and BGR2RGB are the same channel swap; RGB2BGR states the intent.)
  resized_images = [
    [
      cv2.cvtColor(cv2.resize(np.float32(im), (tile_size, tile_size)), cv2.COLOR_RGB2BGR)
      for im in row
    ]
    for row in images
  ]

  img_grid = concat_tile(resized_images)

  if save_path:
    os.makedirs(save_path, exist_ok=True)

  path = os.path.join(save_path, save_name)
  # cv2.imwrite reports failure only through its boolean return value.
  if not cv2.imwrite(path, img_grid):
    raise IOError("cv2.imwrite failed to write {}".format(path))

def gen_feature_grid(rows, cols, feature_index, strength, truncation, save_location):
  """Render and save a grid that sweeps one factorized latent direction.

  Row r uses the latent z drawn from np.random.RandomState(r), so the rows are
  seeds 0 .. rows-1. Across the `cols` columns the latent moves evenly from
  z - strength * v to z + strength * v, with both ends included. Here v is
  column `feature_index` of the notebook-level `eigvec`. With an odd `cols`,
  the middle column is the unmodified seed; a single column shows only it.

  Uses the notebook-level `G`, `device`, `eigvec` and `save_grid`.

  Args:
    rows: number of seeds, i.e. grid rows.
    cols: number of steps along the direction, i.e. grid columns.
    feature_index: which column of `eigvec` to use as the direction.
    strength: multiplier applied to the direction at both sweep ends.
    truncation: truncation_psi passed to G.
    save_location: output directory handed to save_grid.

  Raises:
    ValueError: if `rows` or `cols` is less than 1.
  """
  if rows < 1 or cols < 1:
    raise ValueError("rows and cols must both be >= 1")

  truncation_psi = truncation
  noise_mode = 'const'

  # `eigvec` stores the directions as COLUMNS (right singular vectors, the
  # torch.svd(W).V layout); taking a row would mix every direction together.
  direction = strength * eigvec[:, feature_index]
  # NOTE(review): the factorized affine weights presumably act on the
  # intermediate latent w, yet the shift is applied to z here (before the
  # mapping network). Confirm this is intended, or edit w via G.mapping/G.synthesis.

  # Evenly spaced sweep coefficients in [-1, 1], endpoints included.
  if cols == 1:
    coeffs = [0.0]
  else:
    coeffs = [2.0 * step / (cols - 1) - 1.0 for step in range(cols)]

  all_images = []
  with torch.no_grad():
    for row in range(rows):
      z = torch.from_numpy(np.random.RandomState(row).randn(1, G.z_dim))
      zs = torch.cat([z + coeff * direction for coeff in coeffs]).to(device)
      # One label per batch entry so conditional models (c_dim > 0) also work.
      label = torch.zeros([zs.shape[0], G.c_dim], device=device)
      generated = G(zs, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
      # NCHW output in [-1, 1] -> NHWC uint8 in [0, 255].
      generated = (generated.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
      all_images.append(generated.cpu())

  save_name = "feature{feature_index}_seeds{seeds_start}-{seeds_end}.png".format(feature_index=feature_index, seeds_start = 0, seeds_end = rows - 1)
  save_grid(all_images, save_location, save_name)

In [None]:
# Render a single grid: one row per seed (seeds 0 .. num_seeds-1) and
# `num_steps` columns sweeping feature `feature_index` with magnitude `strength`.
num_seeds = 20  # grid rows; starts at seed 0
num_steps = 7   # grid columns (passed as `cols`)
feature_index = 1
strength = 20.0
truncation = 0.8
save_dir = '/content/output'

gen_feature_grid(num_seeds, num_steps, feature_index, strength, truncation, save_dir)

In [None]:
# Render one grid per factorized direction (every column of `eigvec`).

feature_start = 0
feature_end = eigvec.shape[1] - 1  # inclusive: index of the last direction

num_seeds = 12  # grid rows; starts at seed 0
num_steps = 7   # grid columns (passed as `cols`)
strength = 20.0
truncation = 0.8
save_dir = '/content/drive/MyDrive/art/stylegan/suicide-girls/pytorch-v1-cp1/feature-vectors/grids/'

# +1 because range() excludes its stop value and feature_end is inclusive.
for i in range(feature_start, feature_end + 1):
  gen_feature_grid(num_seeds, num_steps, i, strength, truncation, save_dir)