**Paper Name:** DiffDis: Empowering Generative Diffusion Model with Cross-Modal Discrimination Capability

**Link:** https://openaccess.thecvf.com/content/ICCV2023/papers/Huang_DiffDis_Empowering_Generative_Diffusion_Model_with_Cross-Modal_Discrimination_Capability_ICCV_2023_paper.pdf

**Project Members:**
Furkan Genç,
Barış Sarper Tezcan

In [None]:
# ---------------------------------------------------------------------------
# Shared configuration for the training, testing and inference cells below.
# ---------------------------------------------------------------------------

# Image resolution in pixels; the latent grid is 8x smaller along each side.
WIDTH, HEIGHT = 512, 512
LATENTS_WIDTH, LATENTS_HEIGHT = WIDTH // 8, HEIGHT // 8
BATCH_SIZE = 1
root_dir = "../dataset/cc3m/train"  # TODO: change root_dir with the path to the dataset according to your setup

# --- Training schedule ---
num_train_epochs = 6
Lambda = 1.0        # weight of the text loss in the combined training loss
save_steps = 1000   # checkpoint / sampling interval, in optimizer steps

# --- AdamW optimizer ---
learning_rate = 1e-5
discriminative_learning_rate = 1e-4  # separate learning rate for the discriminative parameter group
adam_beta1, adam_beta2 = 0.9, 0.999
adam_weight_decay = 1e-4
adam_epsilon = 1e-8

# --- Image-to-text evaluation ---
test_dataset = "CIFAR10"  # switch to "CIFAR100" to evaluate on CIFAR-100

# --- Output locations ---
train_output_dir = "../results/output_1"
test_output_dir = f"../results/{test_dataset}"
inference_output_dir = "../results/text_to_image/output_1/last"

# --- Weights ---
model_file = "data/v1-5-pruned.ckpt"
train_unet_file = None  # None: start fine-tuning from the base weights; a path: resume the diffusion model from that file
test_unet_file = "../results/output_1/last.pt"
inference_unet_file = "../results/output_1/last.pt"

# --- Exponential moving average of the UNet weights ---
use_ema = False  # flip to True to maintain and use EMA weights
ema_decay = 0.9999
warmup_steps = 1000

# --- Text-to-image sampling ---
prompt1 = "A river with boats docked and houses in the background"
prompt2 = "A piece of chocolate swirled cake on a plate"
prompt3 = "A large bed sitting next to a small Christmas Tree surrounded by pictures"
prompt4 = "A bear searching for food near the river"
prompts = [prompt1, prompt2, prompt3, prompt4]
uncond_prompt = ""  # the "negative" prompt
do_cfg = True
cfg_scale = 3  # min: 1, max: 14
num_samples = 1

# --- Sampler ---
sampler = "ddpm"
num_inference_steps = 50
seed = 42

In [None]:
import torch
import torch.nn.functional as F
import os
from tqdm import tqdm
from ddpm import DDPMSampler
from pipeline import get_time_embedding
from dataloader import train_dataloader
import model_loader
import time
from diffusion import TransformerBlock, UNet_Transformer  # Ensure these are correctly imported

import pipeline
from PIL import Image
from pathlib import Path
from transformers import CLIPTokenizer

# Set the device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the models
models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)
ddpm = DDPMSampler(generator=None)

# Bookkeeping for a fresh run; overwritten below when resuming from a checkpoint.
best_loss = float('inf')
best_step = 0
last_loss = 0.0
last_step = 0

# Deserialize the resume checkpoint exactly once (it holds the full UNet and
# optimizer state) and map it onto the active device so a checkpoint written
# on a GPU can still be opened on a CPU-only machine.
_train_checkpoint = None
if train_unet_file is not None:
    # Load the UNet model
    print(f"Loading UNet model from {train_unet_file}")
    _train_checkpoint = torch.load(train_unet_file, map_location=DEVICE)
    models['diffusion'].load_state_dict(_train_checkpoint['model_state_dict'])
    # Checkpoints without the loss counters keep the fresh-run defaults.
    if 'best_loss' in _train_checkpoint:
        best_loss = _train_checkpoint['best_loss']
        best_step = _train_checkpoint['best_step']
        last_loss = _train_checkpoint['last_loss']
        last_step = _train_checkpoint['last_step']

