<a href="https://colab.research.google.com/github/nguyenminhvuinfo/250505-mern/blob/main/Simple_VQA_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##Setup

In [84]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json
import torch
import torch.nn as nn
from transformers import Trainer, TrainingArguments, AutoTokenizer, DataCollatorWithPadding, AutoModel, AutoModelForCausalLM, AutoImageProcessor
from torch.utils.data import DataLoader
from PIL import Image
import torchvision.transforms as T
from datasets import load_dataset, Dataset

In [85]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [86]:
IMAGE_ENCODER_MODEL = 'google/vit-base-patch16-224'
TEXT_ENCODER_MODEL = 'distilbert/distilbert-base-cased-distilled-squad'
DECODER_MODEL = 'gpt2'

In [87]:
decoder_tokenizer = AutoTokenizer.from_pretrained(DECODER_MODEL)
decoder_tokenizer.pad_token = decoder_tokenizer.eos_token
decoder_tokenizer.padding_side = "left"

##Building the model

In [88]:
class MultimodalVQAModel(nn.Module):

    def __init__(self, text_encoder_model: str, image_encoder_model: str, decoder_model: str):
        super(MultimodalVQAModel, self).__init__()

        # Load pre-trained models
        self.text_encoder = AutoModel.from_pretrained(text_encoder_model).to(device)
        self.image_encoder = AutoModel.from_pretrained(image_encoder_model).to(device)
        self.decoder = AutoModelForCausalLM.from_pretrained(
            decoder_model,
            add_cross_attention=True,
            tie_word_embeddings=True
        ).to(device)

        # Linear layers to project text and image features to the decoder's hidden size
        self.text_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size)
        self.image_proj = nn.Linear(self.image_encoder.config.hidden_size, self.decoder.config.hidden_size)

    def forward(self, input_text, input_image, decoder_input_ids, attention_mask, labels=None):
        # Encode text
        text_features = self.encode_text(input_text, attention_mask)
        # Encode image
        image_features = self.encode_image(input_image)

        combined_features = (text_features + image_features) / 2

        # Decoding
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            labels=labels,
            encoder_hidden_states=combined_features.unsqueeze(1),
        )
        return decoder_outputs

    def encode_text(self, input_text, attention_mask):
        text_outputs = self.text_encoder(input_text, attention_mask=attention_mask)
        text_features = text_outputs.last_hidden_state.mean(dim=1)
        return self.text_proj(text_features)

    def encode_image(self, input_image):
        image_outputs = self.image_encoder(input_image)
        image_features = image_outputs.pooler_output
        return self.image_proj(image_features)

##Data Processing

In [89]:
from datasets import load_dataset

dataset = load_dataset("HuggingFaceM4/A-OKVQA")
train_ds = dataset['train']
val_ds = dataset['validation']
test_ds = dataset['test']


In [90]:
print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")
train_ds[0]

Train: 17056, Val: 1145, Test: 6702


{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x480>,
 'question_id': '22MexNkBPpdZGX6sxbxVBH',
 'question': 'What is the man by the bags awaiting?',
 'choices': ['skateboarder', 'train', 'delivery', 'cab'],
 'correct_choice_idx': 3,
 'direct_answers': "['ride', 'ride', 'bus', 'taxi', 'travelling', 'traffic', 'taxi', 'cab', 'cab', 'his ride']",
 'difficult_direct_answer': False,
 'rationales': ['A train would not be on the street, he would not have luggage waiting for a delivery, and the skateboarder is there and not paying attention to him so a cab is the only possible answer.',
  'He has bags as if he is going someone, and he is on a road waiting for vehicle that can only be moved on the road and is big enough to hold the bags.',
  'He looks to be waiting for a paid ride to pick him up.']}

In [91]:
text_encoder_tokenizer = AutoTokenizer.from_pretrained(TEXT_ENCODER_MODEL)
image_feature_extractor = AutoImageProcessor.from_pretrained(IMAGE_ENCODER_MODEL)

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


In [92]:
import ast

