In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
from datasets import load_dataset
import torch
from transformers import ViTModel, BertModel, BertTokenizer
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from datasets import load_dataset
from PIL import Image
import io
import json
import os
from torchvision import transforms
from huggingface_hub import HfApi, login


2025-05-16 08:03:35.375840: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747382615.563754      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747382615.617757      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
ds = load_dataset('khoadole/cars_8k_balance_dataset_full_augmented_v2')

README.md:   0%|          | 0.00/670 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/42.9M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/13.8M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/14.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/15468 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4824 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5154 [00:00<?, ? examples/s]

In [3]:
color_list = ['blue', 'white', 'black', 'gray', 'silver']
brand_list = ['bentley', 'audi', 'bmw', 'acura']

def get_answer_type(answer):
    if answer in color_list:
        return 'color'
    elif answer in brand_list:
        return 'brand'
    else:
        return 'car_name'

In [4]:
class VQADataset(Dataset):
    def __init__(self, dataset, tokenizer, answer_to_idx):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.answer_to_idx = answer_to_idx
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Định nghĩa mapping cho answer_type
        self.answer_type_map = {'color': 0, 'brand': 1, 'car_name': 2}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        image_bytes = sample['image']['bytes']
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        image_tensor = self.transform(image)
        question = sample['question']
        tokenized = self.tokenizer(question, padding='max_length', truncation=True, max_length=32, return_tensors='pt')
        input_ids = tokenized['input_ids'].squeeze(0)
        attention_mask = tokenized['attention_mask'].squeeze(0)
        answer = sample['answer']
        answer_idx = self.answer_to_idx.get(answer, -1)
        answer_type = torch.tensor(self.answer_type_map[get_answer_type(answer)], dtype=torch.long)  # Chuyển thành tensor
        return image_tensor, input_ids, attention_mask, answer_idx, answer_type

In [10]:
class VQAModel(nn.Module):
    def __init__(self, num_answers):
        super(VQAModel, self).__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.gate = nn.Sequential(
            nn.Linear(768 * 2, 768),
            nn.Sigmoid()
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(768, 512),
            nn.GELU(),
            nn.Linear(512, num_answers)
        )
        for param in self.bert.encoder.layer[:2].parameters():
            param.requires_grad = False

    def forward(self, image, input_ids, attention_mask):
        image_features = self.vit(image).last_hidden_state[:, 0, :]  # [batch, 768]
        text_features = self.bert(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]  # [batch, 768]
        
        # Gated fusion
        combined_features = torch.cat([image_features, text_features], dim=1)  # [batch, 768*2]
        gate = self.gate(combined_features)  # [batch, 768]
        fused_features = gate * image_features + (1 - gate) * text_features  # [batch, 768]
        
        output = self.classifier(fused_features)
        return output

In [11]:
all_train_answers = list(set(sample['answer'] for sample in ds['train']))
answer_to_idx = {answer: idx for idx, answer in enumerate(all_train_answers)}

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
train_dataset = VQADataset(ds['train'], tokenizer, answer_to_idx)
val_dataset = VQADataset(ds['validation'], tokenizer, answer_to_idx)
test_dataset = VQADataset(ds['test'], tokenizer, answer_to_idx)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

# Khởi tạo mô hình và tối ưu
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = VQAModel(num_answers=len(answer_to_idx)).to(device)
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
scaler = GradScaler()
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=3, verbose=True)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  scaler = GradScaler()


