# Fake News Detection with Deep Learning
## CDS525 Group Project

**Model**: BiLSTM + Attention + GloVe Pre-trained Embeddings

**Features**:
- ★ Merged Dataset: Original (5K) + News_dataset (45K) ≈ 50K samples
- ★ Text Data Augmentation (EDA: Random Deletion / Random Swap)
- ★ GloVe Pre-trained Word Embeddings (97%+ coverage)
- ★ Frozen GloVe → Reduces 2M trainable params, prevents overfitting
- ★ AdamW + Weight Decay + LR Scheduler + Early Stopping
- ★ Chain-of-Thought Reasoning for Explainability

**Required**: GPU Runtime (Runtime → Change runtime type → T4/A100)

In [None]:
# ============================================================
# Cell 1: Environment Setup
# ============================================================
import torch
# Select the GPU when one is visible to PyTorch, otherwise fall back to CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Report the runtime so a missing GPU is noticed before anything else runs.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    print("⚠ No GPU detected! Go to Runtime → Change runtime type → GPU")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

print(f"\nUsing device: {DEVICE}")

## Step 1: Upload Data Files

**Option A (recommended)**: Mount Google Drive and keep the data files there

**Option B**: Upload the files directly

Required files:
1. `fakenews 2.csv` — main dataset
2. `News _dataset/Fake.csv` — external fake-news dataset
3. `News _dataset/True.csv` — external real-news dataset

In [None]:
# ============================================================
# Cell 2: Mount Google Drive & Set Data Paths
# ============================================================
# Option A: Google Drive (推荐 — 数据持久保存)
# 请先在 Google Drive 创建文件夹 "fakenews", 上传数据文件:
#   My Drive/fakenews/fakenews 2.csv
#   My Drive/fakenews/News _dataset/Fake.csv
#   My Drive/fakenews/News _dataset/True.csv

from google.colab import drive
drive.mount('/content/drive')

import os

# ★ Edit this to point at the data directory inside your Google Drive
DATA_ROOT = '/content/drive/MyDrive/fakenews'

# Figures and other artefacts are written here (created if missing).
OUTPUT_DIR = '/content/outputs'
os.makedirs(OUTPUT_DIR, exist_ok=True)

DATA_PATH = os.path.join(DATA_ROOT, 'fakenews 2.csv')
EXTRA_DATA_DIR = os.path.join(DATA_ROOT, 'News _dataset')

# Report every required input file: its size when present, its full path when missing.
for p in (DATA_PATH,
          os.path.join(EXTRA_DATA_DIR, 'Fake.csv'),
          os.path.join(EXTRA_DATA_DIR, 'True.csv')):
    if not os.path.exists(p):
        print(f"  ✗ NOT FOUND: {p}")
        continue
    size_mb = os.path.getsize(p) / 1e6
    print(f"  ✓ {os.path.basename(p)} ({size_mb:.1f} MB)")

In [None]:
# ============================================================
# Cell 3: Download GloVe Pre-trained Embeddings
# ============================================================
import os, zipfile, urllib.request

GLOVE_DIR = '/content/glove'
GLOVE_PATH = os.path.join(GLOVE_DIR, 'glove.6B.100d.txt')

if not os.path.exists(GLOVE_PATH):
    os.makedirs(GLOVE_DIR, exist_ok=True)
    glove_url = 'https://nlp.stanford.edu/data/glove.6B.zip'
    zip_path = os.path.join(GLOVE_DIR, 'glove.6B.zip')

    try:
        print("Downloading GloVe embeddings (822 MB)... This may take a few minutes.")
        urllib.request.urlretrieve(glove_url, zip_path)
        print("Download complete. Extracting...")

        # Only the 100-dimensional vectors are needed; skip the other files in the archive.
        with zipfile.ZipFile(zip_path, 'r') as z:
            z.extract('glove.6B.100d.txt', GLOVE_DIR)
    except BaseException:
        # A truncated .txt would pass the os.path.exists() check on the next run and
        # be treated as a complete download, so remove it before re-raising
        # (BaseException so that an interrupted cell is covered too).
        if os.path.exists(GLOVE_PATH):
            os.remove(GLOVE_PATH)
        raise
    finally:
        # Delete the archive on success AND on failure to reclaim disk space.
        if os.path.exists(zip_path):
            os.remove(zip_path)
    print(f"✓ GloVe extracted: {GLOVE_PATH}")
else:
    print(f"✓ GloVe already exists: {GLOVE_PATH}")

print(f"  File size: {os.path.getsize(GLOVE_PATH) / 1e6:.0f} MB")

## Step 2: Define All Modules

The cells below define the complete model pipeline inline; no external source files are needed.

In [None]:
# ============================================================
# Cell 4: Data Augmentation Module (data_augment)
# ============================================================
"""
数据增强模块: 加载外部数据集 + EDA文本增强
"""
import os, random
import numpy as np
import pandas as pd

# ========================= 加载外部新闻数据集 =========================
def _combine_title_and_body(row):
    """Return the sample text for one CSV row: "<title>. <body>" when both are present."""
    title, body = row.get('title'), row.get('text')
    # A missing body is NaN in pandas; formatting it directly would inject the
    # literal word "nan" into the sample, so treat it as empty instead.
    body = '' if pd.isna(body) else str(body)
    if pd.notna(title) and str(title).strip():
        return f"{title}. {body}" if body else str(title)
    return body


def _load_labeled_csv(dataset_dir, filename, label):
    """Read one CSV from dataset_dir and return its ['text', 'label'] columns, or None if absent."""
    path = os.path.join(dataset_dir, filename)
    if not os.path.exists(path):
        return None
    df = pd.read_csv(path)
    print(f"    外部数据集 {filename}: {len(df)} 条")
    # DataFrame.apply(axis=1) on an empty frame does not return a Series, so skip it then.
    if 'title' in df.columns and len(df) > 0:
        df['text'] = df.apply(_combine_title_and_body, axis=1)
    df['label'] = label
    return df[['text', 'label']]


def load_news_dataset(dataset_dir):
    """Load Fake.csv (label 0) and True.csv (label 1) from ``dataset_dir``.

    When a file has a ``title`` column, the title is prepended to the body as
    ``"<title>. <text>"``. Rows with a title but no body keep just the title, and
    rows with neither become an empty string (instead of the literal "nan").

    Args:
        dataset_dir: Directory expected to contain ``Fake.csv`` and/or ``True.csv``.

    Returns:
        A DataFrame with columns ``text`` and ``label``. Files that do not exist
        are skipped; if neither exists the frame is empty.
    """
    parts = [
        part
        for part in (
            _load_labeled_csv(dataset_dir, "Fake.csv", 0),
            _load_labeled_csv(dataset_dir, "True.csv", 1),
        )
        if part is not None
    ]
    if not parts:
        return pd.DataFrame(columns=['text', 'label'])

    df_extra = pd.concat(parts, ignore_index=True)
    print(f"    外部数据集合计: {len(df_extra)} 条 "
          f"(Fake: {(df_extra['label']==0).sum()}, Real: {(df_extra['label']==1).sum()})")
    return df_extra


# ========================= 多数据集合并 =========================
def merge_datasets(df_list, dedup=True):
    """Concatenate several DataFrames and optionally drop near-duplicate rows.

    Two rows count as duplicates when the first 100 characters of their ``text``
    (after conversion to ``str``) are equal; the first occurrence is kept.

    Args:
        df_list: DataFrames to concatenate; each needs ``text`` and ``label`` columns.
        dedup: When true, remove duplicates and renumber the index from 0.

    Returns:
        The merged (and possibly de-duplicated) DataFrame.
    """
    merged = pd.concat(df_list, ignore_index=True)
    total_before = len(merged)

    if dedup:
        # The key column is temporary: it is added, used for the comparison, then dropped.
        merged = (
            merged
            .assign(_dedup_key=merged['text'].astype(str).str[:100])
            .drop_duplicates(subset='_dedup_key', keep='first')
            .drop(columns='_dedup_key')
            .reset_index(drop=True)
        )

    total_after = len(merged)
    n_fake = int((merged['label'] == 0).sum())
    n_real = int((merged['label'] == 1).sum())
    print(f"  合并数据集: {total_before} 条 → 去重后 {total_after} 条 (移除 {total_before - total_after} 条重复)")
    print(f"  合并后标签分布: Fake={n_fake}, Real={n_real}")
    return merged


# ========================= 文本数据增强 (EDA) =========================
def _random_deletion(words, p=0.1):
    if len(words) <= 1:
        return words
    new_words = [w for w in words if random.random() > p]
    return new_words if new_words else [random.choice(words)]