# TEXT TO IMAGE
tokenizer = CLIPTokenizer("./data/vocab.json", merges_file="./data/merges.txt")

# The image encoder and the CLIP text encoder stay frozen: no gradients, eval mode.
for frozen_name in ('encoder', 'clip'):
    for param in models[frozen_name].parameters():
        param.requires_grad = False
    models[frozen_name].eval()

# Separate parameters for discriminative tasks.
# A parameter is "discriminative" when the top-level attribute of the diffusion
# model that owns it is a TransformerBlock or UNet_Transformer instance.
# NOTE(review): only the first component of the parameter name is inspected, so
# transformer blocks nested deeper inside other containers fall into the
# non-discriminative group — confirm this matches the model layout.
diffusion_model = models['diffusion']
discriminative_params = []
non_discriminative_params = []

for name, param in diffusion_model.named_parameters():
    owner = getattr(diffusion_model, name.split('.')[0], None)
    if isinstance(owner, (TransformerBlock, UNet_Transformer)):
        discriminative_params.append(param)
    else:
        non_discriminative_params.append(param)

# AdamW optimizer with separate learning rates
optimizer = torch.optim.AdamW([
    {'params': non_discriminative_params, 'lr': learning_rate},
    {'params': discriminative_params, 'lr': discriminative_learning_rate}
], betas=(adam_beta1, adam_beta2), weight_decay=adam_weight_decay, eps=adam_epsilon)

if _train_checkpoint is not None:
    print(f"Loading optimizer state from {train_unet_file}")
    optimizer.load_state_dict(_train_checkpoint['optimizer_state_dict'])


# Linear warmup scheduler for non-discriminative parameters
def warmup_lr_lambda(current_step: int):
    """Return the LR multiplier: a linear ramp from 0 to 1 over ``warmup_steps``, then 1.0."""
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    return 1.0


# One multiplier per optimizer parameter group, in the order the groups were declared.
# NOTE(review): the scheduler state is not part of the checkpoint, so the warmup
# ramp starts again from step 0 whenever training is resumed.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[
    warmup_lr_lambda,  # Apply warmup for non-discriminative params
    lambda step: 1.0  # Keep constant learning rate for discriminative params
])

# EMA setup
if use_ema:
    ema_unet = torch.optim.swa_utils.AveragedModel(models['diffusion'], avg_fn=lambda averaged_model_parameter, model_parameter, num_averaged: ema_decay * averaged_model_parameter + (1 - ema_decay) * model_parameter)
    # Continue the saved running average instead of restarting it from the
    # current weights when the resume checkpoint carries EMA state.
    if _train_checkpoint is not None and 'ema_state_dict' in _train_checkpoint:
        print(f"Loading EMA state from {train_unet_file}")
        ema_unet.load_state_dict(_train_checkpoint['ema_state_dict'])

# Drop the reference so the deserialized checkpoint can be garbage collected.
_train_checkpoint = None

In [None]:
def _checkpoint_state(window_loss, step):
    """Assemble the dictionary written to ``best.pt`` / ``last.pt``.

    Args:
        window_loss: Mean training loss since the previous save.
        step: Global step at which the checkpoint is taken.

    Returns:
        A dict with the UNet and optimizer state dicts, the best/last loss
        counters and, when ``use_ema`` is set, the EMA model state dict.
    """
    state = {
        'model_state_dict': models['diffusion'].state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_loss': best_loss,
        'best_step': best_step,
        'last_loss': window_loss,
        'last_step': step,
    }
    if use_ema:
        state['ema_state_dict'] = ema_unet.state_dict()
    return state


def _save_prompt_samples(step):
    """Generate one image per configured prompt and save it under ``train_output_dir``."""
    for i, prompt in enumerate(prompts):
        # Sample images from the model
        output_image = pipeline.generate(
            prompt=prompt,
            uncond_prompt=uncond_prompt,
            input_image=None,
            strength=0.9,
            do_cfg=do_cfg,
            cfg_scale=cfg_scale,
            sampler_name=sampler,
            n_inference_steps=num_inference_steps,
            seed=seed,
            models=models,
            device=DEVICE,
            idle_device=DEVICE,
            tokenizer=tokenizer,
        )

        # Image.save does not create missing directories, so make them first.
        sample_dir = os.path.join(train_output_dir, "images", "prompt" + str(i + 1))
        os.makedirs(sample_dir, exist_ok=True)
        Image.fromarray(output_image).save(os.path.join(sample_dir, f"step{step}.png"))