def data_collator(batch):
    # Text inputs
    text_inputs = [sample['question'] for sample in batch]
    text_tensors = text_encoder_tokenizer(text_inputs, padding=True, return_tensors="pt")

    # Image inputs
    images = [sample['image'].convert('RGB') for sample in batch]
    image_inputs = image_feature_extractor(images, return_tensors="pt")
    image_tensors = image_inputs['pixel_values']

    # ✅ FIX: Parse string thành list trước
    target_inputs = []
    for sample in batch:
        # Parse string thành list
        answers = ast.literal_eval(sample['direct_answers'])
        # Lấy answer đầu tiên
        answer = answers[0]
        target_inputs.append(f"<|endoftext|>{answer}<|endoftext|>")

    target_tensors = decoder_tokenizer(target_inputs, padding=True, return_tensors="pt")

    # Labels
    labels = target_tensors["input_ids"].clone()
    labels = torch.where((labels == decoder_tokenizer.pad_token_id), -100, labels)
    labels[:, -1] = decoder_tokenizer.eos_token_id

    return {
        "input_text": text_tensors["input_ids"],
        "attention_mask": text_tensors["attention_mask"],
        "input_image": image_tensors,
        "decoder_input_ids": target_tensors["input_ids"],
        "labels": labels
    }

In [93]:
# Test data_collator với fix mới
batch = data_collator([train_ds[0], train_ds[1]])

print("Sample 0:")
print(f"  Question: {train_ds[0]['question']}")
print(f"  Direct answers (raw): {train_ds[0]['direct_answers']}")
print(f"  Decoder input: {decoder_tokenizer.decode(batch['decoder_input_ids'][0])}")
print(f"  Labels: {decoder_tokenizer.decode([t for t in batch['labels'][0] if t != -100])}")

print("\nSample 1:")
print(f"  Question: {train_ds[1]['question']}")
print(f"  Direct answers (raw): {train_ds[1]['direct_answers']}")
print(f"  Decoder input: {decoder_tokenizer.decode(batch['decoder_input_ids'][1])}")
print(f"  Labels: {decoder_tokenizer.decode([t for t in batch['labels'][1] if t != -100])}")

Sample 0:
  Question: What is the man by the bags awaiting?
  Direct answers (raw): ['ride', 'ride', 'bus', 'taxi', 'travelling', 'traffic', 'taxi', 'cab', 'cab', 'his ride']
  Decoder input: <|endoftext|>ride<|endoftext|>
  Labels: ride<|endoftext|>

Sample 1:
  Question: Where does this man eat pizza?
  Direct answers (raw): ['work', 'office', 'work', 'work', 'at work', 'desk', 'at desk', 'office', 'work desk', 'office']
  Decoder input: <|endoftext|>work<|endoftext|>
  Labels: work<|endoftext|>


In [94]:
batch = data_collator([train_ds[0], train_ds[1]])
print("Input text shape:", batch['input_text'].shape)
print("Input image shape:", batch['input_image'].shape)
print("Labels shape:", batch['labels'].shape)

Input text shape: torch.Size([2, 11])
Input image shape: torch.Size([2, 3, 224, 224])
Labels shape: torch.Size([2, 3])


##Training the multimodal VQA model


In [95]:
training_args = TrainingArguments(
    output_dir="./aokvqa_output",
    num_train_epochs=3,
    fp16=True,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_safetensors=False,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=128,
    logging_steps=10,
    report_to='none',
    warmup_ratio=0.1,
    learning_rate=2e-5,
    lr_scheduler_type='cosine',
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
    remove_unused_columns=False,
    dataloader_num_workers=4,
)

In [96]:
model = MultimodalVQAModel(
    TEXT_ENCODER_MODEL,
    IMAGE_ENCODER_MODEL,
    DECODER_MODEL
).to(device)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['transformer.h.0.crossattention.c_attn.bias', 'transformer.h.0.crossattention.c_attn.weight', 'transformer.h.0.crossattention.c_proj.bias', 'transformer.h.0.crossattention.c_proj.weight', 'transformer.h.0.crossattention.q_attn.bias', 'transformer.h.0.crossattention.q_attn.weight', 'transformer.h.0.ln_cross_attn.bias', 'transformer.h.0.ln_cross_attn.weight', 'transformer.h.1.crossattention.c_attn.bias', 'transformer.h.1.crossattention.c_attn.weight', 'transformer.h.1.crossattention.c_proj.bias', 'transformer.h.1.crossattention.c_proj.weight', 'transformer.h.1.crossattention.q_attn.bias', 'tran

In [97]:
num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_trainable_params}")

