In [None]:
## Clone the official NVlabs StyleGAN3 repository. The `import legacy` / `import dnnlib`
## lines further down presumably resolve from this checkout once the working directory is moved into it.
## NOTE(review): this clones the latest default branch; pin a commit for reproducible results.
!git clone https://github.com/NVlabs/stylegan3.git

In [None]:
#@markdown Install additional dependencies
!pip install click requests pyspng tqdm pyopengl==3.1.5 ninja==1.10.2 imageio-ffmpeg==0.4.3 imgui==1.3.0

In [None]:
#@markdown Move working directory

## Move into the cloned stylegan3 folder so that its top-level modules (e.g. `legacy`, `dnnlib`)
## are importable by the next cell.
%cd /content/stylegan3 

In [None]:
#@markdown Import libraries
import torch
import argparse
import os
import pickle
import re

import numpy as np
import PIL.Image


from typing import List, Optional

import legacy
import click
import copy

import matplotlib.pyplot as plt

import random
import itertools

from sklearn.decomposition import FastICA, PCA, IncrementalPCA, MiniBatchSparsePCA, SparsePCA, KernelPCA

import dnnlib

# Report the CUDA / PyTorch build and the active GPU. Every generator call in this notebook
# assumes a CUDA device, but the import cell itself should not crash on a CPU-only runtime.
print(torch.version.cuda)
print(torch.__version__)
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(device=None))
else:
    print("WARNING: CUDA is not available; the generator cells below require a GPU runtime.")


In [None]:
#@markdown Set up random seed for reproducibility

# Seed Python's global RNG so that the sampled seed list (and later random.* draws) are reproducible.
SEED = 30
random.seed(SEED)
sample_size = 20
seed_list = random.sample(range(100000), k=sample_size)

In [None]:
#@markdown Utility functions
def loading_network(network_pkl, device=None):
  """
  Load the pretrained StyleGAN generator (the 'G_ema' entry) from a network pickle.

  network_pkl: path to the pretrained network .pkl file (opened with dnnlib.util.open_url),
               e.g. "network-snapshot-soap8k-day9-403-pytorch.pkl".
  device: torch device the generator is moved to; defaults to 'cuda' (previous behaviour).

  Returns the loaded generator module on `device`.

  NOTE: network .pkl files are Python pickles, and unpickling can execute arbitrary code.
  Only load files from trusted sources.
  """
  device = torch.device('cuda') if device is None else torch.device(device)
  print('Loading networks from "%s"...' % network_pkl)
  with dnnlib.util.open_url(network_pkl) as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)

  print("Loading finished")
  return G


def edit_layer(w_codes, direction, step, layer_indices, start_distance, end_distance):
  """
  Move selected layers of latent codes along a direction, producing `step` edits per code.

  w_codes: array-like of latent codes indexed as (N, layers, dim), e.g. W+ codes of
           shape (N, num_ws, w_dim). Lists are accepted and converted with np.asarray.
  direction: offset to move along in latent space; must broadcast against one code
             (e.g. shape (num_ws, w_dim)).
  step: total number of outputs per code, endpoints included (as in np.linspace).
  layer_indices: indices along the layer axis (axis 1 of `w_codes`) to edit, e.g. [6, 7, 8].
                 All other layers are copied unchanged.
  start_distance: multiple of `direction` added at the first step.
  end_distance: multiple of `direction` added at the last step
                (so 0 -> 1 walks from w to w + direction).

  Returns an array of shape (N, step, *w_codes.shape[1:]). The step coefficients are
  float64, so float32 input yields a float64 result.
  """
  w_codes = np.asarray(w_codes)
  # Insert a step axis, (N, 1, ...), so it broadcasts against the per-step coefficients.
  x = w_codes[:, np.newaxis]
  step_shape = [step if axis == 1 else 1 for axis in range(x.ndim)]
  coeffs = np.linspace(start_distance, end_distance, step).reshape(step_shape)
  edited = x + coeffs * direction

  # Compact boolean mask over the layer axis only; it broadcasts over samples, steps and features.
  mask_shape = [1] * x.ndim
  mask_shape[2] = x.shape[2]
  layer_mask = np.zeros(mask_shape, dtype=bool)
  layer_mask[:, :, layer_indices] = True

  # Edited values on the selected layers; the original code (broadcast over steps) elsewhere.
  return np.where(layer_mask, edited, x)