def train(num_train_epochs, device="cuda", save_steps=1000):
    """Fine-tune the diffusion UNet on the combined image-noise and text-query loss.

    Every ``save_steps`` global steps the mean loss since the previous save is
    compared against ``best_loss``; ``best.pt`` is rewritten on improvement,
    ``last.pt`` is always rewritten, and one sample image per entry of
    ``prompts`` is generated and stored.

    Args:
        num_train_epochs: Epoch index at which training stops (exclusive).
        device: Device the models and batches are moved to.
        save_steps: Number of global steps between checkpoints.

    Side effects:
        Updates the module-level ``best_loss`` / ``best_step`` and writes
        checkpoints and sample images below ``train_output_dir``.
    """
    global best_loss, best_step, last_loss, last_step

    num_train_steps = len(train_dataloader)

    if train_unet_file is not None:
        first_epoch = last_step // num_train_steps
        global_step = last_step + 1
    else:
        first_epoch = 0
        global_step = 0

    # Running loss sum and number of steps since the last checkpoint; dividing
    # by the real count keeps the first window (which holds one extra step)
    # from being over-estimated.
    accumulator = 0.0
    window_steps = 0

    # torch.save does not create missing directories.
    os.makedirs(train_output_dir, exist_ok=True)

    # Move models to the device
    models['encoder'].to(device)
    models['clip'].to(device)
    models['diffusion'].to(device)
    if use_ema:
        ema_unet.to(device)

    # A separate name keeps the ``num_train_epochs`` argument intact.
    epoch_bar = tqdm(range(first_epoch, num_train_epochs), desc="Epoch")
    for epoch in epoch_bar:
        train_loss = 0.0
        epoch_steps = 0
        for step, batch in enumerate(train_dataloader):
            start_time = time.time()

            # Extract images and texts from batch and move them to the device
            images = batch["pixel_values"].to(device)
            texts = batch["input_ids"].to(device)

            # Encode images to latent space
            # Noise shape: (batch, 4, LATENTS_HEIGHT, LATENTS_WIDTH)
            encoder_noise = torch.randn(images.shape[0], 4, LATENTS_HEIGHT, LATENTS_WIDTH, device=device)
            latents = models['encoder'](images, encoder_noise)

            # Sample independent timesteps for the image and the text branch
            bsz = latents.shape[0]
            timesteps = torch.randint(0, ddpm.num_train_timesteps, (bsz,), device=latents.device).long()
            text_timesteps = torch.randint(0, ddpm.num_train_timesteps, (bsz,), device=latents.device).long()

            # Add noise to latents and texts
            noisy_latents, image_noise = ddpm.add_noise(latents, timesteps)
            encoder_hidden_states = models['clip'](texts)
            noisy_text_query, text_noise = ddpm.add_noise(encoder_hidden_states, text_timesteps)

            # Get time embeddings. The text branch was noised with
            # ``text_timesteps``, so its embedding must come from those too.
            image_time_embeddings = get_time_embedding(timesteps, is_image=True).to(device)
            text_time_embeddings = get_time_embedding(text_timesteps, is_image=False).to(device)

            # Text query: L2-normalized mean over dim 1 of the noisy text embedding
            average_noisy_text_query = noisy_text_query.mean(dim=1)
            text_query = F.normalize(average_noisy_text_query, p=2, dim=-1)

            # Randomly drop the text and the image condition, each with
            # probability 0.1 per batch (classifier-free guidance training)
            if torch.rand(1).item() < 0.1:
                text_query = torch.zeros_like(text_query)
            if torch.rand(1).item() < 0.1:
                noisy_latents = torch.zeros_like(noisy_latents)

            # Predict the noise residual and compute loss
            image_pred, text_pred = models['diffusion'](noisy_latents, encoder_hidden_states, image_time_embeddings, text_time_embeddings, text_query)
            image_loss = F.mse_loss(image_pred.float(), image_noise.float(), reduction="mean")
            # NOTE(review): the text target is the (noisy, possibly zeroed) query
            # that was fed to the model, while test() compares the prediction to
            # clean class embeddings — confirm the intended target. Likewise the
            # image loss is still computed when the latents were dropped above.
            text_loss = F.mse_loss(text_pred.float(), text_query.float(), reduction="mean")

            loss = image_loss + Lambda * text_loss
            loss_value = loss.item()
            train_loss += loss_value
            accumulator += loss_value
            window_steps += 1
            epoch_steps += 1

            # Backpropagate
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            if use_ema:
                ema_unet.update_parameters(models['diffusion'])

            end_time = time.time()

            if train_unet_file is not None and epoch == first_epoch:
                print(f"Step: {step+1+last_step}/{num_train_steps+last_step}   Loss: {loss_value}   Time: {end_time - start_time}", end="\r")
            else:
                print(f"Step: {step}/{num_train_steps}   Loss: {loss_value}   Time: {end_time - start_time}", end="\r")

            if global_step % save_steps == 0 and global_step > 0:
                window_loss = accumulator / window_steps

                # Check if the current window's loss is the best
                if window_loss < best_loss:
                    best_loss = window_loss
                    best_step = global_step
                    best_save_path = os.path.join(train_output_dir, "best.pt")
                    torch.save(_checkpoint_state(window_loss, global_step), best_save_path)
                    print(f"\nNew best model saved to {best_save_path} with loss {best_loss}")

                # Save model and optimizer state
                last_save_path = os.path.join(train_output_dir, "last.pt")
                torch.save(_checkpoint_state(window_loss, global_step), last_save_path)
                print(f"Saved state to {last_save_path}")

                # Generate samples from the model
                _save_prompt_samples(global_step)

                print(f"\nSaved images for step {global_step}")
                print('Epoch: %d   Step: %d   Loss: %.5f   Best Loss: %.5f   Best Step: %d\n' % (epoch+1, global_step, window_loss, best_loss, best_step))

                accumulator = 0.0
                window_steps = 0

            global_step += 1

        # max(1, ...) keeps an empty dataloader from raising here.
        print(f"Average loss over epoch: {train_loss / max(1, epoch_steps)}")

