<a href="https://colab.research.google.com/github/eyaler/clip_biggan/blob/main/WanderCLIP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BigGAN + CLIP + CMA-ES

[j.mp/bigclip](https://j.mp/bigclip)

By Eyal Gruss [@eyaler](https://twitter.com/eyaler) [eyalgruss.com](https://eyalgruss.com)

Based on SIREN+CLIP Colabs by: [@advadnoun](https://twitter.com/advadnoun), [@norod78](https://twitter.com/norod78)

Other CLIP notebooks: [OpenAI tutorial](https://colab.research.google.com/github/openai/clip/blob/master/Interacting_with_CLIP.ipynb), [SIREN by @advadnoun](https://colab.research.google.com/drive/1FoHdqoqKntliaQKnMoNs3yn5EALqWtvP), [SIREN by @norod78](https://colab.research.google.com/drive/1K1vfpTEvAmxW2rnhAaALRVyis8EiLOnD), [BigGAN by @advadnoun](https://colab.research.google.com/drive/1NCceX2mbiKOSlAd_o7IU7nA9UskKN5WR), [BigGAN by @eyaler](https://j.mp/bigclip), [BigGAN by @tg_bomze](https://colab.research.google.com/github/tg-bomze/collection-of-notebooks/blob/master/Text2Image_v2.ipynb), [BigGAN using big-sleep library by @lucidrains](https://colab.research.google.com/drive/1MEWKbm-driRNF8PrU7ogS5o3se-ePyPb), [BigGAN story hallucinator by @bonkerfield](https://colab.research.google.com/drive/1jF8pyZ7uaNYbk9ZiVdxTOajkp8kbmkLK), [StyleGAN2-ADA Anime by @nagolinc](https://colab.research.google.com/github/nagolinc/notebooks/blob/main/TADNE_and_CLIP.ipynb)

Using the works:

https://github.com/openai/CLIP

https://tfhub.dev/deepmind/biggan-deep-512

https://github.com/huggingface/pytorch-pretrained-BigGAN

http://www.aiartonline.com/design-2019/eyal-gruss (WanderGAN)

For a curated list of more online generative tools see: [j.mp/generativetools](https://j.mp/generativetools)


In [None]:
#@title Restart after running this cell!

!nvidia-smi -L

import subprocess

CUDA_version = [s for s in subprocess.check_output(["nvcc", "--version"]).decode("UTF-8").split(", ") if s.startswith("release")][0].split(" ")[-1]
print("CUDA version:", CUDA_version)

if CUDA_version == "10.0":
    torch_version_suffix = "+cu100"
elif CUDA_version == "10.1":
    torch_version_suffix = "+cu101"
elif CUDA_version == "10.2":
    torch_version_suffix = ""
else:
    torch_version_suffix = "+cu110"

!pip install torch==1.7.1{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex

In [None]:
#@title Setup
# Install the PyTorch BigGAN port and load the pretrained 'biggan-deep-512'
# generator onto the GPU in eval mode.
!pip install pytorch-pretrained-biggan
from pytorch_pretrained_biggan import BigGAN
gan_model = BigGAN.from_pretrained('biggan-deep-512').cuda().eval()

# Clone CLIP (HEAD, unpinned) and cd into the checkout before importing it.
# NOTE(review): the cd is presumably what makes `import clip` resolve, since
# the package is never pip-installed - confirm.
%cd /content
!git clone --depth 1 https://github.com/openai/CLIP
!pip install ftfy
%cd /content/CLIP
import clip
# Load every model CLIP reports as available, keyed by model name.
# perceptor[name] is the model, preprocess[name] the second value returned by clip.load.
# The loop variable `model` is reassigned by the form field in the Generate cell.
models = clip.available_models()
perceptor={}
preprocess={}
for model in models:
  perceptor[model], preprocess[model] = clip.load(model)

# NOTE(review): WordNet is presumably needed by one_hot_from_names in the Generate cell - confirm.
import nltk
nltk.download('wordnet')

!pip install cma

In [None]:
#@title Generate!
#@markdown 1. For **prompt** OpenAI suggest to use the template "A photo of a X." or "A photo of a X, a type of Y." [[paper]](https://cdn.openai.com/papers/Learning_Transferable_Visual_Models_From_Natural_Language_Supervision.pdf)
#@markdown 2. For **initial_class** you can either use free text or select a special option from the drop-down list.
#@markdown 3. Free text and 'From prompt' might fail to find an appropriate ImageNet class.
#@markdown 4. **seed**=0 means no seed.
prompt = 'A photo of a rainbow unicorn.' #@param {type:'string'}
initial_class = 'Random mix' #@param ['From prompt', 'Random class', 'Random Dirichlet', 'Random mix'] {allow-input: true}
optimize_class = True #@param {type:'boolean'}
class_smoothing = 0.1 #@param {type:'number'}
truncation = 1 #@param {type:'number'}
stochastic_truncation = False #@param {type:'boolean'}
optimizer = 'CMA-ES' #@param ['SGD','Adam','CMA-ES','CMA-ES+SGD','CMA-ES+Adam']
pop_size = 50 #@param {type:'integer'}
model = 'ViT-B/32' #@param ['ViT-B/32','RN50']
augmentations =  64#@param {type:'integer'}
learning_rate = 0.1 #@param {type:'number'}
class_ent_reg = 0.0001 #@param {type:'number'}
iterations = 100 #@param {type:'integer'}
save_every = 1 #@param {type:'integer'}
fps = 1 #@param {type:'number'}
freeze_secs = 0 #@param {type:'number'}
seed = 0 #@param {type:'number'}
# seed == 0 is the form's sentinel for "no seed"; it is turned into None
# before being handed to the RNG setup further down.
if seed == 0:
  seed = None
# Without CMA-ES a single candidate is optimised, so the population size
# from the form is overridden to 1.
if 'CMA' not in optimizer:
  pop_size = 1

# Start every run with an empty frame directory.
!rm -rf /content/output
!mkdir -p /content/output

import numpy as np
# `state` is passed as random_state to the scipy samplers later in this cell;
# it is None when no seed was given.
state = None if not seed else np.random.RandomState(seed)
# Seed NumPy's global RNG (seed is None for an unseeded run).
np.random.seed(seed)
import torch
import torchvision
import sys
# torch's seed is drawn from NumPy's global RNG, so it must come after np.random.seed.
torch.manual_seed(np.random.randint(sys.maxsize))
import imageio
from IPython.display import HTML, Image, clear_output
from scipy.stats import truncnorm, dirichlet
from pytorch_pretrained_biggan import convert_to_images, one_hot_from_names
from base64 import b64encode
from time import time
import cma
from cma.sigma_adaptation import CMAAdaptSigmaCSA, CMAAdaptSigmaTPA

# Side lengths used when sampling random crops in ascend_txt (only sideX is read there).
im_shape = [512, 512, 3]
sideX, sideY, channels = im_shape

def save(out,name):
  """Write the first image of a generator output batch to disk.

  out:  torch tensor holding the generator output; it is converted to a
        NumPy array and handed to `convert_to_images`.
  name: destination path passed to `imageio.imwrite`.
  """
  # detach() rather than relying on no_grad(): Tensor.cpu() returns the same
  # tensor when it already lives on the CPU, and numpy() refuses tensors that
  # still require grad.
  out = out.detach().cpu().numpy()
  img = convert_to_images(out)[0]
  imageio.imwrite(name, np.asarray(img))

def checkin(i, best_ind, total_losses, losses, regs, values, out):
  """Save and display the current best frame, then print population statistics.

  i:            current iteration index (printed 1-based).
  best_ind:     index of the best population member in the per-member lists.
  total_losses, losses, regs: per-member loss values for this iteration.
  values:       per-member normalised class vectors.
  out:          generator output for the best member, written as the next frame.

  Increments the global frame counter `sample_num`.
  """
  global sample_num
  frame_path = '/content/output/frame_%05d.jpg'%sample_num
  save(out,frame_path)
  clear_output()
  display(Image(frame_path))

  values = np.array(values)
  best = values[best_ind]
  ranked = np.argsort(best)[::-1]  # class indices, largest weight first

  series = (total_losses, losses, regs)
  best_stats = tuple(s[best_ind] for s in series)
  avg_stats = tuple(np.mean(s) for s in series)
  std_stats = tuple(np.std(s) for s in series)
  top3 = tuple(item for k in ranked[:3] for item in (k, best[k]))
  # Average number of class components per member at or above each threshold.
  components = tuple(np.sum(values >= thresh)/pop_size for thresh in (0.5, 0.3, 0.1))

  report = (sample_num+1, i+1) + best_stats + avg_stats + std_stats + top3 + components
  print('sample=%d iter=%d best: total=%.2f cos=%.2f reg=%.3f avg: total=%.2f cos=%.2f reg=%.3f std: total=%.2f cos=%.2f reg=%.3f 1st_class=%s (%.2f) 2nd_class=%s (%.2f) 3rd_class=%s (%.2f) components: >=0.5=%.0f, >=0.3=%.0f, >=0.1=%.0f'%report)
  sample_num += 1

# Initial latent noise: one 128-d truncated-normal vector per population member.
noise_vector = truncnorm.rvs(-2*truncation, 2*truncation, size=(pop_size, 128), random_state=state).astype(np.float32) #see https://github.com/tensorflow/hub/issues/214

# Initial class weights, one 1000-d row per population member.
if initial_class.lower()=='random class':
  class_vector = np.ones(shape=(pop_size,1000), dtype=np.float32)*class_smoothing/999
  # Apply the randomly chosen class to every row. Setting only row 0 left the
  # other population members with a uniform class vector.
  class_vector[:,np.random.randint(1000)] = 1-class_smoothing
elif initial_class.lower()=='random dirichlet':
  # size=pop_size so there is one row per population member; size=1 produced a
  # (1, 1000) array and class_vector[j:j+1] was empty for j >= 1.
  class_vector = dirichlet.rvs([pop_size/1000] * 1000, size=pop_size, random_state=state).astype(np.float32)
elif initial_class.lower()=='random mix':
  class_vector = np.random.rand(pop_size,1000).astype(np.float32)
else:
  # Free text (or the prompt itself) is looked up as an ImageNet class name.
  if initial_class.lower()=='from prompt':
    initial_class = prompt
  try:
    class_vector = None
    class_vector = one_hot_from_names(initial_class, batch_size=pop_size)
    assert class_vector is not None
    # Label smoothing: the hot entry becomes 1-class_smoothing, the rest class_smoothing/999.
    class_vector = class_vector*(1-class_smoothing*1000/999)+class_smoothing/999
  except Exception as e:
    print('Error: could not find initial_class. Try something else.')
    raise e

# Class weights are stored as logits; get_output applies softmax to them.
eps = 1e-8
class_vector = np.log(class_vector+eps)
noise_vector = torch.tensor(noise_vector, requires_grad='SGD' in optimizer or 'Adam' in optimizer, device='cuda')
class_vector = torch.tensor(class_vector, requires_grad='SGD' in optimizer or 'Adam' in optimizer, device='cuda')

# Gradient optimiser over the noise (and, if requested, the class logits).
if 'SGD' in optimizer or 'Adam' in optimizer:
  params = [noise_vector]
  if optimize_class:
    params = params + [class_vector]
  if 'SGD' in optimizer:
    optim = torch.optim.SGD(params, lr=learning_rate)
  else:
    optim = torch.optim.Adam(params, lr=learning_rate)

# Encode the text prompt once; it is the fixed target for the image embeddings.
tx = clip.tokenize(prompt)
with torch.no_grad():
  target_clip = perceptor[model].encode_text(tx.cuda())

def get_output(noise_vector, class_vector):
  """Run the generator on one noise vector and one row of class logits.

  With `stochastic_truncation` off, the noise is clamped to
  [-2*truncation, 2*truncation]. With it on, out-of-range entries are
  resampled from a truncated normal and written back into `noise_vector`
  in place.

  Returns (generator output, softmax-normalised class vector).
  """
  bound = 2*truncation
  if not stochastic_truncation:
    noise_vector = noise_vector.clamp(-bound, bound)
  else:
    with torch.no_grad():
      outliers = noise_vector.abs() > bound
      n_outliers = torch.count_nonzero(outliers).cpu().numpy()
      resampled = truncnorm.rvs(-bound, bound, size=(1,n_outliers)).astype(np.float32)
      wants_grad = 'SGD' in optimizer or 'Adam' in optimizer
      noise_vector.data[outliers] = torch.tensor(resampled, requires_grad=wants_grad, device='cuda')
  class_probs = class_vector.softmax(dim=-1)
  return gan_model(noise_vector, class_probs, truncation), class_probs

# Input resolution of the selected CLIP model. JIT-traced models carry it as a
# tensor attribute on the model itself; non-JIT models keep it as a plain int on
# the visual tower. Support both, since the CLIP checkout is not pinned.
_clip_model = perceptor[model]
res = getattr(_clip_model, 'input_resolution', None)
if res is None:
  res = _clip_model.visual.input_resolution
res = int(res.item()) if torch.is_tensor(res) else int(res)
# Entropy of a label-smoothed one-hot class vector; the entropy regulariser in
# ascend_txt penalises the distance of the class entropy from this value.
smoothed_ent = -torch.tensor(class_smoothing*np.log(class_smoothing/999+eps)+(1-class_smoothing)*np.log(1-class_smoothing+eps), dtype=torch.float32).cuda()
def ascend_txt(i, grad_step=False, show_save=False):
  """Score every population member against the text prompt.

  For each member the generator output is cut into `augmentations` random
  square crops, resized to the CLIP input resolution, encoded, and compared
  to the prompt embedding by cosine similarity. An entropy regulariser on
  the class vector is added when `optimize_class` and `class_ent_reg` are set.

  i:         current iteration index (decides whether to save/display).
  grad_step: when True, take one optimiser step per member on its total loss.
  show_save: when True, save the final image on the last iteration and call
             `checkin` every `save_every` iterations.

  Returns the list of total losses (Python floats), one per member.
  """
  prev_class_vector_norms = []
  regs = []
  losses = []
  total_losses = []
  best_loss = np.inf
  for j in range(pop_size):
    out, class_vector_norm = get_output(noise_vector[j:j+1], class_vector[j:j+1])
    with torch.no_grad():
      prev_class_vector_norms.append(class_vector_norm.cpu().numpy()[0])
    p_s = []
    fixed_out = (out+1)/2  # generator output is scaled as [-1, 1] -> [0, 1]
    for ch in range(augmentations):
      size = torch.randint(int(.5*sideX), int(.98*sideX), ())
      #size = int(sideX*torch.zeros(1,).normal_(mean=.8, std=.3).clip(.5, .95))
      offsetx = torch.randint(0, sideX - size, ())
      offsety = torch.randint(0, sideX - size, ())
      apper = fixed_out[:, :, offsetx:offsetx + size, offsety:offsety + size]
      apper = torch.nn.functional.interpolate(apper, res, mode='bicubic')
      apper = apper.clamp(0,1)  # bicubic resampling can overshoot [0, 1]
      p_s.append(apper)
    into = nom(torch.cat(p_s, 0))
    predict_clip = perceptor[model].encode_image(into)
    factor = 100
    loss = factor*(1-torch.cosine_similarity(predict_clip, target_clip).mean())
    total_loss = loss
    if optimize_class and class_ent_reg:
      reg = factor*class_ent_reg*((-class_vector_norm*torch.log(class_vector_norm+eps)).sum()-smoothed_ent).abs()
      total_loss = total_loss + reg
      regs.append(reg.item())
    else:
      # Keep regs the same length as losses: checkin indexes regs[best_ind]
      # and would raise IndexError on an empty list.
      regs.append(0.0)
    losses.append(loss.item())
    total_losses.append(total_loss.item())
    if total_losses[-1]<best_loss:
      best_loss = total_losses[-1]
      best_ind = j
      best_out = out
    if grad_step:
      optim.zero_grad()
      total_loss.backward()
      optim.step()

  if show_save and (i == iterations-1 or i % save_every == 0):
    if i==iterations-1:
      save(best_out,'/content/%s.jpg'%prompt)
    if i % save_every == 0:
      checkin(i, best_ind, total_losses, losses, regs, prev_class_vector_norms, best_out)
  return total_losses

# Per-channel mean/std normalisation applied to the crops before they are
# passed to encode_image in ascend_txt.
# NOTE(review): these constants presumably match CLIP's own preprocessing - confirm.
nom = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
if 'CMA' in optimizer:
  # NOTE(review): 'seed': np.nan presumably tells cma not to reseed NumPy's RNG - confirm against the cma docs.
  cma_opts = {'popsize': pop_size, 'seed': np.nan, 'AdaptSigma': True, 'CMA_diagonal': True}
  # Search space is 128 noise dims, plus 1000 class-logit dims when the class is optimised.
  cmaes = cma.CMAEvolutionStrategy([0]*(1128 if optimize_class else 128), 1, inopts=cma_opts)

# Frame counter incremented by checkin.
sample_num = 0
machine = !nvidia-smi -L
start = time()
for i in range(iterations):    
  if 'CMA' in optimizer:
    with torch.no_grad():
      # Ask CMA-ES for a new population and overwrite the current vectors with it.
      cma_results = torch.tensor(cmaes.ask(), dtype=torch.float32).cuda()
      if optimize_class:
        noise_vector.data, class_vector.data = torch.split_with_sizes(cma_results, (128,1000), dim=-1)
      else:
        noise_vector.data = cma_results
      # Pure CMA-ES: this evaluation also saves/displays. In the hybrid modes
      # the gradient pass below does that instead.
      losses = ascend_txt(i, show_save='SGD' not in optimizer and 'Adam' not in optimizer)
  if 'SGD' in optimizer or 'Adam' in optimizer:
    losses = ascend_txt(i, grad_step=True, show_save=True)
    # Sanity check: the optimised tensors must still be leaf tensors that require grad.
    assert noise_vector.requires_grad and noise_vector.is_leaf and (class_vector.requires_grad and class_vector.is_leaf or not optimize_class), (noise_vector.requires_grad, noise_vector.is_leaf, class_vector.requires_grad, class_vector.is_leaf)
  if 'CMA' in optimizer:
    with torch.no_grad():
      # Report the current vectors (after any gradient step or resampling) and their losses to CMA-ES.
      if optimize_class:
        vectors = torch.cat([noise_vector,class_vector], dim=1)
      else:
        vectors = noise_vector
      cmaes.tell(vectors.cpu().numpy(), losses)
  print('took: %d secs (%.2f sec/iter) on %s'%(time()-start,(time()-start)/(i+1), machine[0]))

from google.colab import files, output
files.download('/content/%s.jpg'%prompt)

out = '"/content/%s.mp4"'%prompt
with open('/content/list.txt','w') as f:
  for i in range(iterations):
    f.write('file /content/output/frame_%05d.jpg\n'%i)
  for j in range(int(freeze_secs*fps)):
    f.write('file /content/output/frame_%05d.jpg\n'%i)
!ffmpeg -r $fps -f concat -safe 0 -i /content/list.txt -c:v libx264 -pix_fmt yuv420p -profile:v baseline -movflags +faststart -r $fps $out -y
with open('/content/%s.mp4'%prompt, 'rb') as f:
  data_url = "data:video/mp4;base64," + b64encode(f.read()).decode()
display(HTML("""
  <video controls autoplay loop>
        <source src="%s" type="video/mp4">
  </video>""" % data_url))

from google.colab import files, output
output.eval_js('new Audio("https://freesound.org/data/previews/80/80921_1022651-lq.ogg").play()')
files.download('/content/%s.mp4'%prompt)