In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
from datasets import load_dataset
import torch
from transformers import ViTModel, BertModel, BertTokenizer
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from datasets import load_dataset
from PIL import Image
import io
import json
import os
from torchvision import transforms
from huggingface_hub import HfApi, login


2025-04-28 01:26:14.319976: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745803574.615284      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745803574.688288      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
ds = load_dataset('khoadole/cars_8k_balance_dataset_full_augmented_v2')

README.md:   0%|          | 0.00/670 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/42.9M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/13.8M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/14.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/15468 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4824 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5154 [00:00<?, ? examples/s]

In [3]:
color_list = ['blue', 'white', 'black', 'gray', 'silver']
brand_list = ['bentley', 'audi', 'bmw', 'acura']

def get_answer_type(answer):
    if answer in color_list:
        return 'color'
    elif answer in brand_list:
        return 'brand'
    else:
        return 'car_name'

In [4]:
class VQADataset(Dataset):
    def __init__(self, dataset, tokenizer, answer_to_idx):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.answer_to_idx = answer_to_idx
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Định nghĩa mapping cho answer_type
        self.answer_type_map = {'color': 0, 'brand': 1, 'car_name': 2}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        image_bytes = sample['image']['bytes']
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        image_tensor = self.transform(image)
        question = sample['question']
        tokenized = self.tokenizer(question, padding='max_length', truncation=True, max_length=32, return_tensors='pt')
        input_ids = tokenized['input_ids'].squeeze(0)
        attention_mask = tokenized['attention_mask'].squeeze(0)
        answer = sample['answer']
        answer_idx = self.answer_to_idx.get(answer, -1)
        answer_type = torch.tensor(self.answer_type_map[get_answer_type(answer)], dtype=torch.long)  # Chuyển thành tensor
        return image_tensor, input_ids, attention_mask, answer_idx, answer_type

In [5]:
class VQAModel(nn.Module):
    def __init__(self, num_answers):
        super(VQAModel, self).__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            # nn.Linear(768 * 3, 2048),
            # nn.Linear(2048, 1024),
            # nn.GELU(),
            # nn.Linear(1024, 512),
            # nn.GELU(),
            nn.Linear(768, 512),
            nn.GELU(),
            nn.Linear(512, num_answers)
        )
        # for param in self.vit.encoder.layer[:6].parameters():
        #     param.requires_grad = False
        for param in self.bert.encoder.layer[:2].parameters():
            param.requires_grad = False

    def forward(self, image, input_ids, attention_mask):
        image_features = self.vit(image).last_hidden_state[:, 0, :]
        text_features = self.bert(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        combined = image_features + text_features
        output = self.classifier(combined)
        return output

In [6]:
all_train_answers = list(set(sample['answer'] for sample in ds['train']))
answer_to_idx = {answer: idx for idx, answer in enumerate(all_train_answers)}

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
train_dataset = VQADataset(ds['train'], tokenizer, answer_to_idx)
val_dataset = VQADataset(ds['validation'], tokenizer, answer_to_idx)
test_dataset = VQADataset(ds['test'], tokenizer, answer_to_idx)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

# Khởi tạo mô hình và tối ưu
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = VQAModel(num_answers=len(answer_to_idx)).to(device)
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
scaler = GradScaler()
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=3, verbose=True)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

  scaler = GradScaler()