In [None]:
# Gather the run configuration as "label: value" lines, then print a single banner.
config_lines = [
    f'Model file: {model_file}',
    f'UNet file: {train_unet_file}',
    f'Batch size: {BATCH_SIZE}',
    f'Width: {WIDTH}',
    f'Height: {HEIGHT}',
    f'Latents width: {LATENTS_WIDTH}',
    f'Latents height: {LATENTS_HEIGHT}',
    f'First epoch: {last_step // len(train_dataloader)}',
    f'Number of training epochs: {num_train_epochs}',
    f'Lambda: {Lambda}',
    f'Learning rate: {learning_rate}',
    f'Discriminative learning rate: {discriminative_learning_rate}',
    f'Adam beta1: {adam_beta1}',
    f'Adam beta2: {adam_beta2}',
    f'Adam weight decay: {adam_weight_decay}',
    f'Adam epsilon: {adam_epsilon}',
    f'Use EMA: {use_ema}',
    f'EMA decay: {ema_decay}',
    f'Warmup steps: {warmup_steps}',
    f'Output directory: {train_output_dir}',
    f'Save steps: {save_steps}',
    f'Device: {DEVICE}',
    f'Sampler: {sampler}',
    f'Number of inference steps: {num_inference_steps}',
    f'Seed: {seed}',
]
for i, prompt in enumerate(prompts):
    config_lines.append(f'Prompt {i + 1}: {prompt}')
config_lines += [
    f'Unconditional prompt: {uncond_prompt}',
    f'Do CFG: {do_cfg}',
    f'CFG scale: {cfg_scale}',
]

# Header, blank line, one setting per line, then two trailing newlines.
s = '==> Training starts..' + '\n\n' + '\n'.join(config_lines) + '\n\n'
print(s)


# Kick off training with the configured schedule.
train(num_train_epochs=num_train_epochs, device=DEVICE, save_steps=save_steps)

In [None]:
import torch
import torchvision
from torchvision import transforms
import torch.nn.functional as F
import os
from ddpm import DDPMSampler
from pipeline import get_time_embedding
import model_loader
import time
from transformers import CLIPTokenizer

# Set the device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the models
models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)
ddpm = DDPMSampler(generator=None)


