# Text Conditioning (text2image)

Conditioning means prompting. So far, we have been doing unconditional image generation from pure noise, that is, generation without a prompt or any other context information:

$$
\textrm{noise} \rightarrow \textrm{image}
$$

Though it is cool to see the model generate interesting images like animal faces, it is not very useful yet as it lacks user control.

In this tutorial, we will start to add text conditioning to the model:

$$
\textrm{text} + \textrm{noise} \rightarrow \textrm{image}
$$

so that a prompt like "a black cat with lots of fur" can actually trigger the model to follow the instruction and generate such an image.

A popular approach to text-to-image generation is "classifier-free guidance". Historically, the first major technique along these lines was "classifier guidance", which relies on a separate classifier to steer image synthesis.

Let $x$ be the image, and $y$ be the text prompt. There are 2 distributions we could sample from:

- unconditional distribution $p_{\theta}(x)$: "what do images look like in general?"
- conditional distribution $p_{\theta}(x \mid y)$: "what do images look like given a text prompt $y$?"

Applying Bayes' rule, we have:

$$
p_{\theta}(x \mid y) \propto p_{\theta}(x) \cdot p_{\theta}(y \mid x)
$$

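Taking the log and then the gradient with respect to $x$, the normalizing constant $p_{\theta}(y)$ hidden in the "$\propto$" drops out, because it does not depend on $x$: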

$$
\nabla_{x} \log p_{\theta}(x \mid y) = \nabla_{x} \log p_{\theta}(x) + \nabla_{x} \log p_{\theta}(y \mid x)
$$

Note that the score $\nabla_{x_t} \log p_{\theta}(x_t)$ of the noisy image $x_t$ at time step $t$ is, up to a negative scale factor that depends on $t$, the output of the denoising model (its noise prediction).

Therefore, we can think of text-conditioned generation as taking the unconditional model 
and adjusting ("tweaking") it with $p_{\theta}(y \mid x)$, which is a "classifier" that predicts the label or text description given an image. When we add these two terms, we steer the generation toward images that are both likely under the unconditional model and consistent with the text $y$. This is the core idea behind classifier guidance.
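To make the link between the score and the noise prediction concrete, here is the standard relation for a DDPM-style forward process $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$, where $\bar\alpha_t$ (an assumption here, not defined elsewhere in this section) is the cumulative product of the noise schedule's $\alpha$ values. Multiplying the Bayes decomposition by $-\sqrt{1-\bar\alpha_t}$ and putting a guidance scale $s$ on the classifier term gives a classifier-guided noise prediction:

$$
\begin{aligned}
\nabla_{x_t} \log p_{\theta}(x_t) &\approx -\frac{\epsilon_{\theta}(x_t, t)}{\sqrt{1-\bar\alpha_t}} \\
\hat{\epsilon}_{\theta}(x_t, t, y) &= -\sqrt{1-\bar\alpha_t}\,\big(\nabla_{x_t} \log p_{\theta}(x_t) + s\,\nabla_{x_t} \log p_{\theta}(y \mid x_t)\big) \\
&\approx \epsilon_{\theta}(x_t, t) - s\,\sqrt{1-\bar\alpha_t}\,\nabla_{x_t} \log p_{\theta}(y \mid x_t)
\end{aligned}
$$

With $s = 1$ this is just the Bayes-rule gradient above rewritten as a noise prediction; $s > 1$ weights the classifier term more heavily.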

### Classifier Guidance

With classifier guidance, there are 2 models:

- a diffusion model trained to denoise images (it learns $p_{\theta}(x)$)
- a separate classifier trained to predict the label or text $y$ (e.g. cat, dog) given an image $x$.

During generation (the reverse diffusion process), at each step we:

1. Take the partially denoised image $x_t$.
2. Run it through the classifier to get a prediction $\log p_{\theta}(y \mid x_t)$.
3. Compute the gradient of the log-probability with respect to the image, $\nabla_{x_t} \log p_{\theta}(y \mid x_t)$, which tells us how to tweak $x_t$ to be more consistent with the label $y$.
4. Mix this gradient into the diffusion model’s own prediction to push the sample toward matching $y$.
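The four steps can be sketched in NumPy. This is a toy, not the tutorial's code: the classifier is a hypothetical linear softmax model over a flattened image, so its gradient can be written by hand (a real implementation would backpropagate through a noise-aware network with autograd), `alpha_bar_t` is an assumed cumulative noise-schedule value, and step 4 uses the common mixing rule $\hat\epsilon = \epsilon_{\theta} - s\sqrt{1-\bar\alpha_t}\,\nabla_{x_t}\log p_{\theta}(y \mid x_t)$.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def classifier_grad(x_t: np.ndarray, W: np.ndarray, b: np.ndarray, y: int) -> np.ndarray:
    """Steps 2-3: gradient of log p(y | x_t) w.r.t. x_t for a toy linear softmax classifier.

    With logits = W @ x_t + b, the gradient is W[y] - sum_k p(k | x_t) * W[k].
    """
    probs = np.exp(log_softmax(W @ x_t + b))
    return W[y] - probs @ W


def guided_noise(eps: np.ndarray, x_t: np.ndarray, W: np.ndarray, b: np.ndarray,
                 y: int, alpha_bar_t: float, scale: float = 1.0) -> np.ndarray:
    """Step 4: mix the classifier gradient into the diffusion model's noise prediction."""
    grad = classifier_grad(x_t, W, b, y)
    return eps - scale * np.sqrt(1.0 - alpha_bar_t) * grad


# Toy usage with random stand-ins (hypothetical shapes: 3 classes, 4-pixel "image")
rng = np.random.default_rng(0)
x_t = rng.standard_normal(4)      # step 1: the partially denoised (flattened) image
W = rng.standard_normal((3, 4))   # classifier weights, e.g. classes cat / dog / wild
b = np.zeros(3)
eps = rng.standard_normal(4)      # stand-in for the diffusion model's noise prediction
eps_hat = guided_noise(eps, x_t, W, b, y=0, alpha_bar_t=0.5, scale=2.0)
```

Because the noise prediction is subtracted during denoising, subtracting a multiple of the gradient here moves the sample in the direction that raises $\log p_{\theta}(y \mid x_t)$.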

If $y$ is "a cat sitting on a mat", the classifier $p_{\theta}(y \mid x_t)$ evaluates how likely $x_t$ is to match this description. The guidance term $\nabla_{x_t} \log p_{\theta}(y \mid x_t)$ nudges the sample toward features of a cat on a mat, while $\nabla_{x_t} \log p_{\theta}(x_t)$ keeps the output realistic and coherent as an image.

While this works, it requires training or fine-tuning a separate classifier that can handle noisy intermediate images. Usually the classifier is tied to specific labels or tasks, which limits its flexibility.




### Classifier-Free Guidance

Can we get the benefits of classifier guidance (steering images toward a text prompt) without needing a separate classifier? It turns out a clever trick can do this.

The key idea is to train a single model that has two modes:

- Conditional mode: the model sees the text prompt $y$.
- Unconditional mode: the model receives a null or "empty" prompt.

#### During Training

We train **the same diffusion model** to handle both modes by randomly "dropping out" the text prompt some fraction of the time during training. This gives us a single model that can output both the unconditional noise prediction $\epsilon_{\theta}(x_t, t, \emptyset)$ and the conditional noise prediction $\epsilon_{\theta}(x_t, t, y)$, where $y$ is the text prompt and $\emptyset$ is the null text prompt.
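A minimal sketch of per-example prompt dropout, written in NumPy rather than PyTorch so it runs standalone. The function name `drop_text` is hypothetical, and the null embedding here is a fixed zero vector for simplicity; a learned null embedding works the same way.

```python
import numpy as np


def drop_text(text_emb: np.ndarray, null_emb: np.ndarray, p_uncond: float,
              rng: np.random.Generator):
    """Replace each row of text_emb with null_emb with probability p_uncond.

    text_emb: (batch, dim) caption embeddings; null_emb: (dim,) embedding of the empty prompt.
    Returns the new embeddings and the boolean mask of dropped rows.
    """
    mask = rng.random(text_emb.shape[0]) < p_uncond   # True -> train this example unconditionally
    out = np.where(mask[:, None], null_emb[None, :], text_emb)  # new array; input is not mutated
    return out, mask


rng = np.random.default_rng(0)
text_emb = rng.standard_normal((1000, 8))
null_emb = np.zeros(8)
dropped, mask = drop_text(text_emb, null_emb, p_uncond=0.2, rng=rng)
print(mask.mean())  # roughly 0.2
```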

#### During Generation

At each reverse diffusion step, we make two predictions:

1. Unconditional noise prediction: $\mathbf\epsilon_{\theta}(x_t, t, \emptyset)$ using the null text prompt $\emptyset$, where $t$ is the denoising time step.
2. Conditional noise prediction: $\mathbf\epsilon_{\theta}(x_t, t, y)$ using the real text prompt $y$.

We then mix them:

$$
\mathbf\epsilon = \mathbf\epsilon_{\theta}(x_t, t, \emptyset) + s \cdot (\mathbf\epsilon_{\theta}(x_t, t, y) - \mathbf\epsilon_{\theta}(x_t, t, \emptyset))
$$

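where $s$ is the guidance scale that controls how strongly the final output is "pushed" toward matching the text: $s = 0$ recovers the unconditional prediction, $s = 1$ the plain conditional prediction, and $s > 1$ extrapolates further in the direction of the text.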
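A minimal NumPy sketch of how the mixed prediction is used at sampling time. `eps_uncond` and `eps_cond` stand in for the two forward passes of the denoising model; the linear beta schedule and the DDPM update (with variance $\beta_t$) are assumptions for illustration, not necessarily what this tutorial's sampler uses.

```python
import numpy as np


def cfg_noise(eps_uncond: np.ndarray, eps_cond: np.ndarray, s: float) -> np.ndarray:
    """Classifier-free guidance: eps_uncond + s * (eps_cond - eps_uncond)."""
    return eps_uncond + s * (eps_cond - eps_uncond)


def ddpm_step(x_t: np.ndarray, eps: np.ndarray, t: int, betas: np.ndarray,
              rng: np.random.Generator) -> np.ndarray:
    """One DDPM reverse step x_t -> x_{t-1} driven by the (guided) noise estimate eps.

    Uses the simple variance choice sigma_t^2 = beta_t; no noise is added at the last step (t == 0).
    """
    alphas = 1.0 - betas
    alpha_bar = np.cumprod(alphas)
    mean = (x_t - betas[t] / np.sqrt(1.0 - alpha_bar[t]) * eps) / np.sqrt(alphas[t])
    if t == 0:
        return mean
    return mean + np.sqrt(betas[t]) * rng.standard_normal(x_t.shape)


rng = np.random.default_rng(0)
betas = np.linspace(1e-4, 0.02, 1000)            # assumed linear noise schedule
x_t = rng.standard_normal((3, 32, 32))           # a noisy 32x32 RGB image
eps_uncond = rng.standard_normal(x_t.shape)      # stand-in for eps_theta(x_t, t, null prompt)
eps_cond = rng.standard_normal(x_t.shape)        # stand-in for eps_theta(x_t, t, y)
eps = cfg_noise(eps_uncond, eps_cond, s=3.0)
x_prev = ddpm_step(x_t, eps, t=999, betas=betas, rng=rng)
```

In practice the two predictions are often computed in a single forward pass by stacking the batch twice, once with the null embedding and once with the real text embeddings.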

## Model Architecture Design Choices for Text Conditioning

As with the 2D point clouds, we want to experiment with different model architectures. For text conditioning (combining information from text and image), we consider three approaches:

- Linear modulation
    - Scale and shift other embeddings or intermediate representations based on the text embedding
    - Can apply to time embedding, intermediate representations, or normalization layers.
- Concatenation
    - Concatenate the text embedding with other embeddings.
- Cross attention
    - Use the intermediate representations as queries to attend to the text embedding.
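The three options can be compared side by side in a toy NumPy sketch. Shapes, weight matrices, and function names are all hypothetical: the image features `h` are flattened to a sequence of $N$ spatial positions with $C$ channels, and in a real model the weights would be learned.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N, C = 2, 64, 16   # batch, spatial positions (an 8x8 feature map), channels
L, D = 7, 32          # text tokens, text embedding dim


def film(h, text_vec, W_scale, W_shift):
    """Linear modulation: per-channel scale and shift predicted from a pooled text vector.

    h: (B, N, C) image features; text_vec: (B, D).
    """
    scale = text_vec @ W_scale                       # (B, C)
    shift = text_vec @ W_shift                       # (B, C)
    return h * (1.0 + scale[:, None, :]) + shift[:, None, :]


def concat(other_emb, text_vec):
    """Concatenation: append the text vector to another embedding, e.g. the time embedding."""
    return np.concatenate([other_emb, text_vec], axis=-1)


def cross_attention(h, text_tokens, W_q, W_k, W_v):
    """Single-head cross attention: image positions are queries, text tokens are keys/values.

    h: (B, N, C); text_tokens: (B, L, D). Returns (B, N, C) with a residual connection.
    """
    q = h @ W_q                                      # (B, N, C)
    k = text_tokens @ W_k                            # (B, L, C)
    v = text_tokens @ W_v                            # (B, L, C)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])  # (B, N, L)
    scores = scores - scores.max(axis=-1, keepdims=True)      # numerical stability
    attn = np.exp(scores)
    attn = attn / attn.sum(axis=-1, keepdims=True)   # softmax over text tokens
    return h + attn @ v


h = rng.standard_normal((B, N, C))
text_vec = rng.standard_normal((B, D))               # pooled caption embedding
text_tokens = rng.standard_normal((B, L, D))         # per-token caption embeddings
W_scale, W_shift = 0.01 * rng.standard_normal((2, D, C))
W_q = 0.1 * rng.standard_normal((C, C))
W_k, W_v = 0.1 * rng.standard_normal((2, D, C))

out_film = film(h, text_vec, W_scale, W_shift)              # (B, N, C)
out_cat = concat(rng.standard_normal((B, 64)), text_vec)    # (B, 64 + D)
out_xattn = cross_attention(h, text_tokens, W_q, W_k, W_v)  # (B, N, C)
```

Linear modulation and concatenation only need one vector per caption, such as the pooled CLIP output used later in this tutorial; cross attention needs one embedding per token (for a Hugging Face `CLIPTextModel`, that is `last_hidden_state` rather than `pooler_output`).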

The implementation also depends on whether we use a UNet or a Transformer backbone. Let's focus on the UNet with linear modulation. The modulation can be done by simply adding a projection of the text embedding to the time embedding, which the UNet blocks already consume:

```diff
# In UNetModel.__init__():
+ self.text_proj = nn.Sequential(
+                 linear(text_embed_dim, time_embed_dim),
+                 nn.SiLU(),
+                 linear(time_embed_dim, time_embed_dim),
+             )
...

# In forward():
emb = self.time_embed(timestep_embedding(t, self.model_channels))
+ emb = emb + self.text_proj(text_embeddings)  # add the text embedding to the time embedding
```

Then the combined time+text embedding `emb` and the intermediate hidden representation `h` pass through the regular UNet blocks (no changes are needed to the UNet blocks themselves):

```python
for module in self.input_blocks:
    h = module(h, emb)
    hs.append(h)
h = self.middle_block(h, emb)
for module in self.output_blocks:
    h = th.cat([h, hs.pop()], dim=1)
    h = module(h, emb)
...
```
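A minimal NumPy sketch of the combined embedding computed in the diff, assuming `timestep_embedding` and `time_embed` follow the common guided-diffusion layout (sinusoidal features fed through linear, SiLU, linear). The tutorial does not show those helpers, so this is an illustration rather than their exact code; sizes and weights are hypothetical.

```python
import math

import numpy as np


def timestep_embedding(t: np.ndarray, dim: int, max_period: float = 10000.0) -> np.ndarray:
    """Sinusoidal features [cos(t * f_i), sin(t * f_i)] with geometrically spaced frequencies f_i."""
    half = dim // 2
    freqs = np.exp(-math.log(max_period) * np.arange(half) / half)
    args = t[:, None].astype(np.float64) * freqs[None, :]
    emb = np.concatenate([np.cos(args), np.sin(args)], axis=-1)
    if dim % 2:  # pad odd dimensions with a zero column
        emb = np.concatenate([emb, np.zeros_like(emb[:, :1])], axis=-1)
    return emb


def silu(x: np.ndarray) -> np.ndarray:
    return x / (1.0 + np.exp(-x))


def mlp(x: np.ndarray, W1: np.ndarray, W2: np.ndarray) -> np.ndarray:
    """linear -> SiLU -> linear (biases omitted), the shape of both time_embed and text_proj."""
    return silu(x @ W1) @ W2


rng = np.random.default_rng(0)
model_channels, time_embed_dim, text_embed_dim = 32, 128, 768   # hypothetical sizes
t = np.array([0, 500, 999])
text_embeddings = rng.standard_normal((3, text_embed_dim))     # stand-in for pooled CLIP outputs
W_time = [0.02 * rng.standard_normal(s) for s in [(model_channels, time_embed_dim), (time_embed_dim, time_embed_dim)]]
W_text = [0.02 * rng.standard_normal(s) for s in [(text_embed_dim, time_embed_dim), (time_embed_dim, time_embed_dim)]]

# emb = time_embed(timestep_embedding(t)) + text_proj(text_embeddings)
emb = mlp(timestep_embedding(t, model_channels), *W_time) + mlp(text_embeddings, *W_text)
print(emb.shape)  # (3, 128)
```

Because `emb` is a single vector per image that every block receives, this form of conditioning injects the caption globally; it cannot route different words to different spatial regions, which is what cross attention adds.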


### Text embedding "dropout"

The recipe for classifier-free guidance requires us to randomly drop the text embedding some fraction of the time during training. This way, we train a single model that can operate in both unconditional and conditional modes.

When computing the training loss in the forward pass, we need to randomly drop the text embedding some fraction of the time:

```diff
# In forward() of the denoising model (UNetModel), drop the text embedding at random according to `p_uncond`.
# If p_uncond is 0.2, each example's text embedding is dropped with 20% probability.
+ unconditional_mask = th.rand(text_embeddings.shape[0], device=text_embeddings.device) < p_uncond  # True for examples whose text is dropped
+ text_embeddings = th.where(unconditional_mask[:, None], self.null_text_embed, text_embeddings)  # swap in the null text embedding
```

And the `null_text_embed` is a learnable embedding for the null text prompt:

```diff
# In UNetModel.__init__():
+ self.text_embed_dim = text_embed_dim
+ self.null_text_embed = nn.Parameter(th.randn(1, text_embed_dim) * 0.02)
```

This way the model learns what the null text embedding means in the latent space. An alternative is simply setting `null_text_embed` to all zeros.


## Text Encoder

To produce the text embedding, we can use a pre-trained CLIP model.

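```python
from typing import List, Union

import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer


class TextEncoder(nn.Module):
    """Wraps a pre-trained CLIP text model and returns one pooled embedding per caption."""

    def __init__(self, model_name: str, device: str):
        super().__init__()
        self.model_name = model_name
        self.model = CLIPTextModel.from_pretrained(model_name).to(device)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.device = device
        # Get the text embedding dimension from the config (768 for openai/clip-vit-large-patch14)
        self.text_embed_dim = self.model.config.hidden_size

    def forward(self, text: Union[str, List[str]]) -> torch.Tensor:
        # Pad to the longest caption in the batch; truncate to CLIP's maximum length
        tokens = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt").to(self.device)
        # pooler_output: final hidden state at the end-of-text token, shape (batch, hidden_size)
        return self.model(**tokens).pooler_output
```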


## Training

Now that the denoising model can handle text embeddings as an additional input, we also need to modify the training step to pass in the text embeddings. The main logic remains the same.

### Classifier-Free Guidance in Training: Text Dropout

With 20% probability, each example's text embedding is replaced by the null text embedding. The dropout itself happens inside the denoising model.

```python
import itertools
from typing import Dict, Optional

import torch
import torch.nn as nn
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from lib_4_1.diffusion import forward_diffusion
from lib_4_1.bookkeeping import Bookkeeping
from lib_4_1.config import TrainingConfig


def train(
    config: TrainingConfig,
    model: nn.Module,
    text_encoder: TextEncoder,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    noise_schedule: Dict,
    optimizer: torch.optim.Optimizer,
    steps: int = 100,
    silent: bool = False,
    bookkeeping: Optional[Bookkeeping] = None,
) -> Optional[torch.Tensor]:
    device = config.device
    num_denoising_steps = config.num_denoising_steps

    model.train()
    if not silent:
        print("Training on device:", device)
    max_train_steps = steps

    loss = None
    progress_bar = tqdm(itertools.cycle(train_dataloader), total=max_train_steps, disable=silent)
    step = 0
    criterion = MSELoss()
    # Probability of replacing a caption's embedding with the null embedding (classifier-free guidance)
    text_drop_prob = 0.2
    for batch in progress_bar:
        x_0 = batch[0].float().to(device)  # clean images the model learns to generate
        text = batch[1]["text"]  # captions of the images
        assert isinstance(text, (list, tuple)) and all(isinstance(s, str) for s in text)
        assert len(text) == x_0.shape[0]
        optimizer.zero_grad()

        # Forward diffusion: sample a random time step per image and noise the clean images
        true_noise = torch.randn_like(x_0)
        t = torch.randint(0, num_denoising_steps, (x_0.shape[0],), device=device).long()
        x_t, _ = forward_diffusion(x_0, t, noise_schedule, noise=true_noise)

        # The text encoder is frozen: no gradients flow into CLIP
        with torch.no_grad():
            text_embeddings = text_encoder(text)

        # Classifier-free guidance training: inside the model, each example's text embedding is
        # independently replaced by the null embedding with probability p_uncond, so the model
        # learns to predict the noise both with and without the caption.
        predicted_noise = model(t=t, x=x_t, text_embeddings=text_embeddings, p_uncond=text_drop_prob)

        loss = criterion(predicted_noise, true_noise)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)  # try commenting it out
        optimizer.step()

        step += 1

        if not silent:
            progress_bar.set_postfix({"loss": loss.item()})

        if bookkeeping:
            bookkeeping.run_callbacks(config=config, step=step, loss=loss, optimizer=optimizer, val_dataloader=val_dataloader)

        if step >= max_train_steps:
            break

    return loss
```

### Captioned Image Dataset and Text Encoder

Let's create the main components. The denoising model has under 15M parameters, which is rather small.

For the text encoder, we use CLIP (`openai/clip-vit-large-patch14`).

For the image-text pairs, we use the dataset `reese-green/afhq64_captions_64k`, whose captions were generated by running the `blip2-opt-2.7b` model on the animal face images. Here, we use a resolution of 32x32 pixels to make training faster.

A good dataset is important for the performance of the model. To learn more about how to create a captioned image dataset, please refer to [this tutorial](Image%20Captioning%20Lesson.ipynb).

```python
from torch import optim
from lib_4_1.data import load_data
from lib_4_1.model import create_unet_model
from lib_4_1.diffusion import create_noise_schedule

config = TrainingConfig(dataset="reese-green/afhq64_captions_64k", caption_column="caption_blip2-opt-2.7b", batch_size=16, resolution=32)
text_encoder = TextEncoder("openai/clip-vit-large-patch14", "cuda:0")
text_encoder.eval()
train_ds, val_ds = load_data(config)
noise_schedule = create_noise_schedule(n_T=config.num_denoising_steps, device=config.device)
denoising_model = create_unet_model(config, config.device)
optimizer = optim.AdamW(denoising_model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
```

```
model params: 14.68 M
```


### A Quick Data Check

A quick check that a training example (an image and its caption) looks right.

```python
for x in train_ds:
    print(x[0].shape)
    print(x[1])
    break
```

```
torch.Size([3, 32, 32])
{'label': 0, 'text': 'a large white dog with brown eyes sitting on the grass'}
```


Similarly, check that a mini-batch looks right.

```python
train_dataloader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True)
val_dataloader = DataLoader(val_ds, batch_size=config.batch_size, shuffle=False)
for x in train_dataloader:
    print(x[0].shape)
    print(x[1]["text"])
    text_embeddings = text_encoder(x[1]["text"])
    print(text_embeddings.shape)
    break
```

```
torch.Size([16, 3, 32, 32])
['a gray cat with green eyes sitting on a table', 'a close up of a cat with a sad expression', 'a black cat with green eyes standing in front of a fence', 'a gray and white dog with an orange collar', 'a gray cat is sitting on a red blanket', 'a leopard walking through the grass in the wild', 'a white dog with its tongue out and its tongue hanging out', 'a cheetah with its tongue out in the grass', 'a gray and white cat laying on a couch', 'a black french bulldog sitting down with his ears up', 'a brown dog running on the ground', 'a cat is sitting on a branch with green leaves', 'a small black and white dog wearing a purple harness', 'a white dog with a collar on sitting on a bed', 'a lion cub is sitting in a zoo enclosure', 'a golden retriever sitting on a bench with its owner']
torch.Size([16, 768])
```


### Train the Model

We train for 20,000 steps. This can take 20 to 30 minutes on an A10 instance on Lambda Labs.

```python
%%time

train(
    config=config,
    model=denoising_model,
    text_encoder=text_encoder,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    noise_schedule=noise_schedule,
    optimizer=optimizer,
    steps=20000,
    silent=False
)
```

```
Training on device: cuda
tensor(0.0426, device='cuda:0', grad_fn=<MseLossBackward0>)
```

### Save the Model

The checkpoint will be useful in the next tutorial, where we see the model in action.

```python
# save the model
torch.save(denoising_model.state_dict(), "denoising_model_4_1.pth")
```


## Generate images with text conditioning

In the next tutorial, we will use the updated sampling code to generate images with text conditioning. By varying the `guidance_scale` parameter, we can see how the text conditioning affects the generated images.