In [12]:
# Training loop
num_epochs = 30
best_val_acc = 0
patience = 5
patience_counter = 0
color_loss_weight = 2.0
answer_type_map_reverse = {0: 'color', 1: 'brand', 2: 'car_name'}  # Để map ngược lại từ số sang string

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    total_color_loss = 0
    total_brand_loss = 0
    total_car_name_loss = 0
    color_count = 0
    brand_count = 0
    car_name_count = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
        image, input_ids, attention_mask, answer_idx, answer_type = [x.to(device) for x in batch]
        valid_mask = answer_idx != -1
        if not valid_mask.any():
            continue
        # Lọc các tensor bằng valid_mask
        image = image[valid_mask]
        input_ids = input_ids[valid_mask]
        attention_mask = attention_mask[valid_mask]
        answer_idx = answer_idx[valid_mask]
        answer_type = answer_type[valid_mask]  # answer_type giờ là tensor

        optimizer.zero_grad()
        with autocast():
            output = model(image, input_ids, attention_mask)
            loss = loss_fn(output, answer_idx)
            weighted_loss = torch.zeros_like(loss)
            for i in range(len(answer_type)):
                ans_type_str = answer_type_map_reverse[answer_type[i].item()]  # Map ngược lại thành string
                if ans_type_str == 'color':
                    weighted_loss += loss_fn(output[i].unsqueeze(0), answer_idx[i].unsqueeze(0)) * color_loss_weight
                    total_color_loss += loss_fn(output[i].unsqueeze(0), answer_idx[i].unsqueeze(0)).item()
                    color_count += 1
                elif ans_type_str == 'brand':
                    weighted_loss += loss_fn(output[i].unsqueeze(0), answer_idx[i].unsqueeze(0))
                    total_brand_loss += loss_fn(output[i].unsqueeze(0), answer_idx[i].unsqueeze(0)).item()
                    brand_count += 1
                else:
                    weighted_loss += loss_fn(output[i].unsqueeze(0), answer_idx[i].unsqueeze(0))
                    total_car_name_loss += loss_fn(output[i].unsqueeze(0), answer_idx[i].unsqueeze(0)).item()
                    car_name_count += 1
        scaler.scale(weighted_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    avg_color_loss = total_color_loss / color_count if color_count > 0 else 0
    avg_brand_loss = total_brand_loss / brand_count if brand_count > 0 else 0
    avg_car_name_loss = total_car_name_loss / car_name_count if car_name_count > 0 else 0
    print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}")
    print(f"Color Loss: {avg_color_loss:.4f}, Brand Loss: {avg_brand_loss:.4f}, Car Name Loss: {avg_car_name_loss:.4f}")

    # Validation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
            image, input_ids, attention_mask, answer_idx, _ = [x.to(device) for x in batch]
            valid_mask = answer_idx != -1
            if not valid_mask.any():
                continue
            image = image[valid_mask]
            input_ids = input_ids[valid_mask]
            attention_mask = attention_mask[valid_mask]
            answer_idx = answer_idx[valid_mask]
            with autocast():
                output = model(image, input_ids, attention_mask)
            pred = output.argmax(dim=1)
            correct += (pred == answer_idx).sum().item()
            total += answer_idx.size(0)
    val_accuracy = correct / total if total > 0 else 0
    print(f"Epoch {epoch+1}, Validation Accuracy: {val_accuracy:.4f}")

    scheduler.step(val_accuracy)
    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        patience_counter = 0
        torch.save(model.state_dict(), 'best_vqa_model.pth')
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered!")
            break

  with autocast():
Epoch 1/30 - Training: 100%|██████████| 484/484 [03:15<00:00,  2.48it/s]


Epoch 1, Train Loss: 1.7888
Color Loss: 1.4489, Brand Loss: 1.4459, Car Name Loss: 2.4740


  with autocast():
Epoch 1/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.30it/s]


Epoch 1, Validation Accuracy: 0.7210


Epoch 2/30 - Training: 100%|██████████| 484/484 [03:17<00:00,  2.45it/s]


Epoch 2, Train Loss: 1.1078
Color Loss: 1.0269, Brand Loss: 0.7525, Car Name Loss: 1.5437


Epoch 2/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.26it/s]


Epoch 2, Validation Accuracy: 0.7817


Epoch 3/30 - Training: 100%|██████████| 484/484 [03:18<00:00,  2.44it/s]


Epoch 3, Train Loss: 0.9350
Color Loss: 0.8323, Brand Loss: 0.6978, Car Name Loss: 1.2738


Epoch 3/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.23it/s]


Epoch 3, Validation Accuracy: 0.8128


