<a href="https://colab.research.google.com/github/bjin2364/mit-deep-learning/blob/main/VQGAN%2BCLIP_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Generate images from text prompts with VQGAN and CLIP

Tutorial by Phillip Isola

[CLIP paper](https://arxiv.org/abs/2103.00020)<br>
[VQGAN paper](https://arxiv.org/abs/2012.09841)

This tutorial is a simplified version of the following Colab notebook: https://colab.research.google.com/drive/1ZAus_gn2RhTZWzOWUpPERNC0Q8OhZRTZ

Contributors to the original Colab appear to include Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings), https://twitter.com/advadnoun, Eleiber#8347, Crimeacs#8222 (https://twitter.com/EarthML1), and Abulafia#3734.

<br>
<hr>
MIT License included in the colab from which this was derived:

Copyright (c) 2021 Katherine Crowson

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

In [None]:
# Clone CLIP and VQGAN code
# (the library cell imports `clip` from the cloned ./CLIP directory and adds
#  /content/taming-transformers to sys.path so `taming` can be imported)
!git clone https://github.com/openai/CLIP
!git clone https://github.com/CompVis/taming-transformers.git

# install some helpful packages
!pip install ftfy regex tqdm omegaconf pytorch-lightning
!pip install kornia
!pip install imageio-ffmpeg   
!pip install einops

# make a directory for the per-iteration frames: optimize_z writes
# ./steps/<i>.png every iteration and the video cell reads them back.
# NOTE: `mkdir` prints an error if the directory already exists (e.g. when the
# cell is re-run); that is harmless.
!mkdir steps

In [None]:
# Download VQGAN generator trained on ImageNet
# vqgan_model_name is reused by the Parameters cell to build the .yaml/.ckpt
# file names, so it must match the `-o` output names given to curl below.
vqgan_model_name = "vqgan_imagenet_f16_16384"
# curl flags: -L follows redirects, -o names the output file, -C - resumes a
# partially downloaded file instead of starting over.
# NOTE(review): this is a third-party mirror over plain http — confirm it is
# still reachable and trusted before relying on it.
!curl -L -o vqgan_imagenet_f16_16384.yaml -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.yaml' #ImageNet 16384
!curl -L -o vqgan_imagenet_f16_16384.ckpt -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.ckpt' #ImageNet 16384

In [None]:
# @title Load libraries and variables

import argparse
import math
from pathlib import Path
import sys

sys.path.insert(1, '/content/taming-transformers')
from IPython import display
from base64 import b64encode
from omegaconf import OmegaConf
from PIL import Image
from taming.models import cond_transformer, vqgan
import taming.modules 
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.auto import tqdm

from CLIP import clip
import kornia.augmentation as K
import numpy as np
import imageio
from PIL import ImageFile, Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

# the following functions are for dealing with backprop through the discrete latent variables in VQGAN
class ReplaceGrad(torch.autograd.Function):
    """Straight-through gradient swap.

    The forward pass yields ``x_forward`` unchanged; the backward pass sends no
    gradient to ``x_forward`` and routes the incoming gradient to
    ``x_backward`` instead, summed down to ``x_backward``'s shape.
    """

    @staticmethod
    def forward(ctx, x_forward, x_backward):
        # Only the target shape is needed later, not the tensor itself.
        target_shape = x_backward.shape
        ctx.backward_shape = target_shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        routed = grad_in.sum_to_size(ctx.backward_shape)
        grad_for_forward_input = None
        return grad_for_forward_input, routed

replace_grad = ReplaceGrad.apply

class ClampWithGrad(torch.autograd.Function):
    """Clamp to [min, max] with a gradient that can pull values back into range.

    Forward is an ordinary ``clamp``. In backward, the gradient is kept for an
    element when the element was inside the range, or when a gradient-descent
    step (moving against the gradient) would move it back toward the range;
    otherwise it is zeroed.
    """

    @staticmethod
    def forward(ctx, input, min, max):
        ctx.save_for_backward(input)
        ctx.min, ctx.max = min, max
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        (saved_input,) = ctx.saved_tensors
        # How far each element sits outside the range (0 when inside).
        overshoot = saved_input - saved_input.clamp(ctx.min, ctx.max)
        # Same sign (or zero overshoot) => descending moves back toward range.
        keep = (grad_in * overshoot) >= 0
        return grad_in * keep, None, None

clamp_with_grad = ClampWithGrad.apply

def vector_quantize(x, codebook):
    """Snap every vector in ``x`` to its nearest codebook row (straight-through).

    Args:
        x: tensor whose last dimension is the embedding dimension; all leading
            dimensions are treated as a batch of vectors.
        codebook: 2-D tensor with one codebook entry per row.

    Returns:
        A tensor shaped like ``x`` whose values are the nearest codebook rows
        (by squared Euclidean distance). Its gradient is routed straight back
        to ``x`` by ``replace_grad``, so the argmin does not block backprop.
    """
    # ||x||^2 + ||c||^2 - 2 x.c  is the squared distance to every code.
    d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T
    indices = d.argmin(-1)
    # Gather the winning rows directly instead of building a
    # (num_vectors, num_codes) one-hot matrix and multiplying by the codebook;
    # the selected values are identical but no dense one-hot is allocated.
    x_q = codebook[indices].to(d.dtype)
    return replace_grad(x_q, x)

# MakeCutouts turns the image it is applied to into a batch of `cutn` randomly
# augmented views for CLIP to score. Each view is the WHOLE image resized to
# cut_size x cut_size (mean of adaptive average- and max-pooling); the views
# differ only through the random augmentations in `augs` and additive noise.
# (Despite the name, no random crops are taken.)
class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cutn):
        """
        Args:
            cut_size: side length in pixels of every output view.
            cutn: number of augmented views produced per input image.
        """
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn

        self.augs = nn.Sequential(
            K.RandomAffine(degrees=15, translate=0.1, p=0.7, padding_mode='border'),
            K.RandomPerspective(0.7,p=0.7),
            K.ColorJitter(hue=0.1, saturation=0.1, p=0.7),
            K.RandomErasing((.1, .4), (.3, 1/.3), same_on_batch=True, p=0.7))
        # Upper bound for the per-view noise scale drawn in forward().
        self.noise_fac = 0.1
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))

    def forward(self, input):
        """Return `cutn` augmented views of `input`, concatenated along dim 0.

        Args:
            input: image batch; assumed (batch, channels, height, width) as the
                2-D adaptive pools require — TODO confirm for other callers.

        Returns:
            Tensor with ``cutn * batch`` rows, each ``cut_size x cut_size``.
        """
        # The pooled image is the same for every view, so compute it once.
        pooled = (self.av_pool(input) + self.max_pool(input)) / 2
        batch = self.augs(torch.cat([pooled] * self.cutn, dim=0))
        if self.noise_fac:
            # One noise scale per view, drawn from U(0, noise_fac). Sized from
            # the real batch so inputs holding more than one image also work.
            facs = batch.new_empty([batch.shape[0], 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch

def load_vqgan_model(config_path, checkpoint_path):
    """Build a frozen VQGAN autoencoder from a taming-transformers config.

    The config's ``model.target`` selects the class: a plain ``VQModel``, a
    ``GumbelVQ``, or a ``Net2NetTransformer`` (whose ``first_stage_model`` is
    returned). The model is put in eval mode with gradients disabled, loaded
    from ``checkpoint_path``, and its ``loss`` submodule is deleted.

    Raises:
        ValueError: if ``model.target`` is not one of the supported classes.
    """
    def frozen(module):
        # Same order for every model type: freeze, then load the weights.
        module.eval().requires_grad_(False)
        module.init_from_ckpt(checkpoint_path)
        return module

    config = OmegaConf.load(config_path)
    target = config.model.target

    if target == 'taming.models.vqgan.VQModel':
        model = frozen(vqgan.VQModel(**config.model.params))
    elif target == 'taming.models.vqgan.GumbelVQ':
        model = frozen(vqgan.GumbelVQ(**config.model.params))
    elif target == 'taming.models.cond_transformer.Net2NetTransformer':
        transformer = frozen(cond_transformer.Net2NetTransformer(**config.model.params))
        model = transformer.first_stage_model
    else:
        raise ValueError(f'unknown model type: {target}')

    # The loss module is not used by this notebook, which only decodes.
    del model.loss
    return model




## Settings for this run:
Edit the `prompt` text below to generate an image that depicts it.

In [None]:
#@title Parameters
# Example prompts to try:
#A Van Gogh painting of the Stata Center at MIT
#"A Barcelona tile mosaic of the Stata Center at MIT"
prompt = "Mulan hugging Michael Phelps." #@param {type:"string"}
width =  256#@param {type:"number"}
height = 256#@param {type:"number"}
seed = 7#@param {type:"number"}
max_iterations = 700#@param {type:"number"}

# All run settings in one namespace.
# - size is [width, height]; the setup cell divides each by the VQGAN factor f
#   to get the latent grid, so the output is rounded down to a multiple of f.
# - max_iterations (above) is the number of optimization steps and also the
#   number of frames the video cell expects in ./steps/.
args = argparse.Namespace(
    prompt=prompt,
    size=[width, height],
    clip_model='ViT-B/32',  # model name passed to clip.load
    vqgan_config=f'{vqgan_model_name}.yaml',
    vqgan_checkpoint=f'{vqgan_model_name}.ckpt',
    lr=0.1,  # Adam learning rate for the latent z
    cutn=32,  # number of augmented views of each generated image scored by CLIP
    display_freq=50,  # show the current image every this many iterations
    seed=seed,  # passed to torch.manual_seed
)

In [None]:
# setup gpu and random seed
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(args.seed)

# load models
# Both models are put in eval mode with gradients disabled; the only thing
# optimized below is the latent code z.
vqgan_model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
clip_model = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)

## some setup for clip and vqgan

# Every view scored by CLIP is resized to CLIP's input resolution; make_cutouts
# produces args.cutn augmented views of each generated image for the loss.
cut_size = clip_model.visual.input_resolution
make_cutouts = MakeCutouts(cut_size, args.cutn)

# parameters of the vqgan latent code
# f = 2**(num_resolutions - 1) is used as the number of pixels per latent token
# along each side; toksX x toksY is the latent grid for the requested size.
# sideX/sideY (the size rounded down to a multiple of f) are not used later in
# this notebook.
f = 2**(vqgan_model.decoder.num_resolutions - 1)
toksX, toksY = args.size[0] // f, args.size[1] // f
sideX, sideY = toksX * f, toksY * f

# e_dim: length of each codebook vector; n_toks: number of codebook entries.
e_dim = vqgan_model.quantize.e_dim
n_toks = vqgan_model.quantize.n_e
# Per-channel min/max over all codebook vectors, shaped (1, e_dim, 1, 1) so
# they broadcast against the channels-first z; optimize_z clamps z to them.
z_min = vqgan_model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
z_max = vqgan_model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]

