In [1]:
from transformers import CLIPModel, CLIPProcessor
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from datasets import load_dataset

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.optim import AdamW

# from PIL import Image
# import requests
# import os
# import matplotlib.pyplot as plt
# import numpy as np

In [2]:
###
### hyper-parameters shared by the cells below
###

# sizes on the encoder side
model_d = 512          # input width of the linear projection layers
prefix_length = 16     # number of prefix positions the image embedding is split into
lm_embed_v_dim = model_d // prefix_length  # width of one prefix position (model_d split evenly)

# sizes on the decoder side
target_m_dim = 768     # output width of the final projection that feeds the decoder

# batching / sequence limits
batch_size = 1
max_length = 512       # not referenced by the other visible cells

In [3]:
###
### define models
###

### encoder - turns raw images / text into embeddings

clip_m = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

# split CLIP into its two towers, each paired with its own projection layer
clip_m_v, clip_m_v_proj = clip_m.vision_model, clip_m.visual_projection
clip_m_t, clip_m_t_proj = clip_m.text_model, clip_m.text_projection

# a single processor prepares both images and texts for CLIP
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# linear projection layers
# (constructed in the same order as before so their random initialisation is unchanged)
lin_m_v = nn.Linear(in_features=model_d, out_features=model_d)               # image: model_d -> model_d
lin_m_t = nn.Linear(in_features=model_d, out_features=lm_embed_v_dim)        # text: model_d -> lm_embed_v_dim
lin_m_vt4dec = nn.Linear(in_features=lm_embed_v_dim, out_features=target_m_dim)  # joint: lm_embed_v_dim -> decoder width

### decoder - consumes the embeddings and performs the language-modeling task

dec_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
dec_model = GPT2LMHeadModel.from_pretrained("gpt2")

# the tokenizer ships without a pad token (pad_token is None), so reuse end-of-text
dec_tokenizer.pad_token = dec_tokenizer.eos_token
dec_tokenizer.pad_token_id = dec_tokenizer.eos_token_id

# sanity check: decoder embedding width should equal target_m_dim
print(dec_model.config.n_embd)  # 768 for GPT2-base

768


In [10]:
###
### load data
###

# # load text
# # load image
# text = ["a photo of a cat"]
# url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# img = Image.open(requests.get(url, stream=True).raw)

# Calling load_dataset("nlphuji/flickr30k") without extra arguments failed with:
# UnicodeDecodeError: 'utf-8' codec can't decode byte 0xd0 in position 1: invalid continuation byte
# so the load is retried with alternative encodings when that error shows up.

# NOTE: trust_remote_code=True allows the dataset's own loading script to run on
# this machine; only keep it for repositories you trust.
try:
    dataset = load_dataset("nlphuji/flickr30k", trust_remote_code=True)
except UnicodeDecodeError:
    # Fallback: retry with explicit single-byte encodings
    try:
        dataset = load_dataset("nlphuji/flickr30k", encoding='latin-1', trust_remote_code=True)
    except Exception:
        # `except Exception` (not a bare `except`) so Ctrl-C / kernel interrupts
        # are not swallowed and turned into another download attempt.
        dataset = load_dataset("nlphuji/flickr30k", encoding='cp1252', trust_remote_code=True)

# Keep only the "test" entry of the returned dict; each row still carries its
# own 'split' field.
dataset = dataset["test"]

Using the latest cached version of the module from /home/tsujimura/.cache/huggingface/modules/datasets_modules/datasets/nlphuji--flickr30k/6adb9ab2367c57c3e81e76ecaecb8047ea00c33dccf9da10455037f32ec43382 (last modified on Sun Jun 29 22:06:34 2025) since it couldn't be found locally at nlphuji/flickr30k, or remotely on the Hugging Face Hub.
Repo card metadata block was not found. Setting CardData to empty.


In [11]:
# Display the loaded dataset's summary (feature names and number of rows).
dataset

Dataset({
    features: ['image', 'caption', 'sentids', 'split', 'img_id', 'filename'],
    num_rows: 31014
})

In [None]:
# Use the first sample as the training pair for the decoder: its image together
# with only the first of its captions.
sample = dataset[0]

# `img` and `text` are the inputs consumed by the forward cell; without these
# assignments a fresh kernel run fails there with a NameError.
img = sample["image"]
text = [sample["caption"][0]]  # kept as a list: the forward cell iterates over captions

sample

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=333x500>,
 'caption': ['Two young guys with shaggy hair look at their hands while hanging out in the yard.',
  'Two young, White males are outside near many bushes.',
  'Two men in green shirts are standing in a yard.',
  'A man in a blue shirt standing in a garden.',
  'Two friends enjoy time spent together.'],
 'sentids': ['0', '1', '2', '3', '4'],
 'split': 'train',
 'img_id': '0',
 'filename': '1000092795.jpg'}

In [6]:
# loss function and optimizer
# `criterion` is not called by the other visible cells (the decoder returns its
# own loss when `labels` is passed); it is kept for manual loss computation.
criterion = nn.CrossEntropyLoss(ignore_index=-100)  # commonly used for language modeling with padding