def _random_swap(words, n=1):
    if len(words) < 2:
        return words
    new_words = words.copy()
    for _ in range(n):
        idx1, idx2 = random.sample(range(len(new_words)), 2)
        new_words[idx1], new_words[idx2] = new_words[idx2], new_words[idx1]
    return new_words

def augment_text(text, num_aug=1):
    """Produce ``num_aug`` perturbed variants of ``text``.

    Each variant is made by randomly choosing between word deletion (p=0.1) and
    word swapping (one swap per 20 words, at least one). Texts with fewer than
    three whitespace-separated words are returned unmodified, repeated
    ``num_aug`` times.
    """
    tokens = text.split()
    if len(tokens) < 3:
        return [text] * num_aug

    variants = []
    for _ in range(num_aug):
        if random.choice(['delete', 'swap']) == 'delete':
            mutated = _random_deletion(tokens, p=0.1)
        else:
            mutated = _random_swap(tokens, n=max(1, len(tokens) // 20))
        variants.append(' '.join(mutated))
    return variants

def augment_dataset(X_train, y_train, num_aug=1, seed=42):
    """Append ``num_aug`` augmented copies of every training text, then shuffle.

    Args:
        X_train: Sequence of training texts.
        y_train: Sequence of labels, one per text.
        num_aug: Number of augmented variants generated per original text.
        seed: Seed applied to both ``random`` and ``numpy.random`` (global state).

    Returns:
        ``(X, y)`` as NumPy arrays holding the originals plus their augmented
        variants in shuffled order. Empty input yields two empty arrays.

    Raises:
        ValueError: If ``X_train`` and ``y_train`` differ in length.
    """
    random.seed(seed)
    np.random.seed(seed)
    X_list, y_list = list(X_train), list(y_train)
    orig_size = len(X_list)
    if len(y_list) != orig_size:
        # Appending to misaligned lists would silently pair texts with wrong labels.
        raise ValueError(
            f"X_train and y_train must have the same length "
            f"(got {orig_size} and {len(y_list)})")

    # Iterate over snapshots of the originals so the appended variants are not revisited.
    for text, label in zip(X_list[:orig_size], y_list[:orig_size]):
        for aug_text in augment_text(text, num_aug=num_aug):
            X_list.append(aug_text)
            y_list.append(label)

    if X_list:
        combined = list(zip(X_list, y_list))
        random.shuffle(combined)
        X_aug, y_aug = zip(*combined)
    else:
        # zip(*[]) yields nothing to unpack, so handle empty input explicitly.
        X_aug, y_aug = (), ()

    print(f"    数据增强: {orig_size} → {len(X_aug)} (+{len(X_aug) - orig_size} 条增强样本)")
    return np.array(X_aug), np.array(y_aug)

print("✓ Data Augmentation module loaded")

In [None]:
# ============================================================
# Cell 5: Data Loading & Preprocessing Module (data_utils)
# ============================================================
"""
数据加载与预处理: 清洗 → 词汇表 → Dataset → GloVe
"""
import re
from collections import Counter
import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

# ========================= 停用词列表 =========================
# English stop words removed during text cleaning.
STOP_WORDS = {
    # articles and conjunctions
    'a', 'an', 'the', 'and', 'or', 'but', 'if', 'as', 'than', 'so',
    # prepositions and particles
    'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'from', 'up', 'out',
    'about', 'into', 'over', 'after', 'before', 'between', 'under',
    # forms of "be" and auxiliary / modal verbs
    'is', 'are', 'was', 'were', 'be', 'been', 'being', 'am',
    'have', 'has', 'had', 'do', 'does', 'did',
    'will', 'would', 'could', 'should', 'may', 'might', 'shall', 'can',
    # pronouns and determiners
    'it', 'its', 'that', 'which', 'who', 'whom', 'this', 'these', 'those',
    'i', 'me', 'my', 'we', 'our', 'you', 'your', 'he', 'him', 'his',
    'she', 'her', 'they', 'them', 'their',
    # adverbs and quantifiers
    'when', 'no', 'not', 'again', 'then', 'once', 'here', 'there', 'where',
    'how', 'all', 'both', 'each', 'more', 'other', 'some', 'such', 'own',
    'same', 'just', 'now',
    # single-letter fragments
    's', 't', 'd', 'm',
}


# ========================= 文本清洗 =========================
def clean_text(text):
    """Normalise raw text into a space-separated string of content words.

    Steps: lowercase, replace HTML tags, URLs and every character that is not
    ``a-z`` or whitespace with a space, collapse whitespace, then drop stop
    words and single-character tokens. Non-string input yields ``""``.
    """
    if not isinstance(text, str):
        return ""

    lowered = text.lower()
    # Order matters: tags and URLs must go before the letters-only filter breaks them up.
    for pattern in (r'<.*?>', r'https?://\S+|www\.\S+', r'[^a-z\s]'):
        lowered = re.sub(pattern, ' ', lowered)
    collapsed = re.sub(r'\s+', ' ', lowered).strip()

    kept = []
    for token in collapsed.split():
        if token in STOP_WORDS or len(token) <= 1:
            continue
        kept.append(token)
    return " ".join(kept)


# ========================= 词汇表 =========================
class Vocabulary:
    """Word-to-index mapping with reserved ``<PAD>`` (0) and ``<UNK>`` (1) entries."""

    def __init__(self, max_vocab_size=20000):
        """Create a vocabulary holding at most ``max_vocab_size`` entries, specials included."""
        self.word2idx = {'<PAD>': 0, '<UNK>': 1}
        self.idx2word = {0: '<PAD>', 1: '<UNK>'}
        self.max_vocab_size = max_vocab_size

    def build(self, texts):
        """Add the most frequent whitespace-separated tokens of ``texts``.

        Tokens are added in descending frequency order until the vocabulary
        reaches ``max_vocab_size``. Tokens that are already present (the two
        special tokens, or words from an earlier ``build`` call) are skipped, so
        every index stays unique and the special indices are never reassigned.
        """
        counter = Counter()
        for text in texts:
            counter.update(text.split())
        for word, _ in counter.most_common():
            if len(self.word2idx) >= self.max_vocab_size:
                break
            if word in self.word2idx:
                # Re-assigning an existing word would not grow the dict, and the next
                # new word would then collide on the same index.
                continue
            idx = len(self.word2idx)
            self.word2idx[word] = idx
            self.idx2word[idx] = word

    def encode(self, text, max_length):
        """Convert ``text`` to a list of exactly ``max_length`` indices.

        Over-long texts keep the first 70% and the last 30% of the token budget;
        short texts are right-padded with 0 (``<PAD>``). Unknown words map to 1
        (``<UNK>``).
        """
        tokens = text.split()
        if len(tokens) > max_length:
            head_len = int(max_length * 0.7)
            tail_len = max_length - head_len
            # Slice from an explicit start index: tokens[-0:] would return the whole list.
            tokens = tokens[:head_len] + tokens[len(tokens) - tail_len:]
        indices = [self.word2idx.get(w, 1) for w in tokens]
        padding_len = max_length - len(indices)
        if padding_len > 0:
            indices = indices + [0] * padding_len
        return indices

    def decode(self, indices):
        """Join the words for ``indices``, skipping padding (0); unknown indices become ``<UNK>``."""
        return ' '.join(self.idx2word.get(idx, '<UNK>') for idx in indices if idx != 0)

    @property
    def vocab_size(self):
        """Number of entries, including the two special tokens."""
        return len(self.word2idx)


# ========================= PyTorch Dataset =========================
class FakeNewsDataset(Dataset):
    """Map-style dataset that encodes texts to index tensors on access."""

    def __init__(self, texts, labels, vocab, max_length=500):
        """Store the raw texts and labels; encoding is deferred to ``__getitem__``.

        ``vocab`` must provide ``encode(text, max_length)`` returning a list of ints.
        """
        self.texts, self.labels = texts, labels
        self.vocab, self.max_length = vocab, max_length

    def __len__(self):
        """Number of samples."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``(token_ids, label)`` as a long tensor and a float scalar tensor."""
        token_ids = self.vocab.encode(self.texts[idx], self.max_length)
        features = torch.tensor(token_ids, dtype=torch.long)
        target = torch.tensor(self.labels[idx], dtype=torch.float)
        return features, target


# ========================= 多数据集加载主函数 =========================
def load_and_preprocess_multi_data(csv_path, extra_dataset_dir=None,
                                   max_vocab_size=20000, max_length=500,
                                   test_size=0.2, val_size=0.1,
                                   random_state=42, augment=False, num_aug=1):
    """Load, merge, clean and split the data, then wrap each split in a Dataset.

    Pipeline: read the main CSV, optionally merge the external Fake/True CSVs
    (with de-duplication), clean the text, make a stratified train/val/test
    split, optionally augment the training split only, build the vocabulary
    from the training split, and create the three datasets.

    Args:
        csv_path: Path of the main CSV; needs ``text`` and ``label`` columns.
        extra_dataset_dir: Directory with ``Fake.csv`` / ``True.csv``, or None to skip.
        max_vocab_size: Upper bound on vocabulary size.
        max_length: Fixed sequence length used when encoding.
        test_size: Fraction of all samples held out for testing.
        val_size: Fraction of all samples held out for validation.
        random_state: Seed for both splits and for augmentation.
        augment: Whether to augment the training split.
        num_aug: Augmented variants per training sample (used when ``augment``).

    Returns:
        ``(train_dataset, val_dataset, test_dataset, vocab)``.
    """
    print(f"  加载主数据集: {csv_path}")
    df_main = pd.read_csv(csv_path)
    print(f"    主数据集: {len(df_main)} 条")

    df_list = [df_main]
    if extra_dataset_dir is not None:
        df_extra = load_news_dataset(extra_dataset_dir)
        df_list.append(df_extra)

    if len(df_list) > 1:
        df = merge_datasets(df_list, dedup=True)
    else:
        df = df_main

    # Work on an explicit copy: assigning a new column to the result of dropna()
    # is chained assignment and triggers pandas' SettingWithCopyWarning.
    df = df.dropna(subset=['text', 'label']).copy()
    print("  开始文本清洗...")
    df['clean_text'] = df['text'].apply(clean_text)
    # Samples that clean down to nothing carry no signal; drop them.
    df = df[df['clean_text'].str.len() > 0].reset_index(drop=True)
    print(f"  清洗后数据量: {len(df)}")
    print(f"  标签分布: Fake={int((df['label']==0).sum())}, Real={int((df['label']==1).sum())}")

    texts = df['clean_text'].values
    labels = df['label'].values.astype(int)

    X_train, X_test, y_train, y_test = train_test_split(
        texts, labels, test_size=test_size, random_state=random_state, stratify=labels)
    # val_size is a fraction of the whole data set, so rescale it to the remaining part.
    val_ratio = val_size / (1 - test_size)
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, test_size=val_ratio, random_state=random_state, stratify=y_train)
    print(f"  训练集: {len(X_train)}, 验证集: {len(X_val)}, 测试集: {len(X_test)}")

    # Augment after splitting so no augmented copy ends up in validation or test.
    if augment and num_aug > 0:
        print(f"  ★ 对训练集进行数据增强 (num_aug={num_aug})...")
        X_train, y_train = augment_dataset(X_train, y_train, num_aug=num_aug, seed=random_state)
        print(f"  增强后训练集: {len(X_train)}")

    # The vocabulary is built from the training split only.
    vocab = Vocabulary(max_vocab_size)
    vocab.build(X_train)
    print(f"  词汇表大小: {vocab.vocab_size}")

    train_dataset = FakeNewsDataset(X_train, y_train, vocab, max_length)
    val_dataset = FakeNewsDataset(X_val, y_val, vocab, max_length)
    test_dataset = FakeNewsDataset(X_test, y_test, vocab, max_length)

    return train_dataset, val_dataset, test_dataset, vocab


# ========================= GloVe预训练词向量 =========================
def load_glove_embeddings(glove_path, vocab, embed_dim=100):
    """Build an embedding matrix for ``vocab`` from a GloVe text file.

    Each usable line has the form ``<word> <v1> ... <v_embed_dim>`` separated by
    single spaces. Blank lines and lines with a different number of values are
    skipped. Words missing from the file keep a uniform random vector in
    [-0.25, 0.25]; row 0 (``<PAD>``) is all zeros.

    Args:
        glove_path: Path of the GloVe ``.txt`` file (UTF-8).
        vocab: Object exposing ``word2idx`` (dict of word to row) and ``vocab_size``.
        embed_dim: Expected vector dimensionality.

    Returns:
        ``(embedding_matrix, coverage)``: a float32 tensor of shape
        ``(vocab.vocab_size, embed_dim)`` and the fraction of vocabulary entries
        (special tokens included in the denominator) found in the file.
    """
    print(f"  加载GloVe词向量: {glove_path}")
    glove_dict = {}
    with open(glove_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Split from the right so the last embed_dim fields are always the vector,
            # even if the token itself contains whitespace characters.
            parts = line.rstrip().rsplit(' ', embed_dim)
            if len(parts) != embed_dim + 1:
                continue  # blank line, or a vector of the wrong dimensionality
            word = parts[0]
            if word in vocab.word2idx:
                glove_dict[word] = np.array(parts[1:], dtype=np.float32)

    embedding_matrix = np.random.uniform(-0.25, 0.25, (vocab.vocab_size, embed_dim)).astype(np.float32)
    embedding_matrix[0] = 0  # <PAD>

    found = 0
    for word, idx in vocab.word2idx.items():
        if word in glove_dict:
            embedding_matrix[idx] = glove_dict[word]
            found += 1

    coverage = found / vocab.vocab_size
    print(f"  GloVe覆盖率: {found}/{vocab.vocab_size} = {coverage:.2%}")
    return torch.tensor(embedding_matrix), coverage

print("✓ Data Utils module loaded")

In [None]:
# ============================================================
# Cell 6: Model Definition (BiLSTM + Attention)
# ============================================================
"""
BiLSTM + Attention 分类器: 支持GloVe预训练词向量
"""
import torch
import torch.nn as nn


class BiLSTMAttentionClassifier(nn.Module):
    """Bidirectional LSTM with additive attention pooling for binary classification.

    The model embeds token indices, runs a BiLSTM, pools the outputs with
    attention weights (padding positions masked out) and maps the pooled vector
    to a single logit per sample.
    """

    def __init__(self, vocab_size, embed_dim=100, hidden_dim=128,
                 num_layers=2, dropout=0.5, pad_idx=0,
                 pretrained_embeddings=None, freeze_embeddings=False):
        """Build the network.

        Args:
            vocab_size: Number of rows in the embedding table.
            embed_dim: Embedding dimensionality.
            hidden_dim: LSTM hidden size per direction.
            num_layers: Number of stacked LSTM layers.
            dropout: Dropout used between LSTM layers and before the classifier.
            pad_idx: Index of the padding token; excluded from attention.
            pretrained_embeddings: Optional tensor copied into the embedding table.
            freeze_embeddings: If true (and pretrained weights are given), the
                embedding table is excluded from training.
        """
        super().__init__()

        # Remembered so forward() masks the same index the embedding treats as padding.
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)

        # Load the pre-trained word vectors
        if pretrained_embeddings is not None:
            self.embedding.weight.data.copy_(pretrained_embeddings)
            if freeze_embeddings:
                self.embedding.weight.requires_grad = False

        self.lstm = nn.LSTM(
            input_size=embed_dim, hidden_size=hidden_dim,
            num_layers=num_layers, bidirectional=True,
            # nn.LSTM only applies dropout between layers, so it is set to 0 for one layer.
            dropout=dropout if num_layers > 1 else 0, batch_first=True)

        lstm_out_dim = hidden_dim * 2  # forward + backward directions

        # Attention scorer (two-layer MLP producing one score per time step)
        self.attention = nn.Sequential(
            nn.Linear(lstm_out_dim, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, 1))

        # Classification head
        self.fc = nn.Sequential(
            nn.Dropout(dropout), nn.Linear(lstm_out_dim, 64),
            nn.ReLU(), nn.Dropout(0.3), nn.Linear(64, 1))

        self.dropout_emb = nn.Dropout(0.2)

    def forward(self, x, return_attention=False):
        """Compute logits for a batch of index sequences.

        Args:
            x: Long tensor of token indices, shape ``(batch, seq_len)``.
            return_attention: If true, also return the attention weights.

        Returns:
            Logits of shape ``(batch,)``; with ``return_attention`` a tuple
            ``(logits, attention_weights)`` where the weights have shape
            ``(batch, seq_len)`` and sum to 1 over the sequence.
        """
        embedded = self.dropout_emb(self.embedding(x))
        lstm_out, _ = self.lstm(embedded)

        # Attention: score every step, then push padding scores to a large negative
        # value so they get (numerically) zero weight after the softmax.
        attention_scores = self.attention(lstm_out).squeeze(-1)
        if self.pad_idx is not None:
            attention_scores = attention_scores.masked_fill(x == self.pad_idx, -1e9)
        attention_weights = torch.softmax(attention_scores, dim=1)
        context = torch.bmm(attention_weights.unsqueeze(1), lstm_out).squeeze(1)

        logits = self.fc(context).squeeze(-1)

        if return_attention:
            return logits, attention_weights
        return logits

    def count_parameters(self):
        """Return the number of trainable parameters (frozen ones are excluded)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

print("✓ Model module loaded")

In [None]:
# ============================================================
# Cell 7: Trainer Module (Train & Evaluate)
# ============================================================
"""
训练与评估: FocalLoss, 学习率调度器, 早停, 梯度裁剪
"""
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix


class FocalLoss(nn.Module):
    """Binary focal loss on logits: down-weights samples the model already gets right.

    Each sample's BCE is scaled by ``alpha * (1 - p_t) ** gamma``, where ``p_t``
    is the predicted probability of the true class; the result is averaged.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, logits, targets):
        """Return the mean focal loss for ``logits`` against float ``targets`` in {0, 1}."""
        per_sample_bce = nn.functional.binary_cross_entropy_with_logits(
            logits, targets, reduction='none')
        prob_positive = torch.sigmoid(logits)
        # Probability assigned to whichever class is the true one.
        prob_true_class = prob_positive * targets + (1 - prob_positive) * (1 - targets)
        modulator = (1 - prob_true_class) ** self.gamma
        weighted = self.alpha * modulator * per_sample_bce
        return weighted.mean()