## we will be optimizing over the latent code z that is the input to the VQGAN generator

# we start with a random set of values for z
# A random grid of codebook tokens is embedded and reshaped to channels-first
# (1, e_dim, toksY, toksX)...
one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()
z = one_hot @ vqgan_model.quantize.embedding.weight
z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2) 
# NOTE(review): ...but this line replaces those values with uniform noise in
# [0, 2); only the shape (and the RNG draw of torch.randint above) survives.
# z is not clamped to [z_min, z_max] until after the first optimizer step.
z = torch.rand_like(z)*2

# then we set up the Adam optimizer to optimize over z
z.requires_grad_(True)
optim = torch.optim.Adam([z], lr=args.lr)

# we will use normalize to preprocess the images before processing them with clip
# (per-channel mean/std applied to the [0, 1] images produced by synth)
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                  std=[0.26862954, 0.26130258, 0.27577711])

# synthesize an image from latent code z using the vqgan generator
def synth(z):
    """Decode the channels-first latent ``z`` into an image clamped to [0, 1].

    ``z`` is snapped to the nearest VQGAN codebook vectors (with a
    straight-through gradient), decoded, shifted/scaled by ``(x + 1) / 2``
    (the decoder output is presumably in [-1, 1]), and clamped with a
    gradient-preserving clamp.
    """
    codebook = vqgan_model.quantize.embedding.weight
    # vector_quantize works on the last dimension, so move channels there.
    z_channels_last = z.movedim(1, 3)
    quantized = vector_quantize(z_channels_last, codebook)
    z_q = quantized.movedim(3, 1)
    decoded = vqgan_model.decode(z_q)
    rescaled = decoded.add(1).div(2)
    return clamp_with_grad(rescaled, 0, 1)