def generate_from_w(G, w_codes, noise_mode = "const", device=None):
  """
  Generate a list of images from latent codes (either W or W+), one image per code.

  G: the loaded StyleGAN generator.
  w_codes: array-like (or tensor) of latent codes whose trailing dims are (G.num_ws, G.w_dim).
  noise_mode: noise mode forwarded to G.synthesis ("const" for deterministic output).
  device: torch device to run synthesis on; defaults to 'cuda' (previous behaviour).

  Returns a list of RGB PIL images.
  """
  device = torch.device('cuda') if device is None else torch.device(device)
  # as_tensor accepts numpy arrays and tensors alike (no copy warning); the dtype is preserved.
  w_codes = torch.as_tensor(w_codes, device=device)
  assert w_codes.shape[1:] == (G.num_ws, G.w_dim), \
      "expected codes of shape (N, %d, %d), got %s" % (G.num_ws, G.w_dim, tuple(w_codes.shape))

  generated = []
  # Inference only: no autograd graph is needed for visualisation.
  with torch.no_grad():
    for w in w_codes:
      img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode)
      # Map the synthesis output (presumably in about [-1, 1]) to uint8 HWC pixels.
      img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
      generated.append(PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB'))
  return generated

def plot_subplot(img_list, n_row, n_col, width, height, show_label=True):
  """
  Show a list of images on an (n_row, n_col) grid of subplots, filled row by row.

  img_list: a list of images accepted by `imshow` (e.g. PIL images).
  n_row: number of rows.
  n_col: number of columns.
  width: figure width in inches (passed to `figsize`).
  height: figure height in inches (passed to `figsize`).
  show_label: if True, column titles and row y-labels are drawn on the outer axes.

  Grid cells beyond len(img_list) are left blank; images beyond n_row * n_col are not shown.
  """
  # squeeze=False keeps `axes` 2-D even when n_row and/or n_col is 1.
  fig, axes = plt.subplots(nrows=n_row, ncols=n_col, figsize=(width, height), squeeze=False)
  plt.subplots_adjust(hspace = .001)

  # axes.flat walks the grid row-major, matching the image order.
  for idx, ax in enumerate(axes.flat):
    if idx >= len(img_list):
      ax.axis('off')
      continue
    ax.imshow(img_list[idx])
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])

  if show_label:
    for col, ax in enumerate(axes[0]):
      ax.set_title('Column label {}'.format(col), size='small')
    for row, ax in enumerate(axes[:, 0]):
      ax.set_ylabel('Row label {}'.format(row), rotation=90, size='small')

  fig.tight_layout()
  plt.show()

def convert_rgb(images, scale_contrast):
  """
  Convert a batch of float feature maps to uint8 images for visualisation.

  images: array presumably shaped (N, C, H, W); it is transposed to (N, H, W, C).
  scale_contrast: contrast gain applied around the mean after normalisation.

  The whole batch is min-max normalised to [0, 1] jointly (a single global min/max/mean,
  not per image), then stretched around its mean to 128 and clipped to [0, 255].
  A constant input (zero range) maps to mid-grey (128) instead of producing NaNs.
  """
  images = np.transpose(images, [0, 2, 3, 1])
  images = images - np.min(images)
  value_range = np.max(images)
  # Guard the 0/0 case: dividing a constant batch by zero would yield NaNs.
  if value_range > 0:
    images = images / value_range
  images = scale_contrast * 128 * (images - np.mean(images)) + 128
  return np.clip(images, 0, 255).astype('uint8')
  