def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Gradients are clipped to a total norm of 1.0 before every optimizer step.
    Predictions are thresholded at a sigmoid probability of 0.5.

    Returns:
        ``(mean_batch_loss, accuracy)`` over the epoch.
    """
    model.train()
    running_loss = 0
    predicted, expected = [], []

    for batch_inputs, batch_labels in dataloader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_inputs)
        loss = criterion(logits, batch_labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item()
        hard_preds = (torch.sigmoid(logits) > 0.5).float()
        predicted.extend(hard_preds.detach().cpu().numpy())
        expected.extend(batch_labels.detach().cpu().numpy())

    mean_loss = running_loss / len(dataloader)
    return mean_loss, accuracy_score(expected, predicted)


def evaluate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without updating any weights.

    Predictions are thresholded at a sigmoid probability of 0.5.

    Returns:
        ``(metrics, predictions, labels)`` where ``metrics`` holds ``loss``
        (mean batch loss), ``accuracy``, ``precision``, ``recall``, ``f1`` and
        ``confusion_matrix``; the two lists contain one entry per sample.
    """
    model.eval()
    loss_sum = 0
    predicted, expected = [], []

    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_inputs)
            loss_sum += criterion(logits, batch_labels).item()
            hard_preds = (torch.sigmoid(logits) > 0.5).float()
            predicted.extend(hard_preds.cpu().numpy())
            expected.extend(batch_labels.cpu().numpy())

    metrics = {'loss': loss_sum / len(dataloader)}
    metrics['accuracy'] = accuracy_score(expected, predicted)
    # zero_division=0 keeps these at 0 (without a warning) when a class is never predicted.
    metrics['precision'] = precision_score(expected, predicted, zero_division=0)
    metrics['recall'] = recall_score(expected, predicted, zero_division=0)
    metrics['f1'] = f1_score(expected, predicted, zero_division=0)
    metrics['confusion_matrix'] = confusion_matrix(expected, predicted)
    return metrics, predicted, expected


