Running with Kaggle environment

In [None]:
# ========================================================================
# 1. Install libraries
# ========================================================================
# Kaggle environment setup: configure environment variables, install the
# Python dependencies, and download NLTK tokenizer data.

import os
# Intended to turn off Weights & Biases logging for this run.
os.environ["WANDB_DISABLED"] = "true"
# PyTorch CUDA caching-allocator option (documented as reducing memory
# fragmentation); it only takes effect if set before CUDA is initialised.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# NOTE(review): only transformers is pinned; every other package resolves to
# whatever version is current at install time.
!pip install -q transformers==4.45.0 datasets peft accelerate bitsandbytes
!pip install -q pillow torch torchvision qwen-vl-utils
!pip install -q rouge-score bert-score nltk
!pip install -q git+https://github.com/salaniz/pycocoevalcap.git

import nltk
# NOTE(review): the BLEU code in this notebook tokenizes with str.split(), so
# these tokenizer downloads may be unnecessary — confirm before removing.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

In [None]:
# ========================================================================
# 2. Import thư viện
# ========================================================================

import gc
import json
import re
import unicodedata
from pathlib import Path
from typing import List, Dict, Optional

import numpy as np
import torch
from PIL import Image

from transformers import (
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig
)
from transformers import EarlyStoppingCallback
from peft import LoraConfig, get_peft_model
from datasets import Dataset
from qwen_vl_utils import process_vision_info

from rouge_score import rouge_scorer
from bert_score import score as bert_score
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from pycocoevalcap.cider.cider import Cider

In [None]:
# ========================================================================
# 3. Config
# ========================================================================

class Config:
    """Paths and hyper-parameters for the Qwen2-VL heritage-VQA fine-tuning run.

    Every setting is a class attribute; the module-level ``config`` instance
    below is what the rest of the notebook reads.
    """

    # -- model and data locations -------------------------------------------
    MODEL_NAME = "Qwen/Qwen2-VL-7B-Instruct"
    DATA_PATH = "/kaggle/input/dataset/vn-revolutionary-heritage-vqa.json"
    IMAGE_BASE_PATH = "/kaggle/input/images"
    OUTPUT_DIR = "/kaggle/working/qwen2vl-heritage-vqa"

    # -- sizes ----------------------------------------------------------------
    MAX_LENGTH = 640        # collator truncation length, in tokens
    IMAGE_MAX_SIZE = 336    # longest image side in pixels; also the processor pixel budget (squared)

    # -- split fractions (applied to images, not to individual QA pairs) ------
    TEST_SIZE = 0.10
    VAL_SIZE = 0.10

    # -- batching: per-device train batch = BATCH_SIZE * GRADIENT_ACCUMULATION
    BATCH_SIZE = 1
    EVAL_BATCH_SIZE = 8
    GRADIENT_ACCUMULATION = 32

    # -- schedule ---------------------------------------------------------------
    NUM_EPOCHS = 3
    WARMUP_RATIO = 0.05

    # -- evaluation, checkpointing and logging cadence ------------------------
    SAVE_STRATEGY = "epoch"
    EVAL_STRATEGY = "epoch"
    LOGGING_STEPS = 25

    # -- LoRA adapter ------------------------------------------------------------
    LORA_R = 32
    LORA_ALPHA = 64
    LORA_DROPOUT = 0.1


config = Config()

In [None]:
# ========================================================================
# 4. Class tính metrics
# ========================================================================