In [8]:
# Training loop
num_epochs = 30
best_val_acc = 0
patience = 5
patience_counter = 0
color_loss_weight = 2.0
answer_type_map_reverse = {0: 'color', 1: 'brand', 2: 'car_name'}  # Để map ngược lại từ số sang string

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    total_color_loss = 0
    total_brand_loss = 0
    total_car_name_loss = 0
    color_count = 0
    brand_count = 0
    car_name_count = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
        image, input_ids, attention_mask, answer_idx, answer_type = [x.to(device) for x in batch]
        valid_mask = answer_idx != -1
        if not valid_mask.any():
            continue
        # Lọc các tensor bằng valid_mask
        image = image[valid_mask]
        input_ids = input_ids[valid_mask]
        attention_mask = attention_mask[valid_mask]
        answer_idx = answer_idx[valid_mask]
        answer_type = answer_type[valid_mask]  # answer_type giờ là tensor

        optimizer.zero_grad()
        with autocast():
            output = model(image, input_ids, attention_mask)
            loss = loss_fn(output, answer_idx)
            weighted_loss = torch.zeros_like(loss)
            for i in range(len(answer_type)):
                ans_type_str = answer_type_map_reverse[answer_type[i].item()]  # Map ngược lại thành string
                if ans_type_str == 'color':
                    weighted_loss += loss_fn(output[i].unsqueeze(0), answer_idx[i].unsqueeze(0)) * color_loss_weight
                    total_color_loss += loss_fn(output[i].unsqueeze(0), answer_idx[i].unsqueeze(0)).item()
                    color_count += 1
                elif ans_type_str == 'brand':
                    weighted_loss += loss_fn(output[i].unsqueeze(0), answer_idx[i].unsqueeze(0))
                    total_brand_loss += loss_fn(output[i].unsqueeze(0), answer_idx[i].unsqueeze(0)).item()
                    brand_count += 1
                else:
                    weighted_loss += loss_fn(output[i].unsqueeze(0), answer_idx[i].unsqueeze(0))
                    total_car_name_loss += loss_fn(output[i].unsqueeze(0), answer_idx[i].unsqueeze(0)).item()
                    car_name_count += 1
        scaler.scale(weighted_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    avg_color_loss = total_color_loss / color_count if color_count > 0 else 0
    avg_brand_loss = total_brand_loss / brand_count if brand_count > 0 else 0
    avg_car_name_loss = total_car_name_loss / car_name_count if car_name_count > 0 else 0
    print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}")
    print(f"Color Loss: {avg_color_loss:.4f}, Brand Loss: {avg_brand_loss:.4f}, Car Name Loss: {avg_car_name_loss:.4f}")

    # Validation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
            image, input_ids, attention_mask, answer_idx, _ = [x.to(device) for x in batch]
            valid_mask = answer_idx != -1
            if not valid_mask.any():
                continue
            image = image[valid_mask]
            input_ids = input_ids[valid_mask]
            attention_mask = attention_mask[valid_mask]
            answer_idx = answer_idx[valid_mask]
            with autocast():
                output = model(image, input_ids, attention_mask)
            pred = output.argmax(dim=1)
            correct += (pred == answer_idx).sum().item()
            total += answer_idx.size(0)
    val_accuracy = correct / total if total > 0 else 0
    print(f"Epoch {epoch+1}, Validation Accuracy: {val_accuracy:.4f}")

    scheduler.step(val_accuracy)
    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        patience_counter = 0
        torch.save(model.state_dict(), 'best_vqa_model.pth')
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered!")
            break

  with autocast():
Epoch 1/30 - Training: 100%|██████████| 484/484 [03:11<00:00,  2.53it/s]


Epoch 1, Train Loss: 0.8185
Color Loss: 0.7538, Brand Loss: 0.7378, Car Name Loss: 0.9643


  with autocast():
Epoch 1/30 - Validation: 100%|██████████| 151/151 [00:17<00:00,  8.59it/s]


Epoch 1, Validation Accuracy: 0.8566


Epoch 2/30 - Training: 100%|██████████| 484/484 [03:10<00:00,  2.54it/s]


Epoch 2, Train Loss: 0.6702
Color Loss: 0.6583, Brand Loss: 0.6543, Car Name Loss: 0.6980


Epoch 2/30 - Validation: 100%|██████████| 151/151 [00:17<00:00,  8.55it/s]


Epoch 2, Validation Accuracy: 0.8661


Epoch 3/30 - Training: 100%|██████████| 484/484 [03:10<00:00,  2.54it/s]


Epoch 3, Train Loss: 0.6579
Color Loss: 0.6506, Brand Loss: 0.6489, Car Name Loss: 0.6741


Epoch 3/30 - Validation: 100%|██████████| 151/151 [00:17<00:00,  8.58it/s]


