# VQ-GAN+CLIP 

***Author***: Luis Leal

This notebook implements the VQGAN+CLIP architecture, which generates images conditioned on a text prompt. According to the paper it can also be used for image editing; the only change is to start from the image to be edited instead of a random initial image.

This notebook is based on the paper [VQGAN-CLIP: Open Domain Image Generation and Editing with Natural Language Guidance](https://arxiv.org/abs/2204.08583), whose idea is to combine two models that were originally developed separately:

* **CLIP (Contrastive Language-Image Pre-training)**: originally developed to find the text that best matches an image (and vice versa). It is trained contrastively to embed both images and text into a common latent space where semantics are preserved, so an image of a dog has an embedding similar to that of the text 'dog'.
    <img src="CLIP_diagram.png">
    
    Paper: 
    [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/pdf/2103.00020.pdf)
* **VQ-GAN (Vector Quantized GAN)**: an improvement of the [vector quantized variational autoencoder](https://arxiv.org/pdf/1711.00937.pdf) in which the latent space is discretized via a codebook, so the model learns discrete neural representations (instead of an infinite, continuous space), which is a natural fit for inherently discrete modalities such as text. In VQ-GAN, images are a composition of codebook vectors (the codebook can be thought of as a toolbox containing a finite set of visual features), and new images are generated by a transformer that models sequences of codebook vectors. Like the original GANs, this architecture generates images without conditioning them on text.
<img src="VQGAN_diagram.png">

Paper: 
    [Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/pdf/2012.09841.pdf)
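
The codebook lookup described above can be sketched in plain Python: each latent vector is replaced by its nearest codebook entry under squared Euclidean distance. The tiny 2-D codebook is a made-up toy (the real `vqgan_imagenet_f16_16384` checkpoint uses 16384 codebook vectors with 256 channels), so this illustrates the idea rather than reproducing the taming-transformers implementation.

```python
# Toy sketch (assumption): nearest-neighbour vector quantization with a hand-made codebook.
def quantize(latents, codebook):
    """Replace each latent vector with the index and value of its nearest codebook vector."""
    indices, quantized = [], []
    for z in latents:
        # squared Euclidean distance from z to every codebook entry
        distances = [sum((zi - ci) ** 2 for zi, ci in zip(z, c)) for c in codebook]
        best = min(range(len(codebook)), key=distances.__getitem__)
        indices.append(best)
        quantized.append(codebook[best])
    return indices, quantized


# 3 "visual features" in a 2-D latent space
codebook = [(0.0, 0.0), (1.0, 1.0), (-1.0, 0.5)]
latents = [(0.9, 1.2), (0.1, -0.2), (-0.8, 0.4)]
indices, quantized = quantize(latents, codebook)
print(indices)  # [1, 0, 2]
```

Every continuous latent ends up as one of a finite set of vectors, which is what lets the original VQGAN treat an image as a sequence of codebook indices for its transformer.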

**The high level idea of this mix is:**

Given a text prompt (and some additional text context, described later), the VQGAN generates an image (initially random) and CLIP evaluates how well the image corresponds to the prompt via the cosine similarity of their embeddings. This evaluation is the loss function of the model: the goal is to maximize the similarity, or equivalently to minimize its negative.
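
As a minimal sketch of this loss, assuming the image and text embeddings are plain lists of floats (real CLIP ViT-B/32 embeddings have 512 dimensions), the quantity being minimized is the negative cosine similarity:

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def clip_style_loss(image_embedding, text_embedding):
    """Lower is better: minimizing -cos(image, text) maximizes their similarity."""
    return -cosine_similarity(image_embedding, text_embedding)


# toy 3-D embeddings (assumption, stand-ins for CLIP outputs)
text = [1.0, 0.0, 1.0]
close_image = [0.9, 0.1, 1.1]  # points roughly the same way as the text embedding
far_image = [-1.0, 1.0, 0.0]   # points elsewhere
print(clip_style_loss(close_image, text) < clip_style_loss(far_image, text))  # True
```

Cosine similarity ignores the length of the vectors, so only the direction of the image embedding relative to the text embedding matters.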

The VQGAN paper uses a transformer to generate new images, because the transformer models sequences of codebook vectors. VQGAN+CLIP does not use the transformer at all: the sequence of codebook vectors is replaced by latent vectors obtained via gradient descent, using the gradients of the CLIP loss with respect to those vectors. In other words, instead of the transformer supplying the latent vectors fed to the VQGAN decoder, the latent vectors are optimized so that CLIP's embedding of the current image moves closer to the embedding of the text prompt; CLIP guides the generation.
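
The key point, that the optimizer updates the *input* latent rather than any network weights, can be shown with a toy example. This sketch is a deliberate simplification: the frozen VQGAN decoder and CLIP image encoder are collapsed into the identity map, so the "image embedding" is the latent itself, and plain gradient descent with the analytic gradient of cosine similarity stands in for the AdamW optimizer used in the notebook. The helper `optimize_latent` is hypothetical.

```python
import math


def norm(v):
    return math.sqrt(sum(x * x for x in v))


def cosine_similarity(a, b):
    return sum(x * y for x, y in zip(a, b)) / (norm(a) * norm(b))


def optimize_latent(z, target, lr=0.5, steps=100):
    """Gradient descent on the latent z to minimize -cos(z, target); the 'models' stay frozen."""
    # Assumption: decoder + image encoder replaced by the identity, so embedding(z) == z.
    z = list(z)
    t_hat = [x / norm(target) for x in target]
    for _ in range(steps):
        cos = cosine_similarity(z, target)
        z_norm = norm(z)
        z_hat = [x / z_norm for x in z]
        # gradient of -cos(z, t) with respect to z is -(t_hat - cos * z_hat) / |z|
        grad = [-(ti - cos * zi) / z_norm for ti, zi in zip(t_hat, z_hat)]
        z = [zi - lr * gi for zi, gi in zip(z, grad)]
    return z


text_embedding = [1.0, 2.0, 0.5]  # fixed "text" target, never updated
start = [0.3, -1.0, 0.8]          # random-ish initial latent
final = optimize_latent(start, text_embedding)
print(cosine_similarity(start, text_embedding), cosine_similarity(final, text_embedding))
```

Only `z` changes between steps; the target embedding and the (here trivial) models are untouched, which mirrors how the notebook optimizes the latent while keeping CLIP and VQGAN frozen.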

**NOTE**: this notebook uses pretrained models for both CLIP and VQ-GAN; they are not trained here, their pretrained weights are only downloaded.

**Note**: some of the previous package installs broke my PyTorch installation (CUDA stopped working) and I had to reinstall it.

In [None]:
import sys
sys.path.append("taming_transformers/")

import requests
import os
import importlib

import numpy as np
import matplotlib.pyplot as plt

import torch
import imageio
import math
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as tf

from taming_transformers.taming.models.vqgan import VQModel

import PIL

import yaml
from omegaconf import OmegaConf

from CLIP import clip

# import warnings
# warnings.filterwarnings("ignore")

In [None]:
def show_from_tensor(tensor):
    """show image"""
    img = tensor.clone()
    img = img.mul(255).byte()
    img = img.cpu().numpy().transpose((1,2,0))

    plt.figure(figsize=(10,7))
    plt.axis("off")
    plt.imshow(img)
    plt.show()
    
def norm_data(data):
    return (data.clip(-1,1)+1)/2 # clip to [-1, 1], then rescale to [0, 1]

In [None]:
DOWNLOAD_MODELS = False

In [None]:
DEVICE = "cuda:0"

In [None]:
LEARNING_RATE = 0.5
BATCH_SIZE = 1
WEIGHT_DECAY = 0.1 # for optimizer regularization
NOISE_FACTOR = 0.22 # std of the Gaussian noise injected into the crops

TOTAL_ITERATIONS = 400
IMAGE_SHAPE = [400, 400, 3] # height, width, channels
height, width, channels = IMAGE_SHAPE

## Setup CLIP

CLIP is used to encode both the text prompts and the generated images so that their similarity can be compared.
The goal is to make the generated images move closer and closer to the text prompts.

In [None]:
clipmodel,_= clip.load("ViT-B/32",jit = False)  # download pretrained CLIP (ViT = Vision Transformer)
clipmodel = clipmodel.to(DEVICE)
clipmodel.eval() # CLIP is not trained, only used for inference

print("available CLIP models:",clip.available_models())
print("CLIP model image resolution:",clipmodel.visual.input_resolution)

In [None]:
torch.cuda.empty_cache()

### Setup VQ-GAN (from Taming Transformers)

The transformer part of the architecture is not used. In the original VQ-GAN paper the transformer generates a sequence of z_q vectors (from the codebook) to feed to the decoder; in VQGAN+CLIP, CLIP instead determines the latent vectors (before quantization) based on the text prompt. Minimizing the CLIP loss selects the z_q vectors, rather than having the transformer generate them.

In [None]:
if DOWNLOAD_MODELS:

    URL = "https://heibox.uni-heidelberg.de/f/867b05fc8c4841768640/?dl=1"
    target_dir = "./taming_transformers/models/vqgan_imagenet_f16_16384/checkpoints/"

    os.makedirs(target_dir, exist_ok=True)

    response = requests.get(URL)
    response.raise_for_status() # stop on HTTP errors instead of saving an error page
    with open(f"{target_dir}last.ckpt", "wb") as f:
        f.write(response.content)


In [None]:
if DOWNLOAD_MODELS:

    URL = "https://heibox.uni-heidelberg.de/f/274fb24ed38341bfa753/?dl=1"
    target_dir = "./taming_transformers/models/vqgan_imagenet_f16_16384/configs/"

    os.makedirs(target_dir, exist_ok=True)

    response = requests.get(URL)
    response.raise_for_status() # stop on HTTP errors instead of saving an error page
    with open(f"{target_dir}model.yaml", "wb") as f:
        f.write(response.content)

In [None]:
def load_config(config_path, display=False):
    config_data = OmegaConf.load(config_path)
    
    if display:
        print(yaml.dump(OmegaConf.to_container(config_data)))
    
    return config_data

def load_vqgan(config, checkpoint_path=None):
    model = VQModel(**config.model.params)
    
    if checkpoint_path is not None:
        state_dict = torch.load(checkpoint_path,map_location="cpu")["state_dict"]
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        
    return model.eval()


def generate(model,z_hat):
    """
    Pass the latent z_hat of the current candidate through the quantize + decode
    steps of the VQGAN to get a new image.

    In the original VQGAN, z_hat is obtained as follows:

    - during training, it is the output of the CNN encoder E for the training images;
    - during inference/generation, the transformer generates a sequence of quantized
      vectors directly, so z_hat is not used.

    In VQGAN+CLIP, z_hat is the candidate latent obtained by backpropagating through CLIP,
    so the CNN encoder of the original VQGAN is effectively replaced by CLIP-guided optimization.

    Instead of generating with the transformer, we generate a z_hat guided by CLIP,
    quantize it and decode it.

    VQGAN+CLIP differentiates the CLIP loss with respect to z_hat, i.e. it finds the z_hat
    that minimizes the CLIP loss, so z_hat has to be a parameter tensor updated by the optimizer.
    Gradients pass through the quantization step because taming's vector quantizer uses a
    straight-through estimator (z_q = z + (z_q - z).detach()).
    """
    # quantize z_hat to its nearest codebook vectors, then use the CNN decoder
    # (called G in the VQGAN paper) to generate a new image from the quantized latent
    zquant, emb_loss, info = model.quantize(z_hat)
    z_q = model.post_quant_conv(zquant) 
    new = model.decoder(z_q) 
    return new

In [None]:
vqgan_config = load_config("./taming_transformers/models/vqgan_imagenet_f16_16384/configs/model.yaml",
                          display = True
                          )
vqgan_model = load_vqgan(vqgan_config, 
                        checkpoint_path="./taming_transformers/models/vqgan_imagenet_f16_16384/checkpoints/last.ckpt",
                        ).to(DEVICE)

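Loading the model prints `Working with z of shape (1, 256, 16, 16)`, i.e. (batch, channels, latent height, latent width). The f16 model downsamples each spatial dimension by a factor of 16, so each latent position roughly corresponds to a 16 x 16 pixel patch (a 16 x 16 latent grid for a 256 x 256 image).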
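
The latent grid size follows from the downsampling factor. As a sketch, assuming the f16 model downsamples by exactly 16 in each spatial dimension and uses 256 latent channels (as the `(1, 256, 16, 16)` shape for a 256 x 256 input suggests), the latent shape for the notebook's `IMAGE_SHAPE` can be computed like this; the helper name `latent_shape` is hypothetical:

```python
def latent_shape(image_height, image_width, batch_size=1, channels=256, factor=16):
    """Shape of the VQGAN latent for an image, assuming a fixed spatial downsampling factor."""
    # integer division mirrors height//16 and width//16 in the Parameters class
    return (batch_size, channels, image_height // factor, image_width // factor)


print(latent_shape(256, 256))  # (1, 256, 16, 16)
print(latent_shape(400, 400))  # (1, 256, 25, 25)
```

Because `Parameters` uses `height//16` and `width//16`, an image side that is not a multiple of 16 is rounded down: a 225-pixel side gives a 14-cell grid, which the decoder turns back into 224 pixels.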

### Declare optimization parameters

In [None]:
class Parameters(torch.nn.Module):
    """
    The parameters we optimize the CLIP loss with respect to.
    These are the latent vectors of the candidate image (before quantization in the VQGAN, called z_hat in the paper).
    """
    def __init__(self):
        super(Parameters, self).__init__()
        self.data = 0.5*torch.randn(BATCH_SIZE, 256, 
                                height//16, width//16).to(DEVICE) # 1 x 256 x 25 x 25 for a 400 x 400 image
        self.data = torch.nn.Parameter(torch.sin(self.data)) # sin keeps the random init in [-1, 1]
        
    def forward(self):
        return self.data
    
def init_params():
    params = Parameters().to(DEVICE)
    optimizer = torch.optim.AdamW([{"params":[params.data], "lr":LEARNING_RATE}],
                                weight_decay=WEIGHT_DECAY
                                )
    
    return params, optimizer

### Preprocessing: encoding prompts

In [None]:
# CLIP was pretrained on images normalized with these per-channel means and stds, so we must use the same statistics
normalize = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                             (0.26862954, 0.26130258, 0.27577711))
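
`Normalize` applies `(x - mean) / std` to each channel. A plain-Python sketch of that transform and its inverse, using the same CLIP constants as above, shows what happens to a single RGB pixel in the [0, 1] range (the helper names are hypothetical):

```python
# Same per-channel statistics as the torchvision Normalize call above
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


def normalize_pixel(rgb):
    """Per-channel (x - mean) / std, the formula Normalize applies to every pixel."""
    return tuple((x - m) / s for x, m, s in zip(rgb, CLIP_MEAN, CLIP_STD))


def denormalize_pixel(rgb):
    """Inverse transform: x * std + mean."""
    return tuple(x * s + m for x, m, s in zip(rgb, CLIP_MEAN, CLIP_STD))


pixel = (0.5, 0.5, 0.5)
normed = normalize_pixel(pixel)
print([round(v, 3) for v in normed])
```

The formula assumes inputs in [0, 1], which is why the decoder output goes through `norm_data` before `normalize` in the training step.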


    

In [None]:
def encodeText(text):
    """Create tokens consistent with CLIP"""
    new_text = clip.tokenize(text).to(DEVICE)
    text_encoding = clipmodel.encode_text(new_text).detach().clone()
    return text_encoding

def createEncodings(include, exclude, extras):
    """
    include: what images we want in the space(list)
    exclude: what we don't want in the space(comma separated string)
    extras: additional context or characteristics(comma separated string)
    """
    
    include_encodings = []
    
    
    for text in include:
        include_encodings.append(encodeText(text))
        
    exclude_encodings = encodeText(exclude) if exclude != "" else 0
    extras_encodings = encodeText(extras) if extras != "" else 0
    
    return include_encodings, exclude_encodings, extras_encodings

### Preprocessing: image preprocessing
* first, augmentations are applied to the single generated image
* then several random crops are created from the augmented output
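
The crop-window sampling performed by `create_crops` can be illustrated without tensors. This sketch assumes the image is padded by `height // 2` on every side, as `create_crops` does, and draws square windows whose side is sampled around the image height; the offsets are bounded by the side actually used so every window stays square and inside the padded canvas. `sample_crop_windows` is a hypothetical helper that uses Python's `random` module instead of `torch` random functions.

```python
import random


def sample_crop_windows(height, width, num_crops=30, seed=None):
    """Return (top, left, side) square windows inside the image padded by height // 2 per side."""
    rng = random.Random(seed)  # assumption: Python RNG instead of torch.normal / torch.randint
    pad = height // 2
    padded_h, padded_w = height + 2 * pad, width + 2 * pad
    windows = []
    for _ in range(num_crops):
        # crop side ~ N(1.0, 0.5) * height, clipped to [0.2, 1.5] * height
        scale = min(max(rng.gauss(1.0, 0.5), 0.2), 1.5)
        side = int(scale * height)
        # randint is inclusive, so the window can touch but never cross the border
        top = rng.randint(0, padded_h - side)
        left = rng.randint(0, padded_w - side)
        windows.append((top, left, side))
    return windows


windows = sample_crop_windows(400, 400, num_crops=5, seed=0)
print(windows)
```

Each window would then be resized to CLIP's 224 x 224 input resolution, so CLIP sees the same image at several scales and positions in every step.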

In [None]:
augment_transform = torch.nn.Sequential(
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomAffine(30, (0.2, 0.2), fill=0)
).to(DEVICE)

In [None]:
params, optimizer = init_params()

with torch.no_grad():
    print(params().shape)
    img = norm_data(generate(vqgan_model, params()).cpu()) # 1 x 3 x 400 x 400
    print("img dimensions:", img.shape)
    show_from_tensor(img[0])

create crops

In [None]:
def create_crops(img, num_crops = 30):
    """Pad, augment and randomly crop a 1 x 3 x H x W image into num_crops CLIP-sized crops."""
    p = height//2
    img = torch.nn.functional.pad(img,
                                  (p,p,p,p),
                                  mode="constant",value=0) # 1 x 3 x 800 x 800 for a 400 x 400 image (200 added on all sides)
    img = augment_transform(img) 
    padded_h, padded_w = img.shape[-2:]
    
    crop_set = []
    for _ in range(num_crops):
        # side of the square crop, sampled around the original image height
        gap = int(torch.normal(1.0, 0.5, ()).clip(0.2,1.5)*height)
        # bound the offsets by the crop side actually used, so the crop stays square
        # (a crop running past the border would be truncated and then distorted by the resize)
        offsetx = torch.randint(0, padded_h - gap + 1, ())
        offsety = torch.randint(0, padded_w - gap + 1, ())
        
        crop = img[:,:, offsetx:offsetx+gap, offsety:offsety+gap]
        
        crop = torch.nn.functional.interpolate(crop, (224, 224),
                                              mode="bilinear", align_corners=True)
        crop_set.append(crop)
        
    img_crops = torch.cat(crop_set,0) # num_crops x 3 x 224 x 224
    img_crops = img_crops + NOISE_FACTOR*torch.randn_like(img_crops, requires_grad = False)
    
    return img_crops

In [None]:
def show_generated(params, show_crop):
    with torch.no_grad():
        generated = generate(vqgan_model,params())
        
        if show_crop:
            print("Augmented cropped example:")
            aug_gen = generated.float() # 1 x 3 x height x width
            aug_gen = create_crops(aug_gen, num_crops=1)
            aug_gen_norm = norm_data(aug_gen[0])
            show_from_tensor(aug_gen_norm)
            
        print("Generated:")
        latest_gen = norm_data(generated.cpu()) # 1 x 3 x height x width
        show_from_tensor(latest_gen[0])
        
    return latest_gen[0]

## Training

Optimize the parameters so the generated image matches the prompt

In [None]:
def optimize_result(params, prompt, include_enc, exclude_enc, extras_enc):
    """
    To calculate the loss we compare image and text encodings,
    so we first need to generate an image from the current params.
    Uses the global weights w1 (prompt) and w2 (extras) set in the experiment cells.
    """
    # importance of the encodings (maybe should move this to function params or global constants?)
    alpha = 1.0 # importance of the include encodings
    beta = 0.5 # importance of the exclude encodings
    
    # generate a candidate image with vqgan
    output = generate(vqgan_model, params())
    output = norm_data(output)
    output = create_crops(output)
    output = normalize(output) # 30 x 3 x 224 x 224
    # encode the generated image with clip
    image_encoded = clipmodel.encode_image(output) # 30 x 512 (one 512-d embedding per crop)
    
    # text encoding
    final_encoding = w1*prompt + w2*extras_enc
    final_text_include_enc = final_encoding/final_encoding.norm(dim=-1, keepdim=True) # 1 x 512
    
    loss = torch.cosine_similarity(final_text_include_enc, image_encoded, -1) # 30
    # exclude_enc is 0 when no exclude text was given (see createEncodings)
    if isinstance(exclude_enc, torch.Tensor):
        penalize_loss = torch.cosine_similarity(exclude_enc, image_encoded, -1) # 30
    else:
        penalize_loss = 0
    
    # maximize the similarity between include+extras and image encodings (by minimizing its negative)
    # and minimize the similarity between exclude and image encodings
    final_loss = -alpha*loss + beta*penalize_loss
    
    return final_loss

def optimize(params, optimizer, prompt, include_enc, exclude_enc, extras_enc):
    loss = optimize_result(params, prompt, include_enc, exclude_enc, extras_enc).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    return loss
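
The per-step objective of `optimize_result` can be summarized in plain Python. This sketch assumes embeddings are lists of floats and mirrors the weights used in the notebook (`w1`, `w2` for the prompt and extras, `alpha = 1.0`, `beta = 0.5` for include and exclude); the function name `combined_loss` is hypothetical.

```python
import math


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def combined_loss(crop_embeddings, prompt, extras=None, exclude=None,
                  w1=1.0, w2=0.5, alpha=1.0, beta=0.5):
    """Mean over crops of -alpha*cos(include, crop) + beta*cos(exclude, crop)."""
    # include target: weighted sum of the prompt and (optional) extras embeddings;
    # normalizing it, as optimize_result does, would not change the cosine similarity
    if extras is None:
        include = [w1 * p for p in prompt]
    else:
        include = [w1 * p + w2 * e for p, e in zip(prompt, extras)]
    losses = []
    for crop in crop_embeddings:
        loss = -alpha * cosine_similarity(include, crop)
        if exclude is not None:
            loss += beta * cosine_similarity(exclude, crop)
        losses.append(loss)
    return sum(losses) / len(losses)


# toy 2-D embeddings (assumption): the prompt points along x, the exclude text along y
prompt, exclude = [1.0, 0.0], [0.0, 1.0]
on_prompt = [[1.0, 0.1], [0.9, 0.0]]
on_exclude = [[0.1, 1.0], [0.0, 0.9]]
print(combined_loss(on_prompt, prompt, exclude=exclude) < combined_loss(on_exclude, prompt, exclude=exclude))
```

Averaging over crops means a single lucky view cannot dominate the gradient; the latent is pushed towards images whose crops agree with the prompt and stay away from the exclude text.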



### training loop

In [None]:
def train(include_text, exclude_text, extras_text,
          params, optimizer, show_crop = False, capture_gen_every = -1):
    """
    capture_gen_every  = capture output frequency, -1 means only the last one
    """
    capture_gen_every = TOTAL_ITERATIONS - 1 if capture_gen_every == -1 else capture_gen_every
    
    result_img = [] # store images
    result_z = [] # store latent space/parameters
    
    
    include_enc, exclude_enc, extras_enc = createEncodings(include_text,exclude_text,extras_text)
    
    for i,prompt in enumerate(include_enc):
        print(f"\n-------------------STARTING: {include_text[i]} -------------------------")
        iteration = 0
        params, optimizer = init_params()# 1 x 256 x height/16 x width /16
        
        for iteration in range(TOTAL_ITERATIONS):
            loss = optimize(params, optimizer, prompt, include_enc, exclude_enc, extras_enc)
            
            if iteration > 0 and iteration%capture_gen_every==0:
                new_img = show_generated(params, show_crop)
                result_img.append(new_img)
                result_z.append(params().detach().cpu().numpy())
                
                print(f"prompt:{include_text[i]}")
                print(f"loss:{loss.item()} \niteration:{iteration}")
                
                
            iteration+=1
        print(f"-------------------ENDING: {include_text[i]} -------------------------")
        torch.cuda.empty_cache()
        
    return result_img, result_z

In [None]:
torch.cuda.empty_cache()

In [None]:
include = ['A forest with purple trees', 
           #'one  big elephant at the top of a mountain',
           'A painting of a pineapple in a bowl',
          "a wolf looking at the stars at night",
           #"a futuristic city in synthwave style"
          ] #["a boy in a mountain playing guitar for his dog"]
TOTAL_ITERATIONS = 200
exclude =  "watermark, cropped, confusing, incoherent, cut, blurry"
extras = ""#"watercolor paper texture"

w1=1
w2=0.5

result_imgs, result_z = train(include, exclude, extras, params,optimizer,show_crop=True,
                              capture_gen_every=50)

In [None]:
print(len(result_imgs), len(result_z))
print(result_imgs[0].shape, result_z[0].shape)
print(result_imgs[0].min().item(), result_imgs[0].max().item())
print(result_z[0].min().item(), result_z[0].max().item())

In [None]:
torch.cuda.empty_cache()

### Create animations by interpolating between the generated images

In [None]:
def interpolate(res_z_list, duration_list):
    """
    for very image in res_z_list we pass a duration(seconds) in duration_list
    """
    gen_img_list = []
    
    fps = 25 # frames per second
    
    for idx, (z,duration) in enumerate(zip(res_z_list, duration_list)):
        torch.cuda.empty_cache()
        num_steps = int(duration*fps) # number of frames
        z1 = z
        
        # the modular division allows to interpolate from last image to first 
        z2 = res_z_list[(idx+1)%len(res_z_list)] # 1 x 256 x (height/16) x (width/16)
        
        for step in range(num_steps):
            #make it a bit interesting: interpolation is not linear
            ## faster in the midle and slower at the end periodically
            alpha = math.sin(1.5*step/num_steps)**6 
            z_new = alpha * z2 + (1-alpha) * z1  #common interpolation formula
            
            new_gen = norm_data(generate(vqgan_model,
                               torch.Tensor(z_new).to(DEVICE)).cpu())[0] ## 3 x height x width
            new_img = transforms.ToPILImage(mode="RGB")(new_gen)
            gen_img_list.append(new_img)
            
    return gen_img_list
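
The easing curve used in `interpolate` can be checked in isolation. This plain-Python sketch reproduces `alpha = sin(1.5 * step / num_steps) ** 6` and the blend between two latents (short lists stand in for the `1 x 256 x 25 x 25` arrays). Because 1.5 rad is slightly less than pi/2 and `step` stops at `num_steps - 1`, the last frame of a 25-step transition uses an alpha of roughly 0.95 rather than 1.0, so each transition ends just short of the next latent.

```python
import math


def easing(step, num_steps):
    """Interpolation weight as in interpolate(): sin(1.5 * t) ** 6 with t = step / num_steps."""
    return math.sin(1.5 * step / num_steps) ** 6


def lerp(z1, z2, alpha):
    """Blend two latents: alpha * z2 + (1 - alpha) * z1."""
    return [alpha * b + (1 - alpha) * a for a, b in zip(z1, z2)]


num_steps = 25  # 1 second at 25 fps, as in the notebook
alphas = [easing(step, num_steps) for step in range(num_steps)]
# toy 2-D latents (assumption) standing in for the VQGAN latent arrays
frames = [lerp([0.0, 1.0], [1.0, -1.0], a) for a in alphas]
print(round(alphas[0], 3), round(alphas[-1], 3))
```

The weights grow monotonically, with small increments at the start, the largest around two thirds of the way and smaller ones again near the end.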

In [None]:
durations = [1]*len(result_z)
interpolation_results = interpolate(result_z,durations)

In [None]:
output_video_path = "./outputs/output.mp4"
writer = imageio.get_writer(output_video_path,fps=25)
for pil_img in interpolation_results:
    img = np.array(pil_img, dtype=np.uint8)
    writer.append_data(img)
    
writer.close()

In [None]:
torch.cuda.empty_cache()
include = [
           "a futuristic city in synthwave style"
          ]
TOTAL_ITERATIONS = 500
exclude =  "watermark, cropped, confusing, incoherent, cut, blurry"
extras = ""#"watercolor paper texture"

w1=1
w2=0.5

result_imgs, result_z = train(include, exclude, extras, params,optimizer,show_crop=True,
                              capture_gen_every=25)

In [None]:
durations = [1]*len(result_z)
interpolation_results = interpolate(result_z,durations)

In [None]:
output_video_path = "./outputs/output2.mp4"
writer = imageio.get_writer(output_video_path,fps=25)
for pil_img in interpolation_results:
    img = np.array(pil_img, dtype=np.uint8)
    writer.append_data(img)
    
writer.close()

In [None]:
torch.cuda.empty_cache()
include = [
           "a cyberpunk city in synthwave"
          ]
TOTAL_ITERATIONS = 500
exclude =  "watermark, cropped, confusing, incoherent, cut, blurry"
extras = ""#"watercolor paper texture"

w1=1
w2=1

result_imgs, result_z = train(include, exclude, extras, params,optimizer,show_crop=True,
                              capture_gen_every=25)

In [None]:
durations = [1]*len(result_z)
interpolation_results = interpolate(result_z,durations)

In [None]:
output_video_path = "./outputs/output3.mp4"
writer = imageio.get_writer(output_video_path,fps=25)
for pil_img in interpolation_results:
    img = np.array(pil_img, dtype=np.uint8)
    writer.append_data(img)
    
writer.close()

In [None]:
torch.cuda.empty_cache()
include = [
           "a piano"
          ]
TOTAL_ITERATIONS = 500
exclude =  "watermark, cropped, confusing, incoherent, cut, blurry"
extras = "watercolor paper texture"

w1=1
w2=1

result_imgs, result_z = train(include, exclude, extras, params,optimizer,show_crop=True,
                              capture_gen_every=25)

In [None]:
durations = [1]*len(result_z)
interpolation_results = interpolate(result_z,durations)

In [None]:
output_video_path = "./outputs/output4.mp4"
writer = imageio.get_writer(output_video_path,fps=25)
for pil_img in interpolation_results:
    img = np.array(pil_img, dtype=np.uint8)
    writer.append_data(img)
    
writer.close()

In [None]:
torch.cuda.empty_cache()
include = [
           "a peaceful lake"
          ]
TOTAL_ITERATIONS = 800
exclude =  "watermark, cropped, confusing, incoherent, cut, blurry"
extras = "oil painting"#"watercolor paper texture"

w1=1
w2=1

result_imgs, result_z = train(include, exclude, extras, params,optimizer,show_crop=True,
                              capture_gen_every=25)

In [None]:
durations = [1]*len(result_z)
interpolation_results = interpolate(result_z,durations)

In [None]:
output_video_path = "./outputs/output5.mp4"
writer = imageio.get_writer(output_video_path,fps=25)
for pil_img in interpolation_results:
    img = np.array(pil_img, dtype=np.uint8)
    writer.append_data(img)
    
writer.close()