# Optimise the decoder AND the freshly initialised projection layers. Passing
# only dec_model.parameters() would leave lin_m_v / lin_m_t / lin_m_vt4dec at
# their random initial weights forever. The CLIP towers are deliberately left out.
trainable_params = [
    *dec_model.parameters(),
    *lin_m_v.parameters(),
    *lin_m_t.parameters(),
    *lin_m_vt4dec.parameters(),
]
optimizer = AdamW(trainable_params, lr=5e-5, weight_decay=0.01)

In [None]:
###
### forward data
###
# Sends one batch through the CLIP encoders, projects image and text embeddings
# to the decoder width and runs the decoder with labels so it returns a loss.
# Expects `img` (image input for the CLIP processor) and `text` (a list of
# caption strings) to be defined before this cell runs.

### encoder parts

# convert image to emb: vision tower -> pooled output -> CLIP projection
inputs_v = processor(images=img, return_tensors="pt").pixel_values
outs_v1 = clip_m_v(inputs_v).pooler_output
outs_v_fin = clip_m_v_proj(outs_v1)

# Read the batch dimension from the tensor instead of the global `batch_size`
# so the reshapes below stay valid for any number of examples.
n_examples = outs_v_fin.size(0)

# projection layer, then split each image vector into `prefix_length` positions
projected_image_embed = lin_m_v(outs_v_fin)  # shape: (batch, prefix_length * lm_embed_v_dim)
projected_image_embed = projected_image_embed.view(n_examples, prefix_length, lm_embed_v_dim)
print(projected_image_embed.shape) # (batch, prefix_length, lm_embed_v_dim)

# convert text to emb: text tower -> per-token hidden states -> CLIP projection
inputs_t = processor(text=text, return_tensors="pt", padding=True)
outs_t1 = clip_m_t(**inputs_t).last_hidden_state
outs_t_fin = clip_m_t_proj(outs_t1)  # shape: (batch, seq_len, model_d)
projected_text_embed = lin_m_t(outs_t_fin)  # shape: (batch, seq_len, lm_embed_v_dim)
print(projected_text_embed.shape)
# outs_t_fin.shape # torch.Size([1, 7, 512])

# Concatenate image prefix embeddings with token embeddings
inputs_embeds = torch.cat([projected_image_embed, projected_text_embed], dim=1)
# inputs_embeds.shape # [batch, prefix_length + seq_len, lm_embed_v_dim]
inputs_embeds4dec_m = lin_m_vt4dec(inputs_embeds)
print(f"final emb size from encoder {inputs_embeds4dec_m.shape}") # torch.Size([1, 23, 768]) or [batch, img+txt, dec_dim]

# Modify attention mask to accommodate prefix tokens.
# The prefix mask copies dtype/device from the tokenizer mask so the
# concatenation does not depend on implicit type promotion.
prefix_attention_mask = torch.ones(
    n_examples,
    prefix_length,
    dtype=inputs_t.attention_mask.dtype,
    device=inputs_t.attention_mask.device,
)
attention_mask = torch.cat([prefix_attention_mask, inputs_t.attention_mask], dim=1)

# Prepare labels: wrap every caption in end-of-text markers and tokenize with
# the decoder tokenizer.
text_with_eos = [dec_tokenizer.eos_token + t + dec_tokenizer.eos_token for t in text]  # '<|endoftext|>'
label_enc = dec_tokenizer(text_with_eos, return_tensors="pt", padding=True)
# The pad token is the eos token, so padding cannot be recognised by token id;
# use the tokenizer's attention mask to exclude padded positions from the loss.
labels = label_enc.input_ids.masked_fill(label_enc.attention_mask == 0, -100)
# labels = dec_tokenizer(text, return_tensors="pt", padding=True).input_ids

# Inputs are tokenized by CLIP and labels by GPT-2, so their lengths are not
# guaranteed to agree. Fail early with a readable message instead of a shape
# error from inside the decoder's loss computation.
text_len = projected_text_embed.size(1)
if labels.size(1) != text_len:
    raise ValueError(
        f"label length {labels.size(1)} (decoder tokenizer) does not match "
        f"text input length {text_len} (CLIP tokenizer); the decoder needs "
        "exactly one label per input position"
    )

# Pad labels to match full input length (prefix + text)
# Set prefix part to -100 so loss is ignored there
padding_labels = torch.full((n_examples, prefix_length), -100, dtype=labels.dtype)
labels = torch.cat([padding_labels, labels], dim=1)  # (batch, prefix_length + text_len)

### decoder part

out_dec = dec_model(
    inputs_embeds=inputs_embeds4dec_m,
    attention_mask=attention_mask,
    labels=labels,  # this triggers loss computation
    )


In [None]:
###
### update loss
###

# The decoder was called with `labels`, so its output object carries the loss.
# NOTE(review): no loss.backward() / optimizer.step() follows in the visible
# cells, so nothing is actually updated here despite the heading — confirm
# whether the training step is still to be added.
loss = out_dec.loss
loss
