In [1]:
import sys
!{sys.executable} -m pip install transformers accelerate matplotlib qwen_vl_utils datasets



In [2]:
import torch
import numpy as np
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional, Dict, List
import matplotlib.pyplot as plt

In [3]:
from huggingface_hub import login
login("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")

# import os
# os.environ["HF_TOKEN"] = "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"


In [4]:
import torch
import torch.nn as nn
import math
from typing import Optional, List, Tuple
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from PIL import Image
import requests
from io import BytesIO
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor


# ============================================================
# COMPONENT 1: Learnable γ network
# γₜ = σ(W hₜ) — học khi nào cần trust image
# ============================================================
class LearnableGamma(nn.Module):
    """
    Học adaptive visual trust coefficient từ hidden state.

    γₜ nhỏ → model tự tin vào visual evidence → tăng correction
    γₜ lớn → token ngữ pháp → giảm correction
    """
    def __init__(self, hidden_size: int):
        super().__init__()
        # Simple linear projection: W ∈ ℝ^(hidden_size → 1)
        self.proj = nn.Linear(hidden_size, 1, bias=True)
        # Init bias nhỏ để γ bắt đầu ~0.5
        nn.init.zeros_(self.proj.weight)
        nn.init.constant_(self.proj.bias, 0.0)

    def forward(self, h_t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            h_t: hidden state [batch, hidden_size]
        Returns:
            gamma: [batch, 1] ∈ (0, 1)
        """
        return torch.sigmoid(self.proj(h_t))  # γₜ = σ(W hₜ)


# ============================================================
# COMPONENT 2: Hybrid Patch + Region Encoder
# Visual tokens = {patch tokens} ∪ {region tokens}
# ============================================================
class HybridVisualEncoder(nn.Module):
    """
    Kết hợp hai mức trừu tượng:
    - Patch tokens: perception (texture, background, global context)
    - Region tokens: grounding anchor (object-level semantic)

    Trong Qwen2VL, patch tokens đã có sẵn trong visual_hidden_states.
    Region tokens được tổng hợp từ patch tokens qua attention pooling.
    """
    def __init__(self, hidden_size: int, num_regions: int = 16):
        super().__init__()
        self.num_regions = num_regions
        self.hidden_size = hidden_size

        # Learnable region query vectors (như DETR object queries)
        # Mỗi query học cách attend vào một "semantic region" khác nhau
        self.region_queries = nn.Parameter(
            torch.randn(num_regions, hidden_size) * 0.02
        )

        # Cross-attention: region queries attend patch tokens
        self.region_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=8,
            batch_first=True
        )

        # Fusion gate: học trọng số blend patch vs region
        self.fusion_gate = nn.Linear(hidden_size * 2, 1)

    def forward(self, patch_tokens: torch.Tensor) -> torch.Tensor:
        """
        Args:
            patch_tokens: [batch, num_patches, hidden_size]
        Returns:
            hybrid_tokens: [batch, num_patches + num_regions, hidden_size]
        """
        batch_size = patch_tokens.shape[0]

        # Expand region queries cho batch
        queries = self.region_queries.unsqueeze(0).expand(batch_size, -1, -1)
        # queries: [batch, num_regions, hidden_size]

        # Cross-attention: region queries ← patch tokens
        # Mỗi region query học attend vào spatial region khác nhau
        region_tokens, _ = self.region_attn(
            query=queries,       # [batch, num_regions, H]
            key=patch_tokens,    # [batch, num_patches, H]
            value=patch_tokens   # [batch, num_patches, H]
        )
        # region_tokens: [batch, num_regions, hidden_size]

        # Concatenate: Visual tokens = patches ∪ regions
        hybrid_tokens = torch.cat([patch_tokens, region_tokens], dim=1)
        # [batch, num_patches + num_regions, hidden_size]

        return hybrid_tokens


# ============================================================
# MAIN: Attention-aware M3ID với Learnable γ
# Formula: l̂ₜ = lc + αₜ · ((1-γₜ)/γₜ) · (lc - lu)
# ============================================================
class HybridAttentionM3ID:
    """
    Full framework:
    1. Hybrid Patch + Region Encoder → richer visual representation
    2. Attention-aware scaling (αₜ) → correction chỉ khi attend image
    3. Learnable γₜ → adaptive visual trust

    Decode formula:
        l̂ₜ = lc + αₜ · ((1-γₜ)/γₜ) · (lc - lu)

    So sánh với M3ID gốc:
        l̂ₜ = lc + [indicator] · ((1-exp(-λt))/exp(-λt)) · (lc - lu)

    Thay thế:
    - exp(-λt) heuristic → γₜ = σ(W hₜ) learnable
    - binary indicator → αₜ = attention mass (continuous)
    """

    def __init__(
        self,
        model: Qwen2_5_VLForConditionalGeneration,
        processor: AutoProcessor,
        hidden_size: int = 2048,       # Qwen2VL-7B hidden size is 3584 and 3B is 2048
        num_regions: int = 16,          # số region tokens
        gamma_lr: float = 1e-4,         # learning rate cho γ network
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        self.model = model
        self.processor = processor
        self.device = device
        self.hidden_size = hidden_size

        # === Learnable components ===
        self.gamma_net = LearnableGamma(hidden_size).to(device=device, dtype=torch.float16)
        self.hybrid_encoder = HybridVisualEncoder(hidden_size, num_regions).to(device=device, dtype=torch.float16)

        # Optimizer chỉ cho các learnable components (không train lại LLM)
        self.optimizer = torch.optim.Adam(
            list(self.gamma_net.parameters()) +
            list(self.hybrid_encoder.parameters()),
            lr=gamma_lr
        )

    # ----------------------------------------------------------
    # Image loading
    # ----------------------------------------------------------
    def load_image(self, image_source) -> Image.Image:
      if isinstance(image_source, Image.Image):
          return image_source
      if isinstance(image_source, str) and image_source.startswith(('http://', 'https://')):
          headers = {
              "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
          }
          response = requests.get(image_source, headers=headers, timeout=10)
          # Check HTTP status
          response.raise_for_status()
          # Check content type
          content_type = response.headers.get("Content-Type", "")
          if "image" not in content_type:
              raise ValueError(f"URL did not return an image. Content-Type: {content_type}")
          return Image.open(BytesIO(response.content)).convert("RGB")
      return Image.open(image_source).convert("RGB")

    # ----------------------------------------------------------
    # Input preparation
    # ----------------------------------------------------------
    def _prepare_inputs_with_image(self, prompt: str, image: Image.Image) -> dict:
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }]
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return self.processor(text=[text], images=[image], return_tensors="pt", padding=True)

    def _prepare_inputs_without_image(self, prompt: str) -> dict:
        messages = [{
            "role": "user",
            "content": [{"type": "text", "text": prompt}],
        }]
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return self.processor(text=[text], images=None, return_tensors="pt", padding=True)

    # ----------------------------------------------------------
    # COMPONENT: Tính attention mass αₜ
    # αₜ = tổng attention weight từ text token hiện tại → image tokens
    # ----------------------------------------------------------
    def _compute_attention_mass(
        self,
        attentions: Tuple,              # tuple of [batch, heads, seq, seq]
        num_image_tokens: int,
        layer_idx: int = -1             # dùng layer cuối (most semantic)
    ) -> torch.Tensor:
        """
        Tính αₜ = attention mass từ token cuối → image tokens.

        Dùng layer attention cuối cùng vì nó capture
        semantic-level dependencies (không phải syntactic).

        Returns:
            alpha_t: scalar tensor ∈ [0, 1]
        """
        if attentions is None:
            # Fallback: không có attention → assume moderate attention
            return torch.tensor(0.5, device=self.device)

        # Lọc ra các layer có attention weights thực sự (không phải None)
        valid_attentions = [a for a in attentions if a is not None]

        # Fallback 2: tất cả layers đều là None
        if len(valid_attentions) == 0:
            return torch.tensor(0.5, device=self.device)

        # Lấy layer hợp lệ theo layer_idx
        # Nếu layer_idx=-1 → lấy layer cuối trong danh sách hợp lệ
        attn_layer = valid_attentions[layer_idx]  # [batch, heads, seq_len, seq_len]

        # Attention từ token cuối (vị trí -1) đến tất cả positions
        # Average across heads
        attn_last_token = attn_layer[0, :, -1, :]   # [heads, seq_len]
        attn_avg = attn_last_token.mean(dim=0)       # [seq_len]

        # Image tokens nằm ở đầu sequence
        image_start = 1
        image_end = min(image_start + num_image_tokens, attn_avg.shape[0])

        # αₜ = tổng attention mass đến image tokens
        alpha_t = attn_avg[image_start:image_end].mean()

        return alpha_t.clamp(0.0, 1.0)

    # ----------------------------------------------------------
    # GENERATE: Main decode loop
    # ----------------------------------------------------------
    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        image_path: str,
        max_new_tokens: int = 100,
        temperature: float = 0.7,
        top_p: float = 0.9,
        verbose: bool = True
    ) -> str:
        """
        Decode với full framework:

        Mỗi bước t:
        1. lc = log p(yₜ | x, image)
        2. lu = log p(yₜ | x)
        3. αₜ = attention mass → image tokens
        4. γₜ = σ(W hₜ) từ hidden state
        5. l̂ₜ = lc + αₜ · ((1-γₜ)/γₜ) · (lc - lu)
        6. Sample yₜ ~ softmax(l̂ₜ)
        """
        image = self.load_image(image_path)

        # Prepare inputs
        inputs_c = self._prepare_inputs_with_image(prompt, image)
        inputs_u = self._prepare_inputs_without_image(prompt)
        inputs_c = {k: v.to(self.device) for k, v in inputs_c.items()}
        inputs_u = {k: v.to(self.device) for k, v in inputs_u.items()}

        # === PREFILL: lần đầu chạy cả sequence, khởi tạo KV cache ===
        # output_attentions=True để lấy αₜ
        outputs_c = self.model(
            **inputs_c,
            use_cache=True,
            output_attentions=True,
            output_hidden_states=True,   # cần hₜ cho γₜ
        )
        past_kv_c = outputs_c.past_key_values
        logits_c   = outputs_c.logits[:, -1, :]

        # Hidden state của layer cuối, token cuối → dùng cho γₜ
        # hidden_states: tuple of [batch, seq, hidden], lấy layer cuối
        h_t = outputs_c.hidden_states[-1][:, -1, :]  # [1, hidden_size]

        # Số image tokens trong conditioned input
        # Qwen2VL trả về image_grid_thw để tính số patch tokens
        num_image_tokens = 0
        if hasattr(outputs_c, 'image_grid_thw') or 'image_grid_thw' in inputs_c:
            grid = inputs_c.get('image_grid_thw', None)
            if grid is not None:
                # num_patches = T * H * W (với T=1 cho ảnh tĩnh)
                num_image_tokens = int(grid[0].prod().item())
        # Fallback estimate
        if num_image_tokens == 0:
            num_image_tokens = 256  # typical for 448px image

        # Tính αₜ từ prefill attention
        alpha_attention = self._compute_attention_mass(
            outputs_c.attentions,
            num_image_tokens
        )

        # Unconditioned prefill (không cần attention/hidden)
        outputs_u = self.model(**inputs_u, use_cache=True)
        past_kv_u = outputs_u.past_key_values
        logits_u  = outputs_u.logits[:, -1, :]

        seq_len_c = inputs_c['input_ids'].shape[1]
        seq_len_u = inputs_u['input_ids'].shape[1]

        inv_temp = 1.0 / temperature
        eos_token_id = self.processor.tokenizer.eos_token_id
        generated_ids = []

        if verbose:
            print(f"\n{'t':<4} {'α_attn':<10} {'γt':<8} {'w=(1-γ)/γ':<12} {'Token'}")
            print("-" * 60)

        # === DECODE LOOP ===
        for t in range(1, max_new_tokens + 1):

            # --- Step 4: Tính γₜ từ hidden state ---
            # Dùng gamma_net (có thể fine-tune sau)
            # inference_mode: tạm thời enable grad chỉ cho gamma_net
            with torch.enable_grad():
                h_t_input = h_t.detach().to(dtype=next(self.gamma_net.parameters()).dtype)
                gamma_t = self.gamma_net(h_t_input)  # [1, 1]
            gamma_t = gamma_t.detach().squeeze()         # scalar

            # --- Step 2-3: Log-probs ---
            lc = torch.log_softmax(logits_c * inv_temp, dim=-1)  # [1, vocab]
            lu = torch.log_softmax(logits_u * inv_temp, dim=-1)  # [1, vocab]

            # --- Step 5: Attention-aware M3ID formula ---
            # l̂ₜ = lc + αₜ · ((1-γₜ)/γₜ) · (lc - lu)
            # αₜ: continuous attention mass (thay cho binary indicator)
            # (1-γₜ)/γₜ: learnable correction weight (thay cho heuristic)

            eps = 1e-6
            correction_weight = ((1.0 - gamma_t) / (gamma_t + eps)).clamp(0.0, 5.0)

            # αₜ đóng vai trò gate: chỉ correct khi thực sự attend image
            l_star = lc + alpha_attention * correction_weight * (lc - lu)
            # [1, vocab]

            # --- Step 6: Top-p sampling ---
            probs = torch.softmax(l_star, dim=-1)

            # Nucleus sampling
            sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
            cum_probs = torch.cumsum(sorted_probs, dim=-1)

            mask = cum_probs > top_p
            mask[..., 1:] = mask[..., :-1].clone()
            mask[..., 0] = False

            remove_mask = torch.zeros_like(probs, dtype=torch.bool)
            remove_mask.scatter_(-1, sorted_idx, mask)
            probs[remove_mask] = 0.0

            prob_sum = probs.sum(dim=-1, keepdim=True).clamp(min=eps)
            probs = probs / prob_sum

            next_token_id = torch.multinomial(probs, num_samples=1)  # [1, 1]

            if next_token_id.item() == eos_token_id:
                break

            generated_ids.append(next_token_id.item())

            if verbose:
                token_str = self.processor.tokenizer.decode([next_token_id.item()])
                w_val = correction_weight.item()
                print(f"{t:<4} {alpha_attention.item():<10.4f} {gamma_t.item():<8.4f} {w_val:<12.4f} {repr(token_str)}")

            # === DECODE STEP: feed 1 token mới với KV cache ===
            next_token_tensor = next_token_id.view(1, 1)

            # Conditioned: cần attention + hidden state cho bước sau
            seq_len_c += 1
            out_c = self.model(
                input_ids=next_token_tensor,
                attention_mask=torch.ones((1, seq_len_c), dtype=torch.long, device=self.device),
                past_key_values=past_kv_c,
                use_cache=True,
                output_attentions=True,
                output_hidden_states=True,
            )
            logits_c  = out_c.logits[:, -1, :]
            past_kv_c = out_c.past_key_values

            # Cập nhật hₜ và αₜ cho bước tiếp theo
            h_t = out_c.hidden_states[-1][:, -1, :]
            alpha_attention = self._compute_attention_mass(
                out_c.attentions,
                num_image_tokens
            )

            # Unconditioned (không cần hidden/attention)
            seq_len_u += 1
            out_u = self.model(
                input_ids=next_token_tensor,
                attention_mask=torch.ones((1, seq_len_u), dtype=torch.long, device=self.device),
                past_key_values=past_kv_u,
                use_cache=True,
            )
            logits_u  = out_u.logits[:, -1, :]
            past_kv_u = out_u.past_key_values

        return self.processor.tokenizer.decode(generated_ids, skip_special_tokens=True)

    # ----------------------------------------------------------
    # TRAINING: Fine-tune γ network + hybrid encoder
    # Dùng khi có labeled data (response có/không hallucinate)
    # ----------------------------------------------------------
    def train_step(
        self,
        prompt: str,
        image_path: str,
        target_response: str,
        hallucination_label: float = 0.0
    ) -> float:
        """
        Fine-tune γ network với supervision đơn giản.
        """
        image = self.load_image(image_path)

        # 1. Chuẩn bị input có chứa cả câu hỏi và câu trả lời
        full_text = prompt + " " + target_response
        inputs_c = self._prepare_inputs_with_image(full_text, image)
        inputs_c = {k: v.to(self.device) for k, v in inputs_c.items()}

        # 2. [FIX QUAN TRỌNG] Tính chiều dài prompt PHẢI BAO GỒM CẢ ẢNH
        # Vì ảnh sẽ chèn hàng trăm tokens <|image_pad|> vào chuỗi.
        prompt_inputs = self._prepare_inputs_with_image(prompt, image)
        prompt_len = prompt_inputs['input_ids'].shape[1]

        # Kiểm tra an toàn: nếu câu trả lời bị tokenizer nuốt mất
        if inputs_c['input_ids'].shape[1] <= prompt_len:
            raise ValueError(f"Sequence length ({inputs_c['input_ids'].shape[1]}) <= prompt_len ({prompt_len})")

        self.optimizer.zero_grad()

        # LLM không tính gradient
        with torch.no_grad():
            outputs = self.model(
                **inputs_c,
                output_hidden_states=True,
                use_cache=False,
            )

        full_hidden = outputs.hidden_states[-1]  # [1, seq_len, hidden_size]

        # CHỈ LẤY HIDDEN STATES CỦA CÂU TRẢ LỜI
        response_hidden = full_hidden[:, prompt_len:, :]

        # Tính γ cho các token trong câu trả lời
        gammas = self.gamma_net(response_hidden.squeeze(0))  # [response_len, 1]
        avg_gamma = gammas.mean()

        # [FIX QUAN TRỌNG] Ép kiểu dữ liệu (dtype) khớp với avg_gamma (thường là float16)
        # Nếu không ép, target sẽ là float32 và gây lỗi RuntimeError.
        target_gamma = torch.tensor(
            1.0 - hallucination_label,
            device=self.device,
            dtype=avg_gamma.dtype
        )

        loss = nn.functional.mse_loss(avg_gamma, target_gamma)

        loss.backward()
        self.optimizer.step()

        return loss.item()

In [5]:
# # ============================================================
# # USAGE EXAMPLE
# # ============================================================
# if __name__ == "__main__":
#     # Load model
#     model_name = "Qwen/Qwen2.5-VL-3B-Instruct"

#     processor = AutoProcessor.from_pretrained(model_name)
#     # model = Qwen2VLForConditionalGeneration.from_pretrained(
#     #     model_name,
#     #     torch_dtype=torch.float16,
#     #     device_map="auto",
#     # )

#     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
#         model_name,
#         torch_dtype=torch.float16,
#         device_map="auto",
#         attn_implementation='eager'
#     )

#     # Khởi tạo framework
#     framework = HybridAttentionM3ID(
#         model=model,
#         processor=processor,
#         hidden_size=2048,    # Qwen2VL-7B
#         num_regions=16,
#         # device="cuda"
#     )

#     # Inference
#     result = framework.generate(
#         prompt="Describe what you see in this image in detail.",
#         image_path="https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg",
#         max_new_tokens=150,
#         temperature=0.7,
#         top_p=0.9,
#         verbose=True
#     )

#     print(f"\n=== Generated Response ===\n{result}")

## Load dataset for training gamma

In [6]:
!pip install datasets tqdm



In [None]:
import torch
import json
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

# Giả sử framework M3ID của bạn được lưu trong file m3id.py
# from m3id import HybridAttentionM3ID

def plot_loss_curve(step_losses, save_path="training_loss.png"):
    """
    Hàm vẽ biểu đồ loss.
    Vẽ loss theo từng bước (mờ) và đường trung bình trượt (rõ nét) để dễ quan sát xu hướng.
    """
    plt.figure(figsize=(12, 6))

    # 1. Vẽ loss gốc của từng bước (màu xanh nhạt)
    plt.plot(step_losses, label='Step Loss', color='blue', alpha=0.3, linewidth=1)

    # 2. Vẽ đường làm mượt (Moving Average) để nhìn rõ xu hướng
    window_size = 50
    if len(step_losses) >= window_size:
        # Tính trung bình trượt
        smoothed_losses = np.convolve(step_losses, np.ones(window_size)/window_size, mode='valid')
        # Căn chỉnh trục x cho khớp
        plt.plot(range(window_size-1, len(step_losses)), smoothed_losses,
                 label=f'Trend (Moving Average window={window_size})', color='red', linewidth=2)

    plt.title('Training Loss of Learnable Gamma (M3ID Framework)', fontsize=14)
    plt.xlabel('Training Steps', fontsize=12)
    plt.ylabel('MSE Loss', fontsize=12)
    plt.legend(fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.7)

    # Lưu ra file ảnh
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    print(f"\n=> Đã lưu biểu đồ loss tại: {save_path}")

    # Hiển thị biểu đồ ra màn hình (nếu chạy trên Jupyter Notebook/Colab)
    plt.show()

def train():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("1. Loading Model...")
    model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
    processor = AutoProcessor.from_pretrained(model_name)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        attn_implementation='eager' # Quan trọng để lấy được attentions
    )

    # BẮT BUỘC: Đóng băng toàn bộ LLM, chỉ train Gamma_net
    model.requires_grad_(False)
    model.eval()

    print("2. Initializing Framework...")
    framework = HybridAttentionM3ID(
        model=model,
        processor=processor,
        hidden_size=2048, # 2048 cho Qwen 3B
        num_regions=16,
        gamma_lr=1e-4,
        device=device
    )

    print("3. Loading RLHF-V Dataset...")
    dataset = load_dataset("openbmb/RLHF-V-Dataset", split="train")

    # Lấy 1000 mẫu để train nhanh nghiệm thu thuật toán
    train_data = dataset.select(range(1000))

    epochs = 2

    # --- BIẾN LƯU TRỮ LOSS ĐỂ VẼ BIỂU ĐỒ ---
    history_step_losses = []

    print(f"4. Starting Training for {epochs} epochs...")
    framework.gamma_net.train() # Bật chế độ train cho mạng Gamma

    for epoch in range(epochs):
        total_loss = 0
        valid_steps = 0

        progress_bar = tqdm(train_data, desc=f"Epoch {epoch+1}/{epochs}")

        for item in progress_bar:
            try:
                image = item['image']

                # Bóc tách text_data
                text_data = item.get('text', item) # Đề phòng dataset format bị đổi
                if isinstance(text_data, str):
                    text_data = json.loads(text_data)

                prompt = text_data.get('question', '')
                chosen_response = text_data.get('chosen', '')
                rejected_response = text_data.get('rejected', '')

                if not prompt or not chosen_response or not rejected_response:
                    continue # Bỏ qua nếu data bị thiếu field

                # 1. Câu đúng (không ảo giác) -> target gamma cao (ít correction)
                loss_chosen = framework.train_step(
                    prompt=prompt,
                    image_path=image,
                    target_response=chosen_response,
                    hallucination_label=0.0
                )

                # 2. Câu ảo giác -> target gamma thấp (ép correction mạnh)
                loss_rejected = framework.train_step(
                    prompt=prompt,
                    image_path=image,
                    target_response=rejected_response,
                    hallucination_label=1.0
                )

                # Trung bình loss của cả 2 nhánh
                step_loss = (loss_chosen + loss_rejected) / 2.0

                # LƯU LOSS VÀO LIST ĐỂ TÍ VẼ
                history_step_losses.append(step_loss)

                total_loss += step_loss
                valid_steps += 1

                progress_bar.set_postfix({"Avg Loss": f"{total_loss/valid_steps:.4f}"})

            except Exception as e:
                print(f"\n[Cảnh báo] Bỏ qua sample do lỗi: {str(e)}")
                continue

        epoch_avg = total_loss / max(1, valid_steps)
        print(f"\n=> End of Epoch {epoch+1} | Mean Epoch Loss: {epoch_avg:.4f}")

        # LƯU TRỌNG SỐ
        save_path = f"gamma_net_epoch_{epoch+1}.pt"
        torch.save(framework.gamma_net.state_dict(), save_path)
        print(f"=> Saved Checkpoint to {save_path}\n")

    # 5. KẾT THÚC HUẤN LUYỆN -> VẼ BIỂU ĐỒ
    print("5. Generating Loss Curve...")
    plot_loss_curve(history_step_losses, save_path="gamma_training_loss.png")

if __name__ == "__main__":
    train()

1. Loading Model...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. 


Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading weights:   0%|          | 0/824 [00:00<?, ?it/s]

2. Initializing Framework...
3. Loading RLHF-V Dataset...
4. Starting Training for 2 epochs...


Epoch 1/2:   4%|▍         | 39/1000 [00:45<14:18,  1.12it/s, Avg Loss=nan]


[Cảnh báo] Bỏ qua sample do lỗi: CUDA out of memory. Tried to allocate 1.80 GiB. GPU 0 has a total capacity of 14.56 GiB of which 1.70 GiB is free. Including non-PyTorch memory, this process has 12.87 GiB memory in use. Of the allocated memory 12.65 GiB is allocated by PyTorch, and 81.86 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)


Epoch 1/2:   9%|▉         | 91/1000 [01:57<23:13,  1.53s/it, Avg Loss=nan]