# clip embedding for the text prompt (this will serve as our target for the clip embedding of the generated image)
# Computed once, cast to float32, and unit-normalized along dim 1 so that
# optimize_z can compare it directly with normalized image embeddings.
text_embedding = clip_model.encode_text(clip.tokenize(args.prompt).to(device)).float()
text_embedding_normed = F.normalize(text_embedding, dim=1)


def optimize_z(optim, z):
    """Optimize the latent ``z`` so its decoded image matches the prompt under CLIP.

    Runs ``max_iterations`` steps. Each step decodes ``z`` with ``synth``,
    scores ``args.cutn`` augmented views of the image against the prompt's
    CLIP embedding, takes one optimizer step, and clamps ``z`` back into the
    per-channel codebook range ``[z_min, z_max]``. The image decoded at step
    ``i`` (before that step's update) is written to ``./steps/<i>.png`` and
    displayed every ``args.display_freq`` steps.

    Args:
        optim: optimizer that updates ``z`` (the setup cell uses Adam over [z]).
        z: latent tensor being optimized; it is modified in place.
    """
    for i in tqdm(range(max_iterations)):
        optim.zero_grad()

        # generate image from z vector
        img = synth(z)

        # Loss between the CLIP embeddings of the image views and the prompt.
        # For unit vectors a, b: arcsin(||a - b|| / 2) is half the angle
        # between them, so this is the mean over views of (angle / 2)^2.
        img_embedding = clip_model.encode_image(normalize(make_cutouts(img))).float()
        img_embedding_normed = F.normalize(img_embedding, dim=1)
        loss = img_embedding_normed.sub(text_embedding_normed).norm(dim=1).div(2).arcsin().pow(2).mean()

        # do a step of gradient descent on z
        loss.backward()
        optim.step()
        with torch.no_grad():
            # keep z inside the per-channel range spanned by the codebook
            z.copy_(z.maximum(z_min).minimum(z_max))

        # write the output image to disk; detach first so the uint8 conversion
        # is not recorded by autograd
        frame = img.detach()[0].mul(255).clamp(0, 255).cpu().numpy().astype(np.uint8)
        frame = np.transpose(frame, (1, 2, 0))  # channels-first -> channels-last
        frame_path = './steps/' + str(i) + '.png'
        imageio.imwrite(frame_path, frame)

        # display it every now and then
        if i % args.display_freq == 0:
            print('iter {}'.format(i))
            display.display(display.Image(frame_path))