Epoch 4/30 - Training: 100%|██████████| 484/484 [03:17<00:00,  2.45it/s]


Epoch 4, Train Loss: 0.8088
Color Loss: 0.7199, Brand Loss: 0.6734, Car Name Loss: 1.0336


Epoch 4/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.24it/s]


Epoch 4, Validation Accuracy: 0.8387


Epoch 5/30 - Training: 100%|██████████| 484/484 [03:17<00:00,  2.45it/s]


Epoch 5, Train Loss: 0.7275
Color Loss: 0.6828, Brand Loss: 0.6583, Car Name Loss: 0.8415


Epoch 5/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.25it/s]


Epoch 5, Validation Accuracy: 0.8472


Epoch 6/30 - Training: 100%|██████████| 484/484 [03:17<00:00,  2.45it/s]


Epoch 6, Train Loss: 0.6847
Color Loss: 0.6636, Brand Loss: 0.6498, Car Name Loss: 0.7408


Epoch 6/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.28it/s]


Epoch 6, Validation Accuracy: 0.8557


Epoch 7/30 - Training: 100%|██████████| 484/484 [03:17<00:00,  2.45it/s]


Epoch 7, Train Loss: 0.6678
Color Loss: 0.6603, Brand Loss: 0.6439, Car Name Loss: 0.6991


Epoch 7/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.25it/s]


Epoch 7, Validation Accuracy: 0.8628


Epoch 8/30 - Training: 100%|██████████| 484/484 [03:17<00:00,  2.45it/s]


Epoch 8, Train Loss: 0.6553
Color Loss: 0.6476, Brand Loss: 0.6405, Car Name Loss: 0.6779


Epoch 8/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.27it/s]


Epoch 8, Validation Accuracy: 0.8667


Epoch 9/30 - Training: 100%|██████████| 484/484 [03:18<00:00,  2.44it/s]


Epoch 9, Train Loss: 0.6599
Color Loss: 0.6469, Brand Loss: 0.6391, Car Name Loss: 0.6939


Epoch 9/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.24it/s]


Epoch 9, Validation Accuracy: 0.8653


Epoch 10/30 - Training: 100%|██████████| 484/484 [03:17<00:00,  2.45it/s]


Epoch 10, Train Loss: 0.6646
Color Loss: 0.6676, Brand Loss: 0.6418, Car Name Loss: 0.6845


Epoch 10/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.26it/s]


Epoch 10, Validation Accuracy: 0.7964


Epoch 11/30 - Training: 100%|██████████| 484/484 [03:17<00:00,  2.45it/s]


Epoch 11, Train Loss: 0.6945
Color Loss: 0.7009, Brand Loss: 0.6589, Car Name Loss: 0.7234


Epoch 11/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.23it/s]


Epoch 11, Validation Accuracy: 0.8481


Epoch 12/30 - Training: 100%|██████████| 484/484 [03:16<00:00,  2.46it/s]


Epoch 12, Train Loss: 0.6601
Color Loss: 0.6653, Brand Loss: 0.6496, Car Name Loss: 0.6654


Epoch 12/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.29it/s]


Epoch 12, Validation Accuracy: 0.8630


Epoch 13/30 - Training: 100%|██████████| 484/484 [03:15<00:00,  2.47it/s]


Epoch 13, Train Loss: 0.6415
Color Loss: 0.6389, Brand Loss: 0.6360, Car Name Loss: 0.6497


Epoch 13/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.28it/s]


Epoch 13, Validation Accuracy: 0.8688


Epoch 14/30 - Training: 100%|██████████| 484/484 [03:16<00:00,  2.47it/s]


Epoch 14, Train Loss: 0.6404
Color Loss: 0.6382, Brand Loss: 0.6353, Car Name Loss: 0.6479


Epoch 14/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.25it/s]


Epoch 14, Validation Accuracy: 0.8698


Epoch 15/30 - Training: 100%|██████████| 484/484 [03:16<00:00,  2.46it/s]


Epoch 15, Train Loss: 0.6388
Color Loss: 0.6365, Brand Loss: 0.6341, Car Name Loss: 0.6459