Epoch 3, Validation Accuracy: 0.8673


Epoch 4/30 - Training: 100%|██████████| 484/484 [03:10<00:00,  2.54it/s]


Epoch 4, Train Loss: 0.6527
Color Loss: 0.6465, Brand Loss: 0.6459, Car Name Loss: 0.6658


Epoch 4/30 - Validation: 100%|██████████| 151/151 [00:17<00:00,  8.59it/s]


Epoch 4, Validation Accuracy: 0.8686


Epoch 5/30 - Training: 100%|██████████| 484/484 [03:10<00:00,  2.54it/s]


Epoch 5, Train Loss: 0.6522
Color Loss: 0.6452, Brand Loss: 0.6460, Car Name Loss: 0.6653


Epoch 5/30 - Validation: 100%|██████████| 151/151 [00:17<00:00,  8.53it/s]


Epoch 5, Validation Accuracy: 0.8692


Epoch 6/30 - Training: 100%|██████████| 484/484 [03:10<00:00,  2.54it/s]


Epoch 6, Train Loss: 0.6513
Color Loss: 0.6443, Brand Loss: 0.6452, Car Name Loss: 0.6643


Epoch 6/30 - Validation: 100%|██████████| 151/151 [00:17<00:00,  8.57it/s]


Epoch 6, Validation Accuracy: 0.8688


Epoch 7/30 - Training: 100%|██████████| 484/484 [03:11<00:00,  2.53it/s]


Epoch 7, Train Loss: 0.6500
Color Loss: 0.6442, Brand Loss: 0.6444, Car Name Loss: 0.6613


Epoch 7/30 - Validation: 100%|██████████| 151/151 [00:17<00:00,  8.59it/s]


Epoch 7, Validation Accuracy: 0.8692


Epoch 8/30 - Training: 100%|██████████| 484/484 [03:10<00:00,  2.54it/s]


Epoch 8, Train Loss: 0.6497
Color Loss: 0.6432, Brand Loss: 0.6443, Car Name Loss: 0.6615


Epoch 8/30 - Validation: 100%|██████████| 151/151 [00:17<00:00,  8.57it/s]


Epoch 8, Validation Accuracy: 0.8706


Epoch 9/30 - Training: 100%|██████████| 484/484 [03:10<00:00,  2.54it/s]


Epoch 9, Train Loss: 0.6495
Color Loss: 0.6435, Brand Loss: 0.6440, Car Name Loss: 0.6610


Epoch 9/30 - Validation: 100%|██████████| 151/151 [00:17<00:00,  8.62it/s]


Epoch 9, Validation Accuracy: 0.8704


Epoch 10/30 - Training: 100%|██████████| 484/484 [03:10<00:00,  2.54it/s]


Epoch 10, Train Loss: 0.6490
Color Loss: 0.6433, Brand Loss: 0.6436, Car Name Loss: 0.6601


Epoch 10/30 - Validation: 100%|██████████| 151/151 [00:17<00:00,  8.57it/s]


Epoch 10, Validation Accuracy: 0.8706


Epoch 11/30 - Training: 100%|██████████| 484/484 [03:10<00:00,  2.54it/s]


Epoch 11, Train Loss: 0.6497
Color Loss: 0.6437, Brand Loss: 0.6441, Car Name Loss: 0.6613


Epoch 11/30 - Validation: 100%|██████████| 151/151 [00:17<00:00,  8.42it/s]


Epoch 11, Validation Accuracy: 0.8706


Epoch 12/30 - Training: 100%|██████████| 484/484 [03:11<00:00,  2.53it/s]


Epoch 12, Train Loss: 0.6494
Color Loss: 0.6436, Brand Loss: 0.6442, Car Name Loss: 0.6605


Epoch 12/30 - Validation: 100%|██████████| 151/151 [00:17<00:00,  8.57it/s]


Epoch 12, Validation Accuracy: 0.8704


Epoch 13/30 - Training: 100%|██████████| 484/484 [03:10<00:00,  2.54it/s]


Epoch 13, Train Loss: 0.6488
Color Loss: 0.6429, Brand Loss: 0.6439, Car Name Loss: 0.6595


