In [1]:
import sys
!{sys.executable} -m pip install transformers accelerate matplotlib qwen_vl_utils datasets



In [2]:
import torch
import numpy as np
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional, Dict, List
import matplotlib.pyplot as plt

In [3]:
import torch
import math
from typing import Optional
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from PIL import Image
import requests
from io import BytesIO


class M3ID_Paper:
    """
    Implementation of Algorithm 1 from the paper:
    "Multi-Modal Hallucination Control by Visual Information Grounding"

    Optimized with KV caching: each decode step processes only the NEW token
    instead of the full sequence, yielding ~10-50x speedup for long generations.

    Main formula (Equation 4):
    l̂* = lc + [max_k(lc)_k < log α] * ((1-αt)/αt) * (lc - lu)
    """

    def __init__(
        self,
        model: Qwen2VLForConditionalGeneration,
        processor: AutoProcessor,
        lambda_param: float = 0.02,  # λ: forgetting rate
        alpha: float = 0.3,           # α: confidence threshold
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        self.model = model
        self.processor = processor
        self.lambda_param = lambda_param
        self.alpha = alpha
        self.device = device
        # Pre-compute log(α) for indicator comparison in log-space
        self.log_alpha = math.log(alpha)

    def load_image(self, image_source) -> Image.Image:
        """Load image from URL or file path"""
        if isinstance(image_source, Image.Image):
            return image_source
        if image_source.startswith(('http://', 'https://')):
            response = requests.get(image_source)
            return Image.open(BytesIO(response.content))
        return Image.open(image_source)

    def prepare_inputs_with_image(self, prompt: str, image: Image.Image):
        """Prepare inputs with image (conditioned)"""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return self.processor(
            text=[text], images=[image], return_tensors="pt", padding=True
        )

    def prepare_inputs_without_image(self, prompt: str):
        """Prepare inputs without image (unconditioned)"""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return self.processor(
            text=[text], images=None, return_tensors="pt", padding=True
        )

    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        image_path: str,
        max_new_tokens: int = 100,
        temperature: float = 0.2,
        verbose: bool = False
    ) -> str:
        """
        Generate text using M3ID (Algorithm 1) with KV caching.

        Optimization: instead of reprocessing the full sequence at every step,
        we cache key/value states and only feed the single new token each step.
        This reduces per-step complexity from O(seq_len) to O(1) for attention.
        """
        image = self.load_image(image_path)

        # Prepare inputs
        inputs_c = self.prepare_inputs_with_image(prompt, image)
        inputs_u = self.prepare_inputs_without_image(prompt)
        inputs_c = {k: v.to(self.device) for k, v in inputs_c.items()}
        inputs_u = {k: v.to(self.device) for k, v in inputs_u.items()}

        # === PREFILL PASS: process full sequences, initialize KV caches ===
        outputs_c = self.model(**inputs_c, use_cache=True)
        past_kv_c = outputs_c.past_key_values
        logits_c = outputs_c.logits[:, -1, :]

        outputs_u = self.model(**inputs_u, use_cache=True)
        past_kv_u = outputs_u.past_key_values
        logits_u = outputs_u.logits[:, -1, :]

        # Track total sequence lengths for attention masks
        seq_len_c = inputs_c['input_ids'].shape[1]
        seq_len_u = inputs_u['input_ids'].shape[1]

        # Pre-compute constants
        inv_temp = 1.0 / temperature
        eos_token_id = self.processor.tokenizer.eos_token_id
        generated_ids = []

        if verbose:
            print(f"{'t':<4} {'αt':<8} {'max(pc)':<10} {'Indicator':<12} {'Token'}")
            print("-" * 70)

        # === DECODE LOOP: each step processes only 1 token ===
        for t in range(1, max_new_tokens + 1):
            # Step 1: αt ← exp(-λt)
            alpha_t = math.exp(-self.lambda_param * t)

            # Step 2-3: Compute log-probs from cached logits (no re-encoding)
            lc = torch.log_softmax(logits_c * inv_temp, dim=-1)
            lu = torch.log_softmax(logits_u * inv_temp, dim=-1)

            # Step 4: Indicator [max_k p(y_k|...) < α]
            # Equivalent in log-space: max(log_softmax) < log(α)
            max_log_prob_c = lc.max(dim=-1).values
            indicator = (max_log_prob_c < self.log_alpha).float()

            # Step 5: l̂* = lc + [indicator] * ((1-αt)/αt) * (lc - lu)
            if alpha_t > 0:
                w = (1.0 - alpha_t) / alpha_t
                l_star = lc + indicator.unsqueeze(-1) * w * (lc - lu)
            else:
                l_star = lc

            # Step 6: yt = argmax l̂*
            next_token_id = l_star.argmax(dim=-1)

            if next_token_id.item() == eos_token_id:
                break

            generated_ids.append(next_token_id.item())

            if verbose:
                token_str = self.processor.tokenizer.decode([next_token_id.item()])
                max_prob = max_log_prob_c.exp().item()
                ind_str = "INTERVENE" if indicator.item() > 0 else "NO"
                print(f"{t:<4} {alpha_t:<8.4f} {max_prob:<10.4f} {ind_str:<12} {repr(token_str)}")

            # === DECODE STEP: feed only the new token with KV cache ===
            next_token_tensor = next_token_id.view(1, 1)

            seq_len_c += 1
            attn_mask_c = torch.ones((1, seq_len_c), dtype=torch.long, device=self.device)
            out_c = self.model(
                input_ids=next_token_tensor,
                attention_mask=attn_mask_c,
                past_key_values=past_kv_c,
                use_cache=True,
            )
            logits_c = out_c.logits[:, -1, :]
            past_kv_c = out_c.past_key_values

            seq_len_u += 1
            attn_mask_u = torch.ones((1, seq_len_u), dtype=torch.long, device=self.device)
            out_u = self.model(
                input_ids=next_token_tensor,
                attention_mask=attn_mask_u,
                past_key_values=past_kv_u,
                use_cache=True,
            )
            logits_u = out_u.logits[:, -1, :]
            past_kv_u = out_u.past_key_values

        return self.processor.tokenizer.decode(generated_ids, skip_special_tokens=True)