def convert_images_to_uint8(images, drange=(-1, 1)):
    """Convert a minibatch of NCHW float images to NHWC uint8 with a configurable dynamic range.

    images: torch tensor presumably shaped (N, C, H, W); it is permuted to (N, H, W, C).
    drange: (low, high) input values that map to 0 and 255 respectively. A tuple default is
            used to avoid a shared mutable default; lists are still accepted.

    Returns a uint8 tensor of shape (N, H, W, C).
    """
    low, high = drange[0], drange[1]
    scale = 255 / (high - low)
    images = images.permute(0, 2, 3, 1)
    # The +0.5 makes the truncating uint8 cast round to the nearest integer.
    images = images * scale + (0.5 - low * scale)
    return images.clamp(0, 255).to(torch.uint8)


def get_intermediate_trgb(G, w_codes, device="cuda:0"):
  """
  Run G.synthesis once and collect the output of every ToRGB ("torgb") layer.

  G: the loaded StyleGAN generator. Hooks are attached temporarily and removed
     afterwards, so G is left unchanged and no deep copy of the network is needed.
  w_codes: latent codes accepted by G.synthesis (e.g. a (N, num_ws, w_dim) W+ batch).
  device: device the latent codes are moved to (default "cuda:0", as before).

  Returns a dict mapping module name (e.g. "b8.torgb") to that layer's output, in the
  order the layers ran during the forward pass. The entry of the final block,
  "b{G.img_resolution}.torgb" ("b1024.torgb" for a 1024px model), is replaced by the
  final synthesized image.
  """
  w_codes = torch.as_tensor(w_codes, device=device)

  activation_layer = {}
  def get_activation_layer(name):
    def hook(model, input, output):
      activation_layer[name] = output.detach()
    return hook

  handles = []
  try:
    for name, module in G.synthesis.named_modules():
      # Keep the ToRGB layers themselves and skip their internal 'affine' sub-modules.
      if "torgb" in name and "affine" not in name:
        print(name)
        handles.append(module.register_forward_hook(get_activation_layer(name)))
    # no_grad so the final image does not carry an autograd graph (it is later .numpy()'d).
    with torch.no_grad():
      img = G.synthesis(w_codes, noise_mode="const")
  finally:
    # Always detach the hooks so they do not fire on later calls to G.
    for handle in handles:
      handle.remove()

  final_key = "b{}.torgb".format(getattr(G, "img_resolution", 1024))
  activation_layer[final_key] = img
  return activation_layer

def plot_generative_stages(out_levels, sample_idx=1, scale_contrast=2):
  """
  Plot the intermediate ToRGB output at each resolution for one sample of the batch.

  out_levels: dict from layer name (e.g. "b8.torgb") to a batched output tensor, as returned
              by get_intermediate_trgb. The first entry (lowest resolution) is skipped.
  sample_idx: which sample of the batch to show. The default of 1 (the second sample) is
              the previously hard-coded choice and needs at least two codes in the batch.
  scale_contrast: contrast gain passed to convert_rgb.

  Returns the list of PIL images that were plotted.
  """
  layer_names = list(out_levels.keys())[1:]
  intermediate_trgb = []
  for layer_name in layer_names:
    layer = out_levels[layer_name].to(torch.float32)
    # convert_rgb normalises over the whole batch, so convert first and then pick the sample.
    trgb_img = convert_rgb(layer.cpu().numpy(), scale_contrast)
    intermediate_trgb.append(PIL.Image.fromarray(trgb_img[sample_idx], 'RGB'))

  if not intermediate_trgb:
    return intermediate_trgb

  # squeeze=False keeps `axs` 2-D even when only one stage is plotted.
  fig, axs = plt.subplots(1, len(intermediate_trgb), figsize=(20, 10), squeeze=False)
  for layer_name, ax, image in zip(layer_names, axs[0], intermediate_trgb):
    # Derive "R x R" from block names like "b64.torgb"; fall back to the raw name otherwise.
    block = layer_name.split(".")[0]
    res = block[1:]
    title = "{0} x {0}".format(res) if block.startswith("b") and res.isdigit() else layer_name
    ax.imshow(image)
    ax.set_title(title)
    ax.grid(False)
    ax.axis('off')
  plt.show()

  return intermediate_trgb

