Running with Kaggle environment

In [None]:
# ========================================================================
# 1. Cài đặt thư viện
# ========================================================================

!pip install -q torch torchvision transformers pillow numpy tqdm
!pip install -q nltk rouge-score bert-score pycocoevalcap
!pip install -q peft==0.7.1 pyvi

In [None]:
# ========================================================================
# 2. Import thư viện
# ========================================================================

import os
import json
import torch
import math
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import numpy as np
from PIL import Image
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

from transformers import (
    ViTModel, ViTImageProcessor,
    AutoModel, AutoTokenizer,
    AutoModelForSeq2SeqLM,
    get_linear_schedule_with_warmup
)
from transformers.modeling_outputs import BaseModelOutput
from peft import LoraConfig, get_peft_model, TaskType, PeftModel

# Metrics
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from bert_score import score as bert_score
from pycocoevalcap.cider.cider import Cider
from pyvi import ViTokenizer

In [None]:
# ========================================================================
# 3. Cấu hình
# ========================================================================

class Config:
    """Central configuration: paths, model names and hyperparameters.

    All values are plain class attributes; a single shared instance is
    created right after the class definition and read by the other cells.
    """
    # Paths
    JSON_PATH = '/kaggle/input/dataset/vn-revolutionary-heritage-vqa.json'
    IMAGE_BASE_PATH = '/kaggle/input/images'
    OUTPUT_DIR = '/kaggle/working/vit-phobert-vit5-lora'
    
    # Pretrained model identifiers
    VIT_MODEL = 'google/vit-base-patch16-224'
    PHOBERT_MODEL = 'vinai/phobert-base'
    VIT5_MODEL = 'VietAI/vit5-base'
    
    # Training hyperparameters
    BATCH_SIZE = 4
    GRADIENT_ACCUMULATION = 4
    NUM_EPOCHS = 20
    MAX_GRAD_NORM = 1.0
    
    # LoRA (shared settings)
    USE_LORA = True
    LR_LORA = 1e-4
    WARMUP_RATIO = 0.05
    WEIGHT_DECAY = 0.01
    
    # Per-model LoRA settings
    LORA_VIT_R = 16
    LORA_VIT_ALPHA = 32
    LORA_VIT_DROPOUT = 0.15
    LORA_VIT_TARGET_MODULES = ["query", "key", "value"]
    
    LORA_PHOBERT_R = 16
    LORA_PHOBERT_ALPHA = 32
    LORA_PHOBERT_DROPOUT = 0.15
    LORA_PHOBERT_TARGET_MODULES = ["query", "key", "value"]
    
    LORA_VIT5_R = 16
    LORA_VIT5_ALPHA = 32
    LORA_VIT5_DROPOUT = 0.15
    LORA_VIT5_TARGET_MODULES = ["q", "k", "v"]
    
    # Data split ratios (the split is done per image, not per sample)
    TRAIN_RATIO = 0.8
    VAL_RATIO = 0.10
    TEST_RATIO = 0.10
    SEED = 42
    
    # Maximum token lengths for padded/truncated sequences
    MAX_QUESTION_LEN = 128
    MAX_ANSWER_LEN = 128
    
    # Model architecture (fusion module dimensions)
    HIDDEN_SIZE = 768
    NUM_ATTENTION_HEADS = 8
    FUSION_DROPOUT = 0.15
    
    # Optimisation / runtime options
    USE_FP16 = True
    USE_GRADIENT_CHECKPOINTING = True
    NUM_WORKERS = 2
    PATIENCE = 3
    
    # Generation parameters (tuned for Vietnamese, per the original author)
    GEN_MAX_LENGTH = 128
    GEN_MIN_LENGTH = 3
    GEN_NUM_BEAMS = 3
    GEN_NO_REPEAT_NGRAM = 2
    GEN_LENGTH_PENALTY = 0.8
    GEN_REPETITION_PENALTY = 1.2
    
    # Weights of the composite validation score used for early stopping
    SCORE_WEIGHTS = {
        'bleu': 0.30,
        'bertscore_f1': 0.30,
        'cider': 0.20,
        'rougeL': 0.15,
        'exact_match': 0.05
    }
    
    # Differentiated learning rates.
    # The trailing values are presumably LR_LORA * multiplier — confirm in the optimizer setup.
    LR_VIT_MULTIPLIER = 1.0      # 1e-4
    LR_PHOBERT_MULTIPLIER = 0.5  # 5e-5
    LR_VIT5_MULTIPLIER = 0.3     # 3e-5
    LR_FUSION_MULTIPLIER = 5.0   # 5e-4
    
    # Fusion strategy: "gating" or "downsample"
    FUSION_STRATEGY = "gating"
    VISUAL_DOWNSAMPLE_TARGET = 49  # 196 -> 49 when the downsample strategy is used

# Shared configuration instance used as a module-level global by later cells.
config = Config()

In [None]:
# ========================================================================
# 4. Phân đoạn từ tiếng Việt
# ========================================================================

class WordSegmenter:
    """Singleton wrapper around PyVi's Vietnamese word segmenter.

    Every call to ``WordSegmenter()`` returns the same object; the load
    message is printed only when that object is first created.
    """
    _instance = None

    def __new__(cls):
        # Fast path: hand back the instance created by an earlier call.
        if cls._instance is not None:
            return cls._instance
        cls._instance = super().__new__(cls)
        print("✓ Đã load PyVi tokenizer")
        return cls._instance

    def segment(self, text):
        """Return ``text`` word-segmented by PyVi.

        If segmentation raises, a warning is printed and the input text is
        returned unchanged.
        """
        try:
            segmented = ViTokenizer.tokenize(text)
        except Exception as e:
            print(f"Cảnh báo: Lỗi tách từ: {e}")
            segmented = text
        return segmented

In [None]:
# ========================================================================
# 5. Load dữ liệu 
# ========================================================================

def validate_samples(samples, name="samples"):
    """Filter out low-quality samples and print a short report.

    A sample is rejected for the first of these reasons that applies:
    missing/too-short question (< 3 chars), missing/too-short answer
    (< 2 chars), single-word answer, or an image path that does not exist.

    Args:
        samples: list of dicts with 'question', 'answer' and 'image_path'.
        name: label used in the printed report.

    Returns:
        List of the samples that passed every check, in input order.
    """
    print(f"\nValidating {name}...")

    def first_issue(entry):
        """Return the rejection reason for ``entry``, or None if it is valid."""
        question = entry.get('question')
        if not question or len(question.strip()) < 3:
            return 'empty_question'

        answer = entry.get('answer')
        if not answer or len(answer.strip()) < 2:
            return 'empty_answer'

        # Single-word answers (e.g. "Có", "Không") are filtered out.
        if len(answer.strip().split()) < 2:
            return 'too_short_answer'

        if not os.path.exists(entry['image_path']):
            return 'missing_image'

        return None

    issues = dict.fromkeys(
        ('empty_question', 'empty_answer', 'too_short_answer', 'missing_image'), 0
    )
    valid_samples = []

    for entry in samples:
        reason = first_issue(entry)
        if reason is None:
            valid_samples.append(entry)
        else:
            issues[reason] += 1

    # Report
    print(f"  Total samples: {len(samples)}")
    print(f"  Valid samples: {len(valid_samples)}")
    if sum(issues.values()) > 0:
        print("  Issues found:")
        for issue_type, count in issues.items():
            if count > 0:
                print(f"    - {issue_type}: {count}")

    return valid_samples