def train_model(model, train_dataset, val_dataset, test_dataset,
                criterion, optimizer, device,
                num_epochs=10, batch_size=32, use_scheduler=True,
                early_stop_patience=5):
    """Train with optional LR scheduling, early stopping and best-checkpoint restore.

    After every epoch the model is evaluated on the validation and test sets.
    Validation accuracy drives the scheduler, early stopping and checkpoint
    selection; test accuracy is only recorded in the history.

    Args:
        model: Network to train (modified in place).
        train_dataset, val_dataset, test_dataset: Datasets for the three splits.
        criterion: Loss taking ``(logits, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        device: Device the batches are moved to.
        num_epochs: Maximum number of epochs.
        batch_size: Batch size for all three loaders.
        use_scheduler: If true, halve the LR after 3 epochs without improvement
            in validation accuracy (floor 1e-6).
        early_stop_patience: Stop after this many consecutive epochs without a
            new best validation accuracy.

    Returns:
        ``(history, model)``: ``history`` maps ``train_loss``, ``train_acc``,
        ``val_loss``, ``val_acc`` and ``test_acc`` to per-epoch lists; ``model``
        carries the weights of the best validation epoch, if any epoch improved.
    """
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    scheduler = None
    if use_scheduler:
        # mode='max' because the monitored quantity is an accuracy.
        scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3, min_lr=1e-6)

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'test_acc': []}
    best_val_acc = 0
    best_model_state = None
    patience_counter = 0

    for epoch in range(num_epochs):
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_metrics, _, _ = evaluate(model, val_loader, criterion, device)
        test_metrics, _, _ = evaluate(model, test_loader, criterion, device)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_metrics['loss'])
        history['val_acc'].append(val_metrics['accuracy'])
        history['test_acc'].append(test_metrics['accuracy'])

        current_lr = optimizer.param_groups[0]['lr']
        print(f"  Epoch [{epoch+1}/{num_epochs}] "
              f"Loss: {train_loss:.4f} | Train: {train_acc:.4f} | "
              f"Val: {val_metrics['accuracy']:.4f}(F1:{val_metrics['f1']:.4f}) | "
              f"Test: {test_metrics['accuracy']:.4f} | LR: {current_lr:.6f}")

        if scheduler is not None:
            scheduler.step(val_metrics['accuracy'])

        if val_metrics['accuracy'] > best_val_acc:
            best_val_acc = val_metrics['accuracy']
            # Clone the tensors: state_dict() returns references that later steps overwrite.
            best_model_state = {k: v.clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= early_stop_patience:
            print(f"  Early stopping at epoch {epoch+1} (val acc not improving)")
            break

    # Restore the best checkpoint; stays None if validation accuracy never exceeded 0.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return history, model

print("✓ Trainer module loaded")

In [None]:
# ============================================================
# Cell 8: Visualization Module
# ============================================================
"""
可视化: 满足作业要求的全部图表 (Fig 1-8)
"""
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix as cm_func

# Global matplotlib configuration for every figure drawn after this point.
# Use the DejaVu Sans font family for all text.
matplotlib.rcParams['font.family'] = ['DejaVu Sans']
import warnings
# Silence warnings whose message mentions "findfont".
warnings.filterwarnings('ignore', message='.*findfont.*')
# Draw minus signs with the ASCII hyphen rather than the Unicode minus glyph.
matplotlib.rcParams['axes.unicode_minus'] = False


def plot_training_curves(history, title="Training Curves", save_path=None):
    """Fig 1/2: training loss (left axis) with train and test accuracy (right axis).

    Args:
        history: Dict with per-epoch lists ``train_loss``, ``train_acc``, ``test_acc``.
        title: Figure title.
        save_path: If given, the figure is also written to this path.
    """
    n_epochs = len(history['train_loss'])
    xs = range(1, n_epochs + 1)
    fig, loss_ax = plt.subplots(figsize=(10, 6))

    # Left y-axis: loss
    loss_ax.set_xlabel('Epoch', fontsize=12)
    loss_ax.set_ylabel('Loss', color='tab:red', fontsize=12)
    loss_line = loss_ax.plot(xs, history['train_loss'], 'r-o',
                             label='Train Loss', markersize=4, linewidth=2)
    loss_ax.tick_params(axis='y', labelcolor='tab:red')
    loss_ax.grid(True, alpha=0.3)

    # Right y-axis: accuracies, sharing the epoch axis
    acc_ax = loss_ax.twinx()
    acc_ax.set_ylabel('Accuracy', color='tab:blue', fontsize=12)
    acc_lines = []
    for key, fmt, label in (('train_acc', 'b-s', 'Train Accuracy'),
                            ('test_acc', 'g-^', 'Test Accuracy')):
        acc_lines += acc_ax.plot(xs, history[key], fmt,
                                 label=label, markersize=4, linewidth=2)
    acc_ax.tick_params(axis='y', labelcolor='tab:blue')
    acc_ax.set_ylim([0, 1.05])

    # One combined legend for the lines of both axes
    handles = loss_line + acc_lines
    loss_ax.legend(handles, [h.get_label() for h in handles], loc='center right', fontsize=10)

    plt.title(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"    Saved: {save_path}")
    plt.show()


def _plot_history_comparison(histories_dict, label_prefix, title, save_path):
    """Draw train-loss / train-accuracy / test-accuracy panels, one curve per setting."""
    fig, axes = plt.subplots(1, 3, figsize=(20, 6))
    colors = plt.cm.tab10(np.linspace(0, 1, len(histories_dict)))
    # (history key, marker) for the three panels, left to right
    panels = (('train_loss', 'o'), ('train_acc', 's'), ('test_acc', '^'))

    for idx, (setting, history) in enumerate(histories_dict.items()):
        epochs = range(1, len(history['train_loss']) + 1)
        for ax, (key, marker) in zip(axes, panels):
            ax.plot(epochs, history[key], color=colors[idx], marker=marker,
                    markersize=3, linewidth=1.5, label=f'{label_prefix}={setting}')

    for ax, ylabel, subtitle in zip(axes, ['Loss', 'Accuracy', 'Accuracy'],
                                     ['Training Loss', 'Training Accuracy', 'Test Accuracy']):
        ax.set_xlabel('Epoch', fontsize=11)
        ax.set_ylabel(ylabel, fontsize=11)
        ax.set_title(subtitle, fontsize=12)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)

    fig.suptitle(title, fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"    Saved: {save_path}")
    plt.show()