class VQAMetrics:
    def __init__(self):
        self.rouge_scorer = rouge_scorer.RougeScorer(
            ['rouge1', 'rouge2', 'rougeL'], use_stemmer=True
        )
        self.smooth = SmoothingFunction()
    
    def compute_bleu(self, predictions, references):
        scores = []
        for pred, ref in zip(predictions, references):
            pred_tokens = str(pred).lower().split()
            ref_tokens = str(ref).lower().split()
            try:
                score = sentence_bleu([ref_tokens], pred_tokens, 
                                     smoothing_function=self.smooth.method1)
                scores.append(score)
            except:
                scores.append(0.0)
        return scores
    
    def compute_rouge(self, predictions, references):
        r1, r2, rL = [], [], []
        for pred, ref in zip(predictions, references):
            try:
                scores = self.rouge_scorer.score(str(ref).lower(), str(pred).lower())
                r1.append(scores['rouge1'].fmeasure)
                r2.append(scores['rouge2'].fmeasure)
                rL.append(scores['rougeL'].fmeasure)
            except:
                r1.append(0.0)
                r2.append(0.0)
                rL.append(0.0)
        return r1, r2, rL
    
    def compute_bertscore(self, predictions, references):
        try:
            P, R, F1 = bert_score(
                [str(p) for p in predictions], 
                [str(r) for r in references], 
                lang='vi', verbose=False,
                device='cuda' if torch.cuda.is_available() else 'cpu'
            )
            return F1.mean().item()
        except:
            return 0.0
    
    def compute_cider(self, predictions, references):
        try:
            gts = {i: [str(r)] for i, r in enumerate(references)}
            res = {i: [str(p)] for i, p in enumerate(predictions)}
            cider_scorer = Cider()
            score, _ = cider_scorer.compute_score(gts, res)
            return score
        except:
            return 0.0
    
    def compute_exact_match(self, predictions, references):
        matches = []
        for pred, ref in zip(predictions, references):
            pred_norm = str(pred).strip().lower()
            ref_norm = str(ref).strip().lower()
            matches.append(1.0 if pred_norm == ref_norm else 0.0)
        return matches
    
    def compute_all(self, predictions, references):
        bleu = self.compute_bleu(predictions, references)
        r1, r2, rL = self.compute_rouge(predictions, references)
        bertscore = self.compute_bertscore(predictions, references)
        cider = self.compute_cider(predictions, references)
        exact_match = self.compute_exact_match(predictions, references)
        
        return {
            'bleu': np.mean(bleu),
            'rouge1': np.mean(r1),
            'rouge2': np.mean(r2),
            'rougeL': np.mean(rL),
            'bertscore_f1': bertscore,
            'cider': cider,
            'exact_match': np.mean(exact_match)
        }

In [None]:
# ========================================================================
# 5. Xử lý data
# ========================================================================

def resize_image(image_path: str, max_size: int = 336) -> Optional[Image.Image]:
    """Load an image as RGB and shrink it so its longest side is at most ``max_size``.

    The file is opened in a ``with`` block and fully decoded by ``convert``
    before it is closed. The returned image therefore holds no open file
    handle. (A bare lazy ``Image.open`` keeps one open for images that need
    neither conversion nor resizing, which adds up when many images are
    cached in memory.)

    Args:
        image_path: path of the image file.
        max_size: maximum length in pixels of the longest side. The aspect
            ratio is preserved, and images already within the limit keep
            their size.

    Returns:
        An in-memory RGB ``PIL.Image.Image``, or ``None`` (with a printed
        warning) if the file cannot be opened or decoded.
    """
    try:
        with Image.open(image_path) as src:
            # convert() decodes the pixels and returns a new image that does not
            # depend on the file; it also normalises every source mode to RGB.
            img = src.convert("RGB")
    except Exception as exc:  # PIL raises several unrelated types (OSError, DecompressionBombError, ...)
        print(f"⚠ Could not load image {image_path}: {exc}")
        return None

    if max(img.size) > max_size:
        img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

    return img


def _collect_image_groups(data, image_base_path, max_samples=None):
    """Group valid QA samples by image key, loading each referenced image.

    Returns ``{image_key: [sample, ...]}`` in first-seen order. An image is
    skipped entirely (no group is created) when it has no valid QA pair, is
    missing on disk, or fails to load. Every group is therefore non-empty.
    When ``max_samples`` is given, collection stops once that many samples
    have been gathered.
    """
    image_groups = {}
    n_collected = 0

    for item in data:
        name = item['name']
        # `or` (not a .get default) so an empty/None 'location_after' also falls back.
        location = item.get('location_after') or item.get('location_before') or ''
        context = f"{name} tại {location}."

        for img_data in item.get('images', []):
            if max_samples is not None and n_collected >= max_samples:
                return image_groups

            # Keep QA pairs whose answer is a string of at least 2 non-blank chars.
            valid_qas = [
                qa for qa in img_data.get('qas', [])
                if isinstance(qa.get('answer'), str) and len(qa['answer'].strip()) >= 2
            ]
            if not valid_qas:
                continue
            if max_samples is not None:
                valid_qas = valid_qas[:max_samples - n_collected]

            img_key = img_data['image']
            img_path = os.path.join(image_base_path, img_key)
            if not os.path.exists(img_path):
                continue

            img = resize_image(img_path, config.IMAGE_MAX_SIZE)
            if img is None:
                continue

            group = image_groups.setdefault(img_key, [])
            for qa in valid_qas:
                group.append({
                    'image': img,
                    'question': qa['question'],
                    'answer': qa['answer'],
                    'context': context,
                    'site_name': name,
                    'image_path': img_key
                })
            n_collected += len(valid_qas)

    return image_groups


