# Structured Dreaming - Styledream notebook
The Styledream notebook fine-tunes StyleGAN2 models using CLIP guidance.

Disclaimer: the underlying StructuredDreaming repository (https://github.com/ekgren/StructuredDreaming) is under active development, and changes to it may break copies of this notebook.  

Author: Ariel Ekgren  
https://github.com/ekgren  
https://twitter.com/ArYoMo  

Resources:  
CLIP https://github.com/openai/CLIP  
StyleGAN2-ADA (PyTorch) https://github.com/NVlabs/stylegan2-ada-pytorch

In [None]:
# Show the attached GPU(s); the training cell below uses torch.device('cuda').
!nvidia-smi

In [None]:
# Install dependencies and fetch the two source repositories.
!pip install ftfy regex tqdm pyspng ninja imageio-ffmpeg==0.4.3
!git clone https://github.com/ekgren/StructuredDreaming.git
# Editable install: later `git pull`s of the clone are picked up without reinstalling.
!pip install -e ./StructuredDreaming

!git clone https://github.com/NVlabs/stylegan2-ada-pytorch.git
import sys
# insert at 1, 0 is the script path (or '' in REPL)
# Putting the cloned repo on sys.path makes its top-level `dnnlib` and `legacy`
# modules importable; they are used to download and unpickle the generator.
# NOTE: re-running this cell inserts the same path again (a harmless duplicate).
sys.path.insert(1, 'stylegan2-ada-pytorch')

In [None]:
# Update the local StructuredDreaming clone. Run this before the imports cell:
# modules that are already imported are not reloaded by a pull.
!git -C ./StructuredDreaming/ pull

In [None]:
# Imports
import random
import torch
import torchvision
import PIL
from matplotlib import pyplot as pl
from IPython.display import clear_output

# StructuredDreaming imports
from StructuredDreaming import structure
from StructuredDreaming.structure import clip
from StructuredDreaming.structure import sample
from StructuredDreaming.structure import optim

# Stylegan imports
import dnnlib
import legacy

In [None]:
# Load models
# CLIP ViT-B/16 acts as a fixed scorer: the training loop backpropagates
# through it to reach the generator, but never updates CLIP itself.
perceptor, normalize_image = structure.clip.load('ViT-B/16', jit=False)
# Freeze CLIP's weights. Otherwise every loss.backward() allocates .grad
# buffers for all CLIP parameters and keeps accumulating into them (the
# optimizer only owns the generator's parameters, so nothing zeroes them).
# Gradients with respect to the input images are still computed.
perceptor.requires_grad_(False)

In [None]:
# Utils
def display_img(input: torch.Tensor, size: float = 1.):
    """Display the first image of a batch inline in the notebook.

    Args:
        input: 4-D image batch laid out as (batch, channels, height, width),
            with values assumed to be in the range [0, 1]; only ``input[0]``
            is shown, and it must have 3 channels (it is rendered as RGB).
        size: scale factor applied to height and width (area interpolation)
            before display.
    """
    # `import PIL` alone does not load the `PIL.Image` submodule; import it
    # explicitly rather than relying on another library having done so.
    from PIL import Image
    # `display` is only injected as a builtin inside IPython; import it.
    from IPython.display import display

    with torch.no_grad():
        height, width = input.shape[-2:]
        img = torch.nn.functional.interpolate(
            input, (int(size * height), int(size * width)), mode='area')
        # (C, H, W) -> (H, W, C) as PIL expects, then quantize to 8 bits.
        img_hwc = img[0].permute(1, 2, 0).cpu()
        img_out = (img_hwc * 255).clamp(0, 255).to(torch.uint8)
        display(Image.fromarray(img_out.numpy(), 'RGB'))
        pl.show()

def stylegan_to_rgb(input: torch.Tensor) -> torch.Tensor:
    """Map StyleGAN2 generator output to RGB values in [0, 1].

    Assumes the generator emits images nominally in [-1, 1] (the convention
    used by stylegan2-ada-pytorch) and applies x -> (x + 1) / 2, so -1 -> 0
    and 1 -> 1. Values outside [-1, 1] are not clamped, and the map is
    differentiable, so it can sit inside the training graph.

    The previous form, (x * 127.5 + 128) / 255, borrowed the uint8 conversion
    idiom whose +0.5 offset rounds on integer truncation; in float space that
    offset is just a 0.5/255 bias that pushes 1 to ~1.002.
    """
    return (input + 1) / 2

# Sanity check: show a random 10x10 RGB image scaled up 4x.
display_img(torch.rand(1, 3, 10, 10, requires_grad=False), 4)

In [None]:
#@title # Prompt and training parameters{ run: "auto" }
#@markdown Write your image prompt in the txt field below.

#@markdown Prompt suggestions:
#@markdown * "portrait painting of android from dystopic future by James Gurney"
#@markdown * "portrait of anime character in the style of studio ghibli | cute anime character"

# Text prompt; it is tokenized and encoded by CLIP, and the generator is
# fine-tuned to increase cosine similarity with that embedding.
txt = "eternal alien #film #eternity | trending on artstation | art" #@param {type:"string"}

# Training parameters
iterations = 400        # number of optimizer steps
grad_acc_steps = 1      # forward/backward passes per optimizer step
batch_size = 1          # passed to the sampler as `bs` (presumably views per generated image; confirm in structure.sample)
lr = 2e-4               # learning rate for structure.optim.ClampSGD
loss_scale = 100.       # multiplier on the negative CLIP cosine similarity
steps_show = 8          # show a preview image every `steps_show` iterations
truncation_psi = 0.6    # StyleGAN truncation psi passed to G; values < 1 trade diversity for typicality
clamp_val = 1e-30       # ClampSGD `clamp` argument; semantics defined in structure.optim (TODO document)
drop = 0.8              # ClampSGD `drop` argument; semantics defined in structure.optim (TODO document)

# Sampler
# Arguments for structure.sample.ImgSampleStylegan; their exact meaning is
# defined in that module (NOTE(review): document there).
sample_size = 224       # `size` passed to the sampler; presumably the side length fed to CLIP (ViT-B/16 takes 224x224)
kernel_min = 1
kernel_max = 16
grid_size_min = 224
grid_size_max = 3*224
noise = 1.
noise_std = 0.3
cutout = 1.
cutout_size = 0.25

# Pretrained StyleGAN2-ADA generator pickle (FFHQ) that gets fine-tuned.
network_pkl = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl'

In [None]:
#@title Train loop {vertical-output: true}
#@markdown Loading and fine-tuning the model.
#@markdown The image shown during training is displayed at half size.

device = torch.device('cuda')
# NOTE: this unpickles the downloaded file; only point network_pkl at trusted sources.
with dnnlib.util.open_url(network_pkl) as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
# Make every generator weight trainable; the whole network is fine-tuned.
for p in G.parameters():
    p.requires_grad = True
# No conditioning label (assumes an unconditional model such as the FFHQ pickle).
c = None

# Training
# The prompt embedding is a fixed target, so encode it once without building
# an autograd graph.
txt_tok = structure.clip.tokenize(txt)
with torch.no_grad():
    text_latent = perceptor.encode_text(txt_tok.to(device))
sampler = torch.jit.script(
    structure.sample.ImgSampleStylegan(kernel_min=kernel_min,
                                       kernel_max=kernel_max,
                                       grid_size_min=grid_size_min,
                                       grid_size_max=grid_size_max,
                                       noise=noise,
                                       noise_std=noise_std,
                                       cutout=cutout,
                                       cutout_size=cutout_size).to(device)
)
optimizer = structure.optim.ClampSGD(G.parameters(),
                                     lr=lr,
                                     clamp=clamp_val,
                                     drop=drop)

print('Generating image.')
for i in range(iterations):

    # Gradient accumulation: clear gradients once per optimizer step, then sum
    # the gradients of `grad_acc_steps` micro-batches. (Zeroing inside the
    # inner loop would discard every micro-batch except the last.)
    optimizer.zero_grad()
    for _ in range(grad_acc_steps):
        z = torch.randn([1, G.z_dim], device=device)
        img = G(z, c, truncation_psi)
        img = stylegan_to_rgb(img)
        # presumably random crops/augmentations producing `batch_size` views of
        # `sample_size` px (see structure.sample.ImgSampleStylegan)
        img = sampler(img, size=sample_size, bs=batch_size)
        img = normalize_image(img)
        img_latents = perceptor.encode_image(img)
        # Negative cosine similarity to the prompt embedding: minimizing it
        # pulls generated images toward the text.
        loss = torch.cosine_similarity(text_latent, img_latents, dim=-1).mean().neg() * loss_scale
        # Divide so the accumulated gradient is an average over micro-batches
        # (identical to the plain loss when grad_acc_steps == 1).
        (loss / grad_acc_steps).backward()

    optimizer.step()

    # Preview after the update, so `loss` is always defined (even when
    # steps_show == 1) and belongs to the step just taken.
    if (i + 1) % steps_show == 0:
        with torch.no_grad():
            clear_output(True)
            z = torch.randn([1, G.z_dim], device=device)
            img = G(z, c, truncation_psi)
            img = stylegan_to_rgb(img)
            display_img(img, 0.5)
            print(i,
                  loss.item(),
                  img.min().item(),
                  img.max().item())

In [None]:
#@title Generate images from the fine-tuned model

# Render one fresh sample from the fine-tuned generator at full size,
# replacing any previous cell output.
with torch.no_grad():
    clear_output(wait=True)
    z = torch.randn(1, G.z_dim, device=device)
    img = stylegan_to_rgb(G(z, c, truncation_psi=truncation_psi))
    display_img(img, size=1.0)