def load_and_prepare_data(json_path, image_base_path):
    """Load the QA JSON and split it into train/val/test by image.

    Samples are grouped per image file and whole groups are assigned to a
    split, so the same image never appears in two splits. Each split is then
    passed through ``validate_samples``; the validated test split is also
    written to ``<OUTPUT_DIR>/test_samples.json``.

    Args:
        json_path: path to the dataset JSON file.
        image_base_path: directory that the JSON's image names are relative to.

    Returns:
        Tuple ``(train_samples, val_samples, test_samples)`` of sample dicts.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        records = json.load(f)

    # image file name -> list of QA samples that refer to it
    image_groups = {}

    for record in records:
        site_name = record.get('name', '')
        location = record.get('location_after', record.get('location_before', ''))

        for image_entry in record.get('images', []):
            image_key = image_entry['image']
            image_path = os.path.join(image_base_path, image_key)

            # Images missing on disk are dropped together with their QAs.
            if not os.path.exists(image_path):
                continue

            bucket = image_groups.setdefault(image_key, [])

            for qa in image_entry.get('qas', []):
                answer = qa.get('answer')
                if not answer or len(answer.strip()) < 2:
                    continue

                bucket.append({
                    'image_path': image_path,
                    'question': qa['question'],
                    'answer': qa['answer'],
                    'site_name': site_name,
                    'location': location,
                    'image_key': image_key
                })

    # Split at image level using a seeded shuffle of the image keys.
    image_keys = list(image_groups)
    np.random.seed(config.SEED)
    np.random.shuffle(image_keys)

    total_images = len(image_keys)
    n_test = int(total_images * config.TEST_RATIO)
    n_val = int(total_images * config.VAL_RATIO)
    val_end = n_test + n_val

    test_keys = image_keys[:n_test]
    val_keys = image_keys[n_test:val_end]
    train_keys = image_keys[val_end:]

    def gather(keys):
        """Flatten the sample lists of the given image keys, in key order."""
        collected = []
        for key in keys:
            collected.extend(image_groups[key])
        return collected

    train_samples = gather(train_keys)
    val_samples = gather(val_keys)
    test_samples = gather(test_keys)

    np.random.shuffle(train_samples)

    train_samples = validate_samples(train_samples, "train")
    val_samples = validate_samples(val_samples, "val")
    test_samples = validate_samples(test_samples, "test")

    separator = '=' * 60
    print(f"\n{separator}")
    print("Chia dữ liệu theo ảnh (sau khi validate):")
    print(f"{separator}")
    print(f"Ảnh     - Train: {len(train_keys):4d} | Val: {len(val_keys):4d} | Test: {len(test_keys):4d}")
    print(f"Samples - Train: {len(train_samples):4d} | Val: {len(val_samples):4d} | Test: {len(test_samples):4d}")
    print(f"{separator}\n")

    # Persist the test split
    os.makedirs(config.OUTPUT_DIR, exist_ok=True)
    test_dump_path = os.path.join(config.OUTPUT_DIR, "test_samples.json")
    with open(test_dump_path, "w", encoding="utf-8") as f:
        json.dump(test_samples, f, ensure_ascii=False, indent=2)

    return train_samples, val_samples, test_samples

In [None]:
# ========================================================================
# 6. Dataset
# ========================================================================

class HeritageVQADataset(Dataset):
    """Torch dataset producing model-ready tensors for heritage VQA samples.

    Each item bundles the processed image, the PhoBERT-tokenized (word
    segmented) question, the ViT5-tokenized answer, and the raw texts/path.
    """
    def __init__(self, samples, vit_processor, phobert_tokenizer, vit5_tokenizer, segmenter):
        self.samples = samples
        self.vit_processor = vit_processor
        self.phobert_tokenizer = phobert_tokenizer
        self.vit5_tokenizer = vit5_tokenizer
        self.segmenter = segmenter

    def __len__(self):
        """Number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Build the feature dict for sample ``idx``."""
        record = self.samples[idx]
        image_path = record['image_path']

        # Unreadable images are replaced by a gray 224x224 placeholder.
        try:
            image = Image.open(image_path).convert('RGB')
        except Exception as e:
            print(f"Lỗi load ảnh {image_path}: {e}")
            image = Image.new('RGB', (224, 224), color='gray')

        # Image features for ViT (batch dimension removed)
        processed = self.vit_processor(images=image, return_tensors='pt')
        pixel_values = processed['pixel_values'].squeeze(0)

        # The question is word-segmented before PhoBERT tokenization.
        segmented_question = self.segmenter.segment(record['question'])
        question_enc = self.phobert_tokenizer(
            segmented_question,
            max_length=config.MAX_QUESTION_LEN,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # The answer is tokenized as-is for ViT5.
        answer_enc = self.vit5_tokenizer(
            record['answer'],
            max_length=config.MAX_ANSWER_LEN,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        item = {'pixel_values': pixel_values}
        item['question_input_ids'] = question_enc['input_ids'].squeeze(0)
        item['question_attention_mask'] = question_enc['attention_mask'].squeeze(0)
        item['answer_input_ids'] = answer_enc['input_ids'].squeeze(0)
        item['answer_attention_mask'] = answer_enc['attention_mask'].squeeze(0)
        item['answer_text'] = record['answer']
        item['question_text'] = record['question']
        item['image_path'] = image_path
        return item

def collate_fn(batch):
    """Merge a list of dataset items into one batch dict.

    Tensor fields are stacked along a new leading dimension; text/path
    fields are collected into plain Python lists.
    """
    stacked_fields = (
        'pixel_values',
        'question_input_ids',
        'question_attention_mask',
        'answer_input_ids',
        'answer_attention_mask',
    )
    listed_fields = ('answer_text', 'question_text', 'image_path')

    collated = {}
    for field in stacked_fields:
        collated[field] = torch.stack([item[field] for item in batch])
    for field in listed_fields:
        collated[field] = [item[field] for item in batch]
    return collated

In [None]:
# ========================================================================
# 7. Module Fusion 
# ========================================================================

class CrossModalAttention(nn.Module):
    """Multi-head cross-attention with a residual connection and LayerNorm.

    ``query`` attends over ``key_value`` (used as both keys and values); the
    attention output is dropped out, added back to ``query`` and normalized.
    """
    def __init__(self, hidden_size, num_heads, dropout=0.1):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(
            hidden_size, num_heads, dropout=dropout, batch_first=True
        )
        self.norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key_value, key_padding_mask=None):
        """Return ``LayerNorm(query + Dropout(Attention(query, key_value)))``.

        ``key_padding_mask`` is forwarded to ``nn.MultiheadAttention``.
        """
        context = self.multihead_attn(
            query=query,
            key=key_value,
            value=key_value,
            key_padding_mask=key_padding_mask,
        )[0]
        residual = query + self.dropout(context)
        return self.norm(residual)

class FusionModule(nn.Module):
    """Bidirectional cross-attention fusion of visual and question features.

    Depending on ``config.FUSION_STRATEGY``:
      * "gating": a scalar gate per example re-weights the attended visual
        tokens (by g) and the attended question tokens (by 1 - g).
      * "downsample": visual tokens are reduced 4x with two strided Conv1d
        layers (e.g. 196 -> 49) before cross-attention.
    Any other value applies plain cross-attention with neither extra step.
    """
    def __init__(self, config):
        super().__init__()
        self.strategy = config.FUSION_STRATEGY

        # NOTE: the attribute names are kept for checkpoint compatibility, but
        # they read the opposite way round from how forward() uses them:
        # v2q_attention is called with the QUESTION as query and the visual
        # tokens as keys/values (question attends to visual) ...
        self.v2q_attention = CrossModalAttention(
            config.HIDDEN_SIZE, config.NUM_ATTENTION_HEADS, config.FUSION_DROPOUT
        )
        # ... and q2v_attention is called with the VISUAL tokens as query and
        # the question as keys/values (visual attends to question).
        self.q2v_attention = CrossModalAttention(
            config.HIDDEN_SIZE, config.NUM_ATTENTION_HEADS, config.FUSION_DROPOUT
        )

        if self.strategy == "gating":
            # Maps concatenated pooled features [batch, 2*H] to a gate in (0, 1).
            self.gate = nn.Sequential(
                nn.Linear(config.HIDDEN_SIZE * 2, config.HIDDEN_SIZE),
                nn.ReLU(),
                nn.Dropout(config.FUSION_DROPOUT),
                nn.Linear(config.HIDDEN_SIZE, 1),
                nn.Sigmoid()
            )
            print("✓ Fusion: Gating mechanism enabled")

        elif self.strategy == "downsample":
            # Two stride-2 convolutions halve the token count twice.
            self.visual_downsample = nn.Sequential(
                nn.Conv1d(config.HIDDEN_SIZE, config.HIDDEN_SIZE,
                         kernel_size=2, stride=2),  # e.g. 196 -> 98
                nn.ReLU(),
                nn.Conv1d(config.HIDDEN_SIZE, config.HIDDEN_SIZE,
                         kernel_size=2, stride=2)   # e.g. 98 -> 49
            )
            print("✓ Fusion: Visual downsampling enabled (196 -> 49)")

        # Output projection applied to every fused token
        self.projection = nn.Sequential(
            nn.Linear(config.HIDDEN_SIZE, config.HIDDEN_SIZE),
            nn.LayerNorm(config.HIDDEN_SIZE),
            nn.Dropout(config.FUSION_DROPOUT)
        )

    def forward(self, visual_features, question_features, question_mask):
        """Fuse visual and question token sequences.

        Args:
            visual_features: [batch, num_visual, hidden] visual tokens.
            question_features: [batch, seq_len, hidden] question tokens.
            question_mask: [batch, seq_len], 1 for real tokens, 0 for padding.

        Returns:
            [batch, N + seq_len, hidden] tensor: attended visual tokens
            followed by attended question tokens, after projection. N is the
            (possibly downsampled) number of visual tokens.
        """
        if self.strategy == "downsample":
            # Conv1d expects channels first: [batch, hidden, tokens]
            vis_transposed = visual_features.transpose(1, 2)
            vis_downsampled = self.visual_downsample(vis_transposed)
            visual_features = vis_downsampled.transpose(1, 2)

        # True where the question token is padding (MultiheadAttention convention)
        question_padding_mask = (question_mask == 0)

        # Question tokens attend over all visual tokens (none are padding).
        attended_question = self.v2q_attention(
            query=question_features,
            key_value=visual_features,
            key_padding_mask=None
        )

        # Visual tokens attend over the non-padded question tokens.
        attended_visual = self.q2v_attention(
            query=visual_features,
            key_value=question_features,
            key_padding_mask=question_padding_mask
        )

        if self.strategy == "gating":
            vis_pooled = attended_visual.mean(dim=1)  # [batch, hidden]

            # Masked mean over real question tokens only. A plain mean would be
            # dominated by padding positions, because questions are padded to a
            # fixed max length that is usually much longer than the question.
            token_weights = question_mask.unsqueeze(-1).to(attended_question.dtype)
            token_counts = token_weights.sum(dim=1).clamp(min=1.0)  # avoid 0-division
            q_pooled = (attended_question * token_weights).sum(dim=1) / token_counts

            gate_input = torch.cat([vis_pooled, q_pooled], dim=-1)  # [batch, 2*hidden]
            gate_weight = self.gate(gate_input)  # [batch, 1]

            # Complementary weighting of the two modalities
            attended_visual = attended_visual * gate_weight.unsqueeze(1)
            attended_question = attended_question * (1 - gate_weight.unsqueeze(1))

        # Concatenate along the token dimension (visual first) and project
        fused = torch.cat([attended_visual, attended_question], dim=1)
        return self.projection(fused)

In [None]:
# ========================================================================
# 8. Custom ViT với LoRA
# ========================================================================

class ViTWithLoRA(nn.Module):
    """ViT wrapper that adds hand-written LoRA adapters to self-attention.

    For every encoder layer, a low-rank pair (A: in -> r, B: r -> out) is
    created for the query, key and value projections. The adapters are
    applied through forward hooks that add ``B(A(dropout(x))) * alpha / r``
    to the frozen projection's output. The wrapped ViT's own parameters are
    frozen; only the adapters are trainable.

    Args:
        vit_model: module exposing ``encoder.layer[i].attention.attention``
            with ``query``/``key``/``value`` linear layers.
        r: LoRA rank.
        alpha: LoRA scaling numerator (effective scale is ``alpha / r``).
        dropout: dropout applied to the adapter input.
    """
    # Attention projections that receive an adapter, in hook-registration order.
    _PROJECTIONS = ('query', 'key', 'value')

    def __init__(self, vit_model, r=8, alpha=16, dropout=0.1):
        super().__init__()
        self.vit = vit_model
        self.r = r
        self.alpha = alpha
        self.scaling = alpha / r

        self.lora_layers = nn.ModuleDict()

        for i, layer in enumerate(self.vit.encoder.layer):
            self_attention = layer.attention.attention

            # Create the A/B pair for each projection, sized from that
            # projection's own linear layer.
            for proj in self._PROJECTIONS:
                base_linear = getattr(self_attention, proj)
                self.lora_layers[f'layer_{i}_{proj}_A'] = nn.Linear(
                    base_linear.in_features, r, bias=False
                )
                self.lora_layers[f'layer_{i}_{proj}_B'] = nn.Linear(
                    r, base_linear.out_features, bias=False
                )

            # Standard LoRA init: A random, B zero, so the adapter starts as
            # a no-op and the wrapped model's output is initially unchanged.
            for proj in self._PROJECTIONS:
                nn.init.kaiming_uniform_(
                    self.lora_layers[f'layer_{i}_{proj}_A'].weight, a=math.sqrt(5)
                )
                nn.init.zeros_(self.lora_layers[f'layer_{i}_{proj}_B'].weight)

        self.dropout = nn.Dropout(dropout)

        # Freeze the original ViT weights
        for param in self.vit.parameters():
            param.requires_grad = False

        print(f"ViT LoRA: r={r}, alpha={alpha}, {len(self.lora_layers)} adapters")

    def _make_lora_hook(self, layer_idx, param_name):
        """Build a forward hook adding the LoRA delta for one projection."""
        def hook(module, inputs, output):
            lora_A = self.lora_layers[f'layer_{layer_idx}_{param_name}_A']
            lora_B = self.lora_layers[f'layer_{layer_idx}_{param_name}_B']
            x = inputs[0]
            lora_output = lora_B(lora_A(self.dropout(x))) * self.scaling
            # Returning a value from a forward hook replaces the module output.
            return output + lora_output
        return hook

    def forward(self, pixel_values):
        """Run the wrapped ViT with the LoRA adapters active.

        Hooks are registered for the duration of this call only and are
        always removed afterwards, even if the wrapped model raises, so a
        failed call cannot leave stale hooks that would double-apply LoRA.
        """
        hooks = []
        try:
            for i, layer in enumerate(self.vit.encoder.layer):
                self_attention = layer.attention.attention
                for proj in self._PROJECTIONS:
                    handle = getattr(self_attention, proj).register_forward_hook(
                        self._make_lora_hook(i, proj)
                    )
                    hooks.append(handle)

            return self.vit(pixel_values=pixel_values)
        finally:
            for handle in hooks:
                handle.remove()

    def print_trainable_parameters(self):
        """Print trainable vs. total parameter counts."""
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        all_params = sum(p.numel() for p in self.parameters())
        print(f"Trainable: {trainable_params:,} / {all_params:,} ({100*trainable_params/all_params:.4f}%)")

In [None]:
# ========================================================================
# 9. Model chính
# ========================================================================

class ViTPhoBERTViT5Model(nn.Module):
    """VQA model: ViT image encoder + PhoBERT question encoder + ViT5 generator.

    Visual and question token features are fused by ``FusionModule``,
    projected to ViT5's hidden size and handed to ViT5 as precomputed
    encoder outputs. When ``config.USE_LORA`` is set, LoRA adapters are added
    to all three pretrained models (custom wrapper for ViT, PEFT for the
    other two); otherwise the plain pretrained models are used.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        print("\n" + "="*60)
        print("Khởi tạo Model với LoRA")
        print("="*60)
        
        # Load pretrained models
        print("Loading ViT...")
        vit_base = ViTModel.from_pretrained(config.VIT_MODEL)
        
        print("Loading PhoBERT...")
        self.phobert = AutoModel.from_pretrained(config.PHOBERT_MODEL)
        
        print("Loading ViT5...")
        self.vit5 = AutoModelForSeq2SeqLM.from_pretrained(config.VIT5_MODEL)
        
        # Apply LoRA
        if config.USE_LORA:
            print("\nÁp dụng LoRA...")
            
            # Custom (hook-based) LoRA wrapper for ViT
            self.vit = ViTWithLoRA(
                vit_base,
                r=config.LORA_VIT_R,
                alpha=config.LORA_VIT_ALPHA,
                dropout=config.LORA_VIT_DROPOUT
            )
            self.vit.print_trainable_parameters()
            
            # PEFT LoRA for PhoBERT
            phobert_lora_config = LoraConfig(
                r=config.LORA_PHOBERT_R,
                lora_alpha=config.LORA_PHOBERT_ALPHA,
                target_modules=config.LORA_PHOBERT_TARGET_MODULES,
                lora_dropout=config.LORA_PHOBERT_DROPOUT,
                bias="none",
                task_type=TaskType.FEATURE_EXTRACTION
            )
            self.phobert = get_peft_model(self.phobert, phobert_lora_config)
            print("✓ PhoBERT LoRA")
            self.phobert.print_trainable_parameters()
            
            # PEFT LoRA for ViT5
            vit5_lora_config = LoraConfig(
                r=config.LORA_VIT5_R,
                lora_alpha=config.LORA_VIT5_ALPHA,
                target_modules=config.LORA_VIT5_TARGET_MODULES,
                lora_dropout=config.LORA_VIT5_DROPOUT,
                bias="none",
                task_type=TaskType.SEQ_2_SEQ_LM
            )
            self.vit5 = get_peft_model(self.vit5, vit5_lora_config)
            print("✓ ViT5 LoRA")
            self.vit5.print_trainable_parameters()
        else:
            # No LoRA: use the plain pretrained ViT (PhoBERT/ViT5 stay unwrapped too)
            self.vit = vit_base
        
        # Fusion module
        print("\nKhởi tạo Fusion...")
        self.fusion = FusionModule(config)
        
        # Projection from the fusion hidden size to ViT5's model dimension
        vit5_dim = self.vit5.config.d_model
        self.fusion_to_vit5 = nn.Linear(config.HIDDEN_SIZE, vit5_dim)
        
        # Falls back to 0 when the ViT5 config has no (or a falsy) pad token id
        self.pad_token_id = self.vit5.config.pad_token_id or 0
        
        # Gradient checkpointing (enabled on ViT5 only)
        if config.USE_GRADIENT_CHECKPOINTING:
            self.vit5.gradient_checkpointing_enable()
        
        print("="*60 + "\n")
    
    def forward(self, pixel_values, question_input_ids, question_attention_mask, 
                answer_input_ids=None, answer_attention_mask=None):
        """Compute the training loss or generate answer token ids.

        Args:
            pixel_values: image batch passed to the ViT encoder.
            question_input_ids: PhoBERT token ids of the questions.
            question_attention_mask: attention mask for the questions.
            answer_input_ids: optional ViT5 token ids of the target answers.
                When given, the model runs in training mode.
            answer_attention_mask: accepted but not used by this method.

        Returns:
            With ``answer_input_ids``: the ViT5 output object computed with
            labels. Without: the result of ``self.vit5.generate``.
        """
        batch_size = pixel_values.size(0)
        
        # Visual features from ViT; the first token is dropped
        # (presumably the [CLS] token, leaving only patch tokens — confirm).
        vit_outputs = self.vit(pixel_values=pixel_values)
        visual_features = vit_outputs.last_hidden_state[:, 1:, :]  # [batch, 196, 768]
        
        # Question features from PhoBERT
        phobert_outputs = self.phobert(
            input_ids=question_input_ids,
            attention_mask=question_attention_mask
        )
        question_features = phobert_outputs.last_hidden_state
        
        # Fusion, then projection to ViT5's hidden size
        fused_features = self.fusion(visual_features, question_features, question_attention_mask)
        fused_features = self.fusion_to_vit5(fused_features)
        
        # Attention mask for the fused sequence: the visual part (everything
        # beyond the question length) is all ones, followed by the question mask.
        num_visual_tokens = fused_features.size(1) - question_attention_mask.size(1)
        visual_mask = torch.ones(batch_size, num_visual_tokens, dtype=torch.long, device=fused_features.device)
        fused_attention_mask = torch.cat([visual_mask, question_attention_mask], dim=1)
        
        # Wrap the fused features so ViT5 treats them as its encoder output
        encoder_outputs = BaseModelOutput(last_hidden_state=fused_features)
        
        if answer_input_ids is not None:
            # Training mode: pad positions in the labels are set to -100
            labels = answer_input_ids.clone()
            labels[labels == self.pad_token_id] = -100
            
            outputs = self.vit5(
                attention_mask=fused_attention_mask,
                encoder_outputs=encoder_outputs,
                labels=labels,
                return_dict=True
            )
            return outputs
        else:
            # Inference mode: beam-search generation with the configured parameters
            outputs = self.vit5.generate(
                encoder_outputs=encoder_outputs,
                attention_mask=fused_attention_mask,
                max_length=self.config.GEN_MAX_LENGTH,
                min_length=self.config.GEN_MIN_LENGTH,
                num_beams=self.config.GEN_NUM_BEAMS,
                early_stopping=True,
                no_repeat_ngram_size=self.config.GEN_NO_REPEAT_NGRAM,
                length_penalty=self.config.GEN_LENGTH_PENALTY,
                repetition_penalty=self.config.GEN_REPETITION_PENALTY,
                do_sample=False
            )
            return outputs