def _split_image_keys(image_keys, rng):
    """Shuffle image keys with ``rng`` and cut them into (train, val, test) key lists.

    The test and val sizes are ``int(n * config.TEST_SIZE)`` and
    ``int(n * config.VAL_SIZE)``. Train receives the remainder.
    """
    keys = list(image_keys)
    rng.shuffle(keys)

    n_images = len(keys)
    n_test = int(n_images * config.TEST_SIZE)
    n_val = int(n_images * config.VAL_SIZE)

    test_keys = keys[:n_test]
    val_keys = keys[n_test:n_test + n_val]
    train_keys = keys[n_test + n_val:]
    return train_keys, val_keys, test_keys


def _save_test_metadata(test_samples, output_dir):
    """Write the image-free fields of ``test_samples`` to ``<output_dir>/test_set.json``."""
    os.makedirs(output_dir, exist_ok=True)
    fields = ('question', 'answer', 'context', 'site_name', 'image_path')
    test_metadata = [{field: s[field] for field in fields} for s in test_samples]

    with open(os.path.join(output_dir, 'test_set.json'), 'w', encoding='utf-8') as f:
        json.dump(test_metadata, f, ensure_ascii=False, indent=2)

    print(f"✓ Đã lưu {len(test_metadata)} test samples vào test_set.json")


def load_and_prepare_data(json_path, image_base_path, max_samples=None):
    """Load the VQA JSON, split it by image into train/val/test, and save the test split.

    The split is done over images rather than QA pairs, so all questions about
    one image land in the same split and no image leaks between splits.

    Shuffling uses a local ``np.random.RandomState(42)``. It produces the same
    stream as ``np.random.seed(42)`` but leaves NumPy's global RNG untouched.

    NOTE: images with no valid QA pair no longer occupy split slots, so the
    split can differ from runs that kept such empty groups.

    Args:
        json_path: path of the dataset JSON (a list of sites with 'images' / 'qas').
        image_base_path: directory that image keys are relative to.
        max_samples: optional cap on the total number of QA samples collected.

    Returns:
        ``(train_samples, val_samples, test_samples)``: lists of sample dicts
        holding the loaded PIL image plus question, answer, context,
        site_name and image_path. Train samples are shuffled.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    image_groups = _collect_image_groups(data, image_base_path, max_samples)

    rng = np.random.RandomState(42)
    train_img_keys, val_img_keys, test_img_keys = _split_image_keys(image_groups.keys(), rng)

    train_samples = [s for key in train_img_keys for s in image_groups[key]]
    val_samples = [s for key in val_img_keys for s in image_groups[key]]
    test_samples = [s for key in test_img_keys for s in image_groups[key]]

    rng.shuffle(train_samples)

    print(f"Images - Train: {len(train_img_keys)} | Val: {len(val_img_keys)} | Test: {len(test_img_keys)}")
    print(f"Samples - Train: {len(train_samples)} | Val: {len(val_samples)} | Test: {len(test_samples)}")

    _save_test_metadata(test_samples, config.OUTPUT_DIR)

    return train_samples, val_samples, test_samples


def create_conversation_format(sample):
    """Build the two-turn Qwen2-VL chat for one VQA sample.

    The user turn carries the image followed by the question text. The
    assistant turn carries the reference answer.

    Args:
        sample: mapping with at least 'image', 'question' and 'answer' keys.

    Returns:
        ``[user_message, assistant_message]`` in the chat-message format used by
        ``apply_chat_template`` and ``process_vision_info``.
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": sample['image']},
            {"type": "text", "text": sample['question']},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": sample['answer']}],
    }
    return [user_turn, assistant_turn]

In [None]:
# ========================================================================
# 6. Data Collator
# ========================================================================