class M3ID_WithSampling(M3ID_Paper):
    """
    Extended M3ID with nucleus (top-p) sampling instead of greedy search.
    Inherits KV-cache optimization from M3ID_Paper.
    """

    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        image_path: str,
        max_new_tokens: int = 100,
        temperature: float = 0.2,
        top_p: float = 0.9,
        verbose: bool = True
    ) -> str:
        """M3ID with top-p (nucleus) sampling, KV-cached."""
        image = self.load_image(image_path)

        inputs_c = self.prepare_inputs_with_image(prompt, image)
        inputs_u = self.prepare_inputs_without_image(prompt)
        inputs_c = {k: v.to(self.device) for k, v in inputs_c.items()}
        inputs_u = {k: v.to(self.device) for k, v in inputs_u.items()}

        # === PREFILL ===
        outputs_c = self.model(**inputs_c, use_cache=True)
        past_kv_c = outputs_c.past_key_values
        logits_c = outputs_c.logits[:, -1, :]

        outputs_u = self.model(**inputs_u, use_cache=True)
        past_kv_u = outputs_u.past_key_values
        logits_u = outputs_u.logits[:, -1, :]

        seq_len_c = inputs_c['input_ids'].shape[1]
        seq_len_u = inputs_u['input_ids'].shape[1]

        inv_temp = 1.0 / temperature
        eos_token_id = self.processor.tokenizer.eos_token_id
        generated_ids = []

        if verbose:
            print(f"{'t':<4} {'αt':<8} {'max(pc)':<10} {'Indicator':<12} {'Token'}")
            print("-" * 70)

        for t in range(1, max_new_tokens + 1):
            alpha_t = math.exp(-self.lambda_param * t)

            lc = torch.log_softmax(logits_c * inv_temp, dim=-1)
            lu = torch.log_softmax(logits_u * inv_temp, dim=-1)

            max_log_prob_c = lc.max(dim=-1).values
            indicator = (max_log_prob_c < self.log_alpha).float()

            if alpha_t > 0:
                w = (1.0 - alpha_t) / alpha_t
                l_star = lc + indicator.unsqueeze(-1) * w * (lc - lu)
            else:
                l_star = lc

            # Top-p sampling on adjusted log-probs
            probs_final = torch.softmax(l_star, dim=-1)

            sorted_probs, sorted_indices = torch.sort(probs_final, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

            # Mask tokens beyond top-p threshold
            sorted_mask = cumulative_probs > top_p
            sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
            sorted_mask[..., 0] = False

            indices_to_remove = torch.zeros_like(probs_final, dtype=torch.bool)
            indices_to_remove.scatter_(-1, sorted_indices, sorted_mask)
            probs_final[indices_to_remove] = 0.0

            # Renormalize
            prob_sum = probs_final.sum(dim=-1, keepdim=True)
            probs_final = torch.where(
                prob_sum > 0,
                probs_final / prob_sum,
                torch.ones_like(probs_final) / probs_final.shape[-1]
            )

            next_token_id = torch.multinomial(probs_final, num_samples=1)  # [1, 1]

            if next_token_id.item() == eos_token_id:
                break

            generated_ids.append(next_token_id.item())

            if verbose:
                token_str = self.processor.tokenizer.decode([next_token_id.item()])
                max_prob = max_log_prob_c.exp().item()
                ind_str = "INTERVENE" if indicator.item() > 0 else "NO"
                print(f"{t:<4} {alpha_t:<8.4f} {max_prob:<10.4f} {ind_str:<12} {repr(token_str)}")

            # === DECODE STEP with KV cache ===
            next_token_tensor = next_token_id.view(1, 1)

            seq_len_c += 1
            attn_mask_c = torch.ones((1, seq_len_c), dtype=torch.long, device=self.device)
            out_c = self.model(
                input_ids=next_token_tensor,
                attention_mask=attn_mask_c,
                past_key_values=past_kv_c,
                use_cache=True,
            )
            logits_c = out_c.logits[:, -1, :]
            past_kv_c = out_c.past_key_values

            seq_len_u += 1
            attn_mask_u = torch.ones((1, seq_len_u), dtype=torch.long, device=self.device)
            out_u = self.model(
                input_ids=next_token_tensor,
                attention_mask=attn_mask_u,
                past_key_values=past_kv_u,
                use_cache=True,
            )
            logits_u = out_u.logits[:, -1, :]
            past_kv_u = out_u.past_key_values

        return self.processor.tokenizer.decode(generated_ids, skip_special_tokens=True)

In [4]:
from huggingface_hub import login
login("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")

# import os
# os.environ["HF_TOKEN"] = "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"


In [5]:
import torch
import torch.nn as nn
import math
from typing import Optional, List, Tuple
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from PIL import Image
import requests
from io import BytesIO
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor


# ============================================================
# COMPONENT 1: Learnable γ network
# γₜ = σ(W hₜ) — học khi nào cần trust image
# ============================================================
class LearnableGamma(nn.Module):
    """
    Học adaptive visual trust coefficient từ hidden state.

    γₜ nhỏ → model tự tin vào visual evidence → tăng correction
    γₜ lớn → token ngữ pháp → giảm correction
    """
    def __init__(self, hidden_size: int):
        super().__init__()
        # Simple linear projection: W ∈ ℝ^(hidden_size → 1)
        self.proj = nn.Linear(hidden_size, 1, bias=True)
        # Init bias nhỏ để γ bắt đầu ~0.5
        nn.init.zeros_(self.proj.weight)
        nn.init.constant_(self.proj.bias, 0.0)

    def forward(self, h_t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            h_t: hidden state [batch, hidden_size]
        Returns:
            gamma: [batch, 1] ∈ (0, 1)
        """
        return torch.sigmoid(self.proj(h_t))  # γₜ = σ(W hₜ)


# ============================================================
# COMPONENT 2: Hybrid Patch + Region Encoder
# Visual tokens = {patch tokens} ∪ {region tokens}
# ============================================================
class HybridVisualEncoder(nn.Module):
    """
    Kết hợp hai mức trừu tượng:
    - Patch tokens: perception (texture, background, global context)
    - Region tokens: grounding anchor (object-level semantic)

    Trong Qwen2VL, patch tokens đã có sẵn trong visual_hidden_states.
    Region tokens được tổng hợp từ patch tokens qua attention pooling.
    """
    def __init__(self, hidden_size: int, num_regions: int = 16):
        super().__init__()
        self.num_regions = num_regions
        self.hidden_size = hidden_size

        # Learnable region query vectors (như DETR object queries)
        # Mỗi query học cách attend vào một "semantic region" khác nhau
        self.region_queries = nn.Parameter(
            torch.randn(num_regions, hidden_size) * 0.02
        )

        # Cross-attention: region queries attend patch tokens
        self.region_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=8,
            batch_first=True
        )

        # Fusion gate: học trọng số blend patch vs region
        self.fusion_gate = nn.Linear(hidden_size * 2, 1)

    def forward(self, patch_tokens: torch.Tensor) -> torch.Tensor:
        """
        Args:
            patch_tokens: [batch, num_patches, hidden_size]
        Returns:
            hybrid_tokens: [batch, num_patches + num_regions, hidden_size]
        """
        batch_size = patch_tokens.shape[0]

        # Expand region queries cho batch
        queries = self.region_queries.unsqueeze(0).expand(batch_size, -1, -1)
        # queries: [batch, num_regions, hidden_size]

        # Cross-attention: region queries ← patch tokens
        # Mỗi region query học attend vào spatial region khác nhau
        region_tokens, _ = self.region_attn(
            query=queries,       # [batch, num_regions, H]
            key=patch_tokens,    # [batch, num_patches, H]
            value=patch_tokens   # [batch, num_patches, H]
        )
        # region_tokens: [batch, num_regions, hidden_size]

        # Concatenate: Visual tokens = patches ∪ regions
        hybrid_tokens = torch.cat([patch_tokens, region_tokens], dim=1)
        # [batch, num_patches + num_regions, hidden_size]

        return hybrid_tokens


# ============================================================
# MAIN: Attention-aware M3ID với Learnable γ
# Formula: l̂ₜ = lc + αₜ · ((1-γₜ)/γₜ) · (lc - lu)
# ============================================================
class HybridAttentionM3ID:
    """
    Full framework:
    1. Hybrid Patch + Region Encoder → richer visual representation
    2. Attention-aware scaling (αₜ) → correction chỉ khi attend image
    3. Learnable γₜ → adaptive visual trust

    Decode formula:
        l̂ₜ = lc + αₜ · ((1-γₜ)/γₜ) · (lc - lu)

    So sánh với M3ID gốc:
        l̂ₜ = lc + [indicator] · ((1-exp(-λt))/exp(-λt)) · (lc - lu)

    Thay thế:
    - exp(-λt) heuristic → γₜ = σ(W hₜ) learnable
    - binary indicator → αₜ = attention mass (continuous)
    """

    def __init__(
        self,
        model: Qwen2_5_VLForConditionalGeneration,
        processor: AutoProcessor,
        hidden_size: int = 2048,       # Qwen2VL-7B hidden size is 3584 and 3B is 2048
        num_regions: int = 16,          # số region tokens
        gamma_lr: float = 1e-4,         # learning rate cho γ network
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        self.model = model
        self.processor = processor
        self.device = device
        self.hidden_size = hidden_size

        # === Learnable components ===
        self.gamma_net = LearnableGamma(hidden_size).to(device=device, dtype=torch.float16)
        self.hybrid_encoder = HybridVisualEncoder(hidden_size, num_regions).to(device=device, dtype=torch.float16)

        # Optimizer chỉ cho các learnable components (không train lại LLM)
        self.optimizer = torch.optim.Adam(
            list(self.gamma_net.parameters()) +
            list(self.hybrid_encoder.parameters()),
            lr=gamma_lr
        )

    # ----------------------------------------------------------
    # Image loading
    # ----------------------------------------------------------
    def load_image(self, image_source) -> Image.Image:
      if isinstance(image_source, Image.Image):
          return image_source
      if isinstance(image_source, str) and image_source.startswith(('http://', 'https://')):
          headers = {
              "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
          }
          response = requests.get(image_source, headers=headers, timeout=10)
          # Check HTTP status
          response.raise_for_status()
          # Check content type
          content_type = response.headers.get("Content-Type", "")
          if "image" not in content_type:
              raise ValueError(f"URL did not return an image. Content-Type: {content_type}")
          return Image.open(BytesIO(response.content)).convert("RGB")
      return Image.open(image_source).convert("RGB")

    # ----------------------------------------------------------
    # Input preparation
    # ----------------------------------------------------------
    def _prepare_inputs_with_image(self, prompt: str, image: Image.Image) -> dict:
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }]
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return self.processor(text=[text], images=[image], return_tensors="pt", padding=True)

    def _prepare_inputs_without_image(self, prompt: str) -> dict:
        messages = [{
            "role": "user",
            "content": [{"type": "text", "text": prompt}],
        }]
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return self.processor(text=[text], images=None, return_tensors="pt", padding=True)

    # ----------------------------------------------------------
    # COMPONENT: Tính attention mass αₜ
    # αₜ = tổng attention weight từ text token hiện tại → image tokens
    # ----------------------------------------------------------
    def _compute_attention_mass(
        self,
        attentions: Tuple,              # tuple of [batch, heads, seq, seq]
        num_image_tokens: int,
        layer_idx: int = -1             # dùng layer cuối (most semantic)
    ) -> torch.Tensor:
        """
        Tính αₜ = attention mass từ token cuối → image tokens.

        Dùng layer attention cuối cùng vì nó capture
        semantic-level dependencies (không phải syntactic).

        Returns:
            alpha_t: scalar tensor ∈ [0, 1]
        """
        if attentions is None:
            # Fallback: không có attention → assume moderate attention
            return torch.tensor(0.5, device=self.device)

        # Lọc ra các layer có attention weights thực sự (không phải None)
        valid_attentions = [a for a in attentions if a is not None]

        # Fallback 2: tất cả layers đều là None
        if len(valid_attentions) == 0:
            return torch.tensor(0.5, device=self.device)

        # Lấy layer hợp lệ theo layer_idx
        # Nếu layer_idx=-1 → lấy layer cuối trong danh sách hợp lệ
        attn_layer = valid_attentions[layer_idx]  # [batch, heads, seq_len, seq_len]

        # Attention từ token cuối (vị trí -1) đến tất cả positions
        # Average across heads
        attn_last_token = attn_layer[0, :, -1, :]   # [heads, seq_len]
        attn_avg = attn_last_token.mean(dim=0)       # [seq_len]

        # Image tokens nằm ở đầu sequence
        image_start = 1
        image_end = min(image_start + num_image_tokens, attn_avg.shape[0])

        # αₜ = tổng attention mass đến image tokens
        alpha_t = attn_avg[image_start:image_end].mean()

        return alpha_t.clamp(0.0, 1.0)

    # ----------------------------------------------------------
    # GENERATE: Main decode loop
    # ----------------------------------------------------------
    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        image_path: str,
        max_new_tokens: int = 100,
        temperature: float = 0.7,
        top_p: float = 0.9,
        verbose: bool = True
    ) -> str:
        """
        Decode với full framework:

        Mỗi bước t:
        1. lc = log p(yₜ | x, image)
        2. lu = log p(yₜ | x)
        3. αₜ = attention mass → image tokens
        4. γₜ = σ(W hₜ) từ hidden state
        5. l̂ₜ = lc + αₜ · ((1-γₜ)/γₜ) · (lc - lu)
        6. Sample yₜ ~ softmax(l̂ₜ)
        """
        image = self.load_image(image_path)

        # Prepare inputs
        inputs_c = self._prepare_inputs_with_image(prompt, image)
        inputs_u = self._prepare_inputs_without_image(prompt)
        inputs_c = {k: v.to(self.device) for k, v in inputs_c.items()}
        inputs_u = {k: v.to(self.device) for k, v in inputs_u.items()}

        # === PREFILL: lần đầu chạy cả sequence, khởi tạo KV cache ===
        # output_attentions=True để lấy αₜ
        outputs_c = self.model(
            **inputs_c,
            use_cache=True,
            output_attentions=True,
            output_hidden_states=True,   # cần hₜ cho γₜ
        )
        past_kv_c = outputs_c.past_key_values
        logits_c   = outputs_c.logits[:, -1, :]

        # Hidden state của layer cuối, token cuối → dùng cho γₜ
        # hidden_states: tuple of [batch, seq, hidden], lấy layer cuối
        h_t = outputs_c.hidden_states[-1][:, -1, :]  # [1, hidden_size]

        # Số image tokens trong conditioned input
        # Qwen2VL trả về image_grid_thw để tính số patch tokens
        num_image_tokens = 0
        if hasattr(outputs_c, 'image_grid_thw') or 'image_grid_thw' in inputs_c:
            grid = inputs_c.get('image_grid_thw', None)
            if grid is not None:
                # num_patches = T * H * W (với T=1 cho ảnh tĩnh)
                num_image_tokens = int(grid[0].prod().item())
        # Fallback estimate
        if num_image_tokens == 0:
            num_image_tokens = 256  # typical for 448px image

        # Tính αₜ từ prefill attention
        alpha_attention = self._compute_attention_mass(
            outputs_c.attentions,
            num_image_tokens
        )

        # Unconditioned prefill (không cần attention/hidden)
        outputs_u = self.model(**inputs_u, use_cache=True)
        past_kv_u = outputs_u.past_key_values
        logits_u  = outputs_u.logits[:, -1, :]

        seq_len_c = inputs_c['input_ids'].shape[1]
        seq_len_u = inputs_u['input_ids'].shape[1]

        inv_temp = 1.0 / temperature
        eos_token_id = self.processor.tokenizer.eos_token_id
        generated_ids = []

        if verbose:
            print(f"\n{'t':<4} {'α_attn':<10} {'γt':<8} {'w=(1-γ)/γ':<12} {'Token'}")
            print("-" * 60)

        # === DECODE LOOP ===
        for t in range(1, max_new_tokens + 1):

            # --- Step 4: Tính γₜ từ hidden state ---
            # Dùng gamma_net (có thể fine-tune sau)
            # inference_mode: tạm thời enable grad chỉ cho gamma_net
            with torch.enable_grad():
                h_t_input = h_t.detach().to(dtype=next(self.gamma_net.parameters()).dtype)
                gamma_t = self.gamma_net(h_t_input)  # [1, 1]
            gamma_t = gamma_t.detach().squeeze()         # scalar

            # --- Step 2-3: Log-probs ---
            lc = torch.log_softmax(logits_c * inv_temp, dim=-1)  # [1, vocab]
            lu = torch.log_softmax(logits_u * inv_temp, dim=-1)  # [1, vocab]

            # --- Step 5: Attention-aware M3ID formula ---
            # l̂ₜ = lc + αₜ · ((1-γₜ)/γₜ) · (lc - lu)
            # αₜ: continuous attention mass (thay cho binary indicator)
            # (1-γₜ)/γₜ: learnable correction weight (thay cho heuristic)

            eps = 1e-6
            correction_weight = ((1.0 - gamma_t) / (gamma_t + eps)).clamp(0.0, 5.0)

            # αₜ đóng vai trò gate: chỉ correct khi thực sự attend image
            l_star = lc + alpha_attention * correction_weight * (lc - lu)
            # [1, vocab]

            # --- Step 6: Top-p sampling ---
            probs = torch.softmax(l_star, dim=-1)

            # Nucleus sampling
            sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
            cum_probs = torch.cumsum(sorted_probs, dim=-1)

            mask = cum_probs > top_p
            mask[..., 1:] = mask[..., :-1].clone()
            mask[..., 0] = False

            remove_mask = torch.zeros_like(probs, dtype=torch.bool)
            remove_mask.scatter_(-1, sorted_idx, mask)
            probs[remove_mask] = 0.0

            prob_sum = probs.sum(dim=-1, keepdim=True).clamp(min=eps)
            probs = probs / prob_sum

            next_token_id = torch.multinomial(probs, num_samples=1)  # [1, 1]

            if next_token_id.item() == eos_token_id:
                break

            generated_ids.append(next_token_id.item())

            if verbose:
                token_str = self.processor.tokenizer.decode([next_token_id.item()])
                w_val = correction_weight.item()
                print(f"{t:<4} {alpha_attention.item():<10.4f} {gamma_t.item():<8.4f} {w_val:<12.4f} {repr(token_str)}")

            # === DECODE STEP: feed 1 token mới với KV cache ===
            next_token_tensor = next_token_id.view(1, 1)

            # Conditioned: cần attention + hidden state cho bước sau
            seq_len_c += 1
            out_c = self.model(
                input_ids=next_token_tensor,
                attention_mask=torch.ones((1, seq_len_c), dtype=torch.long, device=self.device),
                past_key_values=past_kv_c,
                use_cache=True,
                output_attentions=True,
                output_hidden_states=True,
            )
            logits_c  = out_c.logits[:, -1, :]
            past_kv_c = out_c.past_key_values

            # Cập nhật hₜ và αₜ cho bước tiếp theo
            h_t = out_c.hidden_states[-1][:, -1, :]
            alpha_attention = self._compute_attention_mass(
                out_c.attentions,
                num_image_tokens
            )

            # Unconditioned (không cần hidden/attention)
            seq_len_u += 1
            out_u = self.model(
                input_ids=next_token_tensor,
                attention_mask=torch.ones((1, seq_len_u), dtype=torch.long, device=self.device),
                past_key_values=past_kv_u,
                use_cache=True,
            )
            logits_u  = out_u.logits[:, -1, :]
            past_kv_u = out_u.past_key_values

        return self.processor.tokenizer.decode(generated_ids, skip_special_tokens=True)

    # ----------------------------------------------------------
    # TRAINING: Fine-tune γ network + hybrid encoder
    # Dùng khi có labeled data (response có/không hallucinate)
    # ----------------------------------------------------------
    def train_step(
        self,
        prompt: str,
        image_path: str,
        target_response: str,
        hallucination_label: float = 0.0  # 0=no hallucination, 1=hallucination
    ) -> float:
        """
        Fine-tune γ network với supervision đơn giản.

        Loss: nếu hallucination_label=1 → push γ nhỏ (tăng correction)
              nếu hallucination_label=0 → γ có thể lớn hơn

        Returns:
            loss value
        """
        image = self.load_image(image_path)
        inputs_c = self._prepare_inputs_with_image(prompt, image)
        inputs_c = {k: v.to(self.device) for k, v in inputs_c.items()}

        self.optimizer.zero_grad()

        with torch.no_grad():
            outputs = self.model(
                **inputs_c,
                output_hidden_states=True,
                use_cache=False,
            )

        # Lấy hidden states của toàn bộ generated positions
        hidden = outputs.hidden_states[-1]  # [1, seq, hidden]

        # Tính γ cho từng position
        gammas = self.gamma_net(hidden.squeeze(0))  # [seq, 1]
        avg_gamma = gammas.mean()

        # Loss: hallucination → muốn γ nhỏ (model nên trust image nhiều hơn)
        # no hallucination → γ tự do
        target_gamma = torch.tensor(
            1.0 - hallucination_label,   # hallucination → target γ = 0 (max correction)
            device=self.device
        )
        loss = nn.functional.mse_loss(avg_gamma, target_gamma)

        loss.backward()
        self.optimizer.step()

        return loss.item()



# ============================================================
# USAGE EXAMPLE
# ============================================================
if __name__ == "__main__":
    # Load model
    model_name = "Qwen/Qwen2.5-VL-3B-Instruct"

    processor = AutoProcessor.from_pretrained(model_name)
    # model = Qwen2VLForConditionalGeneration.from_pretrained(
    #     model_name,
    #     torch_dtype=torch.float16,
    #     device_map="auto",
    # )

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        attn_implementation='eager'
    )

    # Khởi tạo framework
    framework = HybridAttentionM3ID(
        model=model,
        processor=processor,
        hidden_size=2048,    # Qwen2VL-7B
        num_regions=16,
        # device="cuda"
    )

    # Inference
    result = framework.generate(
        prompt="Describe what you see in this image in detail.",
        image_path="https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg",
        max_new_tokens=150,
        temperature=0.7,
        top_p=0.9,
        verbose=True
    )

    print(f"\n=== Generated Response ===\n{result}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. 


Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading weights:   0%|          | 0/824 [00:00<?, ?it/s]


t    α_attn     γt       w=(1-γ)/γ    Token
------------------------------------------------------------
1    0.0005     0.5000   1.0000       'The'
2    0.0005     0.5000   1.0000       ' image'
3    0.0005     0.5000   1.0000       ' shows'
4    0.0005     0.5000   1.0000       ' a'
5    0.0005     0.5000   1.0000       ' close'
6    0.0005     0.5000   1.0000       '-up'
7    0.0005     0.5000   1.0000       ' of'
8    0.0005     0.5000   1.0000       ' a'
9    0.0005     0.5000   1.0000       ' cat'
10   0.0005     0.5000   1.0000       "'s"
11   0.0005     0.5000   1.0000       ' face'
12   0.0005     0.5000   1.0000       '.'
13   0.0005     0.5000   1.0000       ' The'
14   0.0005     0.5000   1.0000       ' cat'
15   0.0005     0.5000   1.0000       ' has'
16   0.0005     0.5000   1.0000       ' a'
17   0.0005     0.5000   1.0000       ' light'
18   0.0005     0.5000   1.0000       ' orange'
19   0.0005     0.5000   1.0000       ' or'
20   0.0005     0.5000   1.0000       ' gi

In [6]:
# # ============================================================
# # USAGE EXAMPLE
# # ============================================================
# if __name__ == "__main__":
#     # Load model
#     model_name = "Qwen/Qwen2.5-VL-3B-Instruct"

#     processor = AutoProcessor.from_pretrained(model_name)
#     # model = Qwen2VLForConditionalGeneration.from_pretrained(
#     #     model_name,
#     #     torch_dtype=torch.float16,
#     #     device_map="auto",
#     # )

#     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
#         model_name,
#         torch_dtype=torch.float16,
#         device_map="auto",
#     )

#     # Khởi tạo framework
#     framework = HybridAttentionM3ID(
#         model=model,
#         processor=processor,
#         hidden_size=3584,    # Qwen2VL-7B
#         num_regions=16,
#         device="cuda"
#     )

#     # Inference
#     result = framework.generate(
#         prompt="Describe what you see in this image in detail.",
#         image_path="https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg",
#         max_new_tokens=150,
#         temperature=0.7,
#         top_p=0.9,
#         verbose=True
#     )

#     print(f"\n=== Generated Response ===\n{result}")

In [7]:
# ```

# ---

# ## Giải thích các thay đổi chính

# ### 1. γₜ: Từ heuristic → Learnable
# ```
# Cũ:  gamma_t = exp(-λt)          # chỉ phụ thuộc thời gian
# Mới: gamma_t = σ(W · h_t)        # phụ thuộc nội dung token
# ```
# `LearnableGamma` là một linear layer + sigmoid — nhỏ nhưng học được "khi nào cần trust image". Để train nó bạn dùng `train_step()`.

# ### 2. αₜ: Từ binary indicator → Continuous attention mass
# ```
# Cũ:  indicator = [max(lc) < log(α)]    # binary 0/1
# Mới: alpha_t = sum(attn → image tokens) # continuous [0,1]
# ```
# `_compute_attention_mass` đọc attention weights của layer cuối, tính tổng mass đến image tokens. Token như "the" sẽ có αₜ ≈ 0, token như "dog" sẽ có αₜ cao.

# ### 3. Formula tổng hợp
# ```
# Cũ:  l̂ = lc + [ind] · ((1 - exp(-λt)) / exp(-λt)) · (lc - lu)
# Mới: l̂ = lc + α_attn · ((1 - γₜ) / γₜ) · (lc - lu)