# StyleGAN2-ADA-PyTorch

**Notes**
This is based on Derrick Schultz's [SG2-ADA-PyTorch notebook](https://colab.research.google.com/github/dvschultz/stylegan2-ada-pytorch/blob/main/SG2_ADA_PyTorch.ipynb).

## Setup

Let’s start by checking to see what GPU we’ve been assigned.

In [None]:
!nvidia-smi -L

Next let’s connect our Google Drive account.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## Install repo

The next cell will install the StyleGAN repository in Google Drive. If you have already installed it, it will just move into that folder. If you don’t have Google Drive connected, it will install the necessary code in the Colab session’s local storage instead.

In [None]:
import os
!pip install gdown --upgrade

if os.path.isdir("/content/drive/MyDrive/colab-sg2-ada-pytorch"):
    %cd "/content/drive/MyDrive/colab-sg2-ada-pytorch/stylegan2-ada-pytorch"
elif os.path.isdir("/content/drive/"):
    #install script
    %cd "/content/drive/MyDrive/"
    !mkdir colab-sg2-ada-pytorch
    %cd colab-sg2-ada-pytorch
    !git clone https://github.com/dvschultz/stylegan2-ada-pytorch
    %cd stylegan2-ada-pytorch
    !mkdir downloads
    !mkdir datasets
    !mkdir pretrained
    !gdown --id 1-5xZkD8ajXw1DdopTkH_rAoCsD72LhKU -O /content/drive/MyDrive/colab-sg2-ada-pytorch/stylegan2-ada-pytorch/pretrained/wikiart.pkl
else:
    !git clone https://github.com/dvschultz/stylegan2-ada-pytorch
    %cd stylegan2-ada-pytorch
    !mkdir downloads
    !mkdir datasets
    !mkdir pretrained
    %cd pretrained
    !gdown --id 1-5xZkD8ajXw1DdopTkH_rAoCsD72LhKU
    %cd ../

!pip install ninja opensimplex torch==1.7.1 torchvision==0.8.2

## Dataset Preparation

Upload a .zip of square images to the `datasets` folder.

## Train model

Below are the variables you need to set to run training. You probably won’t need to touch most of them.

* `dataset_path`: the path to your .zip file.
* `resume_from`: the pretrained network pickle to start from. Here it points to NVIDIA's CelebA-HQ 256×256 transfer-learning source network (`celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl`).
* `mirror_x` and `mirror_y`: allow the dataset to use horizontal or vertical mirroring.
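To make the mirroring options concrete, here is a minimal NumPy sketch of what a horizontal and a vertical mirror do to one image. `mirror_views` is a hypothetical helper for illustration, not part of the repository, which applies flips inside its own data pipeline; conceptually, each enabled mirror adds one flipped copy per training image.

```python
import numpy as np

def mirror_views(img, mirror_x=True, mirror_y=False):
    """Return the original image plus its mirrored copies (hypothetical helper).

    img is an H x W x C array: a horizontal mirror flips the width axis,
    a vertical mirror flips the height axis.
    """
    views = [img]
    if mirror_x:
        views.append(img[:, ::-1, :])  # left-right flip
    if mirror_y:
        views.append(img[::-1, :, :])  # top-bottom flip
    return views

# A tiny 2 x 2 RGB "image" whose values encode each pixel's position.
img = np.arange(2 * 2 * 3).reshape(2, 2, 3)
views = mirror_views(img, mirror_x=True, mirror_y=True)
print(len(views))  # 3
```

Horizontal mirroring is usually safe for faces; vertical mirroring produces upside-down faces, so it only makes sense for datasets without a natural "up".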

In [None]:
resume_from = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl'
aug_strength = 0.0

In [None]:
!wget https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl
!mv celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl /content/


In [None]:
%cd /content/
!git clone https://github.com/gtamba/pytorch-slim-cnn 

In [None]:
%cd pytorch-slim-cnn/

In [None]:
from slimnet import SlimNet

In [None]:
%cd /content/drive/MyDrive/colab-sg2-ada-pytorch/stylegan2-ada-pytorch

In [None]:
import torch
import pickle
import PIL.Image
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
import random

In [None]:
device = torch.device('cuda')
model = SlimNet.load_pretrained('/content/pytorch-slim-cnn/models/celeba_20.pth').to(device)

In [None]:
labels = np.array(['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
       'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
       'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
       'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
       'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
       'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
       'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
       'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
       'Wearing_Necklace', 'Wearing_Necktie', 'Young'])
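The classifier produces one raw score (logit) per attribute, and each attribute is an independent yes/no decision, so the notebook applies a sigmoid and keeps every label whose probability exceeds 0.5. The NumPy sketch below mirrors that decoding step; `decode_attributes` is a hypothetical helper and the toy logits are made up for illustration.

```python
import numpy as np

def decode_attributes(logits, labels, threshold=0.5):
    """Return the attribute names whose sigmoid probability exceeds `threshold`."""
    probs = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=float)))  # sigmoid
    return labels[probs > threshold]

# Toy example with 3 of the 40 CelebA attributes (made-up logits).
toy_labels = np.array(['Eyeglasses', 'Male', 'Smiling'])
toy_logits = np.array([2.0, -1.5, 0.3])
print(decode_attributes(toy_logits, toy_labels))  # ['Eyeglasses' 'Smiling']
```

With a threshold of 0.5, this is equivalent to keeping the labels whose logit is positive.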

In [None]:
transform = transforms.Compose([
                              transforms.Resize((178,218)),
                              transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ])
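`ToTensor` scales pixels from [0, 255] to [0, 1] and moves channels first; `Normalize` with mean 0.5 and std 0.5 then computes (x − 0.5) / 0.5, which maps [0, 1] to [−1, 1]. The sketch below reproduces those two steps in NumPy on a single pixel, leaving out the resize; `to_model_range` is a hypothetical name used only for illustration.

```python
import numpy as np

def to_model_range(pixels_uint8):
    """Mimic ToTensor + Normalize((0.5,)*3, (0.5,)*3) on an H x W x 3 uint8 array (no resize)."""
    x = pixels_uint8.astype(np.float32) / 255.0  # ToTensor: [0, 255] -> [0, 1]
    x = (x - 0.5) / 0.5                          # Normalize: [0, 1] -> [-1, 1]
    return x.transpose(2, 0, 1)                  # ToTensor also puts channels first (C x H x W)

pixels = np.array([[[0, 128, 255]]], dtype=np.uint8)  # a single RGB pixel
print(to_model_range(pixels).ravel())  # approximately [-1.0, 0.0039, 1.0]
```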

In [None]:
with open('/content/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl', 'rb') as f:
    G = pickle.load(f)['G_ema'].cuda()  # torch.nn.Module

## Get images with/without glasses

We ran the cells in this section once to sample many images with and without glasses (and with other attributes). We then hand-picked the good ones, since some images were badly generated, and saved the images, plus the latent points used to generate them as .npy files, so we don't have to rerun the sampling and hand-pick the images every time.

In [None]:
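@torch.no_grad()
def generate_img(G, w):
  # G is a StyleGAN2 generator; w is a point in the intermediate latent space W,
  # shaped [1, num_ws, w_dim]. Returns an H x W x 3 uint8 numpy array.
  img = G.synthesis(w, noise_mode='const', force_fp32=True)
  # Map the generator's output range (about [-1, 1]) to [0, 255] pixels.
  img = (img.squeeze(0).permute(1, 2, 0) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
  return img.cpu().numpy()

def img_to_tensor(img):
  # Preprocess a uint8 image for the attribute classifier: resize, normalize, add a batch dim.
  return transform(PIL.Image.fromarray(img, 'RGB')).unsqueeze(0).to(device)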
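The pixel conversion in `generate_img` is easier to see on a few numbers. The NumPy sketch below applies the same formula; `synthesis_to_uint8` is a hypothetical name. Adding 128 rather than 127.5 means the truncating cast to uint8 behaves like rounding to the nearest integer.

```python
import numpy as np

def synthesis_to_uint8(img):
    """Convert generator output in roughly [-1, 1] to displayable uint8 pixels."""
    return np.clip(img * 127.5 + 128, 0, 255).astype(np.uint8)

print(synthesis_to_uint8(np.array([-1.0, 0.0, 1.0, 1.2])))  # [  0 128 255 255]
```

Values slightly outside [−1, 1], which the generator can produce, are clipped instead of wrapping around.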

In [None]:
# n_imgs = 1
# z = torch.randn([n_imgs, G.z_dim]).cuda()    # latent codes
# c = None                                # class labels (not used in this example)
# w = G.mapping(z, c, truncation_psi=0.3, truncation_cutoff=8)
# img = G.synthesis(w, noise_mode='const', force_fp32=True)

In [None]:
# img = (img.squeeze(0).permute(1, 2, 0) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
# x = transform(PIL.Image.fromarray(img.cpu().numpy(), 'RGB')).unsqueeze(0).to(device)

In [None]:
# PIL.Image.fromarray(img.cpu().numpy(), 'RGB')

In [None]:
w_female = []
z_female = []
model.eval()
for cont in range(200):
  # Rejection sampling: draw latents until the classifier does not predict 'Male'.
  target = set(['Male'])
  while 'Male' in target:
    n_imgs = 1
    z = torch.randn([n_imgs, G.z_dim]).cuda()    # latent codes
    c = None                                     # class labels (not used in this example)
    w = G.mapping(z, c, truncation_psi=1, truncation_cutoff=8)
    img = G.synthesis(w, noise_mode='const', force_fp32=True)

    img = (img.squeeze(0).permute(1, 2, 0) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    x = transform(PIL.Image.fromarray(img.cpu().numpy(), 'RGB')).unsqueeze(0).to(device)
    with torch.no_grad():
      logits = model(x)
      sigmoid_logits = torch.sigmoid(logits)
      predictions = (sigmoid_logits > 0.5).squeeze().cpu().numpy()
    target = set(labels[predictions.astype(bool)])
  w_female.append(w)
  z_female.append(z)
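The loop above is rejection sampling: keep drawing latents until the classifier's predicted attributes exclude the unwanted one. A generic sketch follows, with stand-ins for the generator and the classifier (a single random number whose sign decides the label). `sample_until`, `draw` and `predict` are hypothetical names; the `max_tries` guard is an addition that the notebook's `while` loop does not have, so a rare attribute cannot hang the cell forever.

```python
import numpy as np

def sample_until(sample_fn, predict_fn, forbidden, max_tries=1000):
    """Draw candidates until predict_fn(candidate) does not contain `forbidden`."""
    for _ in range(max_tries):
        candidate = sample_fn()
        if forbidden not in predict_fn(candidate):
            return candidate
    raise RuntimeError(f"no sample without {forbidden!r} after {max_tries} tries")

rng = np.random.default_rng(0)

def draw():
    # Stand-in for sampling z and mapping it to w: a single standard-normal number.
    return rng.standard_normal()

def predict(latent):
    # Stand-in for the attribute classifier: positive values are labelled 'Male'.
    return {'Male'} if latent > 0 else set()

samples = [sample_until(draw, predict, 'Male') for _ in range(5)]
print(all(s <= 0 for s in samples))  # True
```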

In [None]:
len(z_female)

In [None]:
# Run after the no-glasses sampling loop below, which collects `z_noglasses`.
z_noglasses_list = np.array([z.cpu().numpy() for z in z_noglasses])
np.save('/content/drive/MyDrive/IMA206-Project/z_no_glasses_list.npy', z_noglasses_list)

In [None]:
import os
import imageio
drive_path = "/content/drive/MyDrive/IMA206-Project/Female/"
os.makedirs(drive_path, exist_ok=True)
for i, w in enumerate(w_female):
  img = generate_img(G, w)
  imageio.imwrite(drive_path + f'female_{i:03d}.jpg', img)

In [None]:
w_noglasses = []
z_noglasses = []
model.eval()
for cont in range(400):
  # Rejection sampling: draw latents until the classifier does not predict 'Eyeglasses'.
  target = set(['Eyeglasses'])
  while 'Eyeglasses' in target:
    n_imgs = 1
    z = torch.randn([n_imgs, G.z_dim]).cuda()    # latent codes
    c = None                                     # class labels (not used in this example)
    w = G.mapping(z, c, truncation_psi=1, truncation_cutoff=8)
    img = G.synthesis(w, noise_mode='const', force_fp32=True)

    img = (img.squeeze(0).permute(1, 2, 0) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    x = transform(PIL.Image.fromarray(img.cpu().numpy(), 'RGB')).unsqueeze(0).to(device)
    with torch.no_grad():
      logits = model(x)
      sigmoid_logits = torch.sigmoid(logits)
      predictions = (sigmoid_logits > 0.5).squeeze().cpu().numpy()
    target = set(labels[predictions.astype(bool)])
  w_noglasses.append(w)
  z_noglasses.append(z)

In [None]:
# Show each candidate one at a time; press Enter to move on to the next image.
for i in range(400):
  img_print = generate_img(G, w_noglasses[i])
  plt.figure()
  plt.imshow(img_print)
  plt.title(f'{i}')
  plt.show()
  _ = input('')
  plt.close()

In [None]:
male_index = [0, 1, 3, 4, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 23, 27, 30, 31, 34, 38, 39, 40, 42, 44, 50, 51, 52, 53, 54, 55, 56, 57, 59, 61, 63, 65, 67, 68, 69, 74, 76, 78, 79, 82, 83, 84, 85, 86, 88, 93, 94, 96, 103, 108, 109, 113]
female_index = [3, 6, 7, 8, 11, 12, 13, 16, 18, 20, 25, 26, 27, 28, 29, 31, 32, 33, 34, 35, 38, 39, 41, 42, 43, 44, 45, 49, 50, 51, 58, 68, 71, 74, 76, 83, 84, 91, 92, 93, 95, 97, 98, 101, 104, 106, 107, 108, 109, 110, 112, 116, 119, 121, 122, 123, 124, 126, 127, 128]

young_index = [4, 9, 12, 14, 16, 17, 21, 22, 25, 37, 40, 50, 52, 56, 57, 61, 63, 71, 73, 75, 76, 77, 78, 80, 81, 82, 86, 88, 92, 94, 101, 113, 115, 118, 120, 121, 129, 131, 140, 144, 150, 152, 153, 160, 161, 166, 167, 168, 170, 171, 172, 173, 174, 178, 180, 181, 189, 193, 194, 195]
old_index = [4, 7, 9, 10, 11, 18, 19, 21, 22, 27, 31, 32, 36, 38, 40, 42, 45, 47, 49, 50, 52, 54, 56, 57, 59, 62, 64, 65, 68, 69, 74, 75, 79, 87, 95, 101, 106, 109, 111, 114, 117, 119, 120, 121, 125, 130, 131, 133, 134, 138, 139, 141, 150, 151, 153, 155, 164, 167, 176, 178]

no_glasses_index = [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 19, 23, 24, 25, 27, 28, 29, 31, 32, 33, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 50, 51, 52, 53, 54, 55, 57, 61, 64, 65, 66, 70, 73, 74, 77, 81, 82, 87, 90, 92, 95, 100, 103]
# Note: 58 appears twice in glasses_index (after 64); one of the two entries may be a typo.
glasses_index = [1, 6, 7, 8, 14, 16, 17, 18, 20, 26, 29, 30, 33, 34, 37, 38, 40, 43, 47, 49, 50, 58, 59, 60, 64, 58, 71, 74, 75, 76, 77, 80, 85, 86, 90, 93, 95, 97, 102, 105, 107, 113, 114, 115, 120, 125, 128, 129, 130, 131, 147, 153, 154, 155, 156, 157, 159, 165, 171, 176]
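Hand-typed index lists are easy to get subtly wrong, and a repeated index silently duplicates a training sample for the SVM. A quick standard-library check, sketched with a hypothetical helper `find_duplicates`, can be run on each list (for example `find_duplicates(glasses_index)`):

```python
from collections import Counter

def find_duplicates(indices):
    """Return the indices that appear more than once in a hand-picked list."""
    return sorted(i for i, n in Counter(indices).items() if n > 1)

print(find_duplicates([1, 6, 7, 64, 58, 71, 58]))  # [58]
```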


In [None]:
z_no_glasses_list = np.array([z.cpu().numpy() for z in z_noglasses])
w_no_glasses_list = np.array([w.cpu().numpy() for w in w_noglasses])

z_no_glasses_selected = z_no_glasses_list[no_glasses_index]
w_no_glasses_selected = w_no_glasses_list[no_glasses_index]

np.save('/content/drive/MyDrive/IMA206-Project/w_no_glasses_selected.npy', w_no_glasses_selected)
np.save('/content/drive/MyDrive/IMA206-Project/z_no_glasses_selected.npy', z_no_glasses_selected)

In [None]:
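# `w_male` is assumed to come from a male-sampling loop analogous to the female one above (not included in this notebook).
w_male_list = np.array([x.cpu().numpy() for x in w_male])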
w_male_selected = w_male_list[male_index]

np.save('/content/w_male_list.npy', np.array(w_male_list))
np.save('/content/w_male_selected.npy', np.array(w_male_selected))

In [None]:
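# `w_glasses` is assumed to come from a glasses-sampling loop analogous to the no-glasses one above (not included in this notebook).
w_glasses_list = np.array([x.cpu().numpy() for x in w_glasses])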
w_glasses_list

In [None]:
w_glasses_selected = w_glasses_list[glasses_index]

In [None]:
w_noglasses_list = np.array([x.cpu().numpy() for x in w_noglasses])
w_noglasses_list

In [None]:
w_noglasses_selected = w_noglasses_list[no_glasses_index]

In [None]:
np.save('/content/w_glasses_list.npy', np.array(w_glasses_list))
np.save('/content/w_noglasses_list.npy', np.array(w_noglasses_list))
np.save('/content/w_glasses_selected.npy', np.array(w_glasses_selected))
np.save('/content/w_noglasses_selected.npy', np.array(w_noglasses_selected))

In [None]:
with torch.no_grad():
  logits = model(x)
sigmoid_logits = torch.sigmoid(logits)
predictions = (sigmoid_logits > 0.5).squeeze().cpu().numpy()

print(labels[predictions.astype(bool)])

In [None]:
# `w` is assumed to hold a batch of W latents from an earlier run with n_imgs > 35.
img_testing = generate_img(G, w[35].unsqueeze(0))
plt.figure()
plt.imshow(img_testing)

In [None]:
# Assumes the lists were initialized earlier, e.g. sunglasses, glasses, hat = [], [], []
sunglasses.append(w[11])
glasses.append(w[34])
hat.append(w[36])

## Load saved data (see above)

In [None]:
!pwd

In [None]:
!gdown --id 1FFC0tS5YtktEnEA0hTC8Qbqvd1TVVBpR -O /content/w_noglasses_selected.npy
!gdown --id 1FAZQGqlTPZHcxCJ3X4u2E3mZVCPlZ90_ -O /content/w_noglasses_list.npy
!gdown --id 1OoggWJ0OBXtL0WzskSbU-ufNNqYIAjjY -O /content/w_glasses_selected.npy
!gdown --id 1n0_gixUalPr784s7UBS0CgU78R4Iz7BM -O /content/w_glasses_list.npy

In [None]:
w_noglasses_selected = torch.from_numpy(np.load("/content/w_noglasses_selected.npy"))
w_noglasses_list = torch.from_numpy(np.load("/content/w_noglasses_list.npy"))
w_glasses_selected = torch.from_numpy(np.load("/content/w_glasses_selected.npy"))
w_glasses_list = torch.from_numpy(np.load("/content/w_glasses_list.npy"))

w_noglasses_selected = w_noglasses_selected.to(device)
w_noglasses_list = w_noglasses_list.to(device)
w_glasses_selected = w_glasses_selected.to(device)
w_glasses_list = w_glasses_list.to(device)

In [None]:
img_glasses = [generate_img(G,w) for w in w_glasses_selected]
img_noglasses = [generate_img(G,w) for w in w_noglasses_selected]

# Transform tensors to numpy
w_noglasses_arr = np.squeeze(w_noglasses_selected.cpu().numpy())
w_glasses_arr = np.squeeze(w_glasses_selected.cpu().numpy())

# Space shape
space_shape = w_noglasses_arr[0].shape

In [None]:
# Create X (points in mapping space) and y (classes)
X = np.concatenate((w_noglasses_arr, w_glasses_arr))
y = np.concatenate((np.zeros(w_noglasses_arr.shape[0]), np.ones(w_glasses_arr.shape[0])))

# Shuffle the samples (the same random permutation is applied to X and y)
idx = random.sample(list(range(len(y))), len(y))
X = X[idx].reshape(X.shape[0], -1)
y = y[idx]
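The shapes are the easiest part to get wrong here: each saved latent is `[1, num_ws, w_dim]` (14 × 512 for this 256×256 generator), so after squeezing and stacking, every sample is flattened into a 7168-dimensional row for the SVM. The sketch below reproduces the stacking, labelling and shuffling on dummy arrays; `build_dataset` is a hypothetical helper, and it uses a NumPy permutation in place of `random.sample` (same effect).

```python
import numpy as np

def build_dataset(w_neg, w_pos, seed=0):
    """Stack two groups of W latents into a flat, shuffled (X, y) pair.

    w_neg, w_pos: arrays shaped (N, num_ws, w_dim), e.g. (N, 14, 512).
    Label 0 marks w_neg (no glasses), label 1 marks w_pos (glasses).
    """
    X = np.concatenate([w_neg, w_pos]).reshape(len(w_neg) + len(w_pos), -1)
    y = np.concatenate([np.zeros(len(w_neg)), np.ones(len(w_pos))])
    perm = np.random.default_rng(seed).permutation(len(y))  # one order for both X and y
    return X[perm], y[perm]

w_neg = np.zeros((3, 14, 512))
w_pos = np.ones((2, 14, 512))
X, y = build_dataset(w_neg, w_pos)
print(X.shape, y.sum())  # (5, 7168) 2.0
```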

## Getting separating hyperplane using SVM

In [None]:
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

In [None]:
clf_pipe = Pipeline([('scaler', StandardScaler()),
                     ('clf', SVC(gamma='auto', kernel='linear'))])
clf_pipe.fit(X, y)


In [None]:
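# Normal vector of the separating hyperplane, in the scaler's standardized coordinates
hyperplane = clf_pipe['clf'].coef_[0]
hyperplane = hyperplane.reshape(space_shape)
hyperplane = torch.from_numpy(hyperplane).float().to(device)  # float32, like the generator's latents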
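Because the pipeline fits the SVM on standardized features, `coef_` is the hyperplane's normal in the scaler's coordinates rather than in W space. The decision function is a·((x − μ)/σ) + b = (a/σ)·x + (b − a·μ/σ), so the normal of the same hyperplane in raw W coordinates is a/σ, element-wise. The sketch below checks that identity on random toy data; `w_space_direction` is a hypothetical helper, and since it returns a unit vector, the step range used for the traversal would need retuning if you adopt it.

```python
import numpy as np

def w_space_direction(coef, scale):
    """Unit normal, in unscaled coordinates, of a linear SVM trained on standardized inputs.

    coef: weights learned on standardized inputs; scale: per-feature std used by the scaler.
    """
    n = coef / scale              # same hyperplane, expressed in raw coordinates
    return n / np.linalg.norm(n)  # unit length, so step sizes are distances in W space

rng = np.random.default_rng(0)
mean = rng.normal(size=4)
scale = rng.uniform(0.5, 2.0, size=4)
coef = rng.normal(size=4)
bias = 0.3
x = rng.normal(size=4)

f_std = coef @ ((x - mean) / scale) + bias           # decision value through the scaler
n = coef / scale
f_raw = n @ x + (bias - coef @ (mean / scale))        # same value with the raw-space normal
print(np.isclose(f_std, f_raw))  # True
```

In the notebook this would be called as `w_space_direction(clf_pipe['clf'].coef_[0], clf_pipe['scaler'].scale_)` before reshaping to `space_shape`. For a binary `SVC`, positive decision values correspond to `classes_[1]`, so the direction should point toward the class labelled 1 (glasses).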

In [None]:
import os
import imageio
drive_path = "/content/drive/MyDrive/colab-sg2-ada-pytorch/stylegan2-ada-pytorch/saved_images/"
os.makedirs(drive_path, exist_ok=True)
images = []
img_n = 35
# Move one latent step by step along the hyperplane normal and save the frames as a GIF.
for i in range(500):
  w = w_noglasses_selected[img_n] + hyperplane*i
  img = generate_img(G, w)
  images.append(img)
imageio.mimsave(drive_path + f'moving_{img_n}.gif', images, duration=0.02)
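The cell above only moves in the positive direction. A common variant sweeps symmetrically on both sides of the starting latent, which shows the attribute being removed as well as added. The sketch below builds such a sequence with NumPy on toy arrays; `traversal_steps` is a hypothetical helper, and it assumes a unit-length direction so that `max_step` is a distance in W space.

```python
import numpy as np

def traversal_steps(w, direction, max_step, n_frames):
    """Latents moved along `direction` from -max_step to +max_step around w."""
    alphas = np.linspace(-max_step, max_step, n_frames)
    return [w + a * direction for a in alphas]

w = np.zeros((14, 512))
d = np.zeros((14, 512))
d[0, 0] = 1.0  # toy unit direction
frames = traversal_steps(w, d, max_step=3.0, n_frames=7)
print([float(f[0, 0]) for f in frames])  # [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]
```

Each frame would then be passed through `generate_img` before saving, exactly as in the loop above.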

## Testing/Inference

Testing, also known as "inference" or "evaluation", is the process of using your trained model to generate new material, usually images or videos.

### Generate Single Images

`--network`: Make sure the `--network` argument points to the .pkl file.

`--seeds`: Chooses which random inputs to feed the model. Remember that our input to StyleGAN is a 512-dimensional vector, and each seed generates one such vector. Different seeds give different random vectors, and the same seed always gives the same vector, so you can reuse a seed later for other purposes, such as interpolation.
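The determinism comes from seeding a random number generator before drawing the latent. As far as we can tell, `generate.py` draws z from a NumPy `RandomState` seeded with each seed; the sketch below assumes that scheme, and `seed_to_z` is a hypothetical name.

```python
import numpy as np

def seed_to_z(seed, z_dim=512):
    """Deterministically map an integer seed to a (1, z_dim) standard-normal latent."""
    return np.random.RandomState(seed).randn(1, z_dim)

z_a, z_b = seed_to_z(85), seed_to_z(85)
print(z_a.shape, np.array_equal(z_a, z_b))            # (1, 512) True
print(np.array_equal(seed_to_z(85), seed_to_z(265)))  # False
```

Because a seed always yields the same z, two seeds can be used as fixed endpoints for an interpolation.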

`--trunc`: The truncation value (ψ), which pulls latents toward the average latent. This can have a subtle or dramatic effect on your images depending on the value you use. The smaller the value, the more realistic your images tend to look, at the cost of diversity. Most people choose between 0.5 and 1.0, but there is no hard upper limit; values above 1.0 push latents away from the average.
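The truncation trick is a linear interpolation toward the average latent w̄: w′ = w̄ + ψ(w − w̄). The notebook's `G.mapping(..., truncation_psi=..., truncation_cutoff=8)` calls apply it inside the mapping network, with the cutoff limiting it to the first layers. The NumPy sketch below illustrates the formula on toy arrays; `truncate` is a hypothetical helper, not the repository's code.

```python
import numpy as np

def truncate(ws, w_avg, psi, cutoff=None):
    """Pull W latents toward the average latent: w_avg + psi * (w - w_avg).

    ws: (num_ws, w_dim) latents; w_avg: (w_dim,) average latent.
    cutoff: if given, only the first `cutoff` layers are truncated.
    """
    ws = ws.copy()
    layers = slice(None) if cutoff is None else slice(0, cutoff)
    ws[layers] = w_avg + psi * (ws[layers] - w_avg)
    return ws

w_avg = np.zeros(512)
ws = np.ones((14, 512))
print(truncate(ws, w_avg, psi=0.5, cutoff=8)[[0, 13], 0])  # [0.5 1. ]
```

ψ = 1 leaves the latent unchanged, ψ = 0 collapses every image to the "average" face, and ψ > 1 extrapolates away from it.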


In [None]:
!python generate.py --outdir=/content/out/images/ --trunc=1 --seeds=85,265,297,84 --network=$resume_from

In [None]:
!ls /content/out/images/

In [None]:
import glob
import matplotlib.pyplot as plt

# generate.py writes one PNG per seed; sort so the display order is stable.
img_names = sorted(glob.glob("/content/out/images/*.png"))

for fn in img_names:
  plt.figure()
  plt.imshow(plt.imread(fn))
  plt.title(fn)
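Opening one figure per image gets unwieldy with many seeds. A small NumPy helper can tile equally sized images into a single grid to display with one `plt.imshow` call, for example `plt.imshow(tile_images([plt.imread(fn) for fn in img_names], cols=2))`. `tile_images` is a hypothetical helper sketched here; it assumes all images share the same shape and fills unused cells with black.

```python
import numpy as np

def tile_images(images, cols):
    """Arrange equally sized H x W x C images into one grid image, row by row."""
    h, w, c = images[0].shape
    rows = -(-len(images) // cols)  # ceiling division
    grid = np.zeros((rows * h, cols * w, c), dtype=images[0].dtype)
    for k, img in enumerate(images):
        r, col = divmod(k, cols)
        grid[r * h:(r + 1) * h, col * w:(col + 1) * w] = img
    return grid

imgs = [np.full((4, 4, 3), v, dtype=np.uint8) for v in (10, 20, 30)]
grid = tile_images(imgs, cols=2)
print(grid.shape)  # (8, 8, 3)
```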

In [None]:
!rm /content/out/images/seed*.png