# Hierarchical Text-Conditional Image Generation with CLIP Latents (UnCLIP)

unCLIP generates images from captions using two components:

- A `prior` $P(z_i|y)$ that produces CLIP image embeddings $z_i$ conditioned on captions $y$.
- A `decoder` $P(x|z_i,y)$ that produces images $x$ conditioned on CLIP image embeddings $z_i$ (and optionally text captions $y$).

The decoder allows us to invert images given their CLIP image embeddings, while the prior allows us to learn a generative model of the image embeddings themselves. Stacking these two components yields a generative model $P(x|y)$ of images $x$ given captions $y$:
$$P(x|y) = P(x,z_i|y) = P(x|z_i,y) P(z_i|y).$$
The first equality holds because $z_i$ is a deterministic function of $x$; the second follows from the chain rule.
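
The factorization suggests a two-step sampling procedure: first draw $z_i$ from the prior, then draw $x$ from the decoder. The snippet below is a minimal sketch of that flow, not the actual models. `sample_unclip`, `toy_prior`, and `toy_decoder` are hypothetical names, the stand-ins only produce random tensors of plausible shapes, and the 512-dimensional embedding is an assumption.

```python
import torch

def sample_unclip(prior, decoder, text_emb, caption=None):
    """Ancestral sampling from P(x|y) = P(x|z_i, y) P(z_i|y)."""
    z_i = prior(text_emb)      # step 1: z_i ~ P(z_i | y)
    x = decoder(z_i, caption)  # step 2: x ~ P(x | z_i, y)
    return x, z_i

# Hypothetical stand-ins, used only to show how the two stages connect.
def toy_prior(text_emb):
    # Pretend the image embedding is a noisy copy of the text embedding.
    return text_emb + 0.1 * torch.randn_like(text_emb)

def toy_decoder(z_i, caption=None):
    # Return a random "image" in [-1, 1] with one sample per embedding.
    return torch.tanh(torch.randn(z_i.shape[0], 3, 64, 64))

text_emb = torch.randn(2, 512)  # assumed embedding width
x, z_i = sample_unclip(toy_prior, toy_decoder, text_emb)
print(z_i.shape, x.shape)
```

Passing the prior and decoder as callables keeps the sampling logic independent of which prior class is plugged in.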

<img src="./figures/unclip-figurehead.png" title="UNET" />


A high-level overview of unCLIP. Above the dotted line, we depict the CLIP training process, through which we learn a joint representation space for text and images. Below the dotted line, we depict our text-to-image generation process: a CLIP text embedding is first fed to an autoregressive or diffusion prior to produce an image embedding, and then this embedding is used to condition a diffusion decoder which produces a final image. Note that the CLIP model is frozen during training of the prior and decoder.

In [10]:
import PIL
import torch
import numpy as np
from PIL import Image
from tqdm.auto import tqdm, trange
from itertools import islice
import torch.nn as nn

In [2]:
def load_img(path):
    image = Image.open(path).convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h}) from {path}")
    w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2. * image - 1.

def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())

## Dataset

In [34]:
!pip install img2dataset
!wget https://huggingface.co/datasets/ChristophSchuhmann/MS_COCO_2017_URL_TEXT/resolve/main/mscoco.parquet
!img2dataset --url_list mscoco.parquet --input_format "parquet" --url_col "URL" --caption_col "TEXT" --output_format webdataset\
    --output_folder data/mscoco --processes_count 16 --thread_count 64 --image_size 256 --enable_wandb True

In [None]:
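from IPython.display import Image as IPyImage  # alias avoids shadowing PIL's Image
# This path layout matches img2dataset's `files` output; `webdataset` output packs samples into .tar shards.
IPyImage(filename='data/mscoco/00000/000000000.jpg')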
     

## Model 

### 0. CLIP Model

<img src="./figures/CLIP.png" title="UNET" />


In [16]:
# !pip install open_clip_torch
import open_clip
open_clip.list_pretrained()


In [6]:
import torch
from PIL import Image
import open_clip

model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
model.eval()  # model in train mode by default, impacts some models with BatchNorm or stochastic depth active
tokenizer = open_clip.get_tokenizer('ViT-B-32')

image = preprocess(Image.open("figures/CLIP.png")).unsqueeze(0)
text = tokenizer(["a diagram", "a dog", "a cat"])

with torch.no_grad(), torch.cuda.amp.autocast():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print("Label probs:", text_probs)  # prints: [[1., 0., 0.]]

### 1. Prior

The prior produces CLIP image embeddings $z_i$ from captions $y$, which is what enables image generation from text. We consider two classes of prior model:

- `Autoregressive (AR) prior` : the CLIP image embedding $z_i$ is converted into a sequence of discrete codes and predicted autoregressively conditioned on the caption $y$.
- `Diffusion prior` : The continuous vector $z_i$ is directly modelled using a Gaussian diffusion model conditioned on the caption $y$.
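
To make the "sequence of discrete codes" in the AR prior concrete, the snippet below sketches one simple way to discretize a continuous embedding: uniform bucketing of each dimension. This is an illustration under stated assumptions, not the paper's exact procedure. The function names `quantize` and `dequantize`, the bucket count, and the value range are hypothetical, and any preprocessing applied before quantization is omitted. The autoregressive model that predicts the codes is not shown.

```python
import torch
import torch.nn.functional as F

def quantize(z, num_buckets=1024, lo=-1.0, hi=1.0):
    """Map each embedding dimension to an integer code in [0, num_buckets - 1]."""
    z = z.clamp(lo, hi)
    return ((z - lo) / (hi - lo) * (num_buckets - 1)).round().long()

def dequantize(codes, num_buckets=1024, lo=-1.0, hi=1.0):
    """Map integer codes back to approximate continuous values."""
    return codes.float() / (num_buckets - 1) * (hi - lo) + lo

# Unit-normalized toy embeddings, so every entry lies in [-1, 1].
z = F.normalize(torch.randn(4, 512), dim=-1)
codes = quantize(z)        # the token sequence an AR model would predict
z_hat = dequantize(codes)  # reconstruction from the discrete codes
print(codes.shape, (z - z_hat).abs().max())
```

With this scheme the reconstruction error per dimension is at most half a bucket width, which is the price paid for turning a regression problem into next-token prediction.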

#### 1.1. Autoregressive (AR) Prior

In [7]:
ar_prior = None

#### 1.2. Diffusion Prior

In [8]:
diff_prior = None

### 2. Decoder

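The decoder is a diffusion model that produces images conditioned on the CLIP image embedding $z_i$ (and optionally the caption $y$). To generate high-resolution images, we train two diffusion upsampler models: one to upsample images from $64 \times 64$ to $256 \times 256$ resolution, and another to further upsample those to $1024 \times 1024$ resolution.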
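
The resolution cascade can be traced with a short shape check. In the sketch below, `toy_upsampler` is a hypothetical stand-in that uses bilinear interpolation in place of a trained diffusion upsampler; it only illustrates how tensors move through the $64 \rightarrow 256 \rightarrow 1024$ stages.

```python
import torch
import torch.nn.functional as F

def toy_upsampler(x, size):
    # Stand-in for a diffusion upsampler: plain bilinear interpolation.
    return F.interpolate(x, size=(size, size), mode="bilinear", align_corners=False)

x_64 = torch.rand(1, 3, 64, 64) * 2 - 1  # placeholder base image in [-1, 1]
x_256 = toy_upsampler(x_64, 256)         # first stage: 64 -> 256
x_1024 = toy_upsampler(x_256, 1024)      # second stage: 256 -> 1024
print(x_64.shape, x_256.shape, x_1024.shape)
```

A real upsampler would be a diffusion model conditioned on the lower-resolution image, so each stage can add detail rather than only smooth the pixels.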

### 3. UnCLIP Model

In [11]:
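class UnCLIP(nn.Module):
    """Skeleton of the two-stage unCLIP model; the sub-models are placeholders."""

    def __init__(self, prior):
        super().__init__()
        self.prior = prior  # either 'autoregressive' or 'diffusion'
        self.encoder = self.encoder_model()
        self.decoder = self.decoder_model()

    def encoder_model(self):
        # TODO: build the encoder.
        pass

    def decoder_model(self):
        # TODO: build the decoder for the chosen prior type.
        if self.prior == 'autoregressive':
            pass
        elif self.prior == 'diffusion':
            pass
        else:
            raise NotImplementedError(f"Prior option {self.prior!r} is not implemented.")

    def forward(self):
        # TODO: run the prior, then the decoder.
        pass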

### 4. Loss
$$L_{\text{prior}} = \mathbb{E}_{t \sim [1,T], z_i^{(t)} \sim q_t} \big[\|f_{\theta}(z_i^{(t)}, t, y) - z_i\|^2\big]$$
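
Here $t$ is a diffusion timestep, $z_i^{(t)}$ is the image embedding after $t$ steps of noising under $q_t$, and $f_{\theta}$ predicts the un-noised embedding $z_i$ directly. The snippet below is a minimal sketch of this objective, assuming a standard Gaussian forward process with a linear beta schedule. `ToyPrior` is a hypothetical small MLP used only so the example runs; the schedule, network, and dimensions are all illustrative assumptions.

```python
import torch
import torch.nn as nn

T = 1000
betas = torch.linspace(1e-4, 0.02, T)               # assumed linear noise schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def q_sample(z_i, t, noise):
    """Draw z_i^(t) ~ q_t by mixing the clean embedding with Gaussian noise."""
    a = alphas_cumprod[t].sqrt().unsqueeze(-1)
    s = (1.0 - alphas_cumprod[t]).sqrt().unsqueeze(-1)
    return a * z_i + s * noise

class ToyPrior(nn.Module):
    """Hypothetical stand-in for f_theta(z_i^(t), t, y)."""
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2 * dim + 1, 256), nn.SiLU(), nn.Linear(256, dim))

    def forward(self, z_t, t, y):
        t_feat = (t.float() / T).unsqueeze(-1)       # scalar timestep feature
        return self.net(torch.cat([z_t, y, t_feat], dim=-1))

def prior_loss(f_theta, z_i, y):
    t = torch.randint(0, T, (z_i.shape[0],))         # uniform timestep (0-indexed)
    z_t = q_sample(z_i, t, torch.randn_like(z_i))
    pred = f_theta(z_t, t, y)
    # Squared L2 distance to the clean embedding, averaged over the batch.
    return ((pred - z_i) ** 2).sum(dim=-1).mean()

torch.manual_seed(0)
dim = 512
f_theta = ToyPrior(dim)
z_i, y = torch.randn(8, dim), torch.randn(8, dim)    # placeholder image/text embeddings
loss = prior_loss(f_theta, z_i, y)
loss.backward()
print(loss.item())
```

Predicting $z_i$ itself, rather than the added noise, matches the form of the loss above, where the regression target is the clean embedding.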

## Training