class VQADataCollator:
    """Collate raw VQA samples into a Qwen2-VL training batch with prompt-masked labels.

    Each sample (a dict with 'image', 'question', 'answer') is rendered with
    the processor's chat template as a user turn (image + question) followed
    by an assistant turn (answer). Texts and images are tokenised together,
    padded to the longest row and truncated to ``max_length``.

    ``labels`` is a copy of ``input_ids`` in which every token before the
    assistant answer, and every padding position, is set to ``IGNORE_INDEX``
    (-100). The loss therefore only covers the answer and the tokens after it.

    Args:
        processor: the Qwen2-VL ``AutoProcessor``. Its tokenizer is switched
            to right padding.
        max_length: truncation length in tokens.
    """

    def __init__(self, processor, max_length=640):
        self.processor = processor
        self.max_length = max_length
        self.IGNORE_INDEX = -100

        self.processor.tokenizer.padding_side = 'right'

        tokenizer = processor.tokenizer
        # Token ids delimiting chat turns; used to locate the assistant answer.
        self.im_start_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
        self.im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        self.assistant_token_ids = tokenizer.encode("assistant", add_special_tokens=False)
        self.newline_id = tokenizer.encode("\n", add_special_tokens=False)[0]

        print("="*70)
        print("DEBUG COLLATOR INIT")
        print("="*70)
        print(f"im_start_id      : {self.im_start_id}")
        print(f"im_end_id        : {self.im_end_id}")
        print(f"assistant_tokens : {self.assistant_token_ids}")
        print(f"newline_id       : {self.newline_id}")
        print(f"pad_token_id     : {tokenizer.pad_token_id}")
        print(f"eos_token_id     : {tokenizer.eos_token_id}")
        print("="*70)

        # Fail fast if the chat-template special tokens map to the unknown token.
        if self.im_start_id == tokenizer.unk_token_id:
            raise ValueError("<|im_start|> không có trong vocab")
        if self.im_end_id == tokenizer.unk_token_id:
            raise ValueError("<|im_end|> không có trong vocab")

    def find_assistant_start(self, input_ids):
        """Return the index of the first answer token of the assistant turn.

        Finds the first ``<|im_start|>`` immediately followed by the token ids
        of "assistant", skips one optional newline token, and returns the
        next position.

        Returns ``None`` if no assistant header is found. It also returns
        ``None`` if the first header is followed by ``<|im_end|>`` or by the
        end of the sequence (an empty answer).

        Args:
            input_ids: 1-D tensor or sequence of token ids for one row.
        """
        ids = input_ids.tolist() if torch.is_tensor(input_ids) else list(input_ids)
        header = list(self.assistant_token_ids)
        n_header = len(header)

        # The bound guarantees at least one token follows a matched header.
        for i in range(len(ids) - n_header - 1):
            if ids[i] != self.im_start_id or ids[i + 1:i + 1 + n_header] != header:
                continue

            pos = i + 1 + n_header
            if pos < len(ids) and ids[pos] == self.newline_id:
                pos += 1

            if pos >= len(ids) or ids[pos] == self.im_end_id:
                return None
            return pos

        return None

    def _valid_token_mask(self, batch):
        """Boolean mask of real (non-padding) tokens, shaped like ``batch['input_ids']``.

        Uses ``attention_mask`` when present and falls back to ``input_ids != pad_token_id``.
        """
        attention_mask = batch.get("attention_mask")
        if attention_mask is not None:
            return attention_mask.bool()
        return batch["input_ids"] != self.processor.tokenizer.pad_token_id

    def _print_debug_sample(self, input_ids, assistant_start, valid_row):
        """Print how one row splits into the masked prompt and the trainable answer."""
        tokenizer = self.processor.tokenizer
        ids = input_ids.tolist()

        print(f"\n--- SAMPLE 0 ---")
        print(f"Input length: {len(ids)}")

        full_text = tokenizer.decode(input_ids, skip_special_tokens=False)
        print(f"\nFULL TEXT:")
        print(full_text[:600])
        print("...\n")

        im_start_positions = [j for j, tok in enumerate(ids) if tok == self.im_start_id]
        im_end_positions = [j for j, tok in enumerate(ids) if tok == self.im_end_id]
        print(f"Token positions:")
        print(f"   <|im_start|> at: {im_start_positions}")
        print(f"   <|im_end|> at  : {im_end_positions}")
        print(f"   Assistant start: {assistant_start}")

        if assistant_start is None:
            print("\nERROR: Assistant start NOT FOUND!")
            print("   → This sample will be SKIPPED")
            return

        prompt_text = tokenizer.decode(input_ids[:assistant_start], skip_special_tokens=False)
        print(f"\nPROMPT (will be MASKED):")
        print(f"{prompt_text[:400]}")
        print("...")

        answer_text = tokenizer.decode(input_ids[assistant_start:], skip_special_tokens=False)
        print(f"\nANSWER (TRAINABLE):")
        print(f"{answer_text}")

        trainable = int(valid_row[assistant_start:].sum())
        print(f"\nStatistics:")
        print(f"   Total tokens    : {len(ids)}")
        print(f"   Prompt tokens   : {assistant_start}")
        print(f"   Answer tokens   : {len(ids) - assistant_start}")
        print(f"   Trainable tokens: {trainable}")

    def __call__(self, features):
        """Tokenise ``features`` and return the processor batch with ``labels`` added."""
        texts = []
        images_list = []
        for feat in features:
            messages = create_conversation_format(feat)
            texts.append(self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            ))
            image_inputs, _ = process_vision_info(messages)
            images_list.append(image_inputs)

        batch = self.processor(
            text=texts,
            images=images_list,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_length
        )

        input_ids = batch["input_ids"]
        valid_mask = self._valid_token_mask(batch)
        labels = input_ids.clone()
        # Mask padding by position (attention_mask), not by id. If the pad id ever
        # coincides with a real token (e.g. pad_token_id falling back to
        # eos_token_id), comparing ids would also hide that token from the loss.
        labels[~valid_mask] = self.IGNORE_INDEX

        debug_this_batch = not hasattr(self, '_first_batch_debugged')
        if debug_this_batch:
            self._first_batch_debugged = True
            print("\n" + "="*70)
            print("DEBUG FIRST BATCH - KIỂM TRA MASK")
            print("="*70)

        skipped = 0
        for i in range(len(texts)):
            assistant_start = self.find_assistant_start(input_ids[i])

            if debug_this_batch and i == 0:
                self._print_debug_sample(input_ids[i], assistant_start, valid_mask[i])

            if assistant_start is None:
                labels[i, :] = self.IGNORE_INDEX
                skipped += 1
                continue

            # Only the assistant answer contributes to the loss.
            labels[i, :assistant_start] = self.IGNORE_INDEX

            # e.g. the answer was truncated away and only padding follows the header.
            if not (labels[i] != self.IGNORE_INDEX).any():
                skipped += 1

        if debug_this_batch:
            print("="*70 + "\n")

        if skipped > 0:
            print(f"⚠ Skipped {skipped}/{len(texts)} samples trong batch")

        batch["labels"] = labels
        return batch

