<a href="https://colab.research.google.com/github/eyaler/clip_biggan/blob/main/ClipBigGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BigGAN + CLIP

by https://twitter.com/eyaler

Based on Colab notebooks by:

https://twitter.com/advadnoun

https://twitter.com/norod78

Based on the following works:

https://github.com/openai/CLIP

https://tfhub.dev/deepmind/biggan-deep-512

https://github.com/huggingface/pytorch-pretrained-BigGAN



In [None]:
#@title Restart after running this cell!

import subprocess

CUDA_version = [s for s in subprocess.check_output(["nvcc", "--version"]).decode("UTF-8").split(", ") if s.startswith("release")][0].split(" ")[-1]
print("CUDA version:", CUDA_version)

if CUDA_version == "10.0":
    torch_version_suffix = "+cu100"
elif CUDA_version == "10.1":
    torch_version_suffix = "+cu101"
elif CUDA_version == "10.2":
    torch_version_suffix = ""
else:
    torch_version_suffix = "+cu110"

!pip install torch==1.7.1{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex

In [None]:
#@title Setup
# Install the BigGAN port and load the pretrained 'biggan-deep-512' generator
# onto the GPU.
!pip install pytorch-pretrained-biggan
from pytorch_pretrained_biggan import BigGAN
model = BigGAN.from_pretrained('biggan-deep-512').cuda()

# Fetch CLIP from source. The repository is cloned rather than pip-installed,
# so the working directory is switched into the clone before `import clip`
# (presumably so the package resolves from the current directory - TODO confirm).
%cd /content
!git clone --depth 1 https://github.com/openai/CLIP
!pip install ftfy
%cd /content/CLIP
import clip
# `perceptor` is the CLIP model used for scoring; `preprocess` is returned by
# clip.load as well but is not used in the cells visible here.
perceptor, preprocess = clip.load('ViT-B/32')

# WordNet corpus download; presumably required by one_hot_from_names when it
# maps a free-text class name to a class index - TODO confirm.
import nltk
nltk.download('wordnet')


In [None]:
#@title Optimize!
#@markdown note for initial_class you can also enter free text
# Form parameters (the trailing #@param annotations drive the Colab form UI):
#   prompt         - text encoded with CLIP as the optimization target
#   initial_class  - how the starting class vector is chosen: one of the listed
#                    keywords, or free text passed to one_hot_from_names
#   optimize_class - when True the class vector is optimized along with the noise
#   truncation     - scales the noise sampling/clamping range and is passed to the generator
#   learning_rate  - Adam learning rate
#   class_ent_reg  - weight of the class-entropy regularizer (0 disables it)
#   iterations     - number of optimization steps
#   save_every     - a frame is written every `save_every` iterations
#   fps            - frame rate of the output video
prompt = 'Green mushroom' #@param {type:'string'}
initial_class = 'mushroom' #@param ['from prompt', 'random class', 'random dirichlet', 'random mix'] {allow-input: true}
optimize_class = True #@param {type:'boolean'}
truncation = 1 #@param {type:'number'}
learning_rate =  0.1#@param {type:'number'}
class_ent_reg =  0.1#@param {type:'number'}
iterations = 500 #@param {type:'integer'}
save_every = 1 #@param {type:'integer'}
fps = 30 #@param {type:'integer'}

# Start every run with an empty frame directory so frames from a previous run
# do not end up in the new video.
!rm -rf /content/output
!mkdir -p /content/output

import torch
import torchvision
import numpy as np
import imageio
from IPython.display import HTML, Image, clear_output
from scipy.stats import truncnorm, dirichlet
from pytorch_pretrained_biggan import convert_to_images, one_hot_from_names
from base64 import b64encode

# Image dimensions used when picking random crops in ascend_txt; presumably
# matches the output resolution of 'biggan-deep-512' - TODO confirm.
# `channels` is not read by the code visible in this notebook.
im_shape = [512, 512, 3]
sideX, sideY, channels = im_shape

def save(out,name):
  """Write the first image of a generator output batch to disk.

  Args:
    out: tensor produced by the generator; it may still be attached to the
      autograd graph.
    name: destination file path passed to imageio.imwrite.
  """
  # detach() drops the autograd graph explicitly. Relying on no_grad() plus
  # .cpu() only clears requires_grad when .cpu() actually copies, so it fails
  # for a tensor that already lives on the CPU.
  al = out.detach().cpu().numpy()
  # convert_to_images returns one image per batch entry; keep only the first.
  img = convert_to_images(al)[0]
  imageio.imwrite(name, np.asarray(img))

def checkin(total_loss, loss, reg, values, out):
  """Save the current frame, show it, and print a one-line progress report.

  Args:
    total_loss: scalar total loss for this step.
    loss: scalar cosine-distance term.
    reg: scalar regularization term.
    values: NumPy array of class weights; the report counts how many entries
      reach 0.5, 0.3 and 0.1.
    out: generator output handed to save().

  Side effects:
    Writes the next numbered frame under /content/output, replaces the cell
    output, and increments the global frame counter `sample_num`.
  """
  global sample_num
  frame_path = '/content/output/frame_%05d.jpg'%sample_num
  save(out, frame_path)
  clear_output()
  display(Image(frame_path))

  # Number of class components at or above each reporting threshold.
  counts = tuple(np.sum(values >= threshold) for threshold in (0.5, 0.3, 0.1))
  stats = (sample_num, total_loss, loss, reg) + counts
  print('%d: total=%.1f cos=%.1f reg=%.1f components: >=0.5=%d, >=0.3=%d, >=0.1=%d\n'%stats)

  sample_num += 1

# Optional reproducibility: set `seed` to an int to get a dedicated
# RandomState for the scipy samplers and to seed NumPy's global generator.
seed = None
if seed is None:
  state = None
else:
  state = np.random.RandomState(seed)
np.random.seed(seed)

# Sample the initial noise from a truncated normal whose bounds scale with
# `truncation`; see https://github.com/tensorflow/hub/issues/214
noise_bound = 2*truncation
noise_vector = truncnorm.rvs(-noise_bound, noise_bound, size=(1, 128), random_state=state)
noise_vector = noise_vector.astype(np.float32)

# Build the initial class vector (shape (1, 1000)) according to `initial_class`.
if initial_class=='random class':
  # One-hot vector on a uniformly chosen class index.
  class_vector = np.zeros(shape=(1,1000), dtype=np.float32)
  class_vector[0,np.random.randint(1000)] = 1
elif initial_class=='random dirichlet':
  # Dirichlet sample with concentration 1/1000 per class.
  class_vector = dirichlet.rvs([1/1000] * 1000, size=1, random_state=state).astype(np.float32)
elif initial_class=='random mix':
  # Independent uniform weights; the softmax applied later to the log values
  # normalizes them.
  class_vector = np.random.rand(1,1000).astype(np.float32)
else:
  if initial_class=='from prompt':
    initial_class = prompt
  # Free text: look the name up. Fail loudly when the lookup does not produce
  # a vector; merely printing would let the cell continue with an undefined
  # (or stale, from an earlier run) class_vector.
  try:
    class_vector = one_hot_from_names(initial_class, batch_size=1)
  except Exception as err:
    raise ValueError('Error: could not find initial_class. Try something else.') from err
  if class_vector is None:
    raise ValueError('Error: could not find initial_class. Try something else.')
# Work in log space so a softmax recovers the weights; eps keeps log(0) finite.
eps=1e-8
class_vector = np.log(class_vector+eps)

# Turn both latent inputs into trainable tensors on the GPU.
noise_vector, class_vector = (
  torch.tensor(latent, requires_grad=True, device='cuda')
  for latent in (noise_vector, class_vector)
)

# The noise vector is always optimized; the class vector only on request.
params = [noise_vector, class_vector] if optimize_class else [noise_vector]
optimizer = torch.optim.Adam(params, lr=learning_rate)

# Encode the prompt once; the embedding is a fixed target, so no gradients
# are tracked for it.
tx = clip.tokenize(prompt)
with torch.no_grad():
  target_clip = perceptor.encode_text(tx.cuda())

def ascend_txt(i):
  """Run one forward pass and return the loss to minimize.

  Generates an image from the current noise and class vectors, takes `cutn`
  random square crops resized to 224x224, normalizes them with `nom`, and
  compares their CLIP image embeddings with the prompt embedding
  `target_clip`. When the class vector is being optimized and
  `class_ent_reg` is non-zero, an entropy penalty on the class distribution
  is added.

  Args:
    i: zero-based iteration index. On the last iteration the image is also
      saved as /content/<prompt>.jpg; every `save_every` iterations a
      progress frame is written through checkin().

  Returns:
    Scalar tensor: factor * (1 - mean cosine similarity), plus the entropy
    term when it is enabled.
  """
  noise_vector_trunc = noise_vector.clamp(-2*truncation,2*truncation)
  # Explicit dim: calling softmax without one is deprecated and emits a warning.
  class_vector_norm = torch.nn.functional.softmax(class_vector, dim=-1)
  out = model(noise_vector_trunc, class_vector_norm, truncation)
  if i==iterations-1:
    save(out,'/content/%s.jpg'%prompt)

  cutn = 64
  # Crop sizes are limited by the shorter side and each offset by its own
  # axis length. The original bounded both offsets by sideX, which only
  # worked because the image is square.
  short_side = min(sideX, sideY)
  p_s = []
  for _ in range(cutn):
    size = torch.randint(int(.5*short_side), int(.98*short_side), ())
    offsetx = torch.randint(0, sideX - size, ())
    offsety = torch.randint(0, sideY - size, ())
    apper = out[:, :, offsetx:offsetx + size, offsety:offsety + size]
    # align_corners=False is the default behavior; passing it explicitly
    # silences the warning raised for bilinear mode.
    apper = torch.nn.functional.interpolate(apper, (224,224), mode='bilinear', align_corners=False)
    p_s.append(nom(apper))
  into = torch.cat(p_s, 0)

  predict_clip = perceptor.encode_image(into)
  factor = 100
  loss = factor*(1-torch.cosine_similarity(predict_clip, target_clip, dim=-1).mean())
  reg = torch.tensor(0.)
  if optimize_class and class_ent_reg:
    # Entropy of the class distribution, scaled by class_ent_reg.
    reg = -factor*class_ent_reg*(class_vector_norm*torch.log(class_vector_norm+eps)).sum()
  # Build the total out of place: `total_loss = loss` followed by `+=` would
  # modify `loss` too, making the reported "cos" value equal the total.
  total_loss = loss + reg
  if i % save_every == 0:
    checkin(total_loss.item(),loss.item(),reg.item(),class_vector_norm.detach().cpu().numpy(),out)
  return total_loss

# Per-channel mean/std normalization applied to every crop inside ascend_txt
# before it is passed to the CLIP image encoder.
nom = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

# Frame counter advanced by checkin(); reset here so frame numbering starts
# at 00000 on every run of the cell.
sample_num = 0
# Optimization loop: forward pass (which also saves/displays progress),
# clear old gradients, backpropagate, then take one Adam step.
for i in range(iterations):    
  loss = ascend_txt(i)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

# Output path for the video, wrapped in double quotes inside the string so
# the shell treats a prompt containing spaces as a single argument.
out = '"/content/%s.mp4"'%prompt
# Assemble the numbered frames into an H.264 MP4; $fps and $out are expanded
# from the Python variables by IPython before the command runs, and -y
# overwrites an existing file.
!ffmpeg -framerate $fps -i /content/output/frame_%05d.jpg -c:v libx264 -pix_fmt yuv420p -profile:v baseline -movflags +faststart $out -y
# Embed the whole video in the page as a base64 data URL so it can be played
# inline in the cell output.
with open('/content/%s.mp4'%prompt, 'rb') as f:
  data_url = "data:video/mp4;base64," + b64encode(f.read()).decode()
# Replace the last progress frame with the video player.
clear_output()
display(HTML("""
  <video controls autoplay loop>
        <source src="%s" type="video/mp4">
  </video>""" % data_url))

# Colab-only helpers: `output` runs JavaScript in the browser, `files`
# triggers browser downloads.
from google.colab import files, output
# Play a short sound in the browser to signal that the run has finished.
output.eval_js('new Audio("https://freesound.org/data/previews/80/80921_1022651-lq.ogg").play()')
# Download the final image (written by ascend_txt on the last iteration) and
# the video.
files.download('/content/%s.jpg'%prompt)
files.download('/content/%s.mp4'%prompt)