In [13]:
from datasets import load_dataset
from transformers import CLIPModel, CLIPProcessor
from model import CLIPBackbone, CLIPTokenizer
import torch
from PIL import Image
import torch.nn.functional as F

In [2]:


backbone = CLIPBackbone()

In [3]:
# Load test sample
dataset = load_dataset("nlphuji/flickr30k", split="test")
sample = dataset[0]
image = sample["image"]
captions = sample["caption"]           # List of 5 captions
images = [image] * len(captions)        # Duplicate image 5 times

In [4]:
# Prepare inputs
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
inputs = processor(text=captions, images=images,
                   return_tensors="pt", padding=True, truncation=True)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [5]:
input_ids = inputs["input_ids"]
pixel_values = inputs["pixel_values"]

In [6]:
# Extract features
img_hidden = backbone.encode_image(pixel_values)      # (1, 50, 768)
txt_embed = backbone.embedded_text(input_ids)            # (1, L, 512)

In [7]:
print("Image shape:", img_hidden.shape)
print("Text shape:", txt_embed.shape)

Image shape: torch.Size([5, 50, 768])
Text shape: torch.Size([5, 19, 512])


In [8]:
print(dataset.column_names)
print(dataset[0])

['image', 'caption', 'sentids', 'split', 'img_id', 'filename']
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=333x500 at 0x1EFE071FB00>, 'caption': ['Two young guys with shaggy hair look at their hands while hanging out in the yard.', 'Two young, White males are outside near many bushes.', 'Two men in green shirts are standing in a yard.', 'A man in a blue shirt standing in a garden.', 'Two friends enjoy time spent together.'], 'sentids': ['0', '1', '2', '3', '4'], 'split': 'train', 'img_id': '0', 'filename': '1000092795.jpg'}


In [9]:
print(f"Text embeddings shape: {txt_embed.shape}")  # (5, L, 512)

# Optional: print each caption's first token embedding
for i in range(5):
    print(f"\nCaption {i + 1}: {captions[i]}")
    print(
        f"First token embedding (shape: {txt_embed[i, 0].shape}):\n{txt_embed[i, 0]}")

Text embeddings shape: torch.Size([5, 19, 512])