def plot_lr_comparison(histories_dict, title="LR Comparison", save_path=None):
    """Fig 3/4: compare runs with different learning rates (three panels).

    Args:
        histories_dict: Maps each learning rate to its training history dict
            (per-epoch lists ``train_loss``, ``train_acc``, ``test_acc``).
        title: Figure title.
        save_path: If given, the figure is also written to this path.
    """
    _plot_history_comparison(histories_dict, 'LR', title, save_path)


def plot_batch_size_comparison(histories_dict, title="Batch Size Comparison", save_path=None):
    """Fig 5/6: compare runs with different batch sizes (three panels).

    Args:
        histories_dict: Maps each batch size to its training history dict
            (per-epoch lists ``train_loss``, ``train_acc``, ``test_acc``).
        title: Figure title.
        save_path: If given, the figure is also written to this path.
    """
    _plot_history_comparison(histories_dict, 'BS', title, save_path)


def plot_predictions_table(texts, true_labels, pred_labels, n=100, save_path=None):
    """Fig 7: visualise the first ``n`` test predictions.

    The left panel is a square grid with one cell per sample (green = correct,
    red = wrong, white = unused cell); the right panel is a table of the first
    20 samples with a 60-character text preview.

    Args:
        texts: Sequence of sample texts.
        true_labels: Ground-truth labels (0 = Fake, 1 = Real).
        pred_labels: Predicted labels (0 = Fake, 1 = Real).
        n: Number of leading samples to show; capped at ``len(texts)``.
        save_path: If given, the figure is also written to this path.

    Raises:
        ValueError: If there is nothing to plot (``n <= 0`` or no samples).
    """
    from matplotlib.colors import ListedColormap
    n = min(n, len(texts))
    if n <= 0:
        # Fail with a clear message instead of a ZeroDivisionError in the title below.
        raise ValueError("plot_predictions_table needs at least one sample to plot")
    label_map = {0: 'Fake', 1: 'Real'}
    fig, axes = plt.subplots(1, 2, figsize=(16, 8), gridspec_kw={'width_ratios': [1, 3]})

    # Left panel: correctness grid (1 = correct, -1 = wrong, 0 = unused cell)
    grid_size = int(np.ceil(np.sqrt(n)))
    grid = np.zeros((grid_size, grid_size))
    for i in range(n):
        row, col = i // grid_size, i % grid_size
        grid[row][col] = 1 if true_labels[i] == pred_labels[i] else -1
    cmap = ListedColormap(['#ff6b6b', 'white', '#51cf66'])
    axes[0].imshow(grid, cmap=cmap, vmin=-1, vmax=1, aspect='equal')
    axes[0].set_title(f'Prediction Grid (First {n})\nGreen=Correct, Red=Wrong', fontsize=11)

    # Right panel: table of the first samples
    axes[1].axis('off')
    show_n = min(20, n)
    table_data, cell_colors = [], []
    for i in range(show_n):
        text_preview = texts[i][:60] + '...' if len(texts[i]) > 60 else texts[i]
        true_l, pred_l = label_map[int(true_labels[i])], label_map[int(pred_labels[i])]
        correct = 'Yes' if true_labels[i] == pred_labels[i] else 'No'
        table_data.append([i+1, text_preview, true_l, pred_l, correct])
        cell_colors.append(['#d4edda' if correct == 'Yes' else '#f8d7da'] * 5)
    table = axes[1].table(cellText=table_data, colLabels=['#', 'Text', 'True', 'Pred', 'Correct'],
                          cellColours=cell_colors, loc='upper center', cellLoc='left')
    table.auto_set_font_size(False)
    table.set_fontsize(8)
    table.auto_set_column_width([0, 1, 2, 3, 4])

    correct_count = sum(1 for i in range(n) if true_labels[i] == pred_labels[i])
    fig.suptitle(f'Test Predictions: {correct_count}/{n} correct ({100*correct_count/n:.1f}%)',
                 fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"    Saved: {save_path}")
    plt.show()


def plot_confusion_matrix(true_labels, pred_labels, save_path=None):
    """Fig 8: 2x2 confusion-matrix heatmap (rows = true label, columns = predicted).

    Args:
        true_labels: Ground-truth labels (0 = Fake, 1 = Real).
        pred_labels: Predicted labels (0 = Fake, 1 = Real).
        save_path: If given, the figure is also written to this path.
    """
    # Fix the label set so the matrix is always 2x2 and lines up with the tick labels,
    # even when only one class occurs in the inputs.
    cm = cm_func(true_labels, pred_labels, labels=[0, 1])
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Fake (0)', 'Real (1)'], yticklabels=['Fake (0)', 'Real (1)'],
                ax=ax, annot_kws={"size": 16})
    ax.set_xlabel('Predicted Label', fontsize=12)
    ax.set_ylabel('True Label', fontsize=12)
    ax.set_title('Confusion Matrix', fontsize=14, fontweight='bold')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"    Saved: {save_path}")
    plt.show()

print("✓ Visualization module loaded")

In [None]:
# ============================================================
# Cell 9: Chain-of-Thought Reasoning Module
# ============================================================
"""
推理链 (CoT): 结合Attention权重 + 文本特征分析, 生成可解释的判断依据
"""
import torch
import numpy as np

# ========================= 特征词典 =========================
# Single words treated as sensationalist vocabulary; the analyzer below checks each
# whitespace-separated token of the text for membership in this set.
SENSATIONAL_WORDS = {
    'shocking', 'unbelievable', 'breaking', 'exclusive', 'urgent',
    'bombshell', 'horrifying', 'terrifying', 'scandal', 'outrage',
    'devastating', 'explosive', 'stunning', 'alarming', 'incredible',
    'exposed', 'secret', 'leaked', 'conspiracy', 'hoax', 'cover',
    'destroyed', 'slammed', 'blasted', 'ripped', 'epic', 'insane'
}
# Phrases hinting at sourced reporting; the analyzer below looks for each one as a
# substring of the lowercased text.
CREDIBILITY_PHRASES = [
    'according to', 'research shows', 'study finds', 'officials said',
    'reuters', 'associated press', 'confirmed by', 'evidence suggests',
    'data shows', 'report says', 'peer reviewed', 'investigation found',
    'spokesperson said', 'published in', 'university of', 'department of',
    'official statement', 'press release', 'government report'
]
# Emotionally charged single words.
# NOTE(review): the code that consumes this set is not visible here — presumably it is
# matched per token like SENSATIONAL_WORDS; confirm.
EMOTIONAL_WORDS = {
    'hate', 'love', 'angry', 'furious', 'amazing', 'terrible',
    'disgusting', 'wonderful', 'horrible', 'fantastic', 'awful',
    'outrageous', 'evil', 'hero', 'villain', 'miracle', 'disaster',
    'tragic', 'brilliant', 'stupid', 'genius', 'idiot', 'corrupt',
    'patriot', 'traitor', 'danger', 'threat', 'crisis', 'doom'
}
# Click-bait phrase fragments.
# NOTE(review): entries such as 'you won' and 'they don want' look like forms with the
# apostrophe removed, so this list is presumably matched against cleaned text; short
# entries like 'believe' and 'number' may match very broadly — confirm against the caller.
CLICKBAIT_PATTERNS = [
    'you won', 'believe', 'click here', 'share this', 'going viral',
    'mind blowing', 'what happened next', 'number', 'will shock',
    'doctors hate', 'one weird trick', 'exposed', 'they don want'
]