In [None]:
# ========================================================================
# 7. Load Model
# ========================================================================

def load_model_and_processor():
    """Load the 4-bit NF4-quantized Qwen2-VL model and its processor.

    The processor pins ``min_pixels = max_pixels = config.IMAGE_MAX_SIZE**2``
    so every image is resized to the same pixel budget. The tokenizer uses
    right padding, and ``pad_token_id`` falls back to ``eos_token_id`` when
    it is unset.

    NOTE(review): the 4-bit compute dtype here is bfloat16 while
    ``get_training_args`` enables fp16 mixed precision — confirm this is
    intended for the target GPU.

    Returns:
        ``(model, processor)``.
    """
    # Collect unreachable Python objects first so their CUDA blocks can be
    # returned to the driver by empty_cache().
    gc.collect()
    torch.cuda.empty_cache()

    print("Đang load model...")

    processor = AutoProcessor.from_pretrained(
        config.MODEL_NAME,
        trust_remote_code=True,
        min_pixels=config.IMAGE_MAX_SIZE**2,
        max_pixels=config.IMAGE_MAX_SIZE**2
    )

    processor.tokenizer.padding_side = 'right'

    if processor.tokenizer.pad_token_id is None:
        processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        config.MODEL_NAME,
        device_map="auto",
        quantization_config=bnb_config,
        trust_remote_code=True
    )

    return model, processor

In [None]:
# ========================================================================
# 8. Setup LoRA
# ========================================================================

def setup_lora(model):
    """Attach LoRA adapters to ``model`` and prepare it for memory-efficient training.

    Adapters target the language-model projections (q/k/v/o, gate/up/down).
    They also target modules named "qkv", "proj", "fc1" and "fc2", which are
    presumably the vision tower's attention/MLP layers — confirm against
    ``model.named_modules()``.

    Every non-LoRA parameter whose name contains "visual" is frozen.
    Gradient checkpointing is enabled (non-reentrant) and the KV cache is
    disabled for training. The trainable-parameter breakdown is printed.

    Args:
        model: the (quantized) Qwen2-VL model.

    Returns:
        The PEFT-wrapped model.
    """
    lora_config = LoraConfig(
        r=config.LORA_R,
        lora_alpha=config.LORA_ALPHA,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj",
                        "qkv", "proj", "fc1", "fc2"],
        lora_dropout=config.LORA_DROPOUT,
        bias="none",
        task_type="CAUSAL_LM"
    )

    model = get_peft_model(model, lora_config)

    # Make the intent explicit: in the vision tower only the LoRA matrices train.
    for name, param in model.named_parameters():
        if 'visual' in name and 'lora' not in name:
            param.requires_grad = False

    # Non-reentrant checkpointing, matching gradient_checkpointing_kwargs in the
    # training arguments. The reentrant variant does not propagate gradients to
    # adapters when the checkpointed inputs (frozen embeddings) do not require grad.
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    model.config.use_cache = False

    print("✓ LoRA applied")
    model.print_trainable_parameters()

    print("\n" + "="*70)
    print("VERIFICATION")
    print("="*70)
    visual_lora = sum(p.numel() for n, p in model.named_parameters()
                      if 'visual' in n and 'lora' in n and p.requires_grad)
    # Every trainable parameter outside the vision tower.
    text_lora = sum(p.numel() for n, p in model.named_parameters()
                    if 'visual' not in n and p.requires_grad)
    print(f"Visual LoRA: {visual_lora:,} params")
    print(f"Text LoRA:   {text_lora:,} params")
    print(f"Total:       {visual_lora + text_lora:,} params")
    print("="*70 + "\n")

    return model