Caption 1: Two young guys with shaggy hair look at their hands while hanging out in the yard.
First token embedding (shape: torch.Size([512])):
tensor([-3.1102e-03,  2.7125e-04, -8.3240e-03,  6.5898e-03, -5.4523e-03,
        -6.8510e-04, -4.6194e-03, -1.1066e-04,  1.3952e-04,  3.6518e-03,
         7.8880e-05, -3.3456e-03,  3.0308e-03, -2.5638e-03, -6.5136e-04,
         5.1040e-03, -7.1051e-04, -1.5060e-03, -7.7988e-03, -4.4290e-03,
        -6.3955e-04, -4.8186e-03,  1.3089e-03, -1.3589e-03, -3.1162e-03,
        -9.7673e-04,  2.0508e-03,  3.0429e-04, -1.6041e-03,  3.8689e-03,
        -6.1285e-03, -5.2158e-03, -4.5520e-03, -8.7706e-03,  5.2336e-04,
        -5.5566e-03, -9.7672e-03,  1.9584e-02, -1.7442e-03,  1.7207e-03,
        -1.0862e-03,  7.1694e-04,  4.4945e-04, -5.8205e-03,  2.6555e-03,
         6.9091e-03,  2.0991e-03, -8.4015e-05,  1.5352e-03,  2.1451e-03,
        -2.3066e-03,  6.4552e-03,  2.8934e-03,  6.8767e-03,  4.5625e-03,
    

In [10]:
# Constants
B = 1                # Batch size
P = 2                # Image tokens
T = 3                # Text tokens
PAD = 2              # Padding tokens
L = P + T + PAD
D = 16               # Embedding dim
H = 8                # Head dim

In [None]:
import torch.nn as nn
# Mock AttentionHead (with debug prints)


class AttentionHead(nn.Module):
    def __init__(self, embed_dim: int, head_size: int, max_seq_len: int):
        super().__init__()
        self.key = nn.Linear(embed_dim, head_size, bias=False)
        self.query = nn.Linear(embed_dim, head_size, bias=False)
        self.value = nn.Linear(embed_dim, head_size, bias=False)
        self.register_buffer('causal_mask', torch.tril(
            torch.ones(max_seq_len, max_seq_len)))
        self.dropout = nn.Dropout(0.0)  # disable dropout for debug

    def forward(self, x, image_len: int, pad_mask: torch.Tensor):
        B, L, D = x.size()
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        attn_scores = q @ k.transpose(-2, -1) * D**-0.5  # (B, L, L)

        # 1. Build causal mask for text only
        seq_len = L - image_len - pad_mask.sum().item()
        text_start = image_len
        text_end = image_len + seq_len

        mask = torch.ones(L, L, device=x.device)
        causal_text_mask = self.causal_mask[:seq_len, :seq_len]
        mask[text_start:text_end, text_start:text_end] = causal_text_mask

        # 2. Apply padding mask — no attention to/from pad tokens
        pad_indices = (pad_mask == 0).nonzero(as_tuple=True)[1]
        for idx in pad_indices:
            mask[:, idx] = 0
            mask[idx, :] = 0

        # Before masking
        print("\nShape: ", attn_scores.shape)
        print("\nRaw Attention Scores:")
        print(attn_scores[0].detach().numpy().round(2))
        print("\nShape: ", attn_scores.shape)
        
        # 3. Apply mask
        masked_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        print("\nMasked Attention Scores:")
        print(masked_scores[0].detach().numpy().round(2))
        print("\nShape: ", attn_scores.shape)

        # 4. Softmax
        attn_weights = F.softmax(masked_scores, dim=-1)
        print("\nAttention Weights (After Softmax):")
        print(attn_weights[0].detach().numpy().round(2))
         print("\nShape: ", attn_scores.shape)
        out = attn_weights @ v
        return out

In [14]:
# Initialize
head = AttentionHead(D, H, max_seq_len=L)

# Input tensor: random embeddings
x = torch.randn(B, L, D)

# Create mock padding mask: 1 = keep, 0 = pad
# [image, text, pad]
pad_mask = torch.tensor([[1] * (P+T) + [0]*PAD])  # (B, L)

# Forward pass
output = head(x, image_len=P, pad_mask=pad_mask)
print("\nOutput shape:", output.shape)


Raw Attention Scores:
[[-0.04 -0.23 -0.28 -0.03 -0.31  0.27  0.09]
 [ 0.08  0.07  0.15 -0.18  0.06 -0.11  0.04]
 [ 0.07 -0.24  0.33  0.2  -0.28 -0.27  0.03]
 [ 0.41 -0.45  0.34  0.42  0.19 -0.15 -0.05]
 [ 0.17  0.07  0.07  0.37  0.43 -0.15 -0.44]
 [-0.28  0.33 -0.1   0.18  0.04  0.22 -0.16]
 [-0.39  0.41  0.07 -0.07  0.11  0.3   0.44]]

Masked Attention Scores:
[[-0.04 -0.23 -0.28 -0.03 -0.31  -inf  -inf]
 [ 0.08  0.07  0.15 -0.18  0.06  -inf  -inf]
 [ 0.07 -0.24  0.33  0.2  -0.28  -inf  -inf]
 [ 0.41 -0.45  0.34  0.42  0.19  -inf  -inf]
 [ 0.17  0.07  0.07  0.37  0.43  -inf  -inf]
 [ -inf  -inf  -inf  -inf  -inf  -inf  -inf]
 [ -inf  -inf  -inf  -inf  -inf  -inf  -inf]]

Attention Weights (After Softmax):
[[0.23 0.19 0.18 0.23 0.17 0.   0.  ]
 [0.21 0.2  0.22 0.16 0.2  0.   0.  ]
 [0.2  0.15 0.27 0.23 0.15 0.   0.  ]
 [0.24 0.1  0.22 0.24 0.19 0.   0.  ]
 [0.19 0.17 0.17 0.23 0.24 0.   0.  ]
 [ nan  nan  nan  nan  nan  nan  nan]
 [ nan  nan  nan  nan  nan  nan  nan]]

Output shape: t