Epoch 15/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.29it/s]


Epoch 15, Validation Accuracy: 0.8694


Epoch 16/30 - Training: 100%|██████████| 484/484 [03:16<00:00,  2.46it/s]


Epoch 16, Train Loss: 0.6380
Color Loss: 0.6357, Brand Loss: 0.6336, Car Name Loss: 0.6449


Epoch 16/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.28it/s]


Epoch 16, Validation Accuracy: 0.8704


Epoch 17/30 - Training: 100%|██████████| 484/484 [03:16<00:00,  2.46it/s]


Epoch 17, Train Loss: 0.6374
Color Loss: 0.6359, Brand Loss: 0.6331, Car Name Loss: 0.6434


Epoch 17/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.26it/s]


Epoch 17, Validation Accuracy: 0.8698


Epoch 19/30 - Training: 100%|██████████| 484/484 [03:16<00:00,  2.46it/s]


Epoch 19, Train Loss: 0.6365
Color Loss: 0.6349, Brand Loss: 0.6323, Car Name Loss: 0.6423


Epoch 19/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.24it/s]


Epoch 19, Validation Accuracy: 0.8717


Epoch 20/30 - Training: 100%|██████████| 484/484 [03:16<00:00,  2.46it/s]


Epoch 20, Train Loss: 0.6365
Color Loss: 0.6355, Brand Loss: 0.6320, Car Name Loss: 0.6420


Epoch 20/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.27it/s]


Epoch 20, Validation Accuracy: 0.8721


Epoch 21/30 - Training: 100%|██████████| 484/484 [03:17<00:00,  2.45it/s]


Epoch 21, Train Loss: 0.6362
Color Loss: 0.6349, Brand Loss: 0.6320, Car Name Loss: 0.6416


Epoch 21/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.26it/s]


Epoch 21, Validation Accuracy: 0.8750


Epoch 22/30 - Training: 100%|██████████| 484/484 [03:16<00:00,  2.46it/s]


Epoch 22, Train Loss: 0.6356
Color Loss: 0.6345, Brand Loss: 0.6316, Car Name Loss: 0.6408


Epoch 22/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.24it/s]


Epoch 22, Validation Accuracy: 0.8758


Epoch 23/30 - Training: 100%|██████████| 484/484 [03:16<00:00,  2.46it/s]


Epoch 23, Train Loss: 0.6358
Color Loss: 0.6349, Brand Loss: 0.6316, Car Name Loss: 0.6410


Epoch 23/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.29it/s]


Epoch 23, Validation Accuracy: 0.8787


Epoch 24/30 - Training: 100%|██████████| 484/484 [03:16<00:00,  2.46it/s]


Epoch 24, Train Loss: 0.6349
Color Loss: 0.6337, Brand Loss: 0.6310, Car Name Loss: 0.6401


Epoch 24/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.29it/s]


Epoch 24, Validation Accuracy: 0.8756


Epoch 25/30 - Training: 100%|██████████| 484/484 [03:17<00:00,  2.46it/s]


Epoch 25, Train Loss: 0.6349
Color Loss: 0.6343, Brand Loss: 0.6307, Car Name Loss: 0.6397


Epoch 25/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.23it/s]


Epoch 25, Validation Accuracy: 0.8758


Epoch 26/30 - Training: 100%|██████████| 484/484 [03:17<00:00,  2.45it/s]


Epoch 26, Train Loss: 0.6346
Color Loss: 0.6342, Brand Loss: 0.6303, Car Name Loss: 0.6394


Epoch 26/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.27it/s]


Epoch 26, Validation Accuracy: 0.8789


Epoch 27/30 - Training: 100%|██████████| 484/484 [03:18<00:00,  2.44it/s]


Epoch 27, Train Loss: 0.6343
Color Loss: 0.6339, Brand Loss: 0.6302, Car Name Loss: 0.6390


Epoch 27/30 - Validation: 100%|██████████| 151/151 [00:18<00:00,  8.22it/s]