In [None]:
# ========================================================================
# 9. Training Arguments
# ========================================================================

def get_training_args():
    """Build the ``TrainingArguments`` for LoRA fine-tuning from ``config``.

    Evaluation and checkpointing follow ``config.EVAL_STRATEGY`` and
    ``config.SAVE_STRATEGY``. The checkpoint with the lowest ``eval_loss`` is
    reloaded at the end (``load_best_model_at_end``); the early-stopping
    callback used with the Trainer requires this.

    The learning rate is read from ``config.LEARNING_RATE`` when that
    attribute exists. Otherwise it falls back to the previously hard-coded
    5e-6.

    NOTE(review): fp16 mixed precision is enabled while the 4-bit model is
    loaded with ``bnb_4bit_compute_dtype=torch.bfloat16`` — confirm this
    combination is intended.
    """
    learning_rate = getattr(config, "LEARNING_RATE", 5e-6)

    return TrainingArguments(
        output_dir=config.OUTPUT_DIR,
        num_train_epochs=config.NUM_EPOCHS,
        per_device_train_batch_size=config.BATCH_SIZE,
        per_device_eval_batch_size=config.EVAL_BATCH_SIZE,
        gradient_accumulation_steps=config.GRADIENT_ACCUMULATION,
        learning_rate=learning_rate,
        lr_scheduler_type="cosine",
        warmup_ratio=config.WARMUP_RATIO,
        weight_decay=0.01,
        max_grad_norm=1.0,
        fp16=True,
        bf16=False,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        dataloader_num_workers=2,
        dataloader_pin_memory=True,
        dataloader_prefetch_factor=2,
        # The collator needs the raw 'image'/'question'/'answer' columns.
        remove_unused_columns=False,
        dataloader_drop_last=False,
        logging_steps=config.LOGGING_STEPS,
        eval_strategy=config.EVAL_STRATEGY,
        save_strategy=config.SAVE_STRATEGY,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        optim="adamw_torch_fused",
        report_to="none",
        save_safetensors=True,
        disable_tqdm=False,
    )

In [None]:
# ========================================================================
# 10. Evaluation
# ========================================================================