def _unwrap_ema_state_dict(ema_state_dict):
    """Convert an ``AveragedModel`` state dict into one the bare UNet accepts.

    ``torch.optim.swa_utils.AveragedModel`` stores the wrapped network under
    the ``module.`` prefix and adds an ``n_averaged`` buffer, so its state
    dict cannot be passed to the UNet's ``load_state_dict`` unchanged. A dict
    without any ``module.``-prefixed key is returned as is.
    """
    prefix = "module."
    if not any(key.startswith(prefix) for key in ema_state_dict):
        return ema_state_dict
    return {
        key[len(prefix):]: value
        for key, value in ema_state_dict.items()
        if key.startswith(prefix)
    }


if test_unet_file is not None:
    # Load the UNet model
    print(f"Loading UNet model from {test_unet_file}")
    # map_location lets a checkpoint written on a GPU be opened on a CPU-only host.
    _test_checkpoint = torch.load(test_unet_file, map_location=DEVICE)
    if use_ema:
        if 'ema_state_dict' not in _test_checkpoint:
            raise KeyError(
                f"use_ema is True but {test_unet_file} contains no 'ema_state_dict'"
            )
        _unet_state = _unwrap_ema_state_dict(_test_checkpoint['ema_state_dict'])
    else:
        _unet_state = _test_checkpoint['model_state_dict']
    models['diffusion'].load_state_dict(_unet_state)
    # Release the deserialized checkpoint once the weights are copied.
    del _test_checkpoint, _unet_state

# TEXT TO IMAGE
tokenizer = CLIPTokenizer("./data/vocab.json", merges_file="./data/merges.txt")

# Set the models['encoder'], models['clip'], models['diffusion'] to eval mode
models['encoder'].eval()
models['clip'].eval()
models['diffusion'].eval()

