In [1]:
import torch
from transformers import BlipProcessor, BlipForQuestionAnswering
from PIL import Image
import requests
import torch

import torch.nn as nn
import torch.optim as optim

In [2]:
from datasets import load_dataset

# The "val2014" split of OK-VQA (5046 rows) is used as the whole corpus and
# re-partitioned 70 / 15 / 15 into train / validation / test.
ds_dict = load_dataset("lmms-lab/OK-VQA")
full_ds = ds_dict["val2014"]

SPLIT_SEED = 42  # fixed so the partition is reproducible across runs

# Stage 1: 70% train vs. 30% held out.
split_1 = full_ds.train_test_split(test_size=0.30, seed=SPLIT_SEED)
train_ds, temp_ds = split_1["train"], split_1["test"]

# Stage 2: halve the held-out 30% into validation and test (15% each).
split_2 = temp_ds.train_test_split(test_size=0.5, seed=SPLIT_SEED)
val_ds, test_ds = split_2["train"], split_2["test"]

# Report partition sizes.
for split_label, split_rows in (("Train", train_ds), ("Validation", val_ds), ("Test", test_ds)):
    print(f"{split_label} size: {len(split_rows)}")


Found cached dataset parquet (/Users/mngtn/.cache/huggingface/datasets/lmms-lab___parquet/lmms-lab--OK-VQA-134cd3ac306f3257/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/1 [00:00<?, ?it/s]

Loading cached split indices for dataset at /Users/mngtn/.cache/huggingface/datasets/lmms-lab___parquet/lmms-lab--OK-VQA-134cd3ac306f3257/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-5dc0d48e031110cc.arrow and /Users/mngtn/.cache/huggingface/datasets/lmms-lab___parquet/lmms-lab--OK-VQA-134cd3ac306f3257/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-3ffe7268338e8169.arrow
Loading cached split indices for dataset at /Users/mngtn/.cache/huggingface/datasets/lmms-lab___parquet/lmms-lab--OK-VQA-134cd3ac306f3257/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-701a2d3e4aa1093b.arrow and /Users/mngtn/.cache/huggingface/datasets/lmms-lab___parquet/lmms-lab--OK-VQA-134cd3ac306f3257/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-46d3932ecec0d6f3.arrow


Train size: 3532
Validation size: 757
Test size: 757


In [3]:
# Peek at one raw training record (question, image, answer list, metadata).
first_train_example = train_ds[0]
first_train_example

{'question_id': '3625635',
 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x427>,
 'question': 'What position is this person playing?',
 'answers': ['shortstop',
  'shortstop',
  'shortstop',
  'shortstop',
  'outfielder',
  'outfielder',
  'catcher',
  'catcher',
  'first base',
  'first base'],
 'question_type': 'Sports and Recreation',
 'answer_type': 'other'}

In [4]:

processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

# Index the dataset once: each `train_ds[0]` access re-decodes the image.
sample_example = train_ds[0]
inputs = processor(images=sample_example["image"], text=sample_example["question"], return_tensors="pt")

# Inspection-only pass. Without no_grad the embeddings below (which stay alive
# in the notebook namespace) would keep the whole encoder autograd graph in memory.
with torch.no_grad():
    img_embedding = model.vision_model(inputs["pixel_values"]).last_hidden_state
    text_embedding = model.text_encoder(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    ).last_hidden_state

# Bottleneck width of the adapter MLPs used by the fine-tuning head.
hidden_size = 512



In [5]:
# Show the image and question embedding shapes (image first, then text).
for embedding in (img_embedding, text_embedding):
    print(embedding.shape)

torch.Size([1, 577, 768])
torch.Size([1, 9, 768])


In [6]:
# Fresh pretrained BLIP VQA processor and model for the fine-tuning setup.
blip_checkpoint = "Salesforce/blip-vqa-base"
processor, model = (
    BlipProcessor.from_pretrained(blip_checkpoint),
    BlipForQuestionAnswering.from_pretrained(blip_checkpoint),
)

In [7]:
# Freeze both pretrained encoders. Parameters outside these two sub-modules
# keep whatever requires_grad setting they already had.
for frozen_module in (model.vision_model, model.text_encoder):
    frozen_module.requires_grad_(False)


In [None]:
class BLIPFinetuer(nn.Module):
    """Fuse BLIP image and question embeddings with a trainable attention layer
    and decode an answer with BLIP's text decoder.

    Pipeline:
        vision_model(img), text_encoder(question)
        -> per-modality MLP adapters
        -> concat along the sequence axis
        -> multi-head self-attention (fusion)
        -> model.text_decoder, cross-attending to the fused sequence.

    Note: ``self.model`` is registered as a sub-module, so ``self.parameters()``
    also yields the BLIP weights; filter on ``requires_grad`` when building an
    optimizer.

    Args:
        model: a BLIP VQA model exposing ``vision_model``, ``text_encoder``,
            ``text_decoder`` and ``config.text_config.bos_token_id``.
        processor: the matching processor (callable on ``images``/``text``,
            with a ``tokenizer`` attribute).
        hidden_size: bottleneck width of the adapter MLPs.
        num_heads: attention heads in the fusion layer; must divide the fused
            width (the text-encoder hidden size).
    """

    def __init__(self, model, processor, hidden_size, num_heads=16):
        # nn.Module.__init__ must run before any sub-module is assigned,
        # otherwise attribute assignment raises AttributeError.
        super().__init__()
        self.model = model
        self.hidden_size = hidden_size
        self.processor = processor
        # Read encoder widths from the model instead of hard-coding 768.
        self.hidden_size_img = model.vision_model.config.hidden_size
        self.hidden_size_text = model.text_encoder.config.hidden_size
        # Both modalities are projected to one common width. Concatenating along
        # the sequence axis requires it, and the text width is what the decoder
        # normally cross-attends to.
        self.fused_size = self.hidden_size_text

        self.linear_img = nn.Sequential(
            nn.Linear(self.hidden_size_img, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.fused_size),
        )

        self.linear_text = nn.Sequential(
            nn.Linear(self.hidden_size_text, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.fused_size),
        )

        # embed_dim must equal the width of the features fed in. batch_first
        # matches the (batch, seq, dim) layout of `last_hidden_state`.
        self.combine = nn.MultiheadAttention(self.fused_size, num_heads, batch_first=True)

    # Encode (text, img) with the BLIP encoders and fuse them; returns
    # (fused_states, fused_mask) where the mask is 1 for real positions.
    def _encode(self, text, img):
        device = next(self.linear_img.parameters()).device
        # padding=True lets a batch of differently-sized questions be stacked.
        inputs = self.processor(images=img, text=text, padding=True, return_tensors="pt").to(device)
        text_mask = inputs["attention_mask"]

        img_embedding = self.model.vision_model(inputs["pixel_values"]).last_hidden_state
        text_embedding = self.model.text_encoder(
            input_ids=inputs["input_ids"], attention_mask=text_mask
        ).last_hidden_state

        img_feature = self.linear_img(img_embedding)
        text_feature = self.linear_text(text_embedding)

        sequence = torch.cat([text_feature, img_feature], dim=1)
        img_mask = torch.ones(img_feature.shape[:2], dtype=text_mask.dtype, device=device)
        fused_mask = torch.cat([text_mask, img_mask], dim=1)

        # Self-attention over the concatenated sequence (query = key = value).
        # key_padding_mask is True where a key must be ignored (text padding).
        fused, _ = self.combine(
            sequence, sequence, sequence,
            key_padding_mask=fused_mask == 0,
            need_weights=False,
        )
        return fused, fused_mask

    # One decoder start (BOS) token per batch row.
    def _start_tokens(self, batch_size, device):
        bos_token_id = self.model.config.text_config.bos_token_id
        return torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)

    def forward(self, text, img, labels=None, decoder_input_ids=None, decoder_attention_mask=None):
        """Run encoders, fusion and the text decoder.

        Args:
            text: question string or list of strings.
            img: image or list of images accepted by the processor.
            labels: optional answer token ids (batch, answer_len). Positions set
                to -100 are presumably ignored by the decoder loss; confirm
                against the decoder in use.
            decoder_input_ids: optional explicit decoder inputs. Defaults to
                ``labels`` (with -100 replaced by the pad id), or to a single
                BOS token per row when no labels are given.
            decoder_attention_mask: optional mask for the decoder inputs.

        Returns:
            The text decoder's output object (``return_dict=True``).
        """
        fused, fused_mask = self._encode(text, img)

        if labels is not None:
            labels = labels.to(fused.device)
        if decoder_input_ids is None:
            if labels is not None:
                # -100 is not a valid embedding index, so it is replaced by padding
                # for the decoder inputs while the labels keep it for the loss.
                decoder_input_ids = labels.masked_fill(
                    labels == -100, self.processor.tokenizer.pad_token_id
                )
            else:
                decoder_input_ids = self._start_tokens(fused.size(0), fused.device)

        return self.model.text_decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=fused,
            encoder_attention_mask=fused_mask,
            labels=labels,
            return_dict=True,
        )

    def generate(self, text, img, max_length, input_ids=None, eos_token_id=None):
        """Greedy decoding conditioned on the fused image/question features.

        Args:
            text, img: as for ``forward``.
            max_length: maximum number of new tokens to append.
            input_ids: optional start tokens (batch, start_len). Defaults to one
                BOS token per row.
            eos_token_id: stop token. Defaults to the tokenizer's ``eos_token_id``,
                or its ``sep_token_id`` when no EOS is defined (BERT-style tokenizers).

        Returns:
            LongTensor (batch, start_len + generated). Rows that already emitted
            the stop token are filled with the pad id while the others continue.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                fused, fused_mask = self._encode(text, img)

                if input_ids is None:
                    output_seq = self._start_tokens(fused.size(0), fused.device)
                else:
                    output_seq = input_ids.to(fused.device).clone()

                tokenizer = self.processor.tokenizer
                if eos_token_id is None:
                    eos_token_id = (
                        tokenizer.eos_token_id
                        if tokenizer.eos_token_id is not None
                        else tokenizer.sep_token_id
                    )
                pad_token_id = tokenizer.pad_token_id

                finished = torch.zeros(output_seq.size(0), dtype=torch.bool, device=fused.device)
                for _ in range(max_length):
                    outputs = self.model.text_decoder(
                        input_ids=output_seq,
                        attention_mask=torch.ones_like(output_seq),
                        encoder_hidden_states=fused,
                        encoder_attention_mask=fused_mask,
                        return_dict=True,
                    )
                    next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
                    # Rows that already stopped only receive padding.
                    next_token = next_token.masked_fill(finished.unsqueeze(-1), pad_token_id)
                    output_seq = torch.cat([output_seq, next_token], dim=-1)
                    finished |= next_token.squeeze(-1) == eos_token_id
                    if bool(finished.all()):
                        break
                return output_seq
        finally:
            # Restore the caller's train/eval mode.
            self.train(was_training)











        





In [None]:

image_embedding_dim = model.vision_model.config.hidden_size
text_embedding_dim = model.text_encoder.config.hidden_size

# Define the dense layers
image_dense = nn.Linear(image_embedding_dim, 512)  # Example output size of 512
text_dense = nn.Linear(text_embedding_dim, 512)    # Example output size of 512

# Initialize the optimizer, only training the parameters of the dense layers
optimizer = optim.Adam(list(image_dense.parameters()) + list(text_dense.parameters()), lr=1e-4)

# Example input and forward pass (simplified for demonstration)
image_url = "https://i.imgur.com/N601nO1.jpg"
# Timeout + status check: a dead link fails loudly instead of hanging or handing
# PIL an error page. The context manager closes the connection afterwards.
with requests.get(image_url, stream=True, timeout=30) as response:
    response.raise_for_status()
    image = Image.open(response.raw).convert("RGB")
text = "What is the person doing?"
inputs = processor(images=image, text=text, return_tensors="pt")

# Forward pass through the image encoder and the dense layer
image_encoder_output = model.vision_model(inputs["pixel_values"])
image_embeddings = image_encoder_output.last_hidden_state
processed_image_embeddings = image_dense(image_embeddings)

# Forward pass through the text encoder and the dense layer
text_encoder_output = model.text_encoder(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
text_embeddings = text_encoder_output.last_hidden_state
processed_text_embeddings = text_dense(text_embeddings)

# Example fusion of the first token of each modality. A custom decoder would
# consume this; its exact form depends on the decoder architecture.
fused_embeddings = torch.cat((processed_image_embeddings[:, 0, :], processed_text_embeddings[:, 0, :]), dim=-1)

# Inference with the original BLIP model. BlipForQuestionAnswering's forward()
# requires `labels` (or decoder_input_ids) and is meant for training;
# generate() is the inference entry point.
with torch.no_grad():
    answer_ids = model.generate(**inputs)
print("VQA answer (using frozen model):", processor.decode(answer_ids[0], skip_special_tokens=True))

# In a training loop:
# 1. Zero the gradients
optimizer.zero_grad()
# 2. Perform the forward pass (as shown above, potentially through your custom decoder)
# 3. Calculate the loss based on the decoder's output and your target
# 4. Backpropagate the gradients only through the trainable parameters (dense layers)
#    loss.backward()
# 5. Update the parameters
#    optimizer.step()


def count_trainable(named_params, label=""):
    """Print the name of every parameter with requires_grad=True and return
    their total number of elements."""
    total = 0
    suffix = f" ({label})" if label else ""
    for name, param in named_params:
        if param.requires_grad:
            total += param.numel()
            print(f"Trainable parameter{suffix}: {name}")
    return total


# Note: any sub-module of `model` that was not frozen (e.g. the text decoder,
# if only the encoders were frozen) is listed here even though the optimizer
# above only updates the dense layers.
trainable_params = (
    count_trainable(model.named_parameters())
    + count_trainable(image_dense.named_parameters(), "image dense")
    + count_trainable(text_dense.named_parameters(), "text dense")
)
print(f"Total trainable parameters: {trainable_params}")