Epoch 27, Validation Accuracy: 0.8754


Epoch 28/30 - Training:  10%|▉         | 47/484 [00:19<03:02,  2.39it/s]


KeyboardInterrupt: 

In [13]:
from collections import defaultdict
# Test
model.load_state_dict(torch.load('best_vqa_model.pth', weights_only=True))
model.eval()
correct_color = 0
total_color = 0
correct_brand = 0
total_brand = 0
correct_car_name = 0
total_car_name = 0
total_correct = 0
total_samples = 0
total_color_loss = 0
total_brand_loss = 0
total_car_name_loss = 0
color_count = 0
brand_count = 0
car_name_count = 0
answer_type_map_reverse = {0: 'color', 1: 'brand', 2: 'car_name'}
idx_to_answer = {idx: answer for answer, idx in answer_to_idx.items()}  # Map ngược từ idx sang answer

# Theo dõi số lần xuất hiện, đúng, và nhầm lẫn
color_stats = defaultdict(lambda: {'correct': 0, 'total': 0, 'confusion': defaultdict(int)})
brand_stats = defaultdict(lambda: {'correct': 0, 'total': 0, 'confusion': defaultdict(int)})
car_name_stats = defaultdict(lambda: {'correct': 0, 'total': 0, 'confusion': defaultdict(int)})

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Test"):
        image, input_ids, attention_mask, answer_idx, answer_type = [x.to(device) for x in batch]
        valid_mask = (answer_idx != -1).to(device)
        if not valid_mask.any():
            continue
        image = image[valid_mask]
        input_ids = input_ids[valid_mask]
        attention_mask = attention_mask[valid_mask]
        answer_idx = answer_idx[valid_mask]
        answer_type = answer_type[valid_mask]

        with autocast():
            output = model(image, input_ids, attention_mask)
            loss = loss_fn(output, answer_idx)
        pred = output.argmax(dim=1)
        total_correct += (pred == answer_idx).sum().item()
        total_samples += answer_idx.size(0)

        for i in range(len(answer_type)):
            ans_type_str = answer_type_map_reverse[answer_type[i].item()]
            true_answer = idx_to_answer[answer_idx[i].item()]
            pred_answer = idx_to_answer[pred[i].item()]
            is_correct = pred[i] == answer_idx[i]

            if ans_type_str == 'color':
                total_color += 1
                total_color_loss += loss_fn(output[i].unsqueeze(0), answer_idx[i].unsqueeze(0)).item()
                color_count += 1
                if is_correct:
                    correct_color += 1
                color_stats[true_answer]['total'] += 1
                if is_correct:
                    color_stats[true_answer]['correct'] += 1
                else:
                    color_stats[true_answer]['confusion'][pred_answer] += 1

            elif ans_type_str == 'brand':
                total_brand += 1
                total_brand_loss += loss_fn(output[i].unsqueeze(0), answer_idx[i].unsqueeze(0)).item()
                brand_count += 1
                if is_correct:
                    correct_brand += 1
                brand_stats[true_answer]['total'] += 1
                if is_correct:
                    brand_stats[true_answer]['correct'] += 1
                else:
                    brand_stats[true_answer]['confusion'][pred_answer] += 1

            else:
                total_car_name += 1
                total_car_name_loss += loss_fn(output[i].unsqueeze(0), answer_idx[i].unsqueeze(0)).item()
                car_name_count += 1
                if is_correct:
                    correct_car_name += 1
                car_name_stats[true_answer]['total'] += 1
                if is_correct:
                    car_name_stats[true_answer]['correct'] += 1
                else:
                    car_name_stats[true_answer]['confusion'][pred_answer] += 1

# Tính accuracy tổng và riêng
test_color_accuracy = correct_color / total_color if total_color > 0 else 0
test_brand_accuracy = correct_brand / total_brand if total_brand > 0 else 0
test_car_name_accuracy = correct_car_name / total_car_name if total_car_name > 0 else 0
test_total_accuracy = total_correct / total_samples if total_samples > 0 else 0
avg_color_loss = total_color_loss / color_count if color_count > 0 else 0
avg_brand_loss = total_brand_loss / brand_count if brand_count > 0 else 0
avg_car_name_loss = total_car_name_loss / car_name_count if car_name_count > 0 else 0