Epoch 13/30 - Validation: 100%|██████████| 151/151 [00:17<00:00,  8.57it/s]

Epoch 13, Validation Accuracy: 0.8706
Early stopping triggered!





In [9]:
from collections import defaultdict
# Test
model.load_state_dict(torch.load('best_vqa_model.pth', weights_only=True))
model.eval()
correct_color = 0
total_color = 0
correct_brand = 0
total_brand = 0
correct_car_name = 0
total_car_name = 0
total_correct = 0
total_samples = 0
total_color_loss = 0
total_brand_loss = 0
total_car_name_loss = 0
color_count = 0
brand_count = 0
car_name_count = 0
answer_type_map_reverse = {0: 'color', 1: 'brand', 2: 'car_name'}
idx_to_answer = {idx: answer for answer, idx in answer_to_idx.items()}  # Map ngược từ idx sang answer

# Theo dõi số lần xuất hiện, đúng, và nhầm lẫn
color_stats = defaultdict(lambda: {'correct': 0, 'total': 0, 'confusion': defaultdict(int)})
brand_stats = defaultdict(lambda: {'correct': 0, 'total': 0, 'confusion': defaultdict(int)})
car_name_stats = defaultdict(lambda: {'correct': 0, 'total': 0, 'confusion': defaultdict(int)})

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Test"):
        image, input_ids, attention_mask, answer_idx, answer_type = [x.to(device) for x in batch]
        valid_mask = (answer_idx != -1).to(device)
        if not valid_mask.any():
            continue
        image = image[valid_mask]
        input_ids = input_ids[valid_mask]
        attention_mask = attention_mask[valid_mask]
        answer_idx = answer_idx[valid_mask]
        answer_type = answer_type[valid_mask]

        with autocast():
            output = model(image, input_ids, attention_mask)
            loss = loss_fn(output, answer_idx)
        pred = output.argmax(dim=1)
        total_correct += (pred == answer_idx).sum().item()
        total_samples += answer_idx.size(0)

        for i in range(len(answer_type)):
            ans_type_str = answer_type_map_reverse[answer_type[i].item()]
            true_answer = idx_to_answer[answer_idx[i].item()]
            pred_answer = idx_to_answer[pred[i].item()]
            is_correct = pred[i] == answer_idx[i]

            if ans_type_str == 'color':
                total_color += 1
                total_color_loss += loss_fn(output[i].unsqueeze(0), answer_idx[i].unsqueeze(0)).item()
                color_count += 1
                if is_correct:
                    correct_color += 1
                color_stats[true_answer]['total'] += 1
                if is_correct:
                    color_stats[true_answer]['correct'] += 1
                else:
                    color_stats[true_answer]['confusion'][pred_answer] += 1

            elif ans_type_str == 'brand':
                total_brand += 1
                total_brand_loss += loss_fn(output[i].unsqueeze(0), answer_idx[i].unsqueeze(0)).item()
                brand_count += 1
                if is_correct:
                    correct_brand += 1
                brand_stats[true_answer]['total'] += 1
                if is_correct:
                    brand_stats[true_answer]['correct'] += 1
                else:
                    brand_stats[true_answer]['confusion'][pred_answer] += 1

            else:
                total_car_name += 1
                total_car_name_loss += loss_fn(output[i].unsqueeze(0), answer_idx[i].unsqueeze(0)).item()
                car_name_count += 1
                if is_correct:
                    correct_car_name += 1
                car_name_stats[true_answer]['total'] += 1
                if is_correct:
                    car_name_stats[true_answer]['correct'] += 1
                else:
                    car_name_stats[true_answer]['confusion'][pred_answer] += 1

# Tính accuracy tổng và riêng
test_color_accuracy = correct_color / total_color if total_color > 0 else 0
test_brand_accuracy = correct_brand / total_brand if total_brand > 0 else 0
test_car_name_accuracy = correct_car_name / total_car_name if total_car_name > 0 else 0
test_total_accuracy = total_correct / total_samples if total_samples > 0 else 0
avg_color_loss = total_color_loss / color_count if color_count > 0 else 0
avg_brand_loss = total_brand_loss / brand_count if brand_count > 0 else 0
avg_car_name_loss = total_car_name_loss / car_name_count if car_name_count > 0 else 0