In [None]:
## Path to the pretrained soap generator pickle; change it to wherever your .pkl file lives.
G_soap_path = (
    "/content/drive/MyDrive/StyleGAN2-multidomain/soap/training-resume-soap8k-day9-403pkl/00000-stylegan2-soap_8k_1024res-gpus1-batch4-gamma10/network-snapshot-soap8k-day9-403-pytorch.pkl"
)
G_soap = loading_network(network_pkl=G_soap_path)

In [None]:
#@markdown Visualization of Latent space

class visualize_latent_space():
  """
  Helpers for sampling W codes and visualising layer-wise morphs between two latent codes.

  Keeps its own deep copy of the generator in `self.G`.
  """

  def __init__(self,
               G,
               device='cuda',
               ):
    # Private copy of the generator used by all methods below.
    self.G = copy.deepcopy(G)
    # Device used when sampling z / class labels (default matches the rest of the notebook).
    self.device = torch.device(device)


  def sample_W_codes(
    self,
    seeds: Optional[List[int]],
    truncation_psi: float,
    noise_mode: str,
    class_idx: Optional[int],
    ):
    """
    Sample latent codes from the W latent space, one per seed.

    seeds: a non-empty list of integer seeds. Each seed drives its own
           np.random.RandomState, so the same seed always gives the same code.
    truncation_psi: truncation strength forwarded to G.mapping; it controls the
                    diversity of the generated images (1 = no truncation).
    noise_mode: unused here (noise only affects G.synthesis); kept for API compatibility.
    class_idx: class label for conditional generators (required when G.c_dim != 0).
               For unconditional generators it is ignored with a warning.

    Returns a numpy array stacking G.mapping(...)[0] for every seed, i.e. of shape
    (len(seeds), num_ws, w_dim) for a W+-broadcasting mapping network.

    Raises ValueError if `seeds` is empty/None, or if a conditional network gets no class_idx.
    """
    if seeds is None or len(seeds) == 0:
      raise ValueError('seeds must be a non-empty list of integers')

    label = torch.zeros([1, self.G.c_dim], device=self.device)
    if self.G.c_dim != 0:
      if class_idx is None:
        raise ValueError('Must specify class_idx when using a conditional network')
      label[:, class_idx] = 1
    elif class_idx is not None:
      print('warn: class_idx ignored when running on an unconditional network')

    w_samples = []
    # Sample W latent codes; inference only, so no autograd graph is needed.
    with torch.no_grad():
      for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, self.G.z_dim)).to(self.device)
        # Pass the label and truncation so that both documented arguments actually take effect.
        w_sample = self.G.mapping(z, label, truncation_psi=truncation_psi)[0]
        w_samples.append(w_sample)

    return torch.stack(w_samples).cpu().numpy()

  def morph_layers(
    self,
    w_codes,
    steps,
    layer_groups=None,
    ):
    """
    Morph groups of layers from a source latent code towards a target latent code and plot them.

    w_codes: a sequence of at least two latent codes (W or W+). w_codes[0] is the source and
             w_codes[1] the target. If None, two codes are sampled from W space
             (truncation_psi=0.7).
    steps: number of images per row, from the source (distance 0) to the target (distance 1),
           endpoints included.
    layer_groups: list of layer-index lists; one row is plotted per group. Defaults to all layers,
                  early layers [0..5], middle layers [6, 7, 8] and later layers [9..num_ws-1].
                  NOTE: the default split was chosen for an 18-layer (1024px) generator.
    """
    if w_codes is None:
      print("sample 2 points from W space:")
      pair_seeds = random.sample(range(0, 100000), 2)
      w_codes = self.sample_W_codes(seeds=pair_seeds, truncation_psi=0.7, noise_mode="const", class_idx=None)

    if len(w_codes) < 2:
      raise ValueError('morph_layers needs at least two latent codes (source and target)')

    w_source = w_codes[0]
    w_target = w_codes[1]
    direction = w_target - w_source  # direction from the source to the target latent code

    if layer_groups is None:
      num_ws = self.G.num_ws
      layer_groups = [list(range(num_ws)), list(range(6)), [6, 7, 8], list(range(9, num_ws))]

    for layers in layer_groups:
      print("Changed layers index:", layers)
      res = edit_layer(np.expand_dims(w_source, axis=0), direction=direction, step=steps,
                       layer_indices=layers, start_distance=0, end_distance=1)
      imgs = generate_from_w(self.G, res[0])
      plot_subplot(imgs, n_row=1, n_col=steps, width=30, height=15, show_label=False)




    

