In [None]:
# Block 1: Installation
# Installs the notebook's third-party dependencies into the kernel's environment
# via the %pip magic. Only `transformers` is pinned to an exact version; every
# other package resolves to whatever pip selects at install time.
# NOTE(review): later cells also import sentence_transformers, scikit-learn,
# pandas, matplotlib, seaborn and tqdm, none of which are installed here —
# presumably they are preinstalled on the target runtime; verify.
%pip install transformers==4.35.0
%pip install torch torchvision torchaudio
%pip install underthesea
%pip install onnx onnxruntime
%pip install optimum[onnxruntime]
%pip install datasets
# -U upgrades accelerate if a version is already present
%pip install accelerate -U
%pip install onnxscript

In [None]:
# ================================================================================
# CELL 1: IMPORTS VÀ SETUP
# ================================================================================

import re
import numpy as np
from typing import Dict, List, Optional
import torch
import unicodedata
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, f1_score, precision_score, recall_score,
    classification_report, confusion_matrix
)
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.utils.class_weight import compute_class_weight
from tqdm.auto import tqdm
import os
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from torch.utils.data import DataLoader, Dataset
import time
import json
import shutil
import onnx
import onnxruntime as ort

# Probe the optional underthesea dependency. On success this cell also binds
# `word_tokenize` and `sent_tokenize` at module level; on failure only the
# availability flag is set.
try:
    from underthesea import word_tokenize, sent_tokenize
except ImportError:
    UNDERTHESEA_AVAILABLE = False
    print("⚠️ underthesea not available")
else:
    UNDERTHESEA_AVAILABLE = True
    print("✓ underthesea available")

# Pick the compute device: CUDA when PyTorch reports one, otherwise CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"✓ Device: {device}")

In [None]:
# ================================================================================
# CELL 2: LOAD DATA
# ================================================================================

# Mount Google Drive so the dataset path below resolves (Colab runtime only)
from google.colab import drive
drive.mount('/content/drive')

# Read the dataset CSV — change this path to point at your own file
data_path = '/content/drive/MyDrive/FakeNewsModels/dataset_balanced.csv'
df = pd.read_csv(data_path)

# Quick sanity report: shape, column names, label balance and the first two rows.
# A single print with sep="\n" emits exactly one line per item.
print(
    f"Dataset shape: {df.shape}",
    f"Columns: {df.columns.tolist()}",
    f"\nLabel distribution:",
    df['label'].value_counts(),
    f"\nSample data:",
    df.head(2),
    sep="\n",
)

In [None]:
# ================================================================================
# CELL 3: TRAIN/VAL/TEST SPLIT (FIXED)
# ================================================================================

# Stratified 70% / 15% / 15% split into train, validation and test frames
print("Splitting dataset into train/val/test...")

# Stage 1: hold out 30% of the rows for validation + test
train_df, temp_df = train_test_split(
    df, test_size=0.30, random_state=42, stratify=df['label']
)

# Stage 2: cut the held-out part in half (50% of 30% = 15% each)
val_df, test_df = train_test_split(
    temp_df, test_size=0.50, random_state=42, stratify=temp_df['label']
)

# Size summary for each split
print(f"\n{'='*60}")
print(f"Dataset Split Summary")
print(f"{'='*60}")
print(f"Total samples:  {len(df):,}")
print(f"Train set:      {len(train_df):,} ({len(train_df)/len(df)*100:.1f}%)")
print(f"Val set:        {len(val_df):,} ({len(val_df)/len(df)*100:.1f}%)")
print(f"Test set:       {len(test_df):,} ({len(test_df)/len(df)*100:.1f}%)")

# Per-label row counts, one table row per label value
print(f"\n{'Label':<10} {'Train':<10} {'Val':<10} {'Test':<10}")
print(f"{'-'*40}")
for label in sorted(df['label'].unique()):
    train_count, val_count, test_count = (
        (part['label'] == label).sum() for part in (train_df, val_df, test_df)
    )
    print(f"{label:<10} {train_count:<10} {val_count:<10} {test_count:<10}")

# Give every split a fresh 0..n-1 index
train_df, val_df, test_df = (
    part.reset_index(drop=True) for part in (train_df, val_df, test_df)
)

print("\n✓ Data split completed")


In [None]:
# ================================================================================
# CELL 4: TEXT NORMALIZER (OPTIMIZED)
# ================================================================================

class VietnameseTextNormalizer:
    """Text cleaner used before PhoBERT tokenization.

    The pipeline is: NFC Unicode normalization, URL removal, replacement of
    characters outside the allowed set with spaces, whitespace collapsing and,
    unless disabled per call, underthesea word segmentation.
    """

    def __init__(self):
        # Resolve the segmenter once so later calls never re-import it
        segmenter = None
        if UNDERTHESEA_AVAILABLE:
            try:
                from underthesea import word_tokenize as segmenter
            except ImportError:
                segmenter = None
        self._word_tokenize = segmenter
        self.use_word_segment = segmenter is not None

        # Patterns are compiled a single time and reused on every call
        self.whitespace_pattern = re.compile(r'\s+')
        self.url_pattern = re.compile(r'http[s]?://\S+')
        self.special_chars_pattern = re.compile(
            r'[^\w\s.,!?àáảãạăằắẳẵặâầấẩẫậèéẻẽẹêềếểễệìíỉĩịòóỏõọôồốổỗộơờớởỡợùúủũụưừứửữựỳýỷỹỵđĐ]'
        )

    def normalize_unicode(self, text: str) -> str:
        """Return ``text`` in NFC (composed) form."""
        composed = unicodedata.normalize('NFC', text)
        return composed

    def clean_special_chars(self, text: str) -> str:
        """Replace URLs, then every disallowed character, with a single space."""
        without_urls = self.url_pattern.sub(' ', text)
        return self.special_chars_pattern.sub(' ', without_urls)

    def word_segment(self, text: str) -> str:
        """Segment ``text`` with underthesea; return it unchanged when that is not possible."""
        if self.use_word_segment and self._word_tokenize:
            try:
                return self._word_tokenize(text, format="text")
            except Exception:
                # Best effort: any segmenter failure falls through to the raw text
                pass
        return text

    def normalize(self, text: Optional[str], preserve_mask: bool = False) -> str:
        """Run the full cleaning pipeline.

        Args:
            text: Raw input; anything falsy or not a ``str`` yields ``""``.
            preserve_mask: When True, word segmentation is skipped.

        Returns:
            The cleaned, whitespace-collapsed and stripped text.
        """
        if not text or not isinstance(text, str):
            return ""

        cleaned = self.clean_special_chars(self.normalize_unicode(text).strip())
        cleaned = self.whitespace_pattern.sub(' ', cleaned)

        if preserve_mask:
            return cleaned.strip()

        # Segmentation may introduce extra spaces, so collapse once more
        segmented = self.word_segment(cleaned)
        return self.whitespace_pattern.sub(' ', segmented).strip()


# Build the shared normalizer instance; a later cell passes it to the dataset constructors
normalizer = VietnameseTextNormalizer()
print("✓ Text normalizer initialized")

In [None]:
# ================================================================================
# CELL 5: SEMANTIC CHUNKER (FIXED VERSION)
# ================================================================================