class ChainOfThoughtAnalyzer:
    """Explain one prediction by combining attention weights with lexical cues.

    The analyzer runs the classifier on a single text, picks the tokens with
    the largest attention weights, scores the text against the module-level
    cue lexicons (SENSATIONAL_WORDS, CREDIBILITY_PHRASES, EMOTIONAL_WORDS,
    CLICKBAIT_PATTERNS) and renders a three-step, human-readable report.
    """

    def __init__(self, model, vocab, device, max_length=500):
        """Store the collaborators used by :meth:`analyze`.

        Args:
            model: Callable as ``model(tensor, return_attention=True)`` and
                returning ``(logits, attention_weights)``.
            vocab: Object exposing ``encode(text, max_length)``.
            device: Torch device the input tensor is moved to.
            max_length: Maximum number of tokens encoded / inspected.
        """
        self.model = model
        self.vocab = vocab
        self.device = device
        self.max_length = max_length

    def analyze(self, text, original_text=None):
        """Classify ``text`` and build an explanation for the decision.

        Args:
            text: Text fed to the model and to the lexical feature analysis.
            original_text: Optional raw text used only for the preview field;
                when falsy, ``text`` is previewed instead.

        Returns:
            dict with keys ``text_preview``, ``prediction`` ('Real' when the
            sigmoid probability exceeds 0.5, otherwise 'Fake'), ``confidence``
            (percentage string), ``top_attention_words`` (up to five
            ``(token, weight_string)`` pairs), ``text_features`` and
            ``reasoning_chain``.
        """
        display_text = original_text if original_text else text
        # Side effect: the wrapped model is left in eval mode.
        self.model.eval()
        indices = self.vocab.encode(text, self.max_length)
        input_tensor = torch.tensor([indices], dtype=torch.long).to(self.device)
        with torch.no_grad():
            logits, attention_weights = self.model(input_tensor, return_attention=True)
        prob = torch.sigmoid(logits).item()
        prediction = 'Real' if prob > 0.5 else 'Fake'
        confidence = prob if prob > 0.5 else 1 - prob

        # squeeze() collapses a length-1 attention vector to a 0-d array, on
        # which len()/indexing fail; atleast_1d keeps it iterable.
        attention = np.atleast_1d(attention_weights.squeeze().cpu().numpy())
        # NOTE(review): assumes vocab.encode tokenizes on whitespace in the
        # same order, so position i of `tokens` lines up with attention[i].
        tokens = text.split()[:self.max_length]
        word_attention = {}
        # zip() stops at the shorter sequence, so tokens beyond the attention
        # vector are ignored. Repeated tokens keep their largest weight.
        for token, weight in zip(tokens, attention):
            word_attention[token] = max(word_attention.get(token, 0), weight)
        top_words = sorted(word_attention.items(), key=lambda x: x[1], reverse=True)[:10]

        features = self._analyze_features(text)
        reasoning = self._generate_reasoning(prediction, confidence, top_words, features)

        return {
            'text_preview': (display_text[:300] + '...') if len(display_text) > 300 else display_text,
            'prediction': prediction,
            'confidence': f"{confidence:.2%}",
            'top_attention_words': [(w, f"{a:.4f}") for w, a in top_words[:5]],
            'text_features': features,
            'reasoning_chain': reasoning
        }

    @staticmethod
    def _ratio_level(ratio, thresholds=(0.01, 0.005)):
        """Map a hit ratio to 'HIGH' / 'MEDIUM' / 'LOW' using two cut-offs."""
        if ratio > thresholds[0]:
            return 'HIGH'
        if ratio > thresholds[1]:
            return 'MEDIUM'
        return 'LOW'

    @staticmethod
    def _count_level(count):
        """Map a hit count to 'HIGH' (>= 2), 'MEDIUM' (== 1) or 'LOW' (0)."""
        if count >= 2:
            return 'HIGH'
        return 'MEDIUM' if count == 1 else 'LOW'

    def _analyze_features(self, text):
        """Score ``text`` against the cue lexicons.

        Word lexicons are matched per whitespace token, phrase lexicons as
        substrings; both are matched case-insensitively.

        Returns:
            dict with a level ('HIGH'/'MEDIUM'/'LOW') and the first few hits
            for each of the four cue families, plus ``text_length`` (token
            count) and ``avg_word_length`` (string with one decimal).
        """
        words = text.split()
        text_lower = text.lower()
        # Match the word lexicons on lowercased tokens so they behave like the
        # phrase lexicons below; previously "SHOCKING" or "Breaking" were
        # missed whenever the caller passed text that was not pre-lowercased.
        lowered_words = text_lower.split()
        total = max(len(words), 1)  # avoid division by zero on empty text

        sensational_found = [w for w in lowered_words if w in SENSATIONAL_WORDS]
        emotional_found = [w for w in lowered_words if w in EMOTIONAL_WORDS]
        credibility_found = [p for p in CREDIBILITY_PHRASES if p in text_lower]
        clickbait_found = [p for p in CLICKBAIT_PATTERNS if p in text_lower]
        avg_word_len = np.mean([len(w) for w in words]) if words else 0

        return {
            'sensational_score': self._ratio_level(len(sensational_found) / total),
            'sensational_words': sensational_found[:5],
            'credibility_score': self._count_level(len(credibility_found)),
            'credibility_indicators': credibility_found[:3],
            'emotional_score': self._ratio_level(len(emotional_found) / total),
            'emotional_words': emotional_found[:5],
            'clickbait_score': self._count_level(len(clickbait_found)),
            'clickbait_patterns': clickbait_found[:3],
            'text_length': len(words),
            'avg_word_length': f"{avg_word_len:.1f}"
        }

    def _generate_reasoning(self, prediction, confidence, top_words, features):
        """Render the three-step textual report.

        Args:
            prediction: 'Fake' selects the fake-news rationale; any other
                value selects the real-news rationale.
            confidence: Float in [0, 1], printed as a percentage.
            top_words: Sequence of ``(token, weight)`` pairs; the first five
                are listed with a bar of ``int(weight * 500)`` blocks.
            features: Mapping produced by :meth:`_analyze_features`.

        Returns:
            The report as a single newline-joined string.
        """
        lines = []
        lines.append("=" * 50)
        lines.append("Step 1 - Text Feature Analysis:")
        lines.append(f"  [Sensational Language] {features['sensational_score']}")
        if features['sensational_words']:
            lines.append(f"    Found: {', '.join(features['sensational_words'])}")
        lines.append(f"  [Source Credibility]  {features['credibility_score']}")
        if features['credibility_indicators']:
            lines.append(f"    Found: {', '.join(features['credibility_indicators'])}")
        lines.append(f"  [Emotional Tone]     {features['emotional_score']}")
        if features['emotional_words']:
            lines.append(f"    Found: {', '.join(features['emotional_words'])}")
        lines.append(f"  [Clickbait Pattern]  {features['clickbait_score']}")
        if features['clickbait_patterns']:
            lines.append(f"    Found: {', '.join(features['clickbait_patterns'])}")
        lines.append(f"  [Text Length]        {features['text_length']} words")

        lines.append("\nStep 2 - Model Attention Key Words:")
        for word, weight in top_words[:5]:
            bar = '█' * int(float(weight) * 500)
            lines.append(f"  '{word}' [{weight}] {bar}")

        lines.append("\nStep 3 - Reasoning Chain:")
        reasons = []
        if prediction == 'Fake':
            if features['sensational_score'] in ('HIGH', 'MEDIUM'):
                reasons.append("The text contains sensational/exaggerated language, common in fabricated news.")
            if features['credibility_score'] == 'LOW':
                reasons.append("The text lacks references to credible sources, reducing its reliability.")
            if features['emotional_score'] in ('HIGH', 'MEDIUM'):
                reasons.append("Highly emotional language is detected, often used to manipulate reader opinions.")
            if features['clickbait_score'] in ('HIGH', 'MEDIUM'):
                reasons.append("Clickbait-style phrases are present, indicating engagement over accuracy.")
            if not reasons:
                # No lexical cue supports the verdict: fall back to a generic line.
                reasons.append("The model's learned patterns indicate this text matches characteristics of fake news.")
        else:
            if features['credibility_score'] in ('HIGH', 'MEDIUM'):
                reasons.append("The text references credible sources, consistent with legitimate news.")
            if features['sensational_score'] == 'LOW':
                reasons.append("The text uses neutral, factual language consistent with professional journalism.")
            if features['emotional_score'] == 'LOW':
                reasons.append("The text maintains an objective tone without excessive emotional manipulation.")
            if not reasons:
                # No lexical cue supports the verdict: fall back to a generic line.
                reasons.append("The model's learned patterns indicate this text matches characteristics of real news.")
        for i, reason in enumerate(reasons, 1):
            lines.append(f"  {i}. {reason}")

        lines.append(f"\n{'=' * 50}")
        lines.append(f"Conclusion: This article is classified as [{prediction}] with {confidence:.2%} confidence.")
        lines.append("=" * 50)
        return '\n'.join(lines)