def _build_generation_prompt(processor, sample):
    """Render ``sample`` as one user turn (image + question) ready for generation.

    Returns:
        ``(text, image_inputs)``. ``text`` is the chat-template string with the
        assistant generation prompt appended. ``image_inputs`` is the image
        list from ``process_vision_info``.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": sample['image']},
                {"type": "text", "text": sample['question']}
            ]
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, _ = process_vision_info(messages)
    return text, image_inputs


def _generate_answers(model, processor, texts, images, max_new_tokens=256):
    """Greedy-decode one answer per prompt and return the stripped answer strings."""
    inputs = processor(
        text=texts,
        images=images,
        padding=True,
        return_tensors="pt"
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            repetition_penalty=1.1,
            # setup_lora sets model.config.use_cache = False for training;
            # request the KV cache explicitly for generation.
            use_cache=True,
            pad_token_id=processor.tokenizer.pad_token_id
        )

    # With left padding every prompt ends at the same column, so the generated
    # tokens start at the prompt length.
    prompt_len = inputs['input_ids'].shape[1]
    decoded = processor.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
    return [answer.strip() for answer in decoded]


def evaluate_model(model, processor, test_samples, batch_size=8):
    """Generate answers for ``test_samples`` in batches and score them with ``VQAMetrics``.

    Samples whose prompt cannot be built are skipped. A batch whose
    generation fails is dropped as a whole, so predictions, references and
    sample indices stay aligned.

    Side effects:
        Temporarily switches the tokenizer to left padding and restores the
        original side afterwards, even on error. Writes
        ``<config.OUTPUT_DIR>/test_results.json`` with the metrics and up to
        20 example predictions.

    Returns:
        The metric dict from ``VQAMetrics.compute_all``.
    """
    print(f"Bắt đầu đánh giá trên {len(test_samples)} samples")

    original_padding_side = processor.tokenizer.padding_side
    # Batched generation needs prompts left-padded so they all end at the same column.
    processor.tokenizer.padding_side = 'left'

    try:
        model.eval()
        metrics_calculator = VQAMetrics()

        predictions = []
        references = []
        sample_indices = []

        for i in range(0, len(test_samples), batch_size):
            batch_samples = test_samples[i:i+batch_size]
            is_report_batch = (i // batch_size + 1) % 10 == 0

            if is_report_batch:
                print(f"Đã xử lý {i}/{len(test_samples)} samples")

            batch_texts, batch_images, batch_refs, batch_idx = [], [], [], []
            for j, sample in enumerate(batch_samples):
                try:
                    text, image_inputs = _build_generation_prompt(processor, sample)
                    reference = sample['answer']
                except Exception as e:
                    print(f"⚠ Lỗi sample {i+j}: {e}")
                    continue
                batch_texts.append(text)
                # The processor expects one flat image list for the whole batch.
                batch_images.extend(image_inputs or [])
                batch_refs.append(reference)
                batch_idx.append(i + j)

            if batch_texts:
                try:
                    answers = _generate_answers(model, processor, batch_texts, batch_images)
                except Exception as e:
                    print(f"⚠ Lỗi inference batch {i}: {e}")
                else:
                    predictions.extend(answers)
                    references.extend(batch_refs)
                    sample_indices.extend(batch_idx)

            # Release cached CUDA memory every 10 batches.
            if is_report_batch:
                torch.cuda.empty_cache()

        if not predictions:
            print("⚠ No predictions were generated; the metrics below are not meaningful.")

        print("\n" + "="*70)
        print("KẾT QUẢ TEST SET")
        print("="*70)

        test_metrics = metrics_calculator.compute_all(predictions, references)

        for key, value in test_metrics.items():
            print(f"{key:15s}: {value:.4f}")
        print("="*70)

        results = {
            'metrics': test_metrics,
            'num_samples': len(test_samples),
            'processed_samples': len(predictions),
            'sample_predictions': [
                {
                    'question': test_samples[sample_indices[k]]['question'],
                    'predicted': predictions[k],
                    'reference': references[k],
                    'image_path': test_samples[sample_indices[k]]['image_path']
                }
                for k in range(min(20, len(predictions)))
            ]
        }

        with open(os.path.join(config.OUTPUT_DIR, "test_results.json"), 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        return test_metrics
    finally:
        processor.tokenizer.padding_side = original_padding_side
        print(f"\n✓ Đã restore padding_side về '{original_padding_side}'")

In [None]:
# ========================================================================
# 11. Main Training
# ========================================================================

def _smoke_test_forward(model, data_collator, sample):
    """Run one no-grad forward pass on ``sample`` and fail fast on a NaN/Inf loss.

    Leaves the model in train mode on success.
    """
    model.eval()
    batch = data_collator([sample])
    batch = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v
             for k, v in batch.items()}

    with torch.no_grad():
        loss = model(**batch).loss

    if torch.isnan(loss) or torch.isinf(loss):
        raise RuntimeError(f"Forward pass tạo invalid loss: {loss.item()}")

    print(f"✓ Forward pass OK, loss: {loss.item():.4f}")
    model.train()


def train():
    """Run the full pipeline: data, model, LoRA, sanity checks, training, save, test evaluation.

    Returns:
        The test-set metric dict from ``evaluate_model``.

    Raises:
        ValueError: if the train or validation split is empty. Validation is
            required by per-epoch evaluation and ``load_best_model_at_end``.
        RuntimeError: if the pre-training forward pass yields a NaN/Inf loss.
    """
    # gc first so unreachable tensors are freed before the CUDA cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

    print("="*70)
    print("BƯỚC 1: LOAD DATA")
    print("="*70)
    train_samples, val_samples, test_samples = load_and_prepare_data(
        config.DATA_PATH, config.IMAGE_BASE_PATH
    )
    if not train_samples:
        raise ValueError("No training samples were loaded; check DATA_PATH / IMAGE_BASE_PATH.")
    if not val_samples:
        raise ValueError("Validation split is empty; increase the data or VAL_SIZE.")

    train_dataset = Dataset.from_list(train_samples)
    val_dataset = Dataset.from_list(val_samples)

    print("\n" + "="*70)
    print("BƯỚC 2: LOAD MODEL")
    print("="*70)
    model, processor = load_model_and_processor()

    print("\n" + "="*70)
    print("BƯỚC 3: SETUP LoRA")
    print("="*70)
    model = setup_lora(model)

    print("\n" + "="*70)
    print("BƯỚC 4: CHUẨN BỊ COLLATOR")
    print("="*70)
    data_collator = VQADataCollator(processor, max_length=config.MAX_LENGTH)

    print("Kiểm tra collator...")
    # Slice instead of indexing so a one-sample train split does not raise.
    data_collator(train_samples[:2])
    print("✓ Collator OK")

    print("\n" + "="*70)
    print("BƯỚC 5: TEST FORWARD PASS")
    print("="*70)
    _smoke_test_forward(model, data_collator, train_samples[0])

    print("\n" + "="*70)
    print("BƯỚC 6: TRAINING")
    print("="*70)

    training_args = get_training_args()

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        callbacks=[
            EarlyStoppingCallback(
                early_stopping_patience=2,
                early_stopping_threshold=0.005
            )
        ]
    )

    trainer.train()

    print("\n" + "="*70)
    print("BƯỚC 7: LƯU MODEL")
    print("="*70)
    trainer.save_model(config.OUTPUT_DIR)
    processor.save_pretrained(config.OUTPUT_DIR)
    print(f"✓ Đã lưu model tại {config.OUTPUT_DIR}")

    print("\n" + "="*70)
    print("BƯỚC 8: ĐÁNH GIÁ TEST SET")
    print("="*70)
    test_metrics = evaluate_model(model, processor, test_samples,
                                  batch_size=config.EVAL_BATCH_SIZE)

    del model, trainer
    gc.collect()
    torch.cuda.empty_cache()

    return test_metrics

In [None]:
# ========================================================================
# 12. Chạy
# ========================================================================

if __name__ == "__main__":
    # Entry point: print a banner, then run the full training + evaluation pipeline.
    separator = "=" * 70
    for line in (separator, "QWEN2-VL VIETNAMESE HERITAGE VQA TRAINING", separator):
        print(line)

    print("\nBắt đầu training...")
    test_metrics = train()
    print("\nHOÀN THÀNH!")

In [None]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from peft import PeftModel
from qwen_vl_utils import process_vision_info

# Load the LoRA adapter from checkpoint-423 instead of the final output folder.
CHECKPOINT_PATH = "/kaggle/working/qwen2vl-heritage-vqa/checkpoint-423"
BASE_MODEL_NAME = "Qwen/Qwen2-VL-7B-Instruct"

# Must equal Config.IMAGE_MAX_SIZE from training. The training processor pinned
# min_pixels = max_pixels = IMAGE_MAX_SIZE**2. Without the same kwargs the
# processor falls back to the checkpoint's default pixel range, and the vision
# tower would see a different resolution than during fine-tuning.
IMAGE_MAX_SIZE = 336

TEST_IMAGE_PATH = "/kaggle/input/images/images/can_tho/can_cu_vuon_man/19.png"
TEST_QUESTION = "Điều gì đặc biệt về những kỷ vật được trưng bày trong bức ảnh?"

processor = AutoProcessor.from_pretrained(
    BASE_MODEL_NAME,
    min_pixels=IMAGE_MAX_SIZE**2,
    max_pixels=IMAGE_MAX_SIZE**2,
)
# Left padding for decoder-only generation.
processor.tokenizer.padding_side = 'left'

# Same 4-bit NF4 setup as during training.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

base_model = Qwen2VLForConditionalGeneration.from_pretrained(
    BASE_MODEL_NAME,
    device_map="auto",
    quantization_config=bnb_config
)

model = PeftModel.from_pretrained(base_model, CHECKPOINT_PATH)
model.eval()

# Clear sampling parameters; presumably to avoid warnings when decoding greedily.
model.generation_config.temperature = None
model.generation_config.top_p = None
model.generation_config.top_k = None
# NOTE(review): evaluate_model uses repetition_penalty=1.1; this demo uses 1.2.
model.generation_config.repetition_penalty = 1.2

print("✓ Model loaded from checkpoint-423!")

# Smoke test on a single image. The file is closed once converted.
with Image.open(TEST_IMAGE_PATH) as src:
    img = src.convert('RGB')
img.thumbnail((IMAGE_MAX_SIZE, IMAGE_MAX_SIZE), Image.Resampling.LANCZOS)

messages = [{"role": "user", "content": [
    {"type": "image", "image": img},
    {"type": "text", "text": TEST_QUESTION}
]}]

text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, _ = process_vision_info(messages)
inputs = processor(text=[text], images=image_inputs, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False,
                             pad_token_id=processor.tokenizer.pad_token_id)

# Keep only the newly generated tokens after the prompt.
answer = processor.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
print(f"Trả lời: {answer}")