<a href="https://colab.research.google.com/github/dvschultz/ml-art-colabs/blob/master/Advanced_StyleGAN_Network_bending.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Network Bending
## Manipulate StyleGAN2 models through rotation, translation, etc.

[Paper](https://arxiv.org/abs/2005.12420) | [Video](https://youtu.be/IlSMQ2RRTh8) | [GitHub](https://github.com/terrybroad/network-bending)

Thanks to [Sid Black](https://twitter.com/realmeatyhuman) for a lot of the code used here.

## Install Library

This code uses the PyTorch version of StyleGAN2, so the install process may take a couple of minutes.

In [1]:
!nvidia-smi

Thu Jun 17 01:39:02 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    23W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
# Install libraries
!git clone -b audio-animate https://github.com/dvschultz/network-bending
!pip uninstall torch torchvision -y
!pip install torch==1.5.0+cu101 torchvision==0.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
!pip install Ninja kmeans-pytorch
!apt-get install -y vim make gdb libopencv-dev
!wget https://download.pytorch.org/libtorch/cu101/libtorch-shared-with-deps-1.5.0%2Bcu101.zip
!unzip /content/libtorch-shared-with-deps-1.5.0+cu101.zip -d /root/
%cd network-bending

#build custom pytorch transformations
!chmod +x /content/network-bending/build_custom_transforms.sh
!/content/network-bending/build_custom_transforms.sh /root/libtorch/

Cloning into 'network-bending'...
remote: Enumerating objects: 369, done.
remote: Counting objects: 100% (37/37), done.
remote: Compressing objects: 100% (28/28), done.
remote: Total 369 (delta 22), reused 21 (delta 9), pack-reused 332
Receiving objects: 100% (369/369), 21.44 MiB | 7.39 MiB/s, done.
Resolving deltas: 100% (213/213), done.
Uninstalling torch-1.8.1+cu101:
  Successfully uninstalled torch-1.8.1+cu101
Uninstalling torchvision-0.9.1+cu101:
  Successfully uninstalled torchvision-0.9.1+cu101
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.5.0+cu101
  Downloading https://download.pytorch.org/whl/cu101/torch-1.5.0%2Bcu101-cp37-cp37m-linux_x86_64.whl (703.8MB)
     |████████████████████████████████| 703.8MB 25kB/s
Collecting torchvision==0.6.0+cu101
  Downloading https://download.pytorch.org/whl/cu101/torchvision-0.6.0%2Bcu101-cp37-cp37m-linux_x86_64.whl (6.6MB)
     |████████████████████████████████| 

In [3]:
def show_local_mp4_video(file_name, width=640, height=640):
  """Embed a local .mp4 file in the notebook output as a base64 data URI."""
  import base64
  from IPython.display import HTML
  # read the whole video and base64-encode it so it can be inlined in the HTML
  with open(file_name, 'rb') as f:
    video_encoded = base64.b64encode(f.read())
  return HTML(data='''<video width="{0}" height="{1}" controls>
                        <source src="data:video/mp4;base64,{2}" type="video/mp4" />
                      </video>'''.format(width, height, video_encoded.decode('ascii')))

## Download .pt file

As mentioned above, this library uses the PyTorch version of StyleGAN2. If you have a .pkl file, you’ll need to convert it to a .pt file. I have a notebook to do that [here](https://colab.research.google.com/github/dvschultz/stylegan2-ada-pytorch/blob/main/SG2_ADA_PT_to_Rosinality.ipynb).

In [4]:
!gdown --id 1rL-J63eFfn80IYU2GfVY977GI2qOG6dw -O /content/ladiesblack.pt

Downloading...
From: https://drive.google.com/uc?id=1rL-J63eFfn80IYU2GfVY977GI2qOG6dw
To: /content/ladiesblack.pt
133MB [00:01, 83.3MB/s]


## Generate Image Samples (and Latent Vectors)

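This script will generate "normal" images because the transform config is blank. I recommend doing this first so you know which images you want to work with. Passing `--save_latent 1` also saves the latent vectors to `sample/latents.yaml`, which the animation cells below load with `--load_latent`.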

In [5]:
!python generate.py --help

usage: generate.py [-h] [--size SIZE] [--sample SAMPLE] [--pics PICS]
                   [--truncation TRUNCATION]
                   [--truncation_mean TRUNCATION_MEAN] [--ckpt CKPT]
                   [--channel_multiplier CHANNEL_MULTIPLIER] [--config CONFIG]
                   [--load_latent LOAD_LATENT] [--clusters CLUSTERS]
                   [--dir DIR] [--save_latent SAVE_LATENT]

optional arguments:
  -h, --help            show this help message and exit
  --size SIZE
  --sample SAMPLE
  --pics PICS
  --truncation TRUNCATION
  --truncation_mean TRUNCATION_MEAN
  --ckpt CKPT
  --channel_multiplier CHANNEL_MULTIPLIER
  --config CONFIG
  --load_latent LOAD_LATENT
  --clusters CLUSTERS
  --dir DIR             path to output samples
  --save_latent SAVE_LATENT


In [6]:
!python generate.py \
--ckpt /content/ladiesblack.pt \
--pics 20 \
--config /content/network-bending/configs/empty_transform_config.yaml \
--save_latent 1 

  0% 0/20 [00:00<?, ?it/s]torch.Size([1, 512])
  5% 1/20 [00:00<00:11,  1.72it/s]torch.Size([1, 512])
 10% 2/20 [00:01<00:09,  1.83it/s]torch.Size([1, 512])
 15% 3/20 [00:01<00:08,  2.00it/s]torch.Size([1, 512])
 20% 4/20 [00:01<00:08,  1.93it/s]torch.Size([1, 512])
 25% 5/20 [00:02<00:07,  2.02it/s]torch.Size([1, 512])
 30% 6/20 [00:02<00:06,  2.08it/s]torch.Size([1, 512])
 35% 7/20 [00:03<00:06,  2.14it/s]torch.Size([1, 512])
 40% 8/20 [00:03<00:05,  2.13it/s]torch.Size([1, 512])
 45% 9/20 [00:04<00:05,  2.14it/s]torch.Size([1, 512])
 50% 10/20 [00:04<00:04,  2.18it/s]torch.Size([1, 512])
 55% 11/20 [00:05<00:04,  2.21it/s]torch.Size([1, 512])
 60% 12/20 [00:05<00:03,  2.15it/s]torch.Size([1, 512])
 65% 13/20 [00:06<00:03,  2.23it/s]torch.Size([1, 512])
 70% 14/20 [00:06<00:02,  2.22it/s]torch.Size([1, 512])
 75% 15/20 [00:06<00:02,  2.19it/s]torch.Size([1, 512])
 80% 16/20 [00:07<00:01,  2.19it/s]torch.Size([1, 512])
 85% 17/20 [00:07<00:01,  2.27it/s]torch.Size([1, 512])
 90% 18/20

In [7]:
!zip -r samples-ladiesblack-400.zip ./sample

  adding: sample/ (stored 0%)
  adding: sample/000004.png (deflated 0%)
  adding: sample/000006.png (deflated 0%)
  adding: sample/000008.png (deflated 0%)
  adding: sample/000005.png (deflated 0%)
  adding: sample/000017.png (deflated 0%)
  adding: sample/000010.png (deflated 0%)
  adding: sample/000018.png (deflated 0%)
  adding: sample/000009.png (deflated 0%)
  adding: sample/000012.png (deflated 0%)
  adding: sample/000011.png (deflated 0%)
  adding: sample/000014.png (deflated 0%)
  adding: sample/000019.png (deflated 0%)
  adding: sample/config.yaml (deflated 21%)
  adding: sample/latents.yaml (deflated 55%)
  adding: sample/000013.png (deflated 0%)
  adding: sample/000015.png (deflated 0%)
  adding: sample/000016.png (deflated 0%)
  adding: sample/000007.png (deflated 0%)
  adding: sample/000002.png (deflated 0%)
  adding: sample/000001.png (deflated 0%)
  adding: sample/000000.png (deflated 0%)
  adding: sample/000003.png (deflated 0%)


To generate images with transformations, create a config file, update its values, and then run the `generate.py` script with that config.

Note: order of transforms does matter!
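A minimal illustration of why order matters, using plain 2D affine matrices rather than network-bending's own feature-map code: rotating by 45° and then translating by (-0.5, -0.25) sends a point somewhere different than translating first and then rotating. The helpers are hypothetical, and treating the `translate` params as offsets in the same normalized coordinates is an assumption based on the `#range is -1 to 1` comment in the config below.

```python
import math


def rotation(deg):
    """3x3 homogeneous matrix for a counter-clockwise rotation about the origin."""
    t = math.radians(deg)
    c, s = math.cos(t), math.sin(t)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def translation(tx, ty):
    """3x3 homogeneous matrix for a shift by (tx, ty)."""
    return [[1.0, 0.0, tx], [0.0, 1.0, ty], [0.0, 0.0, 1.0]]


def matmul(a, b):
    """Matrix product a @ b, i.e. 'apply b first, then a'."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def apply(m, point):
    """Apply a homogeneous 3x3 transform to a 2D point."""
    x, y = point
    return (m[0][0] * x + m[0][1] * y + m[0][2],
            m[1][0] * x + m[1][1] * y + m[1][2])


rot = rotation(45.0)
shift = translation(-0.5, -0.25)

rotate_then_translate = matmul(shift, rot)
translate_then_rotate = matmul(rot, shift)

# rotating the origin leaves it in place, so only the shift shows up here
p = (0.0, 0.0)
print(apply(rotate_then_translate, p))  # (-0.5, -0.25)
# here the shift itself gets rotated, so the point lands elsewhere
print(apply(translate_then_rotate, p))
```

The same non-commutativity suggests that when several transforms hit the same features, swapping their order in the config will generally change the output.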

In [8]:
!cp ./configs/empty_transform_config.yaml ./configs/custom_transform_config.yaml 

In [9]:
%%writefile ./configs/custom_transform_config.yaml 
transforms:
- layer: 1
  transform: "rotate"
  params: [45.0]
  features: "all"
- layer: 10
  transform: "rotate"
  params: [45.0]
  features: "all"
- layer: 3
  transform: "translate"
  params: [-0.5, -0.25] #range is -1 to 1
  features: "all"
- layer: 8
  transform: "scale"
  params: [1.5]
  features: "all"
- layer: 15
  transform: "flip-h"
  params: []
  features: "all"

Overwriting ./configs/custom_transform_config.yaml
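If you want to sweep a parameter (say, several rotation angles), it can be easier to write the config from Python than to edit YAML by hand. The sketch below uses only the standard library. `transform_entry` and `to_config_yaml` are hypothetical helpers; the parameter counts they check are only the ones used in this notebook's config, and any other transform name is passed through unchecked. Entries are written in list order, since the order of transforms matters.

```python
import json

# Param counts as used in this notebook's config (an assumption, not the
# full set network-bending supports); unknown names are not checked.
EXPECTED_PARAMS = {"rotate": 1, "scale": 1, "translate": 2, "flip-h": 0}


def transform_entry(layer, transform, params=(), features="all"):
    """Build one item of the `transforms:` list (hypothetical helper)."""
    params = [float(p) for p in params]
    expected = EXPECTED_PARAMS.get(transform)
    if expected is not None and len(params) != expected:
        raise ValueError(f"{transform!r} expects {expected} param(s), got {len(params)}")
    return {"layer": int(layer), "transform": transform,
            "params": params, "features": features}


def to_config_yaml(entries):
    """Serialize entries in the same layout as custom_transform_config.yaml.

    json.dumps output for strings and flat lists of numbers is also valid
    YAML, and the entries keep their list order.
    """
    lines = ["transforms:"]
    for e in entries:
        lines.append(f"- layer: {e['layer']}")
        lines.append(f"  transform: {json.dumps(e['transform'])}")
        lines.append(f"  params: {json.dumps(e['params'])}")
        lines.append(f"  features: {json.dumps(e['features'])}")
    return "\n".join(lines) + "\n"


# Example: one config per rotation angle on layer 1
for angle in (15, 30, 45):
    text = to_config_yaml([transform_entry(1, "rotate", [angle])])
    print(text)
```

To use one, write `text` to a file such as `./configs/rotate_15.yaml` (a made-up name) and pass that path to `generate.py --config`.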


In [10]:
!python generate.py --ckpt /content/ladiesblack.pt --pics 10 --config ./configs/custom_transform_config.yaml --dir '/content/custom-samples/'

  0% 0/10 [00:00<?, ?it/s]torch.Size([1, 512])
 10% 1/10 [00:01<00:09,  1.10s/it]torch.Size([1, 512])
 20% 2/10 [00:01<00:08,  1.02s/it]torch.Size([1, 512])
 30% 3/10 [00:02<00:06,  1.05it/s]torch.Size([1, 512])
 40% 4/10 [00:03<00:05,  1.14it/s]torch.Size([1, 512])
 50% 5/10 [00:04<00:04,  1.19it/s]torch.Size([1, 512])
 60% 6/10 [00:04<00:03,  1.22it/s]torch.Size([1, 512])
 70% 7/10 [00:05<00:02,  1.23it/s]torch.Size([1, 512])
 80% 8/10 [00:06<00:01,  1.23it/s]torch.Size([1, 512])
 90% 9/10 [00:07<00:00,  1.21it/s]torch.Size([1, 512])
100% 10/10 [00:08<00:00,  1.22it/s]


You can also generate strips of images in which the transform is applied to each layer in sequence. Note that this requires a separate transformation config file.

In [11]:
!cp ./configs/sample_strip_config.yaml ./configs/custom_strip_config.yaml 

In [12]:
%%writefile ./configs/custom_strip_config.yaml 
transform: "rotate"
params: [45.0]
features: "all"
feature-param: 

Overwriting ./configs/custom_strip_config.yaml


In [13]:
!python generate_sample_strips.py \
--ckpt /content/ladiesblack.pt \
--pics 5 \
--config ./configs/custom_strip_config.yaml   \
--dir '/content/strips/'

100% 5/5 [00:21<00:00,  4.34s/it]


## Animating Vectors: Script version

Use the script-based version if you want to create interpolations with a single transformation.

*   `--num_frames`: how many frames to produce (the animation's length in seconds is this value divided by the video's frame rate)
*   `--transform`: the transform function you want to use
*   `--init_val`, `--end_val`: starting and stopping values for the transform parameter, interpolated linearly over the total number of frames
*   `--layer_id`: which of the StyleGAN layers to apply the transformation to. Lower IDs will affect more of the structure, higher IDs will affect more of the details.
*   `--interpolate_ids`: which of the StyleGAN layers to apply the transformation to. Lower IDs will affect more of the structure, higher IDs will affect more of the details.
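The log printed by the cell below suggests how `--init_val` and `--end_val` are spread over `--num_frames`: with `0.0`, `3.0` and `240`, frame 0 prints `param: 0.0125`, frame 1 `0.025` and frame 2 `0.037500000000000006`. That is, the value appears to advance by `(end_val - init_val) / num_frames` each frame and reach `end_val` on the last one. The sketch reproduces that pattern; it is inferred from the log, not taken from `animate.py`. `duration_seconds` applies the frames-divided-by-fps rule, with 24 fps assumed to match the `ffmpeg -r 24` command later in this notebook.

```python
def param_schedule(init_val, end_val, num_frames):
    """Per-frame parameter values, mirroring the pattern in the animate.py log.

    Assumption: value_f = init_val + step * (f + 1), where
    step = (end_val - init_val) / num_frames, so frame 0 is already one step
    past init_val and the final frame lands on end_val.
    """
    step = (end_val - init_val) / num_frames
    return [init_val + step * (f + 1) for f in range(num_frames)]


def duration_seconds(num_frames, fps):
    """Length of the animation once the frames are encoded at `fps`."""
    return num_frames / fps


values = param_schedule(0.0, 3.0, 240)
print(values[:3])
print(duration_seconds(240, 24))  # 10.0
```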




In [21]:
!python animate.py \
--ckpt /content/ladiesblack.pt \
--load_latent /content/network-bending/sample/latents.yaml \
--interpolate_ids=10,10 \
--latent_id 0 \
--num_frames 240 \
--transform "scale" \
--init_val 0.0 \
--end_val 3.0 \
--layer_id=1 \
--truncation=0.6 \
--noise_interpolation \
--dir="ladiesblack-scale-test"

[0.0]
1
  0% 0/240 [00:00<?, ?it/s]animating frame: 0 , param: 0.0125
240 0/2 0.0
  0% 1/240 [00:00<03:22,  1.18it/s]animating frame: 1 , param: 0.025
240 0/2 0.004166666666666667
  1% 2/240 [00:01<03:13,  1.23it/s]animating frame: 2 , param: 0.037500000000000006
240 0/2 0.008333333333333333
  1% 3/240 [00:02<03:07,  1.27it/s]animating frame: 3 , param: 0.05
240 0/2 0.0125
  2% 4/240 [00:03<03:02,  1.30it/s]animating frame: 4 , param: 0.0625
240 0/2 0.016666666666666666
  2% 5/240 [00:03<02:58,  1.32it/s]animating frame: 5 , param: 0.07500000000000001
240 0/2 0.020833333333333332
  2% 6/240 [00:04<02:55,  1.33it/s]animating frame: 6 , param: 0.08750000000000001
240 0/2 0.025
  3% 7/240 [00:05<02:53,  1.34it/s]animating frame: 7 , param: 0.1
240 0/2 0.029166666666666667
  3% 8/240 [00:05<02:51,  1.36it/s]animating frame: 8 , param: 0.1125
240 0/2 0.03333333333333333
  4% 9/240 [00:06<02:49,  1.36it/s]animating frame: 9 , param: 0.125
240 0/2 0.0375
  4% 10/240 [00:07<02:48,  1.37it/s]a

In [19]:
!rm -r /content/network-bending/ladiesblack-scale-test

In [15]:
show_local_mp4_video('/content/network-bending/ladiescrop28-rotate-layer3.mp4', width=720, height=720)

FileNotFoundError: ignored

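### Multiple Transforms

To animate several transforms at once, pass comma-separated lists to `--transform`, `--init_val`, `--end_val` and `--layer_id`, with one entry per transform in the same order.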

In [None]:
!python animate.py \
--ckpt /content/ladiesblack.pt \
--load_latent /content/network-bending/sample/latents.yaml \
--latent_id 0 \
--num_frames 72 \
--transform "scale,rotate,scale" \
--init_val 0.0,0.0,0.5 \
--end_val 2.0,360.0,1.0 \
--layer_id 3,1,10 \
--truncation=0.5 \
--interpolate_ids=1,4,11,1 \
--dir="ladiesblack-multi-test"
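In the cell above, the comma-separated lists appear to line up by position: layer 3 scales from 0.0 to 2.0, layer 1 rotates from 0 to 360 (presumably degrees), and layer 10 scales from 0.5 to 1.0. The sketch below uses a hypothetical helper that is not part of `animate.py`. It makes that pairing explicit and catches lists of unequal length before a long render starts. It does not cover `--interpolate_ids`.

```python
def parse_transform_flags(transform, init_val, end_val, layer_id):
    """Pair up comma-separated animate.py flag values by position (assumed behaviour)."""
    names = [t.strip() for t in transform.split(",")]
    inits = [float(v) for v in init_val.split(",")]
    ends = [float(v) for v in end_val.split(",")]
    layers = [int(v) for v in layer_id.split(",")]
    if not len(names) == len(inits) == len(ends) == len(layers):
        raise ValueError(
            "--transform, --init_val, --end_val and --layer_id need the same number of entries")
    return list(zip(names, inits, ends, layers))


# the flag values from the cell above
plan = parse_transform_flags("scale,rotate,scale", "0.0,0.0,0.5", "2.0,360.0,1.0", "3,1,10")
for name, start, stop, layer in plan:
    print(f"layer {layer}: {name} from {start} to {stop}")
```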

### Interpolating Noise

You may find that the detail textures in your bend animations look too similar from frame to frame. You can add `--noise_interpolation` to interpolate the noise inputs over the animation as well.


In [None]:
!python animate.py \
--ckpt /content/ladiesblack.pt \
--load_latent /content/network-bending/sample/latents.yaml \
--latent_id 0 \
--num_frames 72 \
--transform "scale,rotate,scale" \
--init_val 0.0,0.0,0.5 \
--end_val 2.0,360.0,1.0 \
--layer_id 3,1,10 \
--truncation=0.5 \
--interpolate_ids=1,4,11,1 \
--dir="ladiesblack-multi-test-noise2" \
--noise_interpolation

## Generate Clusters

More to come on this section; honestly, I’m not sure this is all that’s needed.

In [None]:
!python get_clusters.py --ckpt /content/ladiesblack.pt

## Handmade stuff


In [None]:
import os
import copy
import torch
import yaml

from torchvision import utils
from model import Generator
from tqdm import tqdm
from util import *

frames = 120

# https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/4
def slerp(val, low, high):
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm*high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    print(res.shape)
    return res

device = "cuda"
g_ema = Generator(
        1024, 512, 8, 2
    ).to(device)
checkpoint = torch.load('/content/ladiesblack.pt')
g_ema.load_state_dict(checkpoint["g_ema"])

with torch.no_grad():
    mean_latent = g_ema.mean_latent(4096)

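yaml_config = {}
with open('/content/network-bending/configs/empty_transform_config.yaml', 'r') as stream:
    try:
        # yaml.load() without a Loader warns in PyYAML 5.x and raises in 6.x;
        # safe_load is all a plain config file needs
        yaml_config = yaml.safe_load(stream)
    except yaml.YAMLError as exc:
        print(exc)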

cluster_config = {}
layer_channel_dims = create_layer_channel_dim_dict(2)

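# the generator's fixed per-layer noise maps, each shaped [1, 1, H, W]
noise = [getattr(g_ema.noises, f"noise_{i}") for i in range(g_ema.num_layers)]

# second set of noise maps: layers below 256x256 get fresh noise
# (variance 0.1); larger layers keep their original noise
noise2 = copy.deepcopy(noise)
for i,n in enumerate(noise2):
    if len(n[0][0]) < 256:
        # print('update: ', n.shape)
        noise2[i] = (0.1**0.5)*torch.randn_like(n)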

# only slerp for lower layers, keep defaults for higher layers (won't fit in VRAM)
# noise_slerps = []
# for f in range(int(frames/2)):
#     ns = []
#     for i in range(len(noise)):
#         if len(noise[i][0][0]) < 256:
#             # print('update: ', noise[i].shape)
#             ns.append(slerp(f/(frames/2), noise[i], noise2[i]))
#     noise_slerps.append(ns)

noise_slerps = []
for f in range(int(frames/2)):
    ns = []
    for i in range(len(noise)):
        # if len(noise[i][0][0]) < 256:
            # print('update: ', noise[i].shape)
        #ns.append(slerp((f/(frames/2)), noise[i], noise2[i]))
        ns.append( torch.lerp( noise[i], noise2[i], (f/(frames/2)) ) )
    noise_slerps.append(ns)

#print(len(noise_slerps))

with torch.no_grad():
    g_ema.eval()
    # build the (empty) transform list from the config loaded above
    t_dict_list = create_transforms_dict_list(yaml_config, cluster_config, layer_channel_dims)

    # one fixed latent; only the noise changes from frame to frame
    sample_z = torch.randn(1, 512, device=device)

    # create the output folder once, before the loop
    os.makedirs('interpolations', exist_ok=True)

    for i in tqdm(range(len(noise_slerps))):
        sample, _ = g_ema([sample_z], truncation=0.7, noise=noise_slerps[i], truncation_latent=mean_latent, transform_dict_list=t_dict_list)
        #sample2, _ = g_ema([sample_z], truncation=0.7, noise=noise, truncation_latent=mean_latent, transform_dict_list=t_dict_list)

        utils.save_image(
            sample,
            f'interpolations/{str(i).zfill(6)}.png',
            nrow=1,
            normalize=True,
            range=(-1, 1))
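The lerp loop in the cell above fades each layer's stored noise into `noise2` over `frames/2` steps. One side effect of a straight linear blend: for two independent noise maps of equal variance, the blend at weight `t` has `(1 - t)**2 + t**2` times the original variance, so halfway through the fade the texture is noticeably weaker. The stdlib-only sketch below checks this numerically, with plain Python lists standing in for the noise tensors; `lerp` mirrors what `torch.lerp(start, end, weight)` computes for a scalar weight.

```python
import random
import statistics


def lerp(a, b, t):
    """Element-wise start + t * (end - start), like torch.lerp with a float weight."""
    return [x + t * (y - x) for x, y in zip(a, b)]


random.seed(0)
n = 20000
noise_a = [random.gauss(0.0, 1.0) for _ in range(n)]
noise_b = [random.gauss(0.0, 1.0) for _ in range(n)]

for t in (0.0, 0.5, 1.0):
    mixed = lerp(noise_a, noise_b, t)
    # measured variance vs. the theoretical (1 - t)**2 + t**2 for independent inputs
    print(t, round(statistics.pvariance(mixed), 3), (1 - t) ** 2 + t ** 2)
```

Spherical interpolation (the commented-out `slerp` variant) is the usual alternative that keeps the magnitude roughly constant. Note that the `slerp` helper normalizes along `dim=1`, which has size 1 for a `[1, 1, H, W]` noise map, so each map would need flattening to `[1, H*W]` before the call (and reshaping afterward) for it to behave as intended.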

In [None]:
!ffmpeg -r 24 -i /content/network-bending/interpolations/%06d.png -vcodec libx264 -pix_fmt yuv420p noise-test.mp4 -y

## Helper functions

In [None]:
!rm -rf /content/network-bending/sample-animation/frame*.png

In [None]:
!gdown --id 1rL-J63eFfn80IYU2GfVY977GI2qOG6dw -O /content/ladiesblack.pt