def batch_analyze(analyzer, texts, n=5):
    """Print a chain-of-thought report for the leading entries of ``texts``.

    Args:
        analyzer: Object exposing ``analyze(text)`` that returns a mapping
            with ``text_preview``, ``prediction``, ``confidence`` and
            ``reasoning_chain`` entries.
        texts: Indexable, sized collection of input texts.
        n: Upper bound on how many texts are analyzed.

    Returns:
        None; everything is written to standard output.
    """
    banner_rule = "=" * 60
    sample_rule = '─' * 60

    print("\n" + banner_rule)
    print("  Chain-of-Thought Reasoning Demo")
    print(banner_rule)

    how_many = min(n, len(texts))
    for position in range(how_many):
        # The sample header is emitted before the analysis runs.
        print(f"\n{sample_rule}", f"  Sample {position + 1}:", sample_rule, sep="\n")
        report = analyzer.analyze(texts[position])
        preview = report['text_preview']
        verdict = report['prediction']
        certainty = report['confidence']
        print(f"\nText Preview: {preview}")
        print(f"\nPrediction: {verdict} (Confidence: {certainty})")
        print(f"\n{report['reasoning_chain']}")

    print(f"\n{'═' * 60}")

print("✓ Chain-of-Thought module loaded")

## Step 3: Configuration & Helper Functions

In [None]:
# ============================================================
# Cell 10: Configuration & Helper Functions
# ============================================================

# Default hyperparameters (tuned for the GloVe setup).
# create_model() and run_experiment() read these at call time, and Cell 11
# passes max_vocab_size / max_length / embed_dim to the data and GloVe loaders.
CONFIG = {
    'embed_dim': 100,           # embedding dimension (GloVe vector size)
    'hidden_dim': 128,          # LSTM hidden-state dimension
    'num_layers': 2,            # number of LSTM layers
    'dropout': 0.5,             # dropout rate
    'max_vocab_size': 20000,    # vocabulary size cap
    'max_length': 500,          # maximum text length (in tokens)
    'num_epochs': 20,           # maximum number of training epochs
    'batch_size': 32,           # batch size
    'learning_rate': 0.001,     # learning rate
    'weight_decay': 1e-4,       # weight decay (L2 regularization)
    'freeze_embeddings': True,  # freeze the GloVe embedding layer
}

# Module-level globals; Cell 11 assigns the loaded embedding matrix and
# vocabulary to them.
_glove_matrix = None
_vocab = None


def create_model(vocab_size, glove_matrix=None):
    """Build a fresh classifier from CONFIG and place it on DEVICE.

    Args:
        vocab_size: Number of entries in the vocabulary.
        glove_matrix: Optional pre-trained embedding matrix, forwarded as
            ``pretrained_embeddings``.

    Returns:
        The result of calling ``.to(DEVICE)`` on a new
        BiLSTMAttentionClassifier.
    """
    # Architecture settings are copied from CONFIG under the same names.
    config_keys = ('embed_dim', 'hidden_dim', 'num_layers', 'dropout',
                   'freeze_embeddings')
    hyperparams = {key: CONFIG[key] for key in config_keys}

    classifier = BiLSTMAttentionClassifier(
        vocab_size=vocab_size,
        pretrained_embeddings=glove_matrix,
        **hyperparams,
    )
    return classifier.to(DEVICE)


def run_experiment(train_ds, val_ds, test_ds, vocab_size,
                   loss_fn='bce', learning_rate=None, batch_size=None,
                   num_epochs=None, use_scheduler=True, glove_matrix=None):
    """Train one freshly initialised model and return its history.

    Args:
        train_ds, val_ds, test_ds: Datasets forwarded to ``train_model``.
        vocab_size: Vocabulary size used to build the model.
        loss_fn: ``'bce'`` (BCEWithLogitsLoss) or ``'focal'``
            (FocalLoss with alpha=0.25, gamma=2.0).
        learning_rate: Optimizer learning rate; ``None`` uses
            ``CONFIG['learning_rate']``.
        batch_size: Batch size; ``None`` uses ``CONFIG['batch_size']``.
        num_epochs: Epoch budget; ``None`` uses ``CONFIG['num_epochs']``.
        use_scheduler: Forwarded to ``train_model``.
        glove_matrix: Optional pre-trained embedding matrix for the model.

    Returns:
        ``(history, best_model)`` exactly as returned by ``train_model``.

    Raises:
        ValueError: If ``loss_fn`` is not one of the supported names.
    """
    # Validate up front so a typo fails before a model is built and moved to
    # the device.
    if loss_fn not in ('bce', 'focal'):
        raise ValueError(f"Unknown loss: {loss_fn}")

    # Only None means "use the default"; `x or default` would also discard an
    # explicitly passed 0 / 0.0.
    lr = CONFIG['learning_rate'] if learning_rate is None else learning_rate
    bs = CONFIG['batch_size'] if batch_size is None else batch_size
    epochs = CONFIG['num_epochs'] if num_epochs is None else num_epochs

    model = create_model(vocab_size, glove_matrix)

    if loss_fn == 'bce':
        criterion = nn.BCEWithLogitsLoss()
    else:
        criterion = FocalLoss(alpha=0.25, gamma=2.0)
    criterion = criterion.to(DEVICE)

    # Only parameters with requires_grad=True are handed to the optimizer, so
    # frozen ones (e.g. a frozen embedding layer) are left out.
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=lr, weight_decay=CONFIG['weight_decay'])

    history, best_model = train_model(
        model, train_ds, val_ds, test_ds,
        criterion, optimizer, DEVICE,
        num_epochs=epochs, batch_size=bs, use_scheduler=use_scheduler)

    return history, best_model

print("✓ Configuration loaded")
print(f"  Device: {DEVICE}")
print(f"  Config: {CONFIG}")

## Step 4: Load Data + GloVe Embeddings

This will:
1. Load the main dataset (`fakenews 2.csv`)
2. Load the extra dataset (`News _dataset/Fake.csv` and `News _dataset/True.csv`)
3. Merge & deduplicate → ~50K samples
4. Apply text data augmentation → ~60K training samples
5. Load GloVe pre-trained word embeddings

In [None]:
# ============================================================
# Cell 11: Load Data + GloVe
# ============================================================
# Builds the train/val/test datasets and the vocabulary, then loads the GloVe
# embedding matrix for that vocabulary when the GloVe file is present.
# Defines the globals used by every later cell: train_ds, val_ds, test_ds,
# vocab, vocab_size and glove_matrix.
import time

print("=" * 60)
print("  Fake News Detection with Deep Learning")
print("  BiLSTM + Attention + GloVe + Chain-of-Thought")
print(f"  Device: {DEVICE}")
print("=" * 60)

print("\n[Step 1/7] Loading data and GloVe embeddings...")
t0 = time.time()

train_ds, val_ds, test_ds, vocab = load_and_preprocess_multi_data(
    DATA_PATH,
    extra_dataset_dir=EXTRA_DATA_DIR,
    max_vocab_size=CONFIG['max_vocab_size'],
    max_length=CONFIG['max_length'],
    augment=True,     # ★ enable text data augmentation
    num_aug=1,        # generate one augmented sample per training sample
)
_vocab = vocab
vocab_size = vocab.vocab_size

# ★ Load the pre-trained GloVe word vectors.
# If the file is missing only a warning is printed and glove_matrix stays
# None, which is then passed on to create_model as pretrained_embeddings=None.
glove_matrix = None
if os.path.exists(GLOVE_PATH):
    glove_matrix, coverage = load_glove_embeddings(GLOVE_PATH, vocab, embed_dim=CONFIG['embed_dim'])
    _glove_matrix = glove_matrix
    print(f"  ★ GloVe loaded! Coverage: {coverage:.2%}")
else:
    print(f"  ⚠ GloVe not found: {GLOVE_PATH}")

print(f"\n  ✓ Data loaded in {time.time()-t0:.1f}s")
print(f"  Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")
print(f"  Vocabulary: {vocab_size} words")

## Step 5: Run Experiments (Fig 1-6)

Each experiment trains a fresh model with different hyperparameters.
With GPU, each experiment takes ~2-5 minutes (vs 30+ min on CPU).

In [None]:
# ============================================================
# Cell 12: Experiment 1 - BCE Loss (Default Config) → Fig 1
# ============================================================
# Trains one model with BCE loss; learning rate, batch size and epoch count
# are not passed, so run_experiment falls back to the CONFIG values.
# model_bce is reused by Cell 16 (final evaluation) and Cell 17 (CoT demo).
print("=" * 60)
print("[Step 2/7] Experiment 1: BCE Loss (default config)")
print("=" * 60)

t0 = time.time()
history_bce, model_bce = run_experiment(
    train_ds, val_ds, test_ds, vocab_size,
    loss_fn='bce', glove_matrix=glove_matrix
)
print(f"  Done in {time.time()-t0:.1f}s")