In [None]:
# Wrap the soap generator for latent-space visualisation.
vis_soap = visualize_latent_space(G=G_soap)

In [None]:
#@markdown Visualize the effect of layer-manipulations in W space:
# W-space demo: passing w_codes=None makes morph_layers sample two random W codes itself.
vis_soap.morph_layers(steps=5, w_codes=None)

# Encode real soap images into the W+ latent space (pixel2style2pixel encoder)

In [None]:
## Move to the directory of the pixel2style2pixel encoder, which is used to embed real images into W+.
## NOTE(review): no cell in this part of the notebook clones pixel2style2pixel into /content;
## make sure that directory exists before running this cell.
%cd /content/pixel2style2pixel

In [None]:
# !python scripts/inference.py \
# --exp_dir=/content/drive/MyDrive/Real_Fake_soaps_pavlovia/milk_glycerin_soap1000/milky_soap_encoded \
# --checkpoint_path=/content/drive/MyDrive/pixel2style2pixel/soap_encoded_models/soap_encoder_8k_1024res_pkl09_403in_Wplus/checkpoints/best_model.pt \
# --data_path=/content/drive/MyDrive/Real_Fake_soaps_pavlovia/milk_glycerin_soap1000/milky_soap \
# --test_batch_size=5 \
# --test_workers=4 \
# --save_latent_type=Wplus_all \
# --couple_outputs

In [None]:
#@markdown Load W+ latent codes of milky and glycerin soaps
# Please change the paths to your saved .npy files of the encoded images.
# Milky (opaque) codes come first, glycerin (trans) codes second in the stacked array.
opaque_samples, trans_samples = (
    np.load("/content/drive/MyDrive/Real_Fake_soaps_pavlovia/milk_glycerin_soap1000/milky_soap_encoded/Wplus_all_layers.npy"),
    np.load("/content/drive/MyDrive/Real_Fake_soaps_pavlovia/milk_glycerin_soap1000/glycerin_soap_encoded/Wplus_all_layers.npy"),
)
w_plus = np.vstack([opaque_samples, trans_samples])

In [None]:
#@markdown Visualize the effect of layer-manipulations in W+ space:
# Pick one milky (opaque) and one glycerin (trans) soap at random. The index ranges come from
# the loaded arrays rather than assuming 500 codes each; for 500 + 500 codes the draws are
# identical to randrange(500) / randrange(500, 1000).
n_opaque = len(opaque_samples)
opaque_index = random.randrange(n_opaque)
trans_index = n_opaque + random.randrange(len(trans_samples))
print(opaque_index, trans_index)
w_codes_selected = w_plus[[opaque_index, trans_index]]  ## select 1 opaque and 1 glycerin soap
vis_soap.morph_layers(w_codes=w_codes_selected, steps=5)