class SemanticChunkRetriever:
    """Split a document into sentence-aligned chunks with a small overlap.

    Sentences are packed greedily until adding the next one would push the
    running character count past ``chunk_size``. When a chunk is flushed, its
    trailing sentences that fit within ``chunk_overlap`` characters are carried
    over as the start of the next chunk.

    Args:
        chunk_size: Target maximum number of characters per chunk.
        chunk_overlap: Character budget for sentences repeated between
            consecutive chunks.
        long_sentence_words: Number of words per piece when a single sentence
            longer than ``1.5 * chunk_size`` characters has to be split.
    """

    def __init__(self, chunk_size=400, chunk_overlap=50, long_sentence_words=50):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        # Guard against 0/negative values, which would make range() fail below
        self.long_sentence_words = max(1, int(long_sentence_words))

    def _split_sentences(self, text):
        """Sentence-split with underthesea, falling back to a punctuation regex."""
        try:
            return sent_tokenize(text)
        except Exception:
            # Covers tokenizer failures and the NameError raised when
            # underthesea (and therefore `sent_tokenize`) was never imported.
            # Unlike a bare `except:`, KeyboardInterrupt/SystemExit propagate.
            return re.split(r'[.!?]\s+', text)

    def _overlap_tail(self, sentences):
        """Return the trailing sentences that fit in the overlap budget, and their length."""
        overlap_sents = []
        overlap_len = 0
        for s in reversed(sentences):
            if overlap_len + len(s) > self.chunk_overlap:
                break
            overlap_sents.insert(0, s)
            overlap_len += len(s) + 1  # +1 for the joining space
        return overlap_sents, overlap_len

    def chunk_document(self, text):
        """Chunk ``text`` and return the list of chunk strings.

        Returns an empty list for ``None``, non-string or blank input.
        """
        if not isinstance(text, str) or not text.strip():
            return []

        chunks = []
        current_chunk = []
        current_len = 0

        for sent in self._split_sentences(text):
            sent = sent.strip()
            if not sent:
                continue

            sent_len = len(sent)

            # A single over-long sentence is handled BEFORE it can join a chunk
            if sent_len > self.chunk_size * 1.5:
                # Flush whatever has been collected so far
                if current_chunk:
                    chunks.append(' '.join(current_chunk))
                    current_chunk = []
                    current_len = 0

                # Emit the long sentence as fixed-size word windows
                words = sent.split()
                step = self.long_sentence_words
                chunks.extend(
                    ' '.join(words[i:i + step]) for i in range(0, len(words), step)
                )
                continue

            # Adding this sentence would exceed the budget: flush, keep the overlap
            if current_len + sent_len > self.chunk_size and current_chunk:
                chunks.append(' '.join(current_chunk))
                current_chunk, current_len = self._overlap_tail(current_chunk)

            current_chunk.append(sent)
            current_len += sent_len + 1

        # Flush the final partial chunk
        if current_chunk:
            chunks.append(' '.join(current_chunk))

        return chunks


# Initialize
chunker = SemanticChunkRetriever(chunk_size=400, chunk_overlap=50)
print("✓ Semantic chunker initialized with overlap")

In [None]:
# ================================================================================
# CELL 6: LOAD MODELS (TOKENIZER & RETRIEVER)
# ================================================================================

# Load the PhoBERT tokenizer used to encode the selected chunks
print("Loading PhoBERT tokenizer...")
tokenizer = AutoTokenizer.from_pretrained('vinai/phobert-base-v2')
print("✓ PhoBERT tokenizer loaded")

# Load the sentence-embedding model that ranks chunks against the title,
# and switch it to eval mode
print("\nLoading Vietnamese-SBERT for RAG retrieval...")
retriever = SentenceTransformer('keepitreal/vietnamese-sbert')
retriever.eval()
print("✓ Vietnamese-SBERT loaded")


In [None]:
# ================================================================================
# CELL 7: CHUNK ATTENTION LAYER - FIXED FOR FP16
# ================================================================================

class ChunkAttentionLayer(nn.Module):
    """Additive attention pooling over the chunk axis (safe under FP16)."""

    def __init__(self, hidden_size: int, attention_hidden: int = 128):
        super().__init__()
        # Scorer: Linear -> Tanh -> Linear, producing one scalar per chunk
        scorer_layers = [
            nn.Linear(hidden_size, attention_hidden),
            nn.Tanh(),
            nn.Linear(attention_hidden, 1),
        ]
        self.attention = nn.Sequential(*scorer_layers)

    def forward(self, chunk_embeddings: torch.Tensor, mask: torch.Tensor = None):
        """
        Pool chunk embeddings into one vector per document.

        Args:
            chunk_embeddings: [batch, num_chunks, hidden_size]
            mask: [batch, num_chunks]; entries equal to 0 mark padding chunks

        Returns:
            context: [batch, hidden_size]
            attention_weights: [batch, num_chunks]
        """
        raw_scores = self.attention(chunk_embeddings).squeeze(-1)  # [batch, num_chunks]

        if mask is not None:
            # -1e9 is outside the FP16 range (max magnitude ~65504), so half
            # precision gets the most negative finite FP16 value instead
            is_half = raw_scores.dtype == torch.float16
            fill_value = -65504.0 if is_half else -1e9
            raw_scores = raw_scores.masked_fill(mask == 0, fill_value)

        attention_weights = raw_scores.softmax(dim=1)

        # [batch, 1, num_chunks] x [batch, num_chunks, hidden] -> [batch, 1, hidden]
        pooled = torch.bmm(attention_weights.unsqueeze(1), chunk_embeddings)
        context = pooled.squeeze(1)

        return context, attention_weights

print("✓ ChunkAttentionLayer defined (FP16 compatible)")

In [None]:
# ================================================================================
# CELL 8: DATASET (FIXED VERSION)
# ================================================================================