# Tính top 4 sai nhiều nhất và nhầm lẫn
def get_top_errors(stats, category_name):
    error_rates = []
    for answer, stat in stats.items():
        total = stat['total']
        correct = stat['correct']
        if total > 0:
            error_rate = 1 - (correct / total)
            error_rates.append((answer, error_rate, total, stat['confusion']))
    error_rates.sort(key=lambda x: (x[1], x[2]), reverse=True)
    top_4 = error_rates[:4]
    print(f"\nTop 4 {category_name} sai nhiều nhất:")
    for answer, error_rate, total, confusion in top_4:
        print(f"- {answer}: {error_rate:.4f} ({(error_rate * 100):.2f}%), xuất hiện {total} lần")
        print(f"  Nhầm lẫn với:")
        for pred_answer, count in confusion.items():
            print(f"    + {pred_answer}: {count} lần")

# In kết quả
print(f"Test Total Accuracy: {test_total_accuracy:.4f}")
print(f"Test Color Accuracy: {test_color_accuracy:.4f}, Brand Accuracy: {test_brand_accuracy:.4f}, Car Name Accuracy: {test_car_name_accuracy:.4f}")
print(f"Test Color Loss: {avg_color_loss:.4f}, Brand Loss: {avg_brand_loss:.4f}, Car Name Loss: {avg_car_name_loss:.4f}")

get_top_errors(color_stats, "màu")
get_top_errors(brand_stats, "hãng xe")
get_top_errors(car_name_stats, "tên xe")

  with autocast():
Test: 100%|██████████| 162/162 [00:20<00:00,  7.84it/s]

Test Total Accuracy: 0.8576
Test Color Accuracy: 0.8423, Brand Accuracy: 0.9476, Car Name Accuracy: 0.7829
Test Color Loss: 1.1505, Brand Loss: 0.7713, Car Name Loss: 1.2820

Top 4 màu sai nhiều nhất:
- gray: 0.2735 (27.35%), xuất hiện 245 lần
  Nhầm lẫn với:
    + white: 10 lần
    + black: 31 lần
    + silver: 26 lần
- silver: 0.2500 (25.00%), xuất hiện 248 lần
  Nhầm lẫn với:
    + white: 23 lần
    + gray: 18 lần
    + blue: 8 lần
    + black: 13 lần
- blue: 0.1386 (13.86%), xuất hiện 339 lần
  Nhầm lẫn với:
    + silver: 17 lần
    + black: 27 lần
    + gray: 1 lần
    + white: 2 lần
- black: 0.1304 (13.04%), xuất hiện 368 lần
  Nhầm lẫn với:
    + white: 27 lần
    + silver: 6 lần
    + gray: 13 lần
    + blue: 2 lần

Top 4 hãng xe sai nhiều nhất:
- bmw: 0.0659 (6.59%), xuất hiện 425 lần
  Nhầm lẫn với:
    + bentley: 3 lần
    + audi: 22 lần
    + acura: 3 lần
- audi: 0.0611 (6.11%), xuất hiện 442 lần
  Nhầm lẫn với:
    + bmw: 13 lần
    + bentley: 6 lần
    + acura: 8 lần
- ac




In [14]:
save_path = "/kaggle/working/my-vqa-model"
os.makedirs(save_path, exist_ok=True)
torch.save(model.state_dict(), f"{save_path}/pytorch_model.bin")
config = {
    "vit_model": "google/vit-base-patch16-224",
    "bert_model": "bert-base-uncased",
    "num_answers": len(answer_to_idx),
    "architecture": "ViT+BERT gated fusion"
}
with open(f"{save_path}/config.json", "w") as f:
    json.dump(config, f)
tokenizer.save_pretrained(save_path)
with open(f"{save_path}/answer_list.json", "w") as f:
    json.dump(all_train_answers, f)