# Plot the training history and save it as Fig 1 under OUTPUT_DIR.
plot_training_curves(
    history_bce,
    title="Fig 1: BCE Loss - Default Config (GloVe)",
    save_path=os.path.join(OUTPUT_DIR, "fig1_bce_default.png")
)

In [None]:
# ============================================================
# Cell 13: Experiment 2 - Focal Loss (Default Config) → Fig 2
# ============================================================
# Same setup as Experiment 1 but with loss_fn='focal'; all other
# hyperparameters fall back to the CONFIG values inside run_experiment.
print("=" * 60)
print("[Step 3/7] Experiment 2: Focal Loss (default config)")
print("=" * 60)

t0 = time.time()
history_focal, model_focal = run_experiment(
    train_ds, val_ds, test_ds, vocab_size,
    loss_fn='focal', glove_matrix=glove_matrix
)
print(f"  Done in {time.time()-t0:.1f}s")

# Plot the training history and save it as Fig 2 under OUTPUT_DIR.
plot_training_curves(
    history_focal,
    title="Fig 2: Focal Loss - Default Config (GloVe)",
    save_path=os.path.join(OUTPUT_DIR, "fig2_focal_default.png")
)

In [None]:
# ============================================================
# Cell 14: Experiment 3 - Learning Rate Comparison → Fig 3 & 4
# ============================================================
# Trains one fresh model per (loss, learning rate) pair: 4 learning rates x
# 2 losses = 8 training runs. Only the histories are kept; the model returned
# by run_experiment is discarded.
print("=" * 60)
print("[Step 4/7] Experiment 3: Learning Rate Comparison")
print("=" * 60)

learning_rates = [0.01, 0.001, 0.0001, 0.00001]

# --- Fig 3: BCE with different learning rates ---
print("\n  --- BCE Loss ---")
lr_histories_bce = {}  # learning rate -> training history
for lr in learning_rates:
    print(f"\n  LR = {lr}:")
    h, _ = run_experiment(train_ds, val_ds, test_ds, vocab_size,
                          loss_fn='bce', learning_rate=lr, glove_matrix=glove_matrix)
    lr_histories_bce[lr] = h

plot_lr_comparison(
    lr_histories_bce,
    title="Fig 3: BCE Loss - Learning Rate Comparison",
    save_path=os.path.join(OUTPUT_DIR, "fig3_lr_bce.png")
)

# --- Fig 4: Focal with different learning rates ---
print("\n  --- Focal Loss ---")
lr_histories_focal = {}  # learning rate -> training history
for lr in learning_rates:
    print(f"\n  LR = {lr}:")
    h, _ = run_experiment(train_ds, val_ds, test_ds, vocab_size,
                          loss_fn='focal', learning_rate=lr, glove_matrix=glove_matrix)
    lr_histories_focal[lr] = h

plot_lr_comparison(
    lr_histories_focal,
    title="Fig 4: Focal Loss - Learning Rate Comparison",
    save_path=os.path.join(OUTPUT_DIR, "fig4_lr_focal.png")
)

In [None]:
# ============================================================
# Cell 15: Experiment 4 - Batch Size Comparison → Fig 5 & 6
# ============================================================
# Trains one fresh model per (loss, batch size) pair: 4 batch sizes x
# 2 losses = 8 training runs. Only the histories are kept; the model returned
# by run_experiment is discarded.
print("=" * 60)
print("[Step 5/7] Experiment 4: Batch Size Comparison")
print("=" * 60)

batch_sizes = [16, 32, 64, 128]

# --- Fig 5: BCE with different batch sizes ---
print("\n  --- BCE Loss ---")
bs_histories_bce = {}  # batch size -> training history
for bs in batch_sizes:
    print(f"\n  Batch Size = {bs}:")
    h, _ = run_experiment(train_ds, val_ds, test_ds, vocab_size,
                          loss_fn='bce', batch_size=bs, glove_matrix=glove_matrix)
    bs_histories_bce[bs] = h

plot_batch_size_comparison(
    bs_histories_bce,
    title="Fig 5: BCE Loss - Batch Size Comparison",
    save_path=os.path.join(OUTPUT_DIR, "fig5_bs_bce.png")
)

# --- Fig 6: Focal with different batch sizes ---
print("\n  --- Focal Loss ---")
bs_histories_focal = {}  # batch size -> training history
for bs in batch_sizes:
    print(f"\n  Batch Size = {bs}:")
    h, _ = run_experiment(train_ds, val_ds, test_ds, vocab_size,
                          loss_fn='focal', batch_size=bs, glove_matrix=glove_matrix)
    bs_histories_focal[bs] = h

plot_batch_size_comparison(
    bs_histories_focal,
    title="Fig 6: Focal Loss - Batch Size Comparison",
    save_path=os.path.join(OUTPUT_DIR, "fig6_bs_focal.png")
)

## Step 6: Final Evaluation (Fig 7 & 8)

In [None]:
# ============================================================
# Cell 16: Final Evaluation → Fig 7 (Predictions) & Fig 8 (Confusion Matrix)
# ============================================================
# Evaluates model_bce (from Cell 12) on the test set, prints the metrics and
# saves the prediction table and confusion matrix figures.
print("=" * 60)
print("[Step 6/7] Generating prediction visualizations...")
print("=" * 60)

# Predict with the best model returned by the BCE experiment.
criterion_eval = nn.BCEWithLogitsLoss()
# shuffle=False keeps the loader in dataset order.
# NOTE(review): the [:100] slices below presumably rely on that order to line
# up test_ds.texts with all_labels / all_preds — confirm against evaluate().
test_loader = DataLoader(test_ds, batch_size=64, shuffle=False)
test_metrics, all_preds, all_labels = evaluate(model_bce, test_loader, criterion_eval, DEVICE)

print(f"\n  === Final Test Results (Best BCE + GloVe Model) ===")
print(f"  Accuracy:  {test_metrics['accuracy']:.4f}")
print(f"  Precision: {test_metrics['precision']:.4f}")
print(f"  Recall:    {test_metrics['recall']:.4f}")
print(f"  F1 Score:  {test_metrics['f1']:.4f}")
print(f"  Confusion Matrix:\n{test_metrics['confusion_matrix']}")

# Fig 7: predictions for the first 100 test samples
plot_predictions_table(
    list(test_ds.texts[:100]),
    all_labels[:100], all_preds[:100], n=100,
    save_path=os.path.join(OUTPUT_DIR, "fig7_predictions.png")
)

# Fig 8: confusion matrix over the whole test set
plot_confusion_matrix(
    all_labels, all_preds,
    save_path=os.path.join(OUTPUT_DIR, "fig8_confusion_matrix.png")
)

## Step 7: Chain-of-Thought Reasoning Demo

The CoT module analyzes:
1. **Text features**: sensational language, source credibility, emotional tone, clickbait patterns
2. **Model attention weights**: which words the model focused on most
3. **Reasoning chain**: step-by-step logical explanation for the prediction

In [None]:
# ============================================================
# Cell 17: Chain-of-Thought Reasoning Demo
# ============================================================
# Wraps model_bce (from Cell 12) in a ChainOfThoughtAnalyzer and prints the
# reasoning report for a handful of test texts.
print("=" * 60)
print("[Step 7/7] Chain-of-Thought Reasoning Demo")
print("=" * 60)

cot_analyzer = ChainOfThoughtAnalyzer(model_bce, vocab, DEVICE, CONFIG['max_length'])

# Run the chain-of-thought analysis on the first 5 test samples
# (batch_analyze stops after n texts).
batch_analyze(cot_analyzer, list(test_ds.texts), n=5)

## Step 8: Save Outputs to Google Drive (Optional)

Copy all output figures to your Google Drive for safekeeping.

In [None]:
# ============================================================
# Cell 18: Save outputs to Google Drive
# ============================================================
# Copies every PNG found in OUTPUT_DIR into <DATA_ROOT>/outputs on the mounted
# Drive, creating that folder if needed.
import shutil

drive_output = os.path.join(DATA_ROOT, 'outputs')
os.makedirs(drive_output, exist_ok=True)

# Copy all output figures to Google Drive.
# Count what is actually copied: listing drive_output afterwards would also
# count stale files from earlier runs and any non-PNG entries in that folder.
copied = 0
for fname in sorted(os.listdir(OUTPUT_DIR)):  # sorted for a stable log order
    if fname.endswith('.png'):
        src = os.path.join(OUTPUT_DIR, fname)
        dst = os.path.join(drive_output, fname)
        shutil.copy2(src, dst)
        copied += 1
        print(f"  ✓ Saved: {dst}")

print(f"\n  All {copied} figures saved to Google Drive!")
print(f"  Location: {drive_output}")

print("\n" + "=" * 60)
print("  ✅ All experiments completed!")
print("  Generated figures: fig1 ~ fig8")
print("=" * 60)