In [None]:
# ========================================================================
# 10. Metrics
# ========================================================================

class VQAMetrics:
    """Text-generation metrics for VQA answers (BLEU, ROUGE, BERTScore, CIDEr, EM).

    All methods take parallel lists of predictions and references. Inputs are
    converted with ``str()``; BLEU, ROUGE and exact match compare lowercased
    text.
    """

    class _WhitespaceTokenizer:
        """Whitespace tokenizer passed to ``RougeScorer``.

        The default rouge_score tokenizer keeps only ``[a-z0-9]`` characters,
        which splits Vietnamese words at every accented letter and makes the
        resulting ROUGE scores meaningless. Splitting on whitespace keeps the
        syllables intact and matches how BLEU is tokenized here.
        """
        def tokenize(self, text):
            return text.split()

    def __init__(self):
        # use_stemmer is off: the built-in stemmer is English-only (Porter).
        self.rouge_scorer = rouge_scorer.RougeScorer(
            ['rouge1', 'rouge2', 'rougeL'],
            use_stemmer=False,
            tokenizer=self._WhitespaceTokenizer()
        )
        self.smooth = SmoothingFunction()

    @staticmethod
    def _mean(values):
        """Mean of ``values``; 0.0 for an empty list instead of NaN."""
        return np.mean(values) if len(values) > 0 else 0.0

    def compute_bleu(self, predictions, references):
        """Return per-sample smoothed sentence BLEU scores (0.0 on failure)."""
        scores = []
        for pred, ref in zip(predictions, references):
            pred_tokens = str(pred).lower().split()
            ref_tokens = str(ref).lower().split()
            try:
                score = sentence_bleu([ref_tokens], pred_tokens, smoothing_function=self.smooth.method1)
                scores.append(score)
            except Exception:
                # Best-effort: a failing sample scores 0 rather than aborting evaluation.
                scores.append(0.0)
        return scores

    def compute_rouge(self, predictions, references):
        """Return per-sample F-measures as ``(rouge1, rouge2, rougeL)`` lists."""
        r1, r2, rL = [], [], []
        for pred, ref in zip(predictions, references):
            try:
                scores = self.rouge_scorer.score(str(ref).lower(), str(pred).lower())
                sample_scores = (
                    scores['rouge1'].fmeasure,
                    scores['rouge2'].fmeasure,
                    scores['rougeL'].fmeasure,
                )
            except Exception:
                # Best-effort: a failing sample scores 0 on all three metrics.
                sample_scores = (0.0, 0.0, 0.0)
            r1.append(sample_scores[0])
            r2.append(sample_scores[1])
            rL.append(sample_scores[2])
        return r1, r2, rL

    def compute_bertscore(self, predictions, references):
        """Return the mean BERTScore F1 (0.0 for empty input or on failure)."""
        if len(predictions) == 0:
            return 0.0
        try:
            P, R, F1 = bert_score(
                [str(p) for p in predictions],
                [str(r) for r in references],
                model_type='bert-base-multilingual-cased',
                verbose=False,
                device='cuda' if torch.cuda.is_available() else 'cpu'
            )
            return F1.mean().item()
        except Exception as e:
            # Report the failure so a 0.0 is not mistaken for a real score.
            print(f"Cảnh báo: Lỗi tính BERTScore: {e}")
            return 0.0

    def compute_cider(self, predictions, references):
        """Return the corpus-level CIDEr score (0.0 for empty input or on failure)."""
        if len(predictions) == 0:
            return 0.0
        try:
            # pycocoevalcap expects {id: [sentence, ...]} for both sides.
            gts = {i: [str(r)] for i, r in enumerate(references)}
            res = {i: [str(p)] for i, p in enumerate(predictions)}
            cider_scorer = Cider()
            score, _ = cider_scorer.compute_score(gts, res)
            return score
        except Exception as e:
            print(f"Cảnh báo: Lỗi tính CIDEr: {e}")
            return 0.0

    def compute_exact_match(self, predictions, references):
        """Return per-sample 1.0/0.0 for case- and whitespace-insensitive equality."""
        return [
            1.0 if str(pred).strip().lower() == str(ref).strip().lower() else 0.0
            for pred, ref in zip(predictions, references)
        ]

    def compute_all(self, predictions, references):
        """Compute every metric and return them as a dict of scalar averages.

        Keys: 'bleu', 'rouge1', 'rouge2', 'rougeL', 'bertscore_f1', 'cider',
        'exact_match'. Empty input yields 0.0 for every metric.
        """
        bleu = self.compute_bleu(predictions, references)
        r1, r2, rL = self.compute_rouge(predictions, references)
        bertscore = self.compute_bertscore(predictions, references)
        cider = self.compute_cider(predictions, references)
        exact_match = self.compute_exact_match(predictions, references)

        return {
            'bleu': self._mean(bleu),
            'rouge1': self._mean(r1),
            'rouge2': self._mean(r2),
            'rougeL': self._mean(rL),
            'bertscore_f1': bertscore,
            'cider': cider,
            'exact_match': self._mean(exact_match)
        }