In [None]:
def test(device="cuda"):
    """Evaluate image-to-text classification on the configured CIFAR test split.

    Each class name is tokenized and encoded with CLIP; the L2-normalized mean
    over dim 1 of the encoding is the class embedding. For every test batch the
    diffusion model's text prediction is compared with all class embeddings by
    cosine similarity and the best match is the predicted class. Accuracy and
    the mean text loss are printed.

    Args:
        device: Device the models and batches are moved to.

    Raises:
        ValueError: If ``test_dataset`` is neither "CIFAR10" nor "CIFAR100".
    """
    # Get the transform for the test data (Resize expects (height, width))
    transform = transforms.Compose([
        transforms.Resize((HEIGHT, WIDTH), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    # Load the requested CIFAR test split
    dataset_classes = {
        "CIFAR10": torchvision.datasets.CIFAR10,
        "CIFAR100": torchvision.datasets.CIFAR100,
    }
    if test_dataset not in dataset_classes:
        raise ValueError(
            f"Unsupported test_dataset {test_dataset!r}; expected one of {sorted(dataset_classes)}"
        )
    testset = dataset_classes[test_dataset](
        root='../dataset', train=False, download=True, transform=transform)

    print(f"Test dataset: {test_dataset} | Number of test samples: {len(testset)}")

    # Load the test data
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # Move models to the device
    models['encoder'].to(device)
    models['clip'].to(device)
    models['diffusion'].to(device)

    # Tokenize all class names in one call; every row is padded to 77 tokens
    class_names = testset.classes
    class_tokens = tokenizer.batch_encode_plus(
        list(class_names), padding="max_length", max_length=77
    ).input_ids
    class_tokens = torch.tensor(class_tokens, dtype=torch.long).to(device)
    print(f"Class tokens shape: {class_tokens.shape}")

    # Encode class tokens with the CLIP model
    with torch.no_grad():
        class_hidden_states = models['clip'](class_tokens)

        # Average and normalize class embeddings
        class_embeddings = class_hidden_states.mean(dim=1)
        class_embeddings = F.normalize(class_embeddings, p=2, dim=-1)
        print(f"Class embeddings shape: {class_embeddings.shape}\n")

    # Start testing
    test_loss = 0.0
    num_test_steps = len(testloader)
    correct_predictions = 0
    total_predictions = 0
    with torch.no_grad():
        for batch_idx, (images, targets) in enumerate(testloader):
            start_time = time.time()

            # Move batch to the device
            images = images.to(device)
            targets = targets.to(device)

            # Reuse the class encodings computed above instead of running CLIP
            # again on every batch.
            # NOTE(review): the text conditioning is built from the ground-truth
            # class of each image, so the label is visible to the model during
            # evaluation — confirm this is the intended protocol.
            encoder_hidden_states = class_hidden_states[targets]

            # Encode images to latent space
            # Noise shape: (batch, 4, LATENTS_HEIGHT, LATENTS_WIDTH)
            encoder_noise = torch.randn(images.shape[0], 4, LATENTS_HEIGHT, LATENTS_WIDTH, device=device)
            latents = models['encoder'](images, encoder_noise)

            # Sample noise and timesteps for diffusion process
            # NOTE(review): timesteps are drawn at random, so repeated runs can
            # report different accuracy.
            bsz = latents.shape[0]
            timesteps = torch.randint(0, ddpm.num_train_timesteps, (bsz,), device=latents.device).long()
            text_timesteps = torch.randint(0, ddpm.num_train_timesteps, (bsz,), device=latents.device).long()

            # Add noise to latents and texts
            noisy_latents, image_noise = ddpm.add_noise(latents, timesteps)
            noisy_text_query, text_noise = ddpm.add_noise(encoder_hidden_states, text_timesteps)

            # Get time embeddings; the text branch uses the timesteps it was
            # actually noised with.
            image_time_embeddings = get_time_embedding(timesteps, is_image=True).to(device)
            text_time_embeddings = get_time_embedding(text_timesteps, is_image=False).to(device)

            # Text query: L2-normalized mean over dim 1 of the noisy text embedding
            average_noisy_text_query = noisy_text_query.mean(dim=1)
            text_query = F.normalize(average_noisy_text_query, p=2, dim=-1)

            # Condition dropping is a training-time regularizer only; at
            # evaluation both conditions are always kept.

            # Predict the text query
            _, text_pred = models['diffusion'](noisy_latents, encoder_hidden_states, image_time_embeddings, text_time_embeddings, text_query)

            # Calculate loss
            loss = F.mse_loss(text_pred.float(), text_query.float(), reduction="mean")
            loss_value = loss.item()
            test_loss += loss_value

            # Calculate cosine similarity between the generated text query and class embeddings
            similarities = F.cosine_similarity(text_pred.unsqueeze(1), class_embeddings.unsqueeze(0), dim=-1)
            predicted_classes = similarities.argmax(dim=-1)

            # Compare predictions with actual targets
            correct_predictions += (predicted_classes == targets).sum().item()
            total_predictions += targets.size(0)

            end_time = time.time()

            print(f"Batch {batch_idx + 1}/{num_test_steps} | Loss: {loss_value:.4f} | Time: {end_time - start_time:.2f}s", end="\r")

    # Calculate total accuracy (guard against an empty test set)
    accuracy = correct_predictions / total_predictions if total_predictions else 0.0
    mean_test_loss = test_loss / num_test_steps if num_test_steps else 0.0
    s = f"Accuracy: {correct_predictions}/{total_predictions} ({accuracy:.4f})\n"
    s += f"\nTest Loss: {mean_test_loss:.4f}"
    print(s)

In [None]:
# Gather the evaluation configuration as "label: value" lines, then print a single banner.
config_lines = [
    f'Test dataset: {test_dataset}',
    f'Model file: {model_file}',
    f'UNet file: {test_unet_file}',
    f'Batch size: {BATCH_SIZE}',
    f'Width: {WIDTH}',
    f'Height: {HEIGHT}',
    f'Latents width: {LATENTS_WIDTH}',
    f'Latents height: {LATENTS_HEIGHT}',
    f'Number of training epochs: {num_train_epochs}',
    f'Lambda: {Lambda}',
    f'Learning rate: {learning_rate}',
    f'Discriminative learning rate: {discriminative_learning_rate}',
    f'Adam beta1: {adam_beta1}',
    f'Adam beta2: {adam_beta2}',
    f'Adam weight decay: {adam_weight_decay}',
    f'Adam epsilon: {adam_epsilon}',
    f'Use EMA: {use_ema}',
    f'EMA decay: {ema_decay}',
    f'Warmup steps: {warmup_steps}',
    f'Output directory: {test_output_dir}',
    f'Save steps: {save_steps}',
    f'Device: {DEVICE}',
    f'Sampler: {sampler}',
    f'Number of inference steps: {num_inference_steps}',
    f'Seed: {seed}',
]
for i, prompt in enumerate(prompts):
    config_lines.append(f'Prompt {i + 1}: {prompt}')
config_lines += [
    f'Unconditional prompt: {uncond_prompt}',
    f'Do CFG: {do_cfg}',
    f'CFG scale: {cfg_scale}',
]

# Header, blank line, one setting per line, then two trailing newlines.
s = '==> Testing starts..' + '\n\n' + '\n'.join(config_lines) + '\n\n'
print(s)

# Run the evaluation on the selected device.
test(device=DEVICE)

In [None]:
import torch
import torch.nn.functional as F
import os
import model_loader
import time
import pipeline
from PIL import Image
from transformers import CLIPTokenizer
from IPython.display import display

# Set the device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the models
models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)

if inference_unet_file is not None:
    # Load the UNet model
    print(f"Loading UNet model from {inference_unet_file}")
    # map_location lets a checkpoint written on a GPU be opened on a CPU-only host.
    _inference_checkpoint = torch.load(inference_unet_file, map_location=DEVICE)
    models['diffusion'].load_state_dict(_inference_checkpoint['model_state_dict'])
    # Release the deserialized checkpoint once the weights are copied.
    del _inference_checkpoint

# TEXT TO IMAGE
tokenizer = CLIPTokenizer("./data/vocab.json", merges_file="./data/merges.txt")

# Generate samples from the model
for i, prompt in enumerate(prompts):
    for j in range(num_samples):
        start = time.time()

        # Offset the seed by the sample index: reusing the same seed for every
        # sample of a prompt would presumably reproduce the same image each
        # time (assumes pipeline.generate seeds its RNG from ``seed`` — TODO
        # confirm). The first sample (j == 0) still uses ``seed`` unchanged.
        output_image = pipeline.generate(
            prompt=prompt,
            uncond_prompt=uncond_prompt,
            input_image=None,
            strength=0.9,
            do_cfg=do_cfg,
            cfg_scale=cfg_scale,
            sampler_name=sampler,
            n_inference_steps=num_inference_steps,
            seed=seed + j,
            models=models,
            device=DEVICE,
            idle_device=DEVICE,
            tokenizer=tokenizer,
        )

        end = time.time()

        print(f"PROMPT {i+1} - SAMPLE {j+1} - TIME: {end - start:.2f}s\n")

        # Wrap the raw array in a PIL image (nothing is written to disk here)
        output_image = Image.fromarray(output_image)

        # Display the generated image
        display(output_image)

**IMPORTANT REMARK:** To validate the model's training process, we used a representative dummy dataset. This dataset is not the actual one, but it mirrors the format of the original CC3M dataset. The CC3M dataset is considerably large, with a total size of approximately 430 GB. Given this size and our resource constraints, it was not feasible to complete the training process before the deadline for the first version, which is set for May 5. We plan to complete the training process by the deadline for the second version, scheduled for May 31.

During the implementation of the model as described in the research paper, we encountered several areas of ambiguity that necessitated the formulation of certain assumptions. These assumptions, which guided our implementation, are detailed as follows:

1) The paper did not provide explicit information regarding whether the transformers employed for the image-to-text alignment task were part of a separate architecture or integrated within the proposed UNet middle blocks. Given this lack of clarity, our implementation treats the transformers as a distinct architecture.

2) The paper did not specify which blocks were to be modified to incorporate the dual stream deep fusion blocks. In our implementation, we have chosen to integrate these blocks into both the downsample and middle blocks of the original UNet used in Stable Diffusion. Notably, we did not apply these changes to the upsample layers, as they are not utilized in the image-to-text alignment task.

3) The term 'text query' was used in the paper without a clear definition. In our interpretation, we have chosen to represent the text query as the normalized average of the output from the text encoder.

4) The paper did not provide a clear methodology for the concatenation of the hidden latent image and the output of the fully connected layer. In our implementation, we expanded the output of the fully connected layer from a shape of (Batch Size, Channels) to (Batch Size, Channels, Height, Width), enabling its concatenation with the latent image, which also has a shape of (Batch Size, Channels, Height, Width).

5) The paper did not provide explicit instructions on how the fully connected layer projects the text query back into the text embedding space. We assumed that it generates an output with a shape of (Batch Size, Width * Height, 768). We then computed the normalized average of this output along dimension 1, resulting in an output of shape (Batch Size, 768). This output serves as the hidden text query that is input to the next layer.