class HierarchicalAttentionDataset(Dataset):
    """
    Map-style dataset that turns one article row into ``top_k`` tokenized chunks.

    Per item: title and content are normalized, the content is chunked with
    overlap, the chunks most similar to the title are retrieved with the
    sentence-embedding model, and the selection is tokenized with fixed-length
    padding.

    Args:
        df: Frame with a ``label`` column and (optionally) ``title`` / ``content``.
        tokenizer: Tokenizer applied to the selected chunks.
        normalizer: Object exposing ``normalize(text) -> str``.
        retriever: Embedding model exposing ``encode(..., convert_to_tensor=True)``.
        chunk_size: Character budget per chunk, forwarded to the chunker.
        chunk_overlap: Character overlap between chunks, forwarded to the chunker.
        top_k: Number of chunks returned for every item.
        max_length: Token length each chunk is padded/truncated to.
        min_chunks: Minimum number of chunks before retrieval; shorter documents
            are padded by repeating their chunks.
        min_similarity: Similarity threshold kept for interface compatibility
            (see the note in ``_select_chunks``).
        verbose: Print a warning when a document is padded or a chunk truncated.
    """
    def __init__(
        self,
        df: pd.DataFrame,
        tokenizer: AutoTokenizer,
        normalizer: VietnameseTextNormalizer,
        retriever: SentenceTransformer,
        chunk_size: int = 400,
        chunk_overlap: int = 50,
        top_k: int = 5,
        max_length: int = 256,
        min_chunks: int = 3,
        min_similarity: float = 0.15,
        verbose: bool = False
    ):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.normalizer = normalizer
        self.retriever = retriever
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.top_k = top_k
        self.max_length = max_length
        self.min_chunks = min_chunks
        self.min_similarity = min_similarity
        self.verbose = verbose

        # Chunker configured with the same size/overlap as this dataset
        self.chunker = SemanticChunkRetriever(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _as_text(value) -> str:
        """Convert a cell value to text, mapping None/NaN to ``''`` instead of ``'nan'``."""
        if value is None:
            return ''
        if isinstance(value, str):
            return value
        try:
            if pd.isna(value):
                return ''
        except (TypeError, ValueError):
            # Non-scalar values have no single truth value; stringify them below
            pass
        return str(value)

    def _chunk_content(self, idx: int, title: str, content: str) -> List[str]:
        """Chunk the content and guarantee at least ``min_chunks`` entries."""
        raw_chunks = list(self.chunker.chunk_document(content))

        # Empty content yields no chunks; repeating an empty list can never reach
        # min_chunks (the previous loop spun forever), so fall back to the title,
        # which may itself be an empty string.
        if not raw_chunks:
            raw_chunks = [title]

        if len(raw_chunks) < self.min_chunks:
            if self.verbose:
                print(f"⚠️ Sample {idx}: {len(raw_chunks)} chunks < min {self.min_chunks}")
            # Cycle through the available chunks until min_chunks is reached
            available = len(raw_chunks)
            raw_chunks = [raw_chunks[i % available] for i in range(self.min_chunks)]

        return raw_chunks

    def _select_chunks(self, title: str, raw_chunks: List[str]) -> List[str]:
        """Return exactly ``top_k`` chunks, ranked by similarity to the title when needed."""
        if len(raw_chunks) <= self.top_k:
            # Copy so that padding below never mutates the caller's list
            selected = list(raw_chunks)
        else:
            # Use the title as query; fall back to the first chunk when it is blank
            query = title if title.strip() else raw_chunks[0]

            query_emb = self.retriever.encode(query, convert_to_tensor=True)
            chunk_embs = self.retriever.encode(raw_chunks, convert_to_tensor=True)

            similarities = F.cosine_similarity(
                query_emb.unsqueeze(0),
                chunk_embs,
                dim=1
            )

            # NOTE(review): the earlier code filtered by min_similarity before
            # ranking, but reverted to plain top-k whenever fewer than top_k
            # chunks passed. If at least top_k chunks pass, the overall top_k
            # all pass too, so the filter never changed the selection. Plain
            # top-k is therefore equivalent; decide whether a hard cut-off is
            # actually wanted.
            top_indices = similarities.argsort(descending=True)[:self.top_k].tolist()
            selected = [raw_chunks[i] for i in top_indices]

        # Pad by cycling through the selection. Previously the selection aliased
        # raw_chunks, so the modulo index was always 0 and only the first chunk
        # was ever repeated.
        base = list(selected)
        while len(selected) < self.top_k:
            selected.append(base[len(selected) % len(base)])

        return selected[:self.top_k]

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Build the model inputs for row ``idx``.

        Returns:
            dict with ``chunk_input_ids`` and ``chunk_attention_masks`` (one row
            per selected chunk, as returned by the tokenizer with
            ``return_tensors='pt'``) and a scalar long ``label`` tensor.
        """
        row = self.df.iloc[idx]
        label = int(row['label'])

        # 1. Fetch and normalize title and content (missing values become '')
        title = self.normalizer.normalize(self._as_text(row.get('title', '')))
        content = self.normalizer.normalize(self._as_text(row.get('content', '')))

        # 2. Chunk the content (with overlap) and enforce the minimum count
        raw_chunks = self._chunk_content(idx, title, content)

        # 3. Retrieve / pad to exactly top_k chunks
        selected_chunks = self._select_chunks(title, raw_chunks)

        # 4. Cap chunk length in characters before tokenization
        max_chars = self.max_length * 4
        validated_chunks = []
        for chunk in selected_chunks:
            if len(chunk) > max_chars:
                if self.verbose:
                    print(f"⚠️ Chunk too long ({len(chunk)} chars), truncating")
                chunk = chunk[:max_chars]
            validated_chunks.append(chunk)

        # 5. Tokenize every chunk to a fixed length
        chunk_encodings = self.tokenizer(
            validated_chunks,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'chunk_input_ids': chunk_encodings['input_ids'],
            'chunk_attention_masks': chunk_encodings['attention_mask'],
            'label': torch.tensor(label, dtype=torch.long)
        }


print("✓ HierarchicalAttentionDataset defined (with overlap)")


In [None]:
# ================================================================================
# CELL 9: HAN MODEL
# ================================================================================

class HierarchicalAttentionClassifier(nn.Module):
    """
    Hierarchical Attention Network for fake-news classification.

    Architecture:
    1. PhoBERT encodes every chunk -> chunk embeddings
    2. Chunk-level attention -> document representation
    3. Classification head -> logits
    """
    def __init__(
        self,
        phobert_name: str = "vinai/phobert-base-v2",
        chunk_attention_hidden: int = 128,
        num_classes: int = 2,
        dropout: float = 0.3
    ):
        """
        Args:
            phobert_name: Checkpoint name passed to ``AutoModel.from_pretrained``.
            chunk_attention_hidden: Hidden width of the chunk attention scorer.
            num_classes: Number of output logits.
            dropout: Dropout probability used in the classification head.
        """
        super().__init__()

        # PhoBERT encoder; the hidden size is read from its config
        self.phobert = AutoModel.from_pretrained(phobert_name)
        self.hidden_size = self.phobert.config.hidden_size

        # Chunk-level attention
        self.chunk_attention = ChunkAttentionLayer(
            hidden_size=self.hidden_size,
            attention_hidden=chunk_attention_hidden
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(self.hidden_size, self.hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(self.hidden_size // 2, num_classes)
        )

    def encode_chunks(
        self,
        chunk_input_ids: torch.Tensor,
        chunk_attention_masks: torch.Tensor
    ) -> torch.Tensor:
        """
        Encode every chunk with the encoder.

        Args:
            chunk_input_ids: [batch, num_chunks, max_length]
            chunk_attention_masks: [batch, num_chunks, max_length]

        Returns:
            chunk_embeddings: [batch, num_chunks, hidden_size]
        """
        batch_size, num_chunks, max_length = chunk_input_ids.shape

        # Flatten to [batch * num_chunks, max_length]. reshape() (unlike view())
        # also accepts non-contiguous inputs such as transposed or sliced batches.
        flat_input_ids = chunk_input_ids.reshape(-1, max_length)
        flat_attention_masks = chunk_attention_masks.reshape(-1, max_length)

        outputs = self.phobert(
            input_ids=flat_input_ids,
            attention_mask=flat_attention_masks
        )

        # First-token embedding of every chunk: [batch * num_chunks, hidden_size]
        chunk_embeddings = outputs.last_hidden_state[:, 0, :]

        # Restore the chunk axis: [batch, num_chunks, hidden_size]
        return chunk_embeddings.reshape(batch_size, num_chunks, -1)

    def forward(
        self,
        chunk_input_ids: torch.Tensor,
        chunk_attention_masks: torch.Tensor
    ):
        """
        Forward pass.

        Args:
            chunk_input_ids: [batch, num_chunks, max_length]
            chunk_attention_masks: [batch, num_chunks, max_length]

        Returns:
            dict with:
                - logits: [batch, num_classes]
                - chunk_attention: [batch, num_chunks]
        """
        # 1. Encode chunks: [batch, num_chunks, hidden_size]
        chunk_embeddings = self.encode_chunks(chunk_input_ids, chunk_attention_masks)

        # 2. A chunk counts as valid when at least one of its tokens is attended
        chunk_mask = (chunk_attention_masks.sum(dim=2) > 0).float()

        # 3. Attention pooling over the chunk axis
        doc_representation, chunk_attention_weights = self.chunk_attention(
            chunk_embeddings,
            mask=chunk_mask
        )

        # 4. Classification
        logits = self.classifier(doc_representation)

        return {
            'logits': logits,
            'chunk_attention': chunk_attention_weights
        }

print("✓ HierarchicalAttentionClassifier defined")


In [None]:
# ================================================================================
# CELL 10: CREATE DATASETS (UPDATED - THÊM TEST SET)
# ================================================================================

# Settings shared by the train / validation / test datasets
shared_dataset_args = dict(
    tokenizer=tokenizer,
    normalizer=normalizer,
    retriever=retriever,
    chunk_size=400,
    top_k=5,
    max_length=256,
    min_chunks=3,
    min_similarity=0.15,
    verbose=False,
)

print("Creating training dataset...")
train_dataset = HierarchicalAttentionDataset(df=train_df, **shared_dataset_args)
print(f"✓ Train dataset: {len(train_dataset)} samples")

print("\nCreating validation dataset...")
val_dataset = HierarchicalAttentionDataset(df=val_df, **shared_dataset_args)
print(f"✓ Val dataset: {len(val_dataset)} samples")

print("\nCreating test dataset...")
test_dataset = HierarchicalAttentionDataset(df=test_df, **shared_dataset_args)
print(f"✓ Test dataset: {len(test_dataset)} samples")

# Smoke test: fetch one training item and report its tensor shapes
print("\n### Testing Dataset ###")
sample = train_dataset[0]
print(f"Sample shape:")
print(f"  chunk_input_ids: {sample['chunk_input_ids'].shape}")
print(f"  chunk_attention_masks: {sample['chunk_attention_masks'].shape}")
print(f"  label: {sample['label']}")

# Count non-padding tokens in each of the 5 chunks to spot empty chunks
tokens_per_chunk = [
    (sample['chunk_input_ids'][i] != tokenizer.pad_token_id).sum().item()
    for i in range(5)
]
for i, num_tokens in enumerate(tokens_per_chunk):
    print(f"  Chunk {i}: {num_tokens} tokens")

min_tokens = min(tokens_per_chunk)
print(f"\n{' GOOD' if min_tokens > 10 else ' ISSUE'}: Min tokens = {min_tokens}")


In [None]:
# ================================================================================
# CELL 11: TRAINING SETUP WITH CLASS WEIGHTS (OPTIMIZED)
# ================================================================================
# ========================================

# Balanced class weights derived from the training labels
train_labels = train_df['label']
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_labels),
    y=train_labels.values
)

# Same weights as a float32 tensor on the training device
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)

class_names = (("REAL", 0), ("FAKE", 1))

print(f"Class distribution in training set:")
for class_name, class_id in class_names:
    print(f"  {class_name} ({class_id}): {(train_labels == class_id).sum()} samples")
print(f"\nComputed class weights:")
for class_name, class_id in class_names:
    print(f"  {class_name} ({class_id}): {class_weights[class_id]:.4f}")
print(f"  Weight ratio (FAKE/REAL): {class_weights[1]/class_weights[0]:.2f}x")


# ========================================
# MODEL
# ========================================
model = HierarchicalAttentionClassifier(
    phobert_name='vinai/phobert-base-v2',
    chunk_attention_hidden=128,
    num_classes=2,
    dropout=0.3
)
model = model.to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\n✓ Model initialized on {device}")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")


# ========================================
# DATALOADERS
# ========================================
use_cuda = torch.cuda.is_available()
# Worker processes are disabled on CUDA (the dataset calls sentence_transformers);
# CPU-only runs use 2 workers
num_workers = 0 if use_cuda else 2

if use_cuda:
    print("\n CUDA detected: Setting num_workers=0 to avoid multiprocessing errors")

loader_args = dict(batch_size=16, num_workers=num_workers, pin_memory=False)

# Only the training loader shuffles
train_loader = DataLoader(train_dataset, shuffle=True, **loader_args)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_args)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_args)

print(f"\n✓ DataLoaders created")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Test batches: {len(test_loader)}")


# ========================================
# OPTIMIZER, WEIGHTED LOSS, SCHEDULER
# ========================================
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

# Cross entropy weighted by the class weights computed above
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)

# mode='max': the learning rate is halved when the monitored metric stops increasing
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

print(f"\n✓ Training components ready")
print(f"  Optimizer: AdamW (lr=2e-5)")
print(f"  Loss: CrossEntropyLoss (WEIGHTED)")
print(f"  Scheduler: ReduceLROnPlateau (patience=2)")
print("\n" + "=" * 80)

In [None]:
# ================================================================================
# CELL 12: TRAINING + VALIDATION
# ================================================================================

from torch.cuda.amp import autocast, GradScaler
import gc

# --- Run configuration -----------------------------------------------------
drive_save_dir = '/content/drive/MyDrive/FakeNewsModels/model_v4.1'
os.makedirs(drive_save_dir, exist_ok=True)
best_model_path = os.path.join(drive_save_dir, 'han_rag_best.pth')

use_amp = torch.cuda.is_available()          # mixed precision only when CUDA is present
scaler = GradScaler() if use_amp else None
accumulation_steps = 4  # raised to 4 to reduce VRAM usage
best_val_f1 = 0
patience = 3                                 # early-stopping patience, in epochs
patience_counter = 0
num_epochs = 15

history = {'train_loss': [], 'train_acc': [], 'val_acc': [], 'val_f1': [],
           'val_precision': [], 'val_recall': []}

print(f"✓ Save: {drive_save_dir}")
if use_amp:
    print(f"⚡ FP16 + Accum={accumulation_steps} (effective batch={16*accumulation_steps})")

print("\n" + "=" * 80)
print(f"TRAINING ({num_epochs} epochs)")
print("=" * 80)

for epoch in range(num_epochs):
    print(f"\n{'='*80}\nEPOCH {epoch+1}/{num_epochs}\n{'='*80}")

    # ----------------------------- Training --------------------------------
    model.train()
    train_loss = 0
    train_correct = 0
    train_total = 0
    num_train_batches = len(train_loader)
    pbar = tqdm(train_loader, desc=f"Training {epoch+1}")
    optimizer.zero_grad()

    for batch_idx, batch in enumerate(pbar):
        ids = batch['chunk_input_ids'].to(device, non_blocking=True)
        masks = batch['chunk_attention_masks'].to(device, non_blocking=True)
        labels = batch['label'].to(device, non_blocking=True)

        # The loss is divided so gradients summed over the window average out
        if use_amp:
            with autocast(dtype=torch.float16):
                logits = model(ids, masks)['logits']
                loss = criterion(logits, labels) / accumulation_steps
            scaler.scale(loss).backward()
        else:
            logits = model(ids, masks)['logits']
            loss = criterion(logits, labels) / accumulation_steps
            loss.backward()

        # Training accuracy from the logits already computed above. The previous
        # code ran a second full forward pass per batch (in train mode, after
        # the optimizer step) just to obtain these predictions.
        with torch.no_grad():
            preds = torch.argmax(logits, dim=1)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)

        # Free memory after backward
        del ids, masks, logits, preds

        # Step at the end of each accumulation window, and also on the final
        # batch so gradients from a trailing partial window are not discarded
        # by the next epoch's zero_grad().
        is_last_batch = (batch_idx + 1) == num_train_batches
        if (batch_idx + 1) % accumulation_steps == 0 or is_last_batch:
            if use_amp:
                # Unscale first so clipping sees the true gradient norm
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            optimizer.zero_grad()

        # Undo the accumulation scaling for reporting
        train_loss += loss.item() * accumulation_steps

        pbar.set_postfix({'loss': f'{loss.item()*accumulation_steps:.4f}',
                         'acc': f'{train_correct/train_total:.4f}'})

        # Periodically release cached CUDA blocks
        if batch_idx % 20 == 0:
            torch.cuda.empty_cache()

    train_loss /= len(train_loader)
    train_acc = train_correct / train_total

    # Clear before validation
    torch.cuda.empty_cache()
    gc.collect()

    # ----------------------------- Validation ------------------------------
    model.eval()
    val_preds = []
    val_labels = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            ids = batch['chunk_input_ids'].to(device, non_blocking=True)
            masks = batch['chunk_attention_masks'].to(device, non_blocking=True)
            labels = batch['label'].to(device, non_blocking=True)

            if use_amp:
                with autocast(dtype=torch.float16):
                    logits = model(ids, masks)['logits']
            else:
                logits = model(ids, masks)['logits']

            val_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

            # Free immediately
            del ids, masks, logits

    # Plain Python floats keep NumPy scalar objects out of `history` and the
    # checkpoint; newer torch.load defaults (weights_only=True) refuse to
    # unpickle NumPy scalars.
    val_acc = float(accuracy_score(val_labels, val_preds))
    val_f1 = float(f1_score(val_labels, val_preds, average='macro'))
    val_precision = float(precision_score(val_labels, val_preds, average='macro'))
    val_recall = float(recall_score(val_labels, val_preds, average='macro'))

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)
    history['val_f1'].append(val_f1)
    history['val_precision'].append(val_precision)
    history['val_recall'].append(val_recall)

    print(f"\n### EPOCH {epoch+1} ###")
    print(f"Train: Loss={train_loss:.4f}, Acc={train_acc:.4f}")
    print(f"Val:   Acc={val_acc:.4f}, F1={val_f1:.4f}, P={val_precision:.4f}, R={val_recall:.4f}")

    scheduler.step(val_f1)

    if val_f1 > best_val_f1:
        improvement = val_f1 - best_val_f1
        best_val_f1 = val_f1
        patience_counter = 0

        # Weights are copied to CPU for the file; the live model is untouched,
        # so nothing has to be loaded back afterwards.
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': {k: v.cpu() for k, v in model.state_dict().items()},
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict() if use_amp else None,
            'val_f1': val_f1,
            'val_acc': val_acc,
            'history': history
        }
        torch.save(checkpoint, best_model_path)
        del checkpoint

        print(f"Best saved! F1={best_val_f1:.4f} (↑{improvement:.4f})")
    else:
        patience_counter += 1
        print(f"No improve. Patience: {patience_counter}/{patience}")
        if patience_counter >= patience:
            print(f"\nEarly stop at epoch {epoch+1}, Best F1={best_val_f1:.4f}")
            break

    # Aggressive cleanup
    torch.cuda.empty_cache()
    gc.collect()

print("\n" + "="*80)
print("TRAINING DONE")
print("="*80)
print(f"Best Val F1: {best_val_f1:.4f} (Epoch {history['val_f1'].index(max(history['val_f1']))+1})")

# Plot the recorded training history as a 2x2 grid and save it next to the checkpoint
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
(loss_ax, acc_ax), (f1_ax, pr_ax) = axes

# Top-left: training loss per epoch
loss_ax.plot(history['train_loss'], 'o-', label='Train Loss')

# Top-right: train vs. validation accuracy
acc_ax.plot(history['train_acc'], 'o-', label='Train')
acc_ax.plot(history['val_acc'], 's-', label='Val')

# Bottom-left: validation F1 with the best epoch highlighted
f1_ax.plot(history['val_f1'], 'o-', color='green', label='Val F1')
best_epoch = history['val_f1'].index(max(history['val_f1']))
f1_ax.scatter(best_epoch, max(history['val_f1']), color='red', s=100, label='Best', zorder=5)

# Bottom-right: validation precision and recall
pr_ax.plot(history['val_precision'], 'o-', label='Precision')
pr_ax.plot(history['val_recall'], 's-', label='Recall')

# Shared decoration: title, legend and a light grid on every panel
panel_titles = ('Loss', 'Accuracy', 'F1 Score', 'Precision & Recall')
for panel, panel_title in zip((loss_ax, acc_ax, f1_ax, pr_ax), panel_titles):
    panel.set_title(panel_title)
    panel.legend()
    panel.grid(alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(drive_save_dir, 'training.png'), dpi=300, bbox_inches='tight')
plt.show()

# Final validation: reload the best checkpoint and re-score the validation set.
# map_location places the tensors on the current device even if the file was
# written from a different one. weights_only=False is passed explicitly because
# the checkpoint also carries metric/history entries that may be NumPy scalars,
# which the weights-only unpickler rejects. This is acceptable here only because
# the file is written by this notebook — do not load untrusted checkpoints this way.
checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

val_preds = []
val_labels = []

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Final Validation"):
        ids = batch['chunk_input_ids'].to(device, non_blocking=True)
        masks = batch['chunk_attention_masks'].to(device, non_blocking=True)
        labels = batch['label'].to(device, non_blocking=True)

        # Same precision mode as during training
        if use_amp:
            with autocast(dtype=torch.float16):
                logits = model(ids, masks)['logits']
        else:
            logits = model(ids, masks)['logits']

        val_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
        val_labels.extend(labels.cpu().numpy())

cm = confusion_matrix(val_labels, val_preds)

print("\n" + "="*80)
print("VALIDATION (BEST MODEL)")
print("="*80)
print(f"Accuracy:   {accuracy_score(val_labels, val_preds):.4f}")
print(f"F1:         {f1_score(val_labels, val_preds, average='macro'):.4f}")
print(f"Precision:  {precision_score(val_labels, val_preds, average='macro'):.4f}")
print(f"Recall:     {recall_score(val_labels, val_preds, average='macro'):.4f}")
# Rows are true labels, columns are predictions
print(f"\nConfusion Matrix:")
print(f"[[{cm[0,0]:4d} {cm[0,1]:4d}]  (REAL)")
print(f" [{cm[1,0]:4d} {cm[1,1]:4d}]] (FAKE)")

In [None]:
# ================================================================================
# CELL 13: TEST EVALUATION
# ================================================================================

# Load best model. map_location=device lets the checkpoint load even when it
# was saved from a different device than the one currently in use.
checkpoint = torch.load(best_model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"Testing model from epoch {checkpoint['epoch']} (Val F1={checkpoint['val_f1']:.4f})")

# Per-sample accumulators filled batch by batch
test_preds = []
test_labels = []
test_probs = []
test_attention = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        ids = batch['chunk_input_ids'].to(device, non_blocking=True)
        masks = batch['chunk_attention_masks'].to(device, non_blocking=True)
        labels = batch['label'].to(device, non_blocking=True)

        # Use the same autocast dtype as the final-validation pass.
        if use_amp:
            with autocast(dtype=torch.float16):
                outputs = model(ids, masks)
        else:
            outputs = model(ids, masks)

        # Under autocast the logits may come back in half precision; upcast
        # before softmax so the stored probabilities are computed in float32.
        logits = outputs['logits'].float()
        probs = F.softmax(logits, dim=1)

        test_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
        test_labels.extend(labels.cpu().numpy())
        test_probs.extend(probs.cpu().numpy())
        test_attention.extend(outputs['chunk_attention'].cpu().numpy())

# Metrics. Class ids are pinned to [0, 1] (REAL, FAKE) so the confusion matrix
# is always 2x2 and the report always has both rows, even if one class is
# absent from the labels/predictions.
class_ids = [0, 1]
test_acc = accuracy_score(test_labels, test_preds)
test_f1 = f1_score(test_labels, test_preds, average='macro')
test_f1_weighted = f1_score(test_labels, test_preds, average='weighted')
test_precision = precision_score(test_labels, test_preds, average='macro')
test_recall = recall_score(test_labels, test_preds, average='macro')
cm = confusion_matrix(test_labels, test_preds, labels=class_ids)

print("\n" + "="*80)
print("TEST RESULTS")
print("="*80)
print(f"Accuracy:     {test_acc:.4f}")
print(f"F1 (Macro):   {test_f1:.4f}")
print(f"F1 (Weight):  {test_f1_weighted:.4f}")
print(f"Precision:    {test_precision:.4f}")
print(f"Recall:       {test_recall:.4f}")

print("\n" + classification_report(
    test_labels, test_preds, labels=class_ids, target_names=['REAL', 'FAKE']
))

print(f"\nConfusion Matrix:")
print(f"[[{cm[0,0]:4d} {cm[0,1]:4d}]")
print(f" [{cm[1,0]:4d} {cm[1,1]:4d}]]")
print(f"\nTN={cm[0,0]}, FP={cm[0,1]}, FN={cm[1,0]}, TP={cm[1,1]}")

# Attention analysis: average each chunk position's weight over the test set.
mean_attn = np.array(test_attention).mean(axis=0)
attn_peak = mean_attn.max()
print("\n" + "="*80)
print("ATTENTION WEIGHTS")
print("="*80)
for i, w in enumerate(mean_attn):
    # Scale bars so the largest weight spans 40 characters; guard a zero peak.
    bar = "█" * (int(w * 40 / attn_peak) if attn_peak > 0 else 0)
    print(f"Chunk {i}: {bar} {w:.4f}")

# Val-Test comparison: absolute difference between checkpoint val F1 and test F1.
gap = abs(checkpoint['val_f1'] - test_f1)
print("\n" + "="*80)
print("VAL vs TEST")
print("="*80)
print(f"Val F1:  {checkpoint['val_f1']:.4f}")
print(f"Test F1: {test_f1:.4f}")
print(f"Gap:     {gap:.4f} {'Good' if gap < 0.02 else ' Overfitting' if gap < 0.05 else ' Bad'}")

# Visualize: confusion-matrix heatmap (left) and mean chunk attention (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['REAL', 'FAKE'], yticklabels=['REAL', 'FAKE'], ax=axes[0])
axes[0].set_title('Confusion Matrix')
axes[0].set_ylabel('True')
axes[0].set_xlabel('Predicted')

# One bar per chunk position actually produced by the model, rather than a
# hard-coded count of 5 that breaks as soon as the chunk count changes.
axes[1].bar(range(len(mean_attn)), mean_attn, color='skyblue', edgecolor='navy')
axes[1].set_title('Attention Weights')
axes[1].set_xlabel('Chunk')
axes[1].set_ylabel('Weight')
for i, v in enumerate(mean_attn):
    axes[1].text(i, v, f'{v:.3f}', ha='center', va='bottom')

plt.tight_layout()
plt.savefig(os.path.join(drive_save_dir, 'test_results.png'), dpi=300, bbox_inches='tight')
plt.show()

print(f"\nTest complete. Results saved to {drive_save_dir}")

In [None]:
# ================================================================================
# EXPORT TOKENIZER & ONNX - SIMPLE VERSION
# ================================================================================

import os
import shutil
import warnings
import torch
import onnx

print("="*80)
print("EXPORT TOKENIZER & ONNX MODEL")
print("="*80)

# Paths
drive_save_dir = '/content/drive/MyDrive/FakeNewsModels/model_v4.1'
best_model_path = os.path.join(drive_save_dir, 'han_rag_best.pth')

# `export_dir` is not assigned anywhere in this cell. Reuse it if an earlier
# cell already defined it; otherwise fall back to the model directory so the
# cell also runs on a fresh kernel instead of raising NameError.
try:
    export_dir
except NameError:
    export_dir = drive_save_dir
os.makedirs(export_dir, exist_ok=True)

# ========================================
# 2. EXPORT ONNX (Single batch only)
# ========================================
print("\n2. Exporting ONNX model...")

# Load the best checkpoint onto CPU and export from there.
# NOTE: this leaves `model` on CPU; later cells that run on `device` must move
# it back before feeding it tensors that live on `device`.
checkpoint = torch.load(best_model_path, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to('cpu')
model.eval()
print(f"✓ Loaded model from epoch {checkpoint['epoch']}")

# Export with FIXED batch=1 (no dynamic axes).
# Dummy inputs: random token ids and an all-ones mask of shape (1, 5, 256).
onnx_path = os.path.join(export_dir, 'han_rag_model.onnx')
dummy_input = (
    torch.randint(0, tokenizer.vocab_size, (1, 5, 256), dtype=torch.long),
    torch.ones((1, 5, 256), dtype=torch.long)
)

# Silence warnings only for the duration of the export, instead of installing a
# global 'ignore' filter that would hide warnings in every later cell as well.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=14,  # Safe version
        do_constant_folding=True,
        input_names=['chunk_input_ids', 'chunk_attention_masks'],
        output_names=['logits', 'chunk_attention']
    )

# Verify the exported graph and report its size on disk
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
onnx_size = os.path.getsize(onnx_path) / (1024**2)
print(f"✓ ONNX saved: {onnx_path}")
print(f"  Size: {onnx_size:.2f} MB")

# ========================================
# 3. SUMMARY
# ========================================
print("\n" + "="*80)
print("EXPORT COMPLETED")
print("="*80)
print(f"\nFiles exported:")
# The original summary header was followed by nothing; list what was written.
print(f"  - {onnx_path} ({onnx_size:.2f} MB)")


In [None]:
# After export: re-save the ONNX model with its large tensors stored in a
# single sidecar file next to the .onnx file.
print("\n3. Creating external data file...")

# `location` is interpreted relative to the model file's directory.
external_data_name = 'han_rag_model.onnx.data'
external_data_path = os.path.join(os.path.dirname(onnx_path), external_data_name)

# onnx.load() pulls any existing external tensor data into memory by default,
# so the model object is self-contained after this call.
onnx_model = onnx.load(onnx_path)

# Make the cell idempotent: the ONNX writer appends to an existing external
# data file, so re-running this cell would otherwise grow the sidecar on every
# run. Removing the stale file is safe because the tensors are already loaded.
if os.path.exists(external_data_path):
    os.remove(external_data_path)

onnx.save_model(
    onnx_model,
    onnx_path,
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location=external_data_name,
    size_threshold=1024  # Save tensors >1KB externally
)

print(f"✓ External data created: han_rag_model.onnx.data")


In [None]:
# ================================================================================
# CELL 19: TEST ONNX MODEL WITH TEXT INPUT (NEW)
# ================================================================================

def predict_with_onnx(
    title: str,
    content: str,
    ort_session: ort.InferenceSession,
    tokenizer,
    normalizer,
    retriever,
    device='gpu',
    chunk_size: int = 400,
    top_k: int = 5,
    max_length: int = 256,
    min_chunks: int = 3
) -> Dict:
    """
    Classify a single article with the exported ONNX model, starting from raw text.

    The title/content pair is wrapped in a one-row DataFrame and pushed through
    ``HierarchicalAttentionDataset`` so it is preprocessed by the same dataset
    class the notebook uses elsewhere; the resulting tensors are then fed to
    the ONNX Runtime session.

    Args:
        title: Article headline.
        content: Article body.
        ort_session: ONNX Runtime session whose inputs are named
            ``chunk_input_ids`` and ``chunk_attention_masks`` and whose first
            two outputs are the logits and the chunk attention weights.
        tokenizer: Tokenizer handed to ``HierarchicalAttentionDataset``.
        normalizer: Text normalizer handed to ``HierarchicalAttentionDataset``.
        retriever: Retriever handed to ``HierarchicalAttentionDataset``.
        device: Unused; kept for backward compatibility. The execution device
            is decided by the providers of ``ort_session``.
        chunk_size: Forwarded to the dataset (default 400).
        top_k: Forwarded to the dataset (default 5). If the ONNX graph was
            exported with a fixed chunk dimension, this presumably has to
            match it -- verify against the export.
        max_length: Forwarded to the dataset (default 256). Same caveat as
            ``top_k`` for a fixed-shape export.
        min_chunks: Forwarded to the dataset (default 3).

    Returns:
        Dict with keys:
            ``prediction``: ``'FAKE'`` if the arg-max class is 1, else ``'REAL'``.
            ``confidence``: probability of the predicted class.
            ``probabilities``: ``{'REAL': p0, 'FAKE': p1}``.
            ``chunk_attention``: attention weights of the sample as a list.
    """
    # Create a one-row dataset; the label is a dummy value that is never read
    # by this function.
    test_df = pd.DataFrame({
        'title': [title],
        'content': [content],
        'label': [0]  # Dummy
    })

    test_dataset = HierarchicalAttentionDataset(
        df=test_df,
        tokenizer=tokenizer,
        normalizer=normalizer,
        retriever=retriever,
        chunk_size=chunk_size,
        top_k=top_k,
        max_length=max_length,
        min_chunks=min_chunks
    )

    sample = test_dataset[0]

    # Prepare ONNX inputs: add a leading batch axis of size 1.
    onnx_inputs = {
        'chunk_input_ids': sample['chunk_input_ids'].unsqueeze(0).numpy(),
        'chunk_attention_masks': sample['chunk_attention_masks'].unsqueeze(0).numpy()
    }

    # Run inference and drop the batch axis from both outputs.
    onnx_outputs = ort_session.run(None, onnx_inputs)
    logits = onnx_outputs[0][0]  # expected: one logit per class
    chunk_attention = onnx_outputs[1][0]  # expected: one weight per chunk

    # Numerically stable softmax: subtracting the max logit leaves the result
    # unchanged mathematically but keeps np.exp from overflowing to inf (and
    # the division from producing NaN) on large-magnitude logits.
    exp_shifted = np.exp(logits - np.max(logits))
    probs = exp_shifted / np.sum(exp_shifted)
    pred_class = int(np.argmax(probs))

    return {
        'prediction': 'FAKE' if pred_class == 1 else 'REAL',
        'confidence': float(probs[pred_class]),
        'probabilities': {
            'REAL': float(probs[0]),
            'FAKE': float(probs[1])
        },
        'chunk_attention': chunk_attention.tolist()
    }


# Test with a COVID-19 example
print("=" * 80)
print("TESTING ONNX MODEL WITH TEXT INPUT")
print("=" * 80)

test_title = "NÓNG: Thế giới đối mặt với COVID-19"
test_content = """NÓNG ! 15/3 sẽ là ngày đáng nhớ cho không chỉ Việt Nam mà của toàn thế giới do những gì COVID-19 gây ra !!! - Rạng sáng 15/3, Pháp ra lệnh đóng cửa toàn bộ nhà hàng, rạp chiếu phim, cửa hàng,..trừ.. siêu thị, trạm xăng, ngân hàng, tabac, presse báo chí và pharmacie."""

print(f"\nInput:")
print(f"  Title: {test_title}")
print(f"  Content: {test_content[:150]}...")

# The only visible place that creates `ort_session` is a cell further down the
# notebook, so on a fresh top-to-bottom run the name does not exist yet here.
# Build a CPU session from the exported model in that case; an already
# existing session is reused untouched.
try:
    ort_session
except NameError:
    ort_session = ort.InferenceSession(
        onnx_path,
        providers=['CPUExecutionProvider']
    )

# Predict
result = predict_with_onnx(
    test_title,
    test_content,
    ort_session,
    tokenizer,
    normalizer,
    retriever
)

print(f"\n{'='*80}")
print("ONNX PREDICTION RESULT")
print(f"{'='*80}")
print(f"Prediction: {result['prediction']}")
print(f"Confidence: {result['confidence']:.4f}")
print(f"\nProbabilities:")
print(f"  REAL: {result['probabilities']['REAL']:.4f}")
print(f"  FAKE: {result['probabilities']['FAKE']:.4f}")

print(f"\nChunk Attention Weights:")
# Scale bars so the largest weight spans 40 characters. The peak is computed
# once instead of on every iteration, and an empty/zero peak yields no bar
# rather than an error.
attention_weights = result['chunk_attention']
peak_weight = max(attention_weights) if attention_weights else 0.0
for i, weight in enumerate(attention_weights):
    bar_length = int(weight * 40 / peak_weight) if peak_weight > 0 else 0
    bar = "█" * bar_length
    print(f"  Chunk {i}: {bar} {weight:.4f}")

print("\nONNX model inference completed successfully!")


In [None]:
# ================================================================================
# CONFIG & PATHS
# ================================================================================
drive_save_dir = '/content/drive/MyDrive/FakeNewsModels/model_v4.1'
best_model_path = os.path.join(drive_save_dir, 'han_rag_best.pth')
onnx_path = os.path.join(drive_save_dir, 'han_rag_model.onnx')
deployment_dir = os.path.join(drive_save_dir, 'deployment_package')

batch_size = 1
num_chunks = 5
max_length = 256

# ================================================================================
# LOAD DEPENDENCIES
# ================================================================================
try:
    import onnx
    import onnxruntime as ort
except ImportError:
    !pip install -q onnx onnxruntime
    import onnx
    import onnxruntime as ort  # [web:21][web:30]

# ================================================================================
# LOAD MODEL & DUMMY INPUT
# ================================================================================
checkpoint = torch.load(best_model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

dummy_chunk_input_ids = torch.randint(
    0, tokenizer.vocab_size,
    (batch_size, num_chunks, max_length),
    dtype=torch.long, device=device
)
dummy_chunk_attention_masks = torch.ones(
    (batch_size, num_chunks, max_length),
    dtype=torch.long, device=device
)

with torch.no_grad():
    _ = model(dummy_chunk_input_ids, dummy_chunk_attention_masks)

# ================================================================================
# EXPORT TO ONNX
# ================================================================================
# Names given to the graph inputs/outputs, in positional order.
input_names = ['chunk_input_ids', 'chunk_attention_masks']
output_names = ['logits', 'chunk_attention']

# Every input and output keeps axis 0 symbolic ('batch_size'); the remaining
# axes are fixed to the dummy tensors' sizes.
dynamic_axes = {
    tensor_name: {0: 'batch_size'}
    for tensor_name in input_names + output_names
}

# The two dummy tensors are passed positionally to the model during tracing.
export_args = (dummy_chunk_input_ids, dummy_chunk_attention_masks)
torch.onnx.export(
    model,
    export_args,
    onnx_path,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
    verbose=False
)

# Size of the exported file in MiB (reported in the final summary)
onnx_size_mb = os.path.getsize(onnx_path) / 2**20

# ================================================================================
# VALIDATE & QUICK TEST ONNX
# ================================================================================
# Structural check of the exported graph
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

# CPU-only runtime session over the exported file
ort_session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])

# Take the first test sample and add a leading batch axis of size 1
test_sample = test_dataset[0]
test_input_ids = test_sample['chunk_input_ids'].unsqueeze(0).numpy()
test_attention_masks = test_sample['chunk_attention_masks'].unsqueeze(0).numpy()

onnx_inputs = dict(
    chunk_input_ids=test_input_ids,
    chunk_attention_masks=test_attention_masks,
)
onnx_logits, onnx_attention = ort_session.run(None, onnx_inputs)

# Same sample through the PyTorch model for a reference result
with torch.no_grad():
    pt_out = model(
        test_sample['chunk_input_ids'][None].to(device),
        test_sample['chunk_attention_masks'][None].to(device),
    )
    pytorch_logits = pt_out['logits'].cpu().numpy()
    pytorch_attention = pt_out['chunk_attention'].cpu().numpy()

# Largest absolute element-wise deviation between the two backends
logits_diff = float(np.max(np.abs(onnx_logits - pytorch_logits)))
attention_diff = float(np.max(np.abs(onnx_attention - pytorch_attention)))

# Predicted class of the single sample under each backend
pytorch_pred = int(pytorch_logits.argmax(axis=1)[0])
onnx_pred = int(onnx_logits.argmax(axis=1)[0])
pred_match = bool(onnx_pred == pytorch_pred)

# ================================================================================
# SAVE ONNX METADATA
# ================================================================================
# Both graph inputs share the same (batch, chunks, tokens) shape.
io_shape = [batch_size, num_chunks, max_length]

# Description of the export written next to the model: tensor shapes, model
# hyper-parameters, checkpoint provenance and the ONNX-vs-PyTorch parity check.
onnx_metadata = {
    'model_type': 'HierarchicalAttentionClassifier',
    'framework': 'PyTorch -> ONNX',
    'opset_version': 14,
    'input_shapes': {
        'chunk_input_ids': list(io_shape),
        'chunk_attention_masks': list(io_shape),
    },
    'output_shapes': {
        'logits': [batch_size, 2],
        'chunk_attention': [batch_size, num_chunks],
    },
    'model_config': dict(
        phobert_name='vinai/phobert-base-v2',
        chunk_attention_hidden=128,
        num_classes=2,
        dropout=0.3,
        chunk_size=400,
        top_k=5,
        max_length=256,
    ),
    'training_info': dict(
        best_epoch=checkpoint['epoch'],
        val_f1=float(checkpoint['val_f1']),
        # val_acc is optional in the checkpoint; 0 is recorded when it is absent
        val_acc=float(checkpoint.get('val_acc', 0)),
    ),
    'validation': dict(
        max_logits_diff=logits_diff,
        max_attention_diff=attention_diff,
        predictions_match=pred_match,
    ),
}

# Serialize as UTF-8 JSON without escaping non-ASCII characters
metadata_path = os.path.join(drive_save_dir, 'onnx_metadata.json')
with open(metadata_path, 'w', encoding='utf-8') as f:
    f.write(json.dumps(onnx_metadata, indent=2, ensure_ascii=False))

# ================================================================================
# DEPLOYMENT PACKAGE (TOKENIZER, RETRIEVER, ONNX, CONFIG) + ZIP
# ================================================================================
# Everything below is written under deployment_dir, which is then zipped.
os.makedirs(deployment_dir, exist_ok=True)

# 1. Tokenizer files
tokenizer_dir = os.path.join(deployment_dir, 'tokenizer')
tokenizer.save_pretrained(tokenizer_dir)

# 2. Retriever
retriever_dir = os.path.join(deployment_dir, 'retriever')
retriever.save(retriever_dir)

# 3. ONNX model file, copied under a fixed name
packaged_onnx_path = os.path.join(deployment_dir, 'han_rag_model.onnx')
shutil.copy(onnx_path, packaged_onnx_path)

# 4. Deployment config: checkpoint provenance plus preprocessing settings
config = dict(
    model_version='v2',
    best_epoch=checkpoint['epoch'],
    val_f1=float(checkpoint['val_f1']),
    model_config=dict(
        chunk_size=400,
        top_k=5,
        max_length=256,
    ),
)
config_path = os.path.join(deployment_dir, 'config.json')
with open(config_path, 'w') as f:
    f.write(json.dumps(config, indent=2))

# 5. ZIP: archive the package directory as <drive_save_dir>/deployment_package.zip
zip_base = os.path.join(drive_save_dir, 'deployment_package')
shutil.make_archive(zip_base, 'zip', deployment_dir)
zip_path = zip_base + '.zip'
zip_size = os.path.getsize(zip_path) / 2**20

# Final summary of every artefact produced by this cell
print(f"ONNX: {onnx_path} ({onnx_size_mb:.2f} MB)")
print(f"Diff logits: {logits_diff:.6f}, attention: {attention_diff:.6f}, match: {pred_match}")
print(f"Metadata: {metadata_path}")
print(f"Deployment ZIP: {zip_path} ({zip_size:.2f} MB)")
print("Contents: han_rag_model.onnx, tokenizer/, retriever/, config.json")