In [None]:
# ========================================================================
# 11. Trainer 
# ========================================================================

class Trainer:
    """Training and evaluation driver for the ViT + PhoBERT + ViT5 LoRA model.

    Builds an AdamW optimizer with a separate learning rate per component,
    a linear warmup/decay schedule, optional fp16 loss scaling, and early
    stopping driven by a weighted composite of the validation metrics.
    """

    def __init__(self, model, train_loader, val_loader, vit5_tokenizer, config):
        """Set up device, optimizer, scheduler, scaler and bookkeeping.

        Args:
            model: The multimodal model; moved to CUDA when available.
            train_loader: DataLoader yielding training batches (dicts).
            val_loader: DataLoader yielding validation batches (dicts).
            vit5_tokenizer: Tokenizer used to decode generated token ids.
            config: Object exposing the hyperparameters read here
                (NUM_EPOCHS, GRADIENT_ACCUMULATION, WARMUP_RATIO, USE_FP16, ...).
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.vit5_tokenizer = vit5_tokenizer
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model.to(self.device)

        self.optimizer = self._create_optimizer_with_grouped_lr()

        # The schedule length must equal the number of optimizer updates that
        # train_epoch() really performs: one per full accumulation window plus
        # one for a trailing partial window in every epoch.
        steps_per_epoch = self._optimizer_steps_per_epoch(
            len(train_loader), config.GRADIENT_ACCUMULATION
        )
        num_training_steps = steps_per_epoch * config.NUM_EPOCHS
        num_warmup_steps = int(num_training_steps * config.WARMUP_RATIO)

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )

        # Loss scaler is only created for mixed-precision training.
        self.scaler = GradScaler() if config.USE_FP16 else None

        self.metrics_calculator = VQAMetrics()

        # Early-stopping state (model selection uses the composite score).
        self.best_val_loss = float('inf')
        self.best_val_bleu = 0.0
        self.best_composite_score = 0.0
        self.patience_counter = 0

        # Per-epoch history.
        self.train_losses = []
        self.val_losses = []
        self.val_metrics_history = []

    @staticmethod
    def _optimizer_steps_per_epoch(num_batches, accumulation):
        """Return how many optimizer updates one epoch performs (ceil division)."""
        return (num_batches + accumulation - 1) // accumulation

    def _config_snapshot(self):
        """Return the UPPER_CASE, non-callable config attributes as a plain dict.

        ``vars(config)`` is not used because it yields an empty dict for an
        instance whose settings are class attributes, and an unpicklable
        ``mappingproxy`` when the config is the class object itself.
        """
        cfg = self.config
        return {
            key: getattr(cfg, key)
            for key in dir(cfg)
            if key.isupper() and not callable(getattr(cfg, key))
        }

    def _create_optimizer_with_grouped_lr(self):
        """Create AdamW with a different learning rate per model component.

        Trainable parameters are routed by name into the ViT LoRA, PhoBERT
        LoRA, ViT5 LoRA and fusion groups. Trainable parameters that match
        none of these patterns go into an ``other`` group at the base
        ``LR_LORA`` so they are not silently left out of the optimizer.

        Returns:
            torch.optim.AdamW over the non-empty groups.

        Raises:
            ValueError: If the model has no trainable parameters.
        """
        vit_lora_params = []
        phobert_lora_params = []
        vit5_lora_params = []
        fusion_params = []
        other_params = []
        other_names = []

        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue

            lowered = name.lower()
            if 'vit.lora' in name:
                vit_lora_params.append(param)
            elif 'phobert' in name and 'lora' in lowered:
                phobert_lora_params.append(param)
            elif 'vit5' in name and 'lora' in lowered:
                vit5_lora_params.append(param)
            elif 'fusion' in name:
                # Also covers 'fusion_to_vit5' (it contains 'fusion').
                fusion_params.append(param)
            else:
                other_params.append(param)
                other_names.append(name)

        if other_names:
            preview = ', '.join(other_names[:5])
            print(f"⚠ {len(other_names)} trainable params matched no LR group; "
                  f"using base LR for: {preview}{' ...' if len(other_names) > 5 else ''}")

        base_lr = self.config.LR_LORA
        candidates = [
            ('vit_lora', vit_lora_params, self.config.LR_VIT_MULTIPLIER),
            ('phobert_lora', phobert_lora_params, self.config.LR_PHOBERT_MULTIPLIER),
            ('vit5_lora', vit5_lora_params, self.config.LR_VIT5_MULTIPLIER),
            ('fusion', fusion_params, self.config.LR_FUSION_MULTIPLIER),
            ('other', other_params, 1.0),
        ]
        param_groups = [
            {'params': params, 'lr': base_lr * multiplier, 'name': group_name}
            for group_name, params, multiplier in candidates
            if len(params) > 0
        ]

        if not param_groups:
            raise ValueError("Model has no trainable parameters to optimize")

        print("\nDifferentiated Learning Rates:")
        for pg in param_groups:
            print(f"  {pg['name']:15s}: {pg['lr']:.2e} ({len(pg['params'])} param groups)")

        return torch.optim.AdamW(
            param_groups,
            weight_decay=self.config.WEIGHT_DECAY
        )

    def train_epoch(self, epoch):
        """Train for one epoch with gradient accumulation.

        Args:
            epoch: Zero-based epoch index (used for the progress bar label).

        Returns:
            Mean un-scaled training loss over the epoch's batches.
        """
        self.model.train()
        total_loss = 0
        self.optimizer.zero_grad()

        accumulation = self.config.GRADIENT_ACCUMULATION
        num_batches = len(self.train_loader)

        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch+1}/{self.config.NUM_EPOCHS}')

        for step, batch in enumerate(pbar):
            # Move tensor entries to the target device; leave others untouched.
            batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()}

            # Forward pass (optionally under mixed precision).
            with autocast(enabled=self.config.USE_FP16):
                outputs = self.model(
                    pixel_values=batch['pixel_values'],
                    question_input_ids=batch['question_input_ids'],
                    question_attention_mask=batch['question_attention_mask'],
                    answer_input_ids=batch['answer_input_ids'],
                    answer_attention_mask=batch['answer_attention_mask']
                )
                loss = outputs.loss / accumulation

            # Backward pass.
            if self.scaler:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            # Update on every full accumulation window AND on the last batch,
            # so gradients of a trailing partial window are applied instead of
            # being discarded by the next epoch's zero_grad(). The partial
            # window keeps the 1/accumulation scaling, i.e. a smaller update.
            window_full = (step + 1) % accumulation == 0
            last_batch = (step + 1) == num_batches
            if window_full or last_batch:
                if self.scaler:
                    # Unscale before clipping so the norm is measured on real gradients.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.MAX_GRAD_NORM)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.MAX_GRAD_NORM)
                    self.optimizer.step()

                self.scheduler.step()
                self.optimizer.zero_grad()

            # Track the un-scaled loss for reporting.
            total_loss += loss.item() * accumulation
            current_loss = total_loss / (step + 1)
            pbar.set_postfix({
                'loss': f'{current_loss:.4f}',
                'lr': f'{self.scheduler.get_last_lr()[0]:.2e}'
            })

        avg_loss = total_loss / max(num_batches, 1)
        self.train_losses.append(avg_loss)
        return avg_loss

    @torch.no_grad()
    def evaluate(self, data_loader, desc='Đánh giá', return_details=False):
        """Generate answers for every batch and compute metrics and loss.

        The model is called twice per batch: once without answer tensors
        (its output is decoded as generated token ids) and once with them to
        obtain the loss. Loss computation is best-effort; batches where it
        fails are reported once and excluded from the average.

        Args:
            data_loader: DataLoader whose batches contain the model inputs plus
                ``answer_text``, ``question_text`` and ``image_path``.
            desc: Progress bar label.
            return_details: When True, also return questions and image paths.

        Returns:
            ``(metrics, predictions, references)`` or, with
            ``return_details=True``,
            ``(metrics, predictions, references, questions, image_paths)``.
        """
        self.model.eval()
        total_loss = 0
        loss_batches = 0
        loss_error_reported = False
        all_predictions = []
        all_references = []
        all_questions = []
        all_image_paths = []

        for batch in tqdm(data_loader, desc=desc):
            batch_device = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                           for k, v in batch.items()}

            # Generation pass (no answer tensors supplied).
            generated_ids = self.model(
                pixel_values=batch_device['pixel_values'],
                question_input_ids=batch_device['question_input_ids'],
                question_attention_mask=batch_device['question_attention_mask']
            )

            predictions = self.vit5_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            all_predictions.extend(predictions)
            all_references.extend(batch['answer_text'])
            all_questions.extend(batch['question_text'])
            all_image_paths.extend(batch['image_path'])

            # Loss pass (teacher forcing with the reference answers).
            try:
                outputs = self.model(
                    pixel_values=batch_device['pixel_values'],
                    question_input_ids=batch_device['question_input_ids'],
                    question_attention_mask=batch_device['question_attention_mask'],
                    answer_input_ids=batch_device['answer_input_ids'],
                    answer_attention_mask=batch_device['answer_attention_mask']
                )
                total_loss += outputs.loss.item()
                loss_batches += 1
            except Exception as e:
                # Keep evaluating, but make the failure visible (once) and do
                # not swallow KeyboardInterrupt/SystemExit.
                if not loss_error_reported:
                    print(f"⚠ Could not compute loss for a batch: {e}")
                    loss_error_reported = True

            del generated_ids
            torch.cuda.empty_cache()

        metrics = self.metrics_calculator.compute_all(all_predictions, all_references)
        # Average only over batches whose loss was actually computed.
        metrics['loss'] = total_loss / loss_batches if (loss_batches and total_loss > 0) else 0.0

        if return_details:
            return metrics, all_predictions, all_references, all_questions, all_image_paths
        else:
            return metrics, all_predictions, all_references

    def save_checkpoint(self, epoch, metrics, is_best=False):
        """Write the latest checkpoint and optionally copy it as the best one.

        Saves the ViT LoRA state dict, the PhoBERT and ViT5 adapters (via
        ``save_pretrained``), and a ``checkpoint.pt`` with fusion/projection
        weights, optimizer and scheduler state, metrics and config values.

        Args:
            epoch: Epoch index stored in the checkpoint.
            metrics: Metrics dict stored in the checkpoint.
            is_best: When True, copy every artifact to its ``best_*`` name.

        Returns:
            True when everything requested was written, False otherwise.
        """
        checkpoint_dir = self.config.OUTPUT_DIR
        os.makedirs(checkpoint_dir, exist_ok=True)

        try:
            # ViT custom LoRA weights.
            vit_lora_path = os.path.join(checkpoint_dir, 'vit_lora.pt')
            torch.save(
                self.model.vit.lora_layers.state_dict(),
                vit_lora_path
            )

            # PhoBERT adapter.
            phobert_path = os.path.join(checkpoint_dir, 'phobert_lora')
            self.model.phobert.save_pretrained(phobert_path)

            # ViT5 adapter.
            vit5_path = os.path.join(checkpoint_dir, 'vit5_lora')
            self.model.vit5.save_pretrained(vit5_path)

            # Fusion modules and training state.
            checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pt')
            torch.save({
                'epoch': epoch,
                'fusion': self.model.fusion.state_dict(),
                'projection': self.model.fusion_to_vit5.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict(),
                'metrics': metrics,
                'config': self._config_snapshot(),
                'vit_lora_config': {
                    'r': self.config.LORA_VIT_R,
                    'alpha': self.config.LORA_VIT_ALPHA,
                    'dropout': self.config.LORA_VIT_DROPOUT
                }
            }, checkpoint_path)

        except Exception as e:
            print(f"❌ Error saving checkpoint: {e}")
            import traceback
            traceback.print_exc()
            return False

        if is_best:
            try:
                import shutil

                shutil.copy(
                    os.path.join(checkpoint_dir, 'vit_lora.pt'),
                    os.path.join(checkpoint_dir, 'best_vit_lora.pt')
                )

                for name in ['phobert_lora', 'vit5_lora']:
                    src = os.path.join(checkpoint_dir, name)
                    dst = os.path.join(checkpoint_dir, f'best_{name}')
                    # copytree refuses to overwrite, so drop the old copy first.
                    if os.path.exists(dst):
                        shutil.rmtree(dst)
                    shutil.copytree(src, dst)

                shutil.copy(
                    os.path.join(checkpoint_dir, 'checkpoint.pt'),
                    os.path.join(checkpoint_dir, 'best_checkpoint.pt')
                )
            except Exception as e:
                print(f"❌ Error saving best model: {e}")
                # A failed best-copy is a failed save; do not report success.
                return False

        return True

    def train(self):
        """Run the full training loop with validation and early stopping.

        After each epoch the validation metrics are combined into a composite
        score using ``config.SCORE_WEIGHTS``; a new best score resets the
        patience counter and marks the checkpoint as best. Training history is
        written to ``training_history.json`` in ``config.OUTPUT_DIR``.

        Returns:
            The validation metrics of the last epoch that ran (an empty dict
            when no epoch ran).
        """
        print("\n" + "="*60)
        print(f"Bắt đầu huấn luyện với LoRA")
        print(f"Device: {self.device}")
        print("="*60)

        # Model size statistics.
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"\nThống kê Model:")
        print(f"  Tổng params:      {total_params:,}")
        print(f"  Trainable params: {trainable_params:,}")
        print(f"  Tỷ lệ trainable:  {100*trainable_params/total_params:.2f}%")
        print("="*60 + "\n")

        # Defined up front so the return below is valid even with zero epochs.
        val_metrics = {}

        for epoch in range(self.config.NUM_EPOCHS):
            train_loss = self.train_epoch(epoch)

            val_metrics, val_preds, val_refs = self.evaluate(self.val_loader, desc='Validation')
            self.val_losses.append(val_metrics['loss'])
            self.val_metrics_history.append(val_metrics)

            # Weighted sum over the metrics that have a configured weight.
            composite_score = sum(
                val_metrics[k] * self.config.SCORE_WEIGHTS[k]
                for k in self.config.SCORE_WEIGHTS.keys()
                if k in val_metrics
            )

            print(f"\n{'='*60}")
            print(f"Epoch {epoch+1}/{self.config.NUM_EPOCHS}:")
            print(f"{'='*60}")
            print(f"Train Loss:      {train_loss:.4f}")
            print(f"Val Loss:        {val_metrics['loss']:.4f}")
            print(f"Composite Score: {composite_score:.4f}")
            print(f"\nMetrics:")
            print(f"  BLEU:        {val_metrics['bleu']:.4f}")
            print(f"  ROUGE-1:     {val_metrics['rouge1']:.4f}")
            print(f"  ROUGE-2:     {val_metrics['rouge2']:.4f}")
            print(f"  ROUGE-L:     {val_metrics['rougeL']:.4f}")
            print(f"  BERTScore:   {val_metrics['bertscore_f1']:.4f}")
            print(f"  CIDEr:       {val_metrics['cider']:.4f}")
            print(f"  Exact Match: {val_metrics['exact_match']:.4f}")

            is_best = False
            if composite_score > self.best_composite_score:
                self.best_composite_score = composite_score
                self.best_val_bleu = val_metrics['bleu']
                self.best_val_loss = val_metrics['loss']
                self.patience_counter = 0
                is_best = True
                print(f"\n✓ Model tốt nhất mới!")
                print(f"  Composite: {composite_score:.4f}")
                print(f"  BLEU:      {val_metrics['bleu']:.4f}")
            else:
                self.patience_counter += 1
                print(f"\n⚠ Không cải thiện ({self.patience_counter}/{self.config.PATIENCE})")
                print(f"  Composite: {composite_score:.4f} (best: {self.best_composite_score:.4f})")

            self.save_checkpoint(epoch, val_metrics, is_best)

            # Show a few ground-truth / prediction pairs.
            print(f"\nVí dụ dự đoán:")
            for i in range(min(3, len(val_preds))):
                print(f"  GT:   {val_refs[i][:80]}")
                print(f"  Pred: {val_preds[i][:80]}\n")

            print("="*60 + "\n")

            if self.patience_counter >= self.config.PATIENCE:
                print(f"Early stopping sau {epoch+1} epochs")
                break

        history = {
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'val_metrics': self.val_metrics_history,
            'best_composite_score': self.best_composite_score,
            'best_bleu': self.best_val_bleu
        }
        # The directory may not exist yet if no checkpoint was ever written.
        os.makedirs(self.config.OUTPUT_DIR, exist_ok=True)
        with open(os.path.join(self.config.OUTPUT_DIR, 'training_history.json'), 'w') as f:
            json.dump(history, f, indent=2)

        return val_metrics

In [None]:
# ========================================================================
# 12. Main Function 
# ========================================================================

def main():
    """Run the whole pipeline: load data, train, evaluate on test, save results.

    Reads settings from the module-level ``config`` object, trains with
    ``Trainer`` and writes ``test_results.json`` into ``config.OUTPUT_DIR``.

    NOTE(review): the test set is evaluated with the weights held in memory
    after the last training epoch; the ``best_*`` checkpoint files are not
    reloaded before testing.
    """
    rule = "=" * 60

    print("\n" + rule)
    print("ViT + PhoBERT + ViT5 với LoRA")
    print("Hệ thống VQA Di sản Việt Nam")
    print(rule + "\n")

    # Tokenizers / processors.
    print("Loading tokenizers...")
    segmenter = WordSegmenter()
    vit_processor = ViTImageProcessor.from_pretrained(config.VIT_MODEL)
    phobert_tokenizer = AutoTokenizer.from_pretrained(config.PHOBERT_MODEL)
    vit5_tokenizer = AutoTokenizer.from_pretrained(config.VIT5_MODEL)
    print("✓ Done\n")

    # Raw samples for each split.
    print("Loading dataset...")
    train_samples, val_samples, test_samples = load_and_prepare_data(
        config.JSON_PATH, config.IMAGE_BASE_PATH
    )

    # Dataset wrappers, built in train/val/test order.
    print("Tạo datasets...")
    train_dataset, val_dataset, test_dataset = [
        HeritageVQADataset(
            split_samples, vit_processor, phobert_tokenizer, vit5_tokenizer, segmenter
        )
        for split_samples in (train_samples, val_samples, test_samples)
    ]

    def _build_loader(dataset, shuffle):
        # All three loaders share every setting except shuffling.
        return DataLoader(
            dataset,
            batch_size=config.BATCH_SIZE,
            shuffle=shuffle,
            num_workers=config.NUM_WORKERS,
            collate_fn=collate_fn,
            pin_memory=True
        )

    train_loader = _build_loader(train_dataset, True)
    val_loader = _build_loader(val_dataset, False)
    test_loader = _build_loader(test_dataset, False)
    print("✓ Done\n")

    # Model and trainer.
    model = ViTPhoBERTViT5Model(config)
    trainer = Trainer(model, train_loader, val_loader, vit5_tokenizer, config)
    trainer.train()

    print("\n" + rule)
    print("Đánh giá trên test set...")
    print(rule + "\n")

    test_metrics, test_preds, test_refs, test_questions, test_images = trainer.evaluate(
        test_loader,
        desc='Testing',
        return_details=True
    )

    print("\nKẾT QUẢ CUỐI CÙNG:")
    print(rule)
    # Labels are left-padded to 13 columns to keep the values aligned.
    for label, key in (
        ("BLEU:", 'bleu'),
        ("ROUGE-1:", 'rouge1'),
        ("ROUGE-2:", 'rouge2'),
        ("ROUGE-L:", 'rougeL'),
        ("BERTScore:", 'bertscore_f1'),
        ("CIDEr:", 'cider'),
        ("Exact Match:", 'exact_match'),
    ):
        print(f"{label:<13s}{test_metrics[key]:.4f}")
    print(rule + "\n")

    def _prediction_record(index, image, question, pred, ref):
        # Sample metadata is looked up by position in the test split.
        sample_info = test_samples[index] if index < len(test_samples) else {}
        return {
            'index': index,
            'image': image,
            'question': question,
            'prediction': pred,
            'reference': ref,
            'site_name': sample_info.get('site_name', ''),
            'location': sample_info.get('location', '')
        }

    test_results = {
        'metrics': test_metrics,
        'predictions': [
            _prediction_record(index, image, question, pred, ref)
            for index, (pred, ref, question, image) in enumerate(zip(
                test_preds, test_refs, test_questions, test_images
            ))
        ]
    }

    # Persist metrics and per-sample predictions.
    results_path = os.path.join(config.OUTPUT_DIR, 'test_results.json')
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(test_results, f, ensure_ascii=False, indent=2)

    print("✓ Đã lưu test results với cấu trúc đầy đủ")

    # Show the first few predictions.
    print("\n" + rule)
    print("Một số ví dụ dự đoán:")
    print(rule)
    preview_count = min(5, len(test_preds))
    for position in range(preview_count):
        print(f"\n[{position+1}] Image: {test_images[position]}")
        print(f"Question:   {test_questions[position]}")
        print(f"Prediction: {test_preds[position]}")
        print(f"Reference:  {test_refs[position]}")
        print("-" * 60)

    print("\nHoàn thành huấn luyện!")
    print(f"Kết quả lưu tại: {config.OUTPUT_DIR}")

In [None]:
# Entry point: run the pipeline only when this module is executed as the main
# program (not when it is imported).
if __name__ == '__main__':
    main()