# Tính top 4 sai nhiều nhất và nhầm lẫn
def get_top_errors(stats, category_name):
    error_rates = []
    for answer, stat in stats.items():
        total = stat['total']
        correct = stat['correct']
        if total > 0:
            error_rate = 1 - (correct / total)
            error_rates.append((answer, error_rate, total, stat['confusion']))
    error_rates.sort(key=lambda x: (x[1], x[2]), reverse=True)
    top_4 = error_rates[:4]
    print(f"\nTop 4 {category_name} sai nhiều nhất:")
    for answer, error_rate, total, confusion in top_4:
        print(f"- {answer}: {error_rate:.4f} ({(error_rate * 100):.2f}%), xuất hiện {total} lần")
        print(f"  Nhầm lẫn với:")
        for pred_answer, count in confusion.items():
            print(f"    + {pred_answer}: {count} lần")

# In kết quả
print(f"Test Total Accuracy: {test_total_accuracy:.4f}")
print(f"Test Color Accuracy: {test_color_accuracy:.4f}, Brand Accuracy: {test_brand_accuracy:.4f}, Car Name Accuracy: {test_car_name_accuracy:.4f}")
print(f"Test Color Loss: {avg_color_loss:.4f}, Brand Loss: {avg_brand_loss:.4f}, Car Name Loss: {avg_car_name_loss:.4f}")

get_top_errors(color_stats, "màu")
get_top_errors(brand_stats, "hãng xe")
get_top_errors(car_name_stats, "tên xe")

  with autocast():
Test: 100%|██████████| 162/162 [00:19<00:00,  8.47it/s]

Test Total Accuracy: 0.8481
Test Color Accuracy: 0.8475, Brand Accuracy: 0.9325, Car Name Accuracy: 0.7643
Test Color Loss: 1.1091, Brand Loss: 0.7996, Car Name Loss: 1.3064

Top 4 màu sai nhiều nhất:
- silver: 0.2581 (25.81%), xuất hiện 248 lần
  Nhầm lẫn với:
    + gray: 27 lần
    + white: 17 lần
    + black: 11 lần
    + blue: 9 lần
- gray: 0.2367 (23.67%), xuất hiện 245 lần
  Nhầm lẫn với:
    + white: 9 lần
    + black: 28 lần
    + silver: 20 lần
    + blue: 1 lần
- blue: 0.1298 (12.98%), xuất hiện 339 lần
  Nhầm lẫn với:
    + silver: 16 lần
    + black: 25 lần
    + gray: 3 lần
- black: 0.1141 (11.41%), xuất hiện 368 lần
  Nhầm lẫn với:
    + white: 19 lần
    + silver: 8 lần
    + gray: 13 lần
    + blue: 2 lần

Top 4 hãng xe sai nhiều nhất:
- bmw: 0.1482 (14.82%), xuất hiện 425 lần
  Nhầm lẫn với:
    + audi: 36 lần
    + bentley: 17 lần
    + acura: 10 lần
- audi: 0.0520 (5.20%), xuất hiện 442 lần
  Nhầm lẫn với:
    + bmw: 8 lần
    + acura: 6 lần
    + bentley: 9 lần
- ac




In [16]:
save_path = "/kaggle/working/my-vqa-model"
os.makedirs(save_path, exist_ok=True)
torch.save(model.state_dict(), f"{save_path}/pytorch_model.bin")
config = {
    "vit_model": "google/vit-base-patch16-224",
    "bert_model": "bert-base-uncased",
    "num_answers": len(answer_to_idx),
    "architecture": "ViT+BERT with concatenation and multiplication fusion"
}
with open(f"{save_path}/config.json", "w") as f:
    json.dump(config, f)
tokenizer.save_pretrained(save_path)
with open(f"{save_path}/answer_list.json", "w") as f:
    json.dump(all_train_answers, f)