# Text conditioning for Flow Matching


Similarly, we can add text conditioning to the flow matching algorithm. The modification to the training loop is small. Most of the code change is already done inside the denoising model.

In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
import numpy as np
np.int = np.int32
import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer


class TextEncoder(nn.Module):
    def __init__(self, model_name: str, device: str):
        super().__init__()
        self.model_name = model_name
        self.model = CLIPTextModel.from_pretrained(model_name).to(device)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.device = device
        # Get the text embedding dimension from the config
        self.text_embed_dim = self.model.config.hidden_size

    def forward(self, text: str) -> torch.Tensor:
        tokens = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt").to(self.device)
        return self.model(**tokens).pooler_output

2025-01-09 03:50:19.415450: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-09 03:50:19.428194: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1736394619.442354 3747064 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1736394619.446563 3747064 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-09 03:50:19.462277: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

## Training step

In [3]:
import itertools
import torch
import torch.nn as nn
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import sys

sys.path.append("..")
from lib_4_1.config import TrainingConfig
from src.train_cfm import ExactOptimalTransportConditionalFlowMatcher, ConditionalFlowMatcher

def train(
    config: TrainingConfig,
    model: nn.Module,
    text_encoder: TextEncoder,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    steps: int=100,
    silent: bool=False,
) -> float:
  device = config.device
  # FM = ExactOptimalTransportConditionalFlowMatcher(sigma=0)
  FM = ConditionalFlowMatcher(sigma=0)
  
  model.train()
  if not silent:
    print("Training on device:", device)
  max_train_steps = steps

  loss = None
  progress_bar = tqdm(itertools.cycle(train_dataloader), total=max_train_steps, disable=silent)
  step = 0
  criterion = MSELoss()
  for batch in progress_bar:
    x_1 = batch[0]  # x_0 is the clean image to teach the model to generate
    text = batch[1]["text"]  # text is the caption of the image
    assert len(text) == x_1.shape[0]
    # assert the type of text is a list of strings
    
    optimizer.zero_grad()

    # Implement classifier-free guidance training
    # Randomly drop out text conditioning with 10% probability
    # The dropout is applied to the batch as a whole.
    # Alternatively, we could apply it to each image in the batch.
    text_drop_prob = 0.2
    x_1 = x_1.to(device)
    x_0 = torch.randn_like(x_1).to(device)

    with torch.no_grad():
        text_embeddings = text_encoder(text)

    t, x_t, u_t = FM.sample_location_and_conditional_flow(x0=x_0, x1=x_1)
    # t, x_t, u_t, _, text_embeddings_t = FM.guided_sample_location_and_conditional_flow(x0=x_0, x1=x_1, y0=None, y1=text_embeddings)
    
    # A dropout is applied to the ``text_embeddings`` input:
    #   This means `predicted_noise` will be computed with 20% probability of the text embeddings being dropped out.
    #   The model learns to predict the noise both with and without the text embeddings.
    v_t = model(t=t, x=x_t, text_embeddings=text_embeddings, p_uncond=text_drop_prob)

    loss = criterion(u_t, v_t)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)  # try commenting it out
    optimizer.step()

    step += 1

    if not silent:
      progress_bar.set_postfix({"loss": loss.cpu().item()})

    if step >= max_train_steps:
      break

  return loss

In [4]:
from torch import optim
from lib_4_1.data import load_data
from lib_4_1.model import create_unet_model

config = TrainingConfig(dataset="reese-green/afhq64_captions_64k", caption_column="caption_blip2-opt-2.7b", batch_size=16, resolution=32)
text_encoder = TextEncoder("openai/clip-vit-large-patch14", "cuda:0")
text_encoder.eval()
train_ds, val_ds = load_data(config)
denoising_model = create_unet_model(config, config.device)
optimizer = optim.AdamW(denoising_model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

model params: 14.68 M


In [5]:
train_dataloader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True)
val_dataloader = DataLoader(val_ds, batch_size=config.batch_size, shuffle=False)

In [6]:
train(
    config=config,
    model=denoising_model,
    text_encoder=text_encoder,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    steps=8000,
    silent=False
)

Training on device: cuda


  0%|          | 0/8000 [00:00<?, ?it/s]

tensor(0.1608, device='cuda:0', grad_fn=<MseLossBackward0>)

In [7]:
torch.save(denoising_model.state_dict(), "denoising_model_4_2.pth")

In the next tutorial, we will use the trained flow matching model to generate images with text conditioning.