In [None]:
# Run the optimization: one frame per iteration is written to ./steps/ and
# previews are shown every args.display_freq iterations.
optimize_z(optim,z)

In [None]:
#@title Generate a video with the result
from pathlib import Path
from subprocess import Popen, PIPE

init_frame = 1 #This is the frame where the video will start
last_frame = max_iterations #Frames init_frame..last_frame-1 are used; lower this to end earlier. An error is raised if any of those frames does not exist.

min_fps = 10
max_fps = 60

total_frames = last_frame-init_frame
if total_frames <= 0:
    raise ValueError(f'last_frame ({last_frame}) must be greater than init_frame ({init_frame})')

length = 15 #Desired time of the video in seconds

# Check every frame exists before starting ffmpeg, so a missing frame fails
# fast instead of leaving a truncated video behind.
frame_paths = [Path('./steps') / f'{i}.png' for i in range(init_frame, last_frame)]
missing = [str(path) for path in frame_paths if not path.exists()]
if missing:
    raise FileNotFoundError(f'{len(missing)} frame(s) missing, first: {missing[0]}')

# Aim for `length` seconds of video, keeping the frame rate in [min_fps, max_fps].
fps = float(np.clip(total_frames/length, min_fps, max_fps))

tqdm.write('Generating video...')
ffmpeg_cmd = ['ffmpeg', '-y', '-f', 'image2pipe', '-vcodec', 'png', '-r', str(fps), '-i', '-', '-vcodec', 'libx264', '-r', str(fps), '-pix_fmt', 'yuv420p', '-crf', '17', '-preset', 'veryslow', 'video.mp4']
# Stream frames to ffmpeg one at a time; each frame file is closed as soon as
# it has been written, so at most one frame file is open at once.
with Popen(ffmpeg_cmd, stdin=PIPE) as p:
    for path in tqdm(frame_paths):
        with Image.open(path) as frame:
            frame.save(p.stdin, 'PNG')
    p.stdin.close()
# Leaving the `with` block waits for ffmpeg to exit.
if p.returncode != 0:
    raise RuntimeError(f'ffmpeg exited with status {p.returncode}')

with open('video.mp4', 'rb') as video_file:
    mp4 = video_file.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
display.HTML("""
<video width=400 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)