Number of trainable parameters: 305174784


In [98]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=data_collator,
)

In [99]:
trainer.train()

Epoch,Training Loss,Validation Loss
1,3.1705,2.912757
2,2.6152,2.739261
3,2.5897,2.697063


TrainOutput(global_step=1599, training_loss=3.1480240288043184, metrics={'train_runtime': 694.1198, 'train_samples_per_second': 73.716, 'train_steps_per_second': 2.304, 'total_flos': 0.0, 'train_loss': 3.1480240288043184, 'epoch': 3.0})

In [100]:
# Save model
torch.save(model.state_dict(), "aokvqa_model.pt")
print("✅ Model saved as aokvqa_model.pt")

✅ Model saved as aokvqa_model.pt


##TEST


In [None]:
# Test với sample từ train_ds
test_model_train = lambda idx: test_model_from_dataset(idx, train_ds)

def test_model_from_dataset(idx, dataset):
    sample = dataset[idx]

    plt.imshow(sample['image'])
    plt.axis('off')
    plt.title(f"Q: {sample['question']}")
    plt.show()

    print(f"Ground Truth: {sample['direct_answers']}")

    batch = data_collator([sample])
    batch = {k: v.to(device) for k, v in batch.items()}

    model.eval()
    with torch.no_grad():
        text_features = model.encode_text(batch["input_text"], batch["attention_mask"])
        image_features = model.encode_image(batch["input_image"])
        combined_features = (text_features + image_features) / 2

        attention_mask = torch.ones((1, 1), dtype=torch.long, device=device)

        generated_ids = model.decoder.generate(
            encoder_hidden_states=combined_features.unsqueeze(1),
            attention_mask=attention_mask,
            max_length=20,
            num_beams=3,
            pad_token_id=decoder_tokenizer.eos_token_id
        )

        prediction = decoder_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        print(f"Prediction: {prediction}")

# Test 10 ví dụ từ tập train
for i in range(10):
    print(f"\n{'='*50}")
    print(f"EXAMPLE {i+1}")
    print('='*50)
    test_model_from_dataset(i, train_ds)

In [107]:
# ===== KIỂM TRA QUAN TRỌNG NHẤT =====
print("="*80)
print("🔍 KIỂM TRA: Model đã được train chưa?")
print("="*80)

# Kiểm tra 1: Có file checkpoint không?
import os
checkpoint_dir = "checkpoints"  # Thay bằng folder của bạn
if os.path.exists(checkpoint_dir):
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pth') or f.endswith('.pt')]
    print(f"✅ Tìm thấy {len(checkpoints)} checkpoints: {checkpoints}")
else:
    print(f"❌ KHÔNG TÌM THẤY FOLDER CHECKPOINT!")
    print("   => Model đang dùng RANDOM WEIGHTS - Chưa train!")

# Kiểm tra 2: Model có ở evaluation mode không?
print(f"\nModel training mode: {model.training}")
if model.training:
    print("⚠️ WARNING: Model đang ở TRAINING mode, phải chuyển sang eval!")
    model.eval()

# Kiểm tra 3: Decoder có được train không?
decoder_trainable_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)
decoder_total_params = sum(p.numel() for p in model.decoder.parameters())
print(f"\nDecoder trainable params: {decoder_trainable_params:,} / {decoder_total_params:,}")

if decoder_trainable_params == 0:
    print("❌ DECODER BỊ FREEZE - Không thể train được!")
    print("   => Phải unfreeze decoder trước khi train")

# ===== NGUYÊN NHÂN: Bạn chưa train model =====
# ➡️ GIẢI PHÁP: Train model trước khi test!

🔍 KIỂM TRA: Model đã được train chưa?
❌ KHÔNG TÌM THẤY FOLDER CHECKPOINT!
   => Model đang dùng RANDOM WEIGHTS - Chưa train!

Model training mode: False

Decoder trainable params: 152,806,656 / 152,806,656
