# 🔬 Fake News Detection with DistilBERT
## CDS525 Group Project — Transformer-based Approach

**Model**: DistilBERT (Distilled BERT) — 40% smaller, 60% faster, retains 97% of BERT's performance

**Key Features**:
- ★ **Pre-trained Transformer**: DistilBERT-base-uncased (66M params) with fine-tuning
- ★ **Merged Dataset**: Original (5K) + News_dataset (45K) ≈ 50K samples
- ★ **Text Data Augmentation** (EDA: Random Deletion / Swap)
- ★ **Attention-based Explainability**: Extract DistilBERT attention weights for CoT reasoning
- ★ **Training Strategy**: AdamW + Linear Warmup + Weight Decay + Early Stopping
- ★ **Comprehensive Evaluation**: Accuracy, F1, ROC-AUC, Confusion Matrix
- ★ **8 Required Figures** + Chain-of-Thought Reasoning Demo

**Reference**: Training strategy inspired by `with-new-data.ipynb` (BERT/DistilBERT approach)

**⚠ REQUIRED**: GPU Runtime → Runtime → Change runtime type → **T4 GPU**

In [None]:
# ============================================================
# Cell 1: Environment Setup & Install Dependencies
# ============================================================
%pip install -q transformers accelerate

import torch
# Choose the compute device up front; the diagnostics below are derived from it
# and every later cell reads DEVICE.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {DEVICE.type == 'cuda'}")
if DEVICE.type != 'cuda':
    # No GPU: tell the user how to switch the Colab runtime.
    print("‚ö† No GPU! Go to Runtime ‚Üí Change runtime type ‚Üí GPU")
else:
    # Report the name and total memory (in GB) of GPU 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

print(f"\nUsing: {DEVICE}")

In [None]:
# ============================================================
# Cell 2: Mount Google Drive & Set Data Paths
# ============================================================
from google.colab import drive
drive.mount('/content/drive')

import os

# Change DATA_ROOT so that it points at your data folder on Google Drive.
DATA_ROOT = '/content/drive/MyDrive/fakenews'

# Input locations are derived from DATA_ROOT; outputs go to a separate
# directory that is created on demand.
DATA_PATH = os.path.join(DATA_ROOT, 'fakenews 2.csv')
EXTRA_DATA_DIR = os.path.join(DATA_ROOT, 'News _dataset')
OUTPUT_DIR = '/content/outputs_bert'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Verify the input files: print each file's size in MB, or flag it as missing.
for p in (DATA_PATH,
          *(os.path.join(EXTRA_DATA_DIR, fname) for fname in ('Fake.csv', 'True.csv'))):
    if not os.path.exists(p):
        print(f"  ‚úó NOT FOUND: {p}")
        continue
    print(f"  ‚úì {os.path.basename(p)} ({os.path.getsize(p)/1e6:.1f} MB)")

## Step 1: Data Loading, Merging & Preprocessing

**BERT preprocessing is minimal** (compared to LSTM):
- Preserve punctuation, stopwords → BERT's self-attention needs them for context
- Only remove URLs, HTML tags
- Remove "Reuters" data leakage (as found in reference notebook)
- BERT tokenizer handles subword tokenization automatically

In [None]:
# ============================================================
# Cell 3: Data Loading & Merging & Minimal BERT Cleaning
# ============================================================
import pandas as pd
import numpy as np
import re
import os
import random
from sklearn.model_selection import train_test_split

# ========================= BERT-specific Text Cleaning =========================
def clean_for_bert(text):
    """Return a lightly cleaned, lower-cased copy of ``text`` for BERT input.

    Punctuation, stopwords and sentence structure are kept on purpose. Only
    noise is removed, in this order: URLs, HTML tags, the literal substring
    'reuters', and finally redundant whitespace.

    Args:
        text: Any value; it is converted with ``str()`` first.

    Returns:
        The cleaned string (possibly empty).
    """
    cleaned = str(text).lower()
    # URLs first, then HTML tags -- the same order as the substitutions
    # are meant to be applied.
    for noise_pattern in (r'https?://\S+|www\.\S+', r'<.*?>'):
        cleaned = re.sub(noise_pattern, '', cleaned)
    # Drop the news-agency name so it cannot act as a label shortcut.
    without_agency = cleaned.replace('reuters', '')
    # Collapse every whitespace run to one space and trim the ends.
    return re.sub(r'\s+', ' ', without_agency).strip()


# ========================= Load Extra Dataset =========================
def load_news_dataset(dataset_dir):
    """Load ``Fake.csv`` (label 0) and ``True.csv`` (label 1) from a directory.

    Files that do not exist are skipped. When a file has a ``title`` column,
    each row's text becomes ``"<title>. <text>"`` if the title is present and
    non-blank, otherwise the text alone converted with ``str()``.

    Args:
        dataset_dir: Directory expected to contain the two CSV files.

    Returns:
        A DataFrame with columns ``text`` and ``label``; empty (with those
        two columns) when neither file exists.
    """
    def _title_and_body(row):
        # Prefix the headline only when it is present and non-blank.
        if pd.notna(row.get('title')) and str(row.get('title', '')).strip():
            return f"{row['title']}. {row['text']}"
        return str(row.get('text', ''))

    parts = []
    for fname, label in (('Fake.csv', 0), ('True.csv', 1)):
        fpath = os.path.join(dataset_dir, fname)
        if not os.path.exists(fpath):
            continue
        frame = pd.read_csv(fpath)
        print(f"    {fname}: {len(frame)} rows")
        if 'title' in frame.columns:
            frame['text'] = frame.apply(_title_and_body, axis=1)
        frame['label'] = label
        parts.append(frame[['text', 'label']])

    if not parts:
        return pd.DataFrame(columns=['text', 'label'])

    df = pd.concat(parts, ignore_index=True)
    n_fake = (df['label']==0).sum()
    n_real = (df['label']==1).sum()
    print(f"    Total: {len(df)} (Fake: {n_fake}, Real: {n_real})")
    return df


# ========================= EDA Data Augmentation =========================
def _random_deletion(words, p=0.1):
    """Drop each word independently with probability ``p``.

    Inputs with at most one word are returned unchanged (the same object).
    If every word would be dropped, a single randomly chosen word is kept.
    """
    if len(words) <= 1:
        return words
    survivors = []
    for word in words:
        # One random draw per word, in input order.
        if random.random() > p:
            survivors.append(word)
    if survivors:
        return survivors
    return [random.choice(words)]

def _random_swap(words, n=1):
    """Return a copy of ``words`` after ``n`` swaps of two random positions.

    Inputs with fewer than two words are returned unchanged (the same object).
    """
    if len(words) < 2:
        return words
    swapped = words.copy()
    swaps_done = 0
    while swaps_done < n:
        # Two distinct positions are drawn for every swap.
        left, right = random.sample(range(len(swapped)), 2)
        swapped[left], swapped[right] = swapped[right], swapped[left]
        swaps_done += 1
    return swapped

def augment_dataset(texts, labels, num_aug=1, seed=42):
    """Apply EDA-style augmentation (random deletion / swap) to a training set.

    Every text with at least three whitespace-separated words gets ``num_aug``
    augmented copies carrying the same label; shorter texts are kept but not
    augmented. Originals and copies are shuffled together.

    Args:
        texts: Sequence of strings.
        labels: Sequence of labels, one per text.
        num_aug: Number of augmented copies per eligible text.
        seed: Seed applied to both ``random`` and ``numpy.random``.

    Returns:
        Tuple ``(X_aug, y_aug)`` of NumPy arrays holding the shuffled texts
        and their labels. Both are empty when the input is empty.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length.
    """
    random.seed(seed); np.random.seed(seed)
    X, y = list(texts), list(labels)
    if len(X) != len(y):
        # A mismatch would silently pair augmented texts with wrong labels.
        raise ValueError(
            f"texts and labels must have the same length, got {len(X)} and {len(y)}")
    orig = len(X)
    for i in range(orig):
        words = X[i].split()
        if len(words) < 3:
            continue
        for _ in range(num_aug):
            method = random.choice(['delete', 'swap'])
            if method == 'delete':
                new_words = _random_deletion(words, p=0.1)
            else:
                # Roughly one swap per 20 words, but always at least one.
                new_words = _random_swap(words, n=max(1, len(words) // 20))
            X.append(' '.join(new_words))
            y.append(y[i])
    if X:
        combined = list(zip(X, y))
        random.shuffle(combined)
        X_aug, y_aug = zip(*combined)
    else:
        # zip(*[]) yields nothing to unpack, so handle empty input explicitly.
        X_aug, y_aug = (), ()
    print(f"    Augmentation: {orig} → {len(X_aug)} (+{len(X_aug)-orig})")
    return np.array(X_aug), np.array(y_aug)


# ========================= Main Data Pipeline =========================
# End-to-end data preparation for this notebook. Runs top to bottom and leaves
# X_train/X_val/X_test and y_train/y_val/y_test as module-level variables.
print("=" * 60)
print("  Loading & Preparing Data for DistilBERT")
print("=" * 60)

# 1. Load the main dataset from DATA_PATH.
# NOTE(review): presumably it already has 'text' and 'label' columns, since
# both are read below -- confirm against the CSV.
print(f"\n  Main dataset: {DATA_PATH}")
df_main = pd.read_csv(DATA_PATH)
print(f"    Rows: {len(df_main)}")

# 2. Load the extra dataset (Fake.csv / True.csv) only if its directory exists.
df_list = [df_main]
if os.path.exists(EXTRA_DATA_DIR):
    print(f"\n  Extra dataset: {EXTRA_DATA_DIR}")
    df_extra = load_news_dataset(EXTRA_DATA_DIR)
    df_list.append(df_extra)

# 3. Merge and de-duplicate. Two rows count as duplicates when the first 100
#    characters of their raw text match; the first occurrence is kept.
df = pd.concat(df_list, ignore_index=True)
before = len(df)
df['_key'] = df['text'].astype(str).str[:100]
df = df.drop_duplicates(subset='_key', keep='first').drop(columns='_key').reset_index(drop=True)
print(f"\n  Merged: {before} ‚Üí {len(df)} (removed {before-len(df)} duplicates)")
print(f"  Labels: Fake={int((df['label']==0).sum())}, Real={int((df['label']==1).sum())}")

# 4. Minimal cleaning for BERT. Rows with a missing text or label are dropped
#    here, i.e. after the counts above were printed; rows whose cleaned text
#    is empty are dropped afterwards.
df = df.dropna(subset=['text', 'label'])
print(f"\n  Applying BERT-minimal cleaning...")
df['clean_text'] = df['text'].apply(clean_for_bert)
df = df[df['clean_text'].str.strip().str.len() > 0].reset_index(drop=True)
print(f"  After cleaning: {len(df)} rows")

# 5. Stratified train/val/test split (80/10/10): hold out 20%, then halve it.
texts = df['clean_text'].values
labels = df['label'].values.astype(int)

X_train, X_temp, y_train, y_temp = train_test_split(
    texts, labels, test_size=0.2, random_state=42, stratify=labels)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp)

print(f"\n  Split: Train={len(X_train)}, Val={len(X_val)}, Test={len(X_test)}")

# 6. Augment the training split only; it runs after the split, so validation
#    and test texts are never augmented.
print(f"\n  Data augmentation...")
X_train, y_train = augment_dataset(X_train, y_train, num_aug=1, seed=42)

print(f"\n  ‚úì Data ready!")
print(f"  Final: Train={len(X_train)}, Val={len(X_val)}, Test={len(X_test)}")

## Step 2: BERT Tokenization and PyTorch Dataset

In [None]:
# ============================================================
# Cell 4: Tokenization & Dataset Creation
# ============================================================
from transformers import DistilBertTokenizer
from torch.utils.data import Dataset, DataLoader
import torch
import time

# Hugging Face checkpoint name used to load the tokenizer below.
MODEL_NAME = 'distilbert-base-uncased'
# Sequence length every sample is padded/truncated to by FakeNewsDataset.
MAX_LEN = 256  # author's estimate: covers ~95% of articles (not verified here); DistilBERT max is 512

# Load the pretrained tokenizer for MODEL_NAME and report its vocabulary size.
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
print(f"Tokenizer loaded: {MODEL_NAME} (vocab: {tokenizer.vocab_size})")


class FakeNewsDataset(Dataset):
    """Map-style dataset that tokenizes one text per lookup.

    Tokenization happens in ``__getitem__``, not in the constructor. Each
    item is a dict with ``input_ids`` and ``attention_mask`` (the tokenizer
    output with its leading batch dimension squeezed away) and ``label``
    (a float tensor).
    """
    def __init__(self, texts, labels, tokenizer, max_len=256):
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Pad/truncate every sample to exactly max_len tokens.
        tokenizer_kwargs = dict(
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        encoded = self.tokenizer(str(self.texts[idx]), **tokenizer_kwargs)

        # return_tensors='pt' adds a batch dimension of 1; remove it so the
        # DataLoader can stack samples itself.
        sample = {field: encoded[field].squeeze(0)
                  for field in ('input_ids', 'attention_mask')}
        sample['label'] = torch.tensor(self.labels[idx], dtype=torch.float)
        return sample


# Wrap the three splits in FakeNewsDataset objects.
# NOTE: FakeNewsDataset tokenizes inside __getitem__, so the elapsed time
# printed below covers object construction only, not tokenization.
print("Creating PyTorch datasets...")
t0 = time.time()

train_dataset = FakeNewsDataset(X_train, y_train, tokenizer, MAX_LEN)
val_dataset   = FakeNewsDataset(X_val, y_val, tokenizer, MAX_LEN)
test_dataset  = FakeNewsDataset(X_test, y_test, tokenizer, MAX_LEN)

print(f"  Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
print(f"  Max token length: {MAX_LEN}")
print(f"  Done in {time.time()-t0:.1f}s")

# Sanity check: tokenize the first training sample and decode its first
# 30 token ids back to text.
sample = train_dataset[0]
print(f"\n  Sample input_ids shape: {sample['input_ids'].shape}")
print(f"  Sample tokens: {tokenizer.decode(sample['input_ids'][:30])}")

## Step 3: Model, Trainer, Visualization & CoT Modules

In [None]:
# ============================================================
# Cell 5: DistilBERT Classifier Model
# ============================================================
"""
DistilBERT + Classification Head
- Load pre-trained DistilBERT (66M params)
- Extract [CLS] token representation
- Add Dropout + Dense layer for binary classification
- output_attentions=True for Chain-of-Thought explainability
"""
import torch
import torch.nn as nn
from transformers import DistilBertModel


class DistilBertClassifier(nn.Module):
    """DistilBERT encoder with a small binary-classification head.

    Architecture: DistilBERT -> first-token ([CLS]) vector -> Dropout ->
    Linear(hidden, 256) -> ReLU -> Dropout -> Linear(256, 1).
    ``forward`` returns raw logits; no sigmoid is applied inside the model.
    """
    def __init__(self, model_name='distilbert-base-uncased', dropout=0.2, freeze_bert=False):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained(model_name)

        if freeze_bert:
            # Turn off gradients for every encoder weight so that only the
            # classification head is trained.
            for encoder_param in self.distilbert.parameters():
                encoder_param.requires_grad = False
            print("  ‚òÖ DistilBERT layers FROZEN (only training classifier head)")

        # Width of the encoder output, read from the model config.
        hidden_size = self.distilbert.config.hidden_size

        head_layers = [
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 1),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask, return_attention=False):
        """Compute one logit per sequence.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Padding mask passed to the encoder.
            return_attention: When True, also ask the encoder for its
                attention weights and return them.

        Returns:
            ``logits`` (last dimension squeezed), or the tuple
            ``(logits, attentions)`` when ``return_attention`` is True.
        """
        encoder_out = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=return_attention,
        )
        # Position 0 of the sequence axis is used as the sentence summary.
        cls_vector = encoder_out.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_vector).squeeze(-1)

        if not return_attention:
            return logits
        return logits, encoder_out.attentions

    def count_parameters(self):
        """Return ``(total, trainable)`` parameter counts."""
        total = 0
        trainable = 0
        for param in self.parameters():
            size = param.numel()
            total += size
            if param.requires_grad:
                trainable += size
        return total, trainable


print("‚úì DistilBertClassifier defined")

In [None]:
# ============================================================
# Cell 6: Trainer (Train / Evaluate / Focal Loss)
# ============================================================
"""
Training loop for DistilBERT:
- AdamW with linear warmup (standard for Transformers)
- Gradient clipping
- Early stopping on validation accuracy
- Focal Loss option
"""
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix as cm_func, roc_curve, auc


class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    Per element the loss is ``alpha * (1 - p_t) ** gamma * BCE``, where
    ``p_t`` is the predicted probability of the true class. ``alpha`` is one
    scalar applied to every element. The result is the mean over all elements.
    """
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, logits, targets):
        """Return the mean focal loss for ``logits`` against ``targets``."""
        per_element_bce = nn.functional.binary_cross_entropy_with_logits(
            logits, targets, reduction='none')
        prob_positive = torch.sigmoid(logits)
        # Probability assigned to the correct class of each element.
        p_t = prob_positive * targets + (1 - prob_positive) * (1 - targets)
        # Down-weight elements that are already classified confidently.
        focal_weight = self.alpha * (1 - p_t) ** self.gamma
        return (focal_weight * per_element_bce).mean()


def train_one_epoch(model, dataloader, optimizer, scheduler, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    For every batch: forward, loss, backward, gradient-norm clipping at 1.0,
    optimizer step and (if given) one scheduler step.

    Args:
        model: Callable as ``model(input_ids, attention_mask)``, returning logits.
        dataloader: Yields dicts with ``input_ids``, ``attention_mask``, ``label``.
        optimizer: Optimizer to step.
        scheduler: Per-batch LR scheduler, or None.
        criterion: Loss taking ``(logits, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_batch_loss, accuracy)``, with predictions thresholded at 0.5
        after a sigmoid.
    """
    model.train()
    running_loss = 0
    epoch_preds = []
    epoch_labels = []

    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attn_mask = batch['attention_mask'].to(device)
        targets = batch['label'].to(device)

        optimizer.zero_grad()
        logits = model(input_ids, attn_mask)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        # Clip the global gradient norm before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        running_loss += batch_loss.item()
        hard_preds = (torch.sigmoid(logits) > 0.5).float()
        epoch_preds.extend(hard_preds.detach().cpu().numpy())
        epoch_labels.extend(targets.detach().cpu().numpy())

    mean_loss = running_loss / len(dataloader)
    return mean_loss, accuracy_score(epoch_labels, epoch_preds)


def evaluate_model(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Args:
        model: Callable as ``model(input_ids, attention_mask)``, returning logits.
        dataloader: Yields dicts with ``input_ids``, ``attention_mask``, ``label``.
        criterion: Loss taking ``(logits, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(metrics, all_preds, all_labels, all_probs)``. ``metrics`` holds
        ``loss`` (mean batch loss), ``accuracy``, ``precision``, ``recall``,
        ``f1`` and ``confusion_matrix``. Predictions are sigmoid
        probabilities thresholded at 0.5.
    """
    model.eval()
    running_loss = 0
    pred_list, label_list, prob_list = [], [], []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attn_mask = batch['attention_mask'].to(device)
            targets = batch['label'].to(device)

            logits = model(input_ids, attn_mask)
            running_loss += criterion(logits, targets).item()

            batch_probs = torch.sigmoid(logits)
            batch_preds = (batch_probs > 0.5).float()
            pred_list.extend(batch_preds.cpu().numpy())
            label_list.extend(targets.cpu().numpy())
            prob_list.extend(batch_probs.cpu().numpy())

    # zero_division=0 keeps the score at 0 when a class is never predicted.
    metrics = {
        'loss': running_loss / len(dataloader),
        'accuracy': accuracy_score(label_list, pred_list),
        'precision': precision_score(label_list, pred_list, zero_division=0),
        'recall': recall_score(label_list, pred_list, zero_division=0),
        'f1': f1_score(label_list, pred_list, zero_division=0),
        'confusion_matrix': cm_func(label_list, pred_list),
    }
    return metrics, pred_list, label_list, prob_list


def train_bert_model(model, train_ds, val_ds, test_ds,
                     criterion, device,
                     num_epochs=3, batch_size=32, learning_rate=2e-5,
                     weight_decay=0.01, warmup_ratio=0.1,
                     early_stop_patience=3):
    """Fine-tune ``model`` with AdamW, linear warmup and early stopping.

    After each epoch the model is evaluated on both the validation and the
    test set. Only validation accuracy drives model selection and early
    stopping; test accuracy is recorded in the history for plotting.

    Args:
        model: Model to train (updated in place).
        train_ds, val_ds, test_ds: Datasets wrapped in DataLoaders here.
        criterion: Loss taking ``(logits, labels)``.
        device: Device used for the batches.
        num_epochs: Maximum number of epochs.
        batch_size: Batch size for all three loaders.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        warmup_ratio: Fraction of all optimisation steps used for LR warmup.
        early_stop_patience: Stop after this many consecutive epochs without
            a new best validation accuracy.

    Returns:
        ``(history, model)``. ``history`` maps ``train_loss``, ``train_acc``,
        ``val_loss``, ``val_acc`` and ``test_acc`` to per-epoch lists; the
        model carries the weights of the best validation epoch, if any epoch
        improved on the initial best of 0.
    """
    # Imported locally so the function does not rely on another notebook cell
    # having imported `time` first.
    import time

    loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)

    # AdamW over trainable parameters only (frozen encoder weights are skipped).
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=learning_rate, weight_decay=weight_decay)

    # Linear warmup followed by linear decay, stepped once per batch.
    total_steps = len(train_loader) * num_epochs
    warmup_steps = int(total_steps * warmup_ratio)
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'test_acc': []}
    best_val_acc = 0
    best_model_state = None
    patience_counter = 0

    for epoch in range(num_epochs):
        t0 = time.time()
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, scheduler, criterion, device)
        val_metrics, _, _, _ = evaluate_model(model, val_loader, criterion, device)
        test_metrics, _, _, _ = evaluate_model(model, test_loader, criterion, device)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_metrics['loss'])
        history['val_acc'].append(val_metrics['accuracy'])
        history['test_acc'].append(test_metrics['accuracy'])

        lr_now = optimizer.param_groups[0]['lr']
        elapsed = time.time() - t0

        print(f"  Epoch [{epoch+1}/{num_epochs}] ({elapsed:.0f}s) "
              f"Loss: {train_loss:.4f} | Train: {train_acc:.4f} | "
              f"Val: {val_metrics['accuracy']:.4f}(F1:{val_metrics['f1']:.4f}) | "
              f"Test: {test_metrics['accuracy']:.4f} | LR: {lr_now:.2e}")

        if val_metrics['accuracy'] > best_val_acc:
            best_val_acc = val_metrics['accuracy']
            # Snapshot the weights; clone() detaches them from later updates.
            best_model_state = {k: v.clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= early_stop_patience:
            print(f"  Early stopping at epoch {epoch+1}")
            break

    # Restore the best validation checkpoint when one was recorded.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return history, model


print("‚úì Trainer module defined")

In [None]:
# ============================================================
# Cell 7: Visualization Module (All 8 Figures + ROC-AUC)
# ============================================================
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix as cm_func, roc_curve, auc

# Use the DejaVu Sans font family for all figures.
matplotlib.rcParams['font.family'] = ['DejaVu Sans']
import warnings
# Silence warnings whose message contains "findfont".
warnings.filterwarnings('ignore', message='.*findfont.*')
# Sets matplotlib's 'axes.unicode_minus' option to False.
matplotlib.rcParams['axes.unicode_minus'] = False


def plot_training_curves(history, title="Training Curves", save_path=None):
    """Fig 1/2: plot loss (left axis) and accuracies (right axis) per epoch.

    Args:
        history: Dict with per-epoch lists ``train_loss``, ``train_acc``,
            ``val_acc`` and ``test_acc``; ``val_loss`` is optional.
        title: Figure title.
        save_path: If truthy, the figure is also saved there (150 dpi).
    """
    epochs = range(1, len(history['train_loss']) + 1)
    fig, ax1 = plt.subplots(figsize=(10, 6))
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Loss', color='tab:red', fontsize=12)

    # Collect every plotted line so the combined legend lists all of them,
    # including the optional validation-loss curve.
    lines = ax1.plot(epochs, history['train_loss'], 'r-o', label='Train Loss', markersize=5, linewidth=2)
    if 'val_loss' in history:
        lines += ax1.plot(epochs, history['val_loss'], 'r--s', label='Val Loss',
                          markersize=4, linewidth=1.5, alpha=0.7)
    ax1.tick_params(axis='y', labelcolor='tab:red')
    ax1.grid(True, alpha=0.3)

    # Accuracies share the x-axis but use their own y-axis.
    ax2 = ax1.twinx()
    ax2.set_ylabel('Accuracy', fontsize=12)
    lines += ax2.plot(epochs, history['train_acc'], 'b-s', label='Train Acc', markersize=5, linewidth=2)
    lines += ax2.plot(epochs, history['val_acc'], 'c-D', label='Val Acc', markersize=4, linewidth=1.5)
    lines += ax2.plot(epochs, history['test_acc'], 'g-^', label='Test Acc', markersize=5, linewidth=2)
    ax2.set_ylim([0.4, 1.05])

    ax1.legend(lines, [l.get_label() for l in lines], loc='center right', fontsize=9)
    plt.title(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"    Saved: {save_path}")
    plt.show()


def _plot_history_comparison(histories_dict, label_prefix, title, save_path=None):
    """Draw train loss / train accuracy / test accuracy panels, one curve per run.

    Args:
        histories_dict: Maps a setting value (shown in the legend) to a
            history dict with ``train_loss``, ``train_acc`` and ``test_acc``.
        label_prefix: Text placed before ``=<setting>`` in legend labels.
        title: Figure super-title.
        save_path: If truthy, the figure is also saved there (150 dpi).
    """
    # (history key, marker, y-axis label, panel title) for each of the 3 panels.
    panels = (
        ('train_loss', 'o', 'Loss', 'Training Loss'),
        ('train_acc', 's', 'Accuracy', 'Training Accuracy'),
        ('test_acc', '^', 'Accuracy', 'Test Accuracy'),
    )
    fig, axes = plt.subplots(1, 3, figsize=(20, 6))
    colors = plt.cm.tab10(np.linspace(0, 1, len(histories_dict)))
    for color, (setting, h) in zip(colors, histories_dict.items()):
        ep = range(1, len(h['train_loss']) + 1)
        label = f'{label_prefix}={setting}'
        for ax, (key, marker, _, _) in zip(axes, panels):
            ax.plot(ep, h[key], color=color, marker=marker, ms=3, lw=1.5, label=label)
    for ax, (_, _, ylabel, panel_title) in zip(axes, panels):
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(panel_title)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)
    fig.suptitle(title, fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"    Saved: {save_path}")
    plt.show()


def plot_lr_comparison(histories_dict, title="LR Comparison", save_path=None):
    """Fig 3/4: compare runs trained with different learning rates.

    ``histories_dict`` maps each learning rate to its training history.
    """
    _plot_history_comparison(histories_dict, 'LR', title, save_path)


def plot_batch_size_comparison(histories_dict, title="Batch Size Comparison", save_path=None):
    """Fig 5/6: compare runs trained with different batch sizes.

    ``histories_dict`` maps each batch size to its training history.
    """
    _plot_history_comparison(histories_dict, 'BS', title, save_path)


def plot_predictions_table(texts, true_labels, pred_labels, n=100, save_path=None):
    """Fig 7: correctness grid for the first ``n`` samples plus a 20-row table.

    Args:
        texts: Sequence of input texts.
        true_labels: Ground-truth labels (0 = Fake, 1 = Real).
        pred_labels: Predicted labels (0 = Fake, 1 = Real).
        n: Maximum number of samples to show; capped at the shortest input.
        save_path: If truthy, the figure is also saved there (150 dpi).

    Returns:
        None. When there is nothing to show, a notice is printed and no
        figure is created.
    """
    # Never index past the end of any of the three inputs.
    n = min(n, len(texts), len(true_labels), len(pred_labels))
    if n <= 0:
        # An empty selection would otherwise end in a division by zero.
        print("    No predictions to plot.")
        return

    from matplotlib.colors import ListedColormap

    is_correct = [true_labels[i] == pred_labels[i] for i in range(n)]

    fig, axes = plt.subplots(1, 2, figsize=(16, 8), gridspec_kw={'width_ratios': [1, 3]})

    # Left panel: square grid, 1 = correct, -1 = wrong, 0 = unused cell.
    grid_size = int(np.ceil(np.sqrt(n)))
    grid = np.zeros((grid_size, grid_size))
    for i, ok in enumerate(is_correct):
        grid[i // grid_size][i % grid_size] = 1 if ok else -1
    cmap = ListedColormap(['#ff6b6b', 'white', '#51cf66'])
    axes[0].imshow(grid, cmap=cmap, vmin=-1, vmax=1, aspect='equal')
    axes[0].set_title(f'Prediction Grid ({n})\nGreen=Correct, Red=Wrong', fontsize=11)

    # Right panel: table with the first (at most) 20 samples.
    axes[1].axis('off')
    show_n = min(20, n)
    rows, colors_list = [], []
    label_map = {0: 'Fake', 1: 'Real'}
    for i in range(show_n):
        snippet = str(texts[i])  # convert once so slicing works for any value
        txt = snippet[:55] + '...' if len(snippet) > 55 else snippet
        ok = 'Yes' if is_correct[i] else 'No'
        rows.append([i+1, txt, label_map[int(true_labels[i])], label_map[int(pred_labels[i])], ok])
        colors_list.append(['#d4edda' if ok == 'Yes' else '#f8d7da'] * 5)
    t = axes[1].table(cellText=rows, colLabels=['#', 'Text', 'True', 'Pred', 'OK'],
                      cellColours=colors_list, loc='upper center', cellLoc='left')
    t.auto_set_font_size(False)
    t.set_fontsize(8)
    t.auto_set_column_width([0, 1, 2, 3, 4])

    correct = sum(1 for ok in is_correct if ok)
    fig.suptitle(f'Predictions: {correct}/{n} correct ({100*correct/n:.1f}%)', fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"    Saved: {save_path}")
    plt.show()


def plot_confusion_matrix(true_labels, pred_labels, save_path=None):
    """Fig 8: draw the confusion matrix as an annotated heatmap.

    Args:
        true_labels: Ground-truth labels.
        pred_labels: Predicted labels.
        save_path: If truthy, the figure is also saved there (150 dpi).
    """
    matrix = cm_func(true_labels, pred_labels)
    class_names = ['Fake (0)', 'Real (1)']

    fig, ax = plt.subplots(figsize=(8, 6))
    heatmap_kwargs = dict(
        annot=True, fmt='d', cmap='Blues',
        xticklabels=class_names, yticklabels=class_names,
        ax=ax, annot_kws={"size": 16},
    )
    sns.heatmap(matrix, **heatmap_kwargs)

    ax.set_xlabel('Predicted', fontsize=12)
    ax.set_ylabel('Actual', fontsize=12)
    ax.set_title('Confusion Matrix', fontsize=14, fontweight='bold')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"    Saved: {save_path}")
    plt.show()


def plot_roc_auc(true_labels, pred_probs, save_path=None):
    """Plot the ROC curve and return the area under it.

    Args:
        true_labels: Ground-truth binary labels.
        pred_probs: Predicted scores or probabilities.
        save_path: If truthy, the figure is also saved there (150 dpi).

    Returns:
        The AUC computed from the ROC curve.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(true_labels, pred_probs)
    roc_auc = auc(false_pos_rate, true_pos_rate)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(false_pos_rate, true_pos_rate, color='darkorange', lw=2,
            label=f'ROC (AUC = {roc_auc:.4f})')
    # Diagonal reference line from (0, 0) to (1, 1).
    ax.plot([0, 1], [0, 1], 'navy', lw=2, linestyle='--')
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1.05])
    ax.set_xlabel('False Positive Rate', fontsize=12)
    ax.set_ylabel('True Positive Rate', fontsize=12)
    ax.set_title('ROC Curve', fontsize=14, fontweight='bold')
    ax.legend(loc='lower right', fontsize=12)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"    Saved: {save_path}")
    plt.show()
    return roc_auc

print("‚úì Visualization module defined")

In [None]:
# ============================================================
# Cell 8: Chain-of-Thought Reasoning (BERT Attention)
# ============================================================
"""
CoT using DistilBERT's multi-head attention:
- Extract attention from ALL layers and heads
- Average across heads/layers → word-level importance
- Combine with rule-based feature analysis
"""

# Lexicons for the rule-based part of the Chain-of-Thought analysis.
# BertChainOfThought._analyze_features matches the two sets against
# whitespace-split, lower-cased words and the two lists as substrings of the
# lower-cased text.

# Single words treated as sensational / exaggerated language.
SENSATIONAL_WORDS = {
    'shocking', 'unbelievable', 'breaking', 'exclusive', 'urgent',
    'bombshell', 'horrifying', 'terrifying', 'scandal', 'outrage',
    'devastating', 'explosive', 'stunning', 'alarming', 'incredible',
    'exposed', 'secret', 'leaked', 'conspiracy', 'hoax', 'destroyed',
    'slammed', 'blasted', 'ripped', 'epic', 'insane'
}
# Phrases treated as references to sources (substring match).
# NOTE(review): clean_for_bert removes 'reuters', so that entry can only
# match text that was not cleaned -- confirm which text analyze() receives.
CREDIBILITY_PHRASES = [
    'according to', 'research shows', 'study finds', 'officials said',
    'reuters', 'associated press', 'confirmed by', 'evidence suggests',
    'data shows', 'report says', 'investigation found', 'spokesperson said',
    'published in', 'university of', 'official statement', 'government report'
]
# Single words treated as emotionally loaded language.
EMOTIONAL_WORDS = {
    'hate', 'love', 'angry', 'furious', 'amazing', 'terrible',
    'disgusting', 'wonderful', 'horrible', 'fantastic', 'awful',
    'outrageous', 'evil', 'hero', 'villain', 'miracle', 'disaster',
    'tragic', 'brilliant', 'stupid', 'genius', 'idiot', 'corrupt'
}
# Phrases treated as clickbait (substring match, so 'believe' also matches
# inside longer words). 'exposed' also appears in SENSATIONAL_WORDS.
# NOTE(review): 'they don want' looks written for apostrophe-free text;
# verify it can match the text actually passed in.
CLICKBAIT_PATTERNS = [
    'you won', 'believe', 'click here', 'share this', 'going viral',
    'mind blowing', 'what happened next', 'will shock', 'doctors hate',
    'one weird trick', 'exposed', 'they don want'
]


class BertChainOfThought:
    """Explain a prediction using attention weights plus lexicon features.

    ``analyze`` runs the classifier on one text, averages the attention that
    the first ([CLS]) position pays to every token over all layers and heads,
    and combines the highest-scoring tokens with rule-based text features
    into a step-by-step textual explanation.
    """

    def __init__(self, model, tokenizer, device, max_len=256):
        """Store the classifier, tokenizer, device and max sequence length."""
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.max_len = max_len

    def analyze(self, text):
        """Classify ``text`` and build the explanation.

        Args:
            text: Article text (a string).

        Returns:
            Dict with ``text_preview`` (first 300 characters), ``prediction``
            ('Real' if the sigmoid probability exceeds 0.5, else 'Fake'),
            ``confidence`` (formatted percentage), ``probability`` (float),
            ``top_attention_words`` (up to 5 ``(token, weight_str)`` pairs),
            ``features`` and ``reasoning`` (multi-line string).
        """
        self.model.eval()
        encoding = self.tokenizer(
            text, max_length=self.max_len, padding='max_length',
            truncation=True, return_tensors='pt')
        ids = encoding['input_ids'].to(self.device)
        mask = encoding['attention_mask'].to(self.device)

        with torch.no_grad():
            logits, attentions = self.model(ids, mask, return_attention=True)

        prob = torch.sigmoid(logits).item()
        prediction = 'Real' if prob > 0.5 else 'Fake'
        confidence = prob if prob > 0.5 else 1 - prob

        # Stack the per-layer attention tensors, take batch item 0 and the
        # attention row of position 0, then average over layers and heads.
        attn_all = torch.stack(attentions)
        cls_attn = attn_all[:, 0, :, 0, :].mean(dim=(0, 1))
        cls_attn = cls_attn.cpu().numpy()

        # Map positions back to tokens; a plain list of ints is passed to
        # the tokenizer rather than a tensor.
        tokens = self.tokenizer.convert_ids_to_tokens(ids[0].tolist())
        real_len = int(mask[0].sum().item())
        word_attn = {}
        # Skip the first and last unmasked positions (special tokens).
        for i in range(1, real_len - 1):
            token = tokens[i]
            if token.startswith('##'):
                continue  # skip subword continuations
            # Keep the largest weight seen for a repeated token.
            word_attn[token] = max(word_attn.get(token, 0), cls_attn[i])
        top_words = sorted(word_attn.items(), key=lambda x: x[1], reverse=True)[:10]

        features = self._analyze_features(text)
        reasoning = self._build_reasoning(prediction, confidence, top_words, features)

        return {
            'text_preview': text[:300] + '...' if len(text) > 300 else text,
            'prediction': prediction,
            'confidence': f"{confidence:.2%}",
            'probability': prob,
            'top_attention_words': [(w, f"{a:.4f}") for w, a in top_words[:5]],
            'features': features,
            'reasoning': reasoning
        }

    def _analyze_features(self, text):
        """Count lexicon hits in ``text``.

        Returns:
            Dict mapping ``sensational``, ``credibility``, ``emotional`` and
            ``clickbait`` to ``(level, matches)`` where level is 'HIGH'
            (two or more hits), 'MEDIUM' (one hit) or 'LOW', plus
            ``word_count`` (number of whitespace-separated words).
        """
        import string

        text_lower = text.lower()
        words = text_lower.split()
        # Strip surrounding punctuation so "shocking!" or "hate," still
        # match the single-word lexicons.
        bare_words = [w.strip(string.punctuation) for w in words]

        sens = [w for w in bare_words if w in SENSATIONAL_WORDS]
        cred = [p for p in CREDIBILITY_PHRASES if p in text_lower]
        emot = [w for w in bare_words if w in EMOTIONAL_WORDS]
        click = [p for p in CLICKBAIT_PATTERNS if p in text_lower]

        def level(items, hi=2, lo=1):
            return 'HIGH' if len(items) >= hi else 'MEDIUM' if len(items) >= lo else 'LOW'

        return {
            'sensational': (level(sens), sens[:5]),
            'credibility': (level(cred), cred[:3]),
            'emotional': (level(emot), emot[:5]),
            'clickbait': (level(click), click[:3]),
            'word_count': len(words)
        }

    def _build_reasoning(self, prediction, confidence, top_words, features):
        """Render the three-step explanation as a multi-line string.

        Args:
            prediction: 'Fake' or 'Real'.
            confidence: Confidence as a fraction (formatted as a percentage).
            top_words: ``(token, weight)`` pairs; the first five are shown.
            features: Output of ``_analyze_features``.
        """
        lines = ["=" * 50, "Step 1 - Text Feature Analysis:"]
        for key in ('sensational', 'credibility', 'emotional', 'clickbait'):
            lvl, items = features[key]
            lines.append(f"  [{key.title():20s}] {lvl}")
            if items:
                lines.append(f"    Found: {', '.join(str(x) for x in items)}")
        lines.append(f"  [{'Word Count':20s}] {features['word_count']}")

        lines.append("\nStep 2 - DistilBERT Attention Key Tokens:")
        for token, weight in top_words[:5]:
            weight = float(weight)
            # One bar character per 0.005 of attention weight.
            bar = '█' * int(weight * 200)
            lines.append(f"  '{token}' [{weight:.4f}] {bar}")

        lines.append("\nStep 3 - Reasoning Chain:")
        reasons = []
        if prediction == 'Fake':
            if features['sensational'][0] != 'LOW':
                reasons.append("Sensational/exaggerated language detected → common in fabricated news.")
            if features['credibility'][0] == 'LOW':
                reasons.append("No references to credible sources → reduces reliability.")
            if features['emotional'][0] != 'LOW':
                reasons.append("Highly emotional language → often used to manipulate readers.")
            if features['clickbait'][0] != 'LOW':
                reasons.append("Clickbait patterns present → prioritizes engagement over accuracy.")
            if not reasons:
                reasons.append("DistilBERT's learned contextual patterns match fake news characteristics.")
        else:
            if features['credibility'][0] != 'LOW':
                reasons.append("References to credible sources → consistent with legitimate news.")
            if features['sensational'][0] == 'LOW':
                reasons.append("Neutral, factual language → consistent with professional journalism.")
            if features['emotional'][0] == 'LOW':
                reasons.append("Objective tone → no excessive emotional manipulation.")
            if not reasons:
                reasons.append("DistilBERT's learned contextual patterns match real news characteristics.")
        for i, r in enumerate(reasons, 1):
            lines.append(f"  {i}. {r}")

        lines.extend([f"\n{'=' * 50}",
                      f"Conclusion: [{prediction}] with {confidence:.2%} confidence.",
                      "=" * 50])
        return '\n'.join(lines)

print("‚úì Chain-of-Thought module defined")

In [None]:
# ============================================================
# Cell 9: Configuration & Experiment Helper
# ============================================================
import time

# Default hyper-parameters shared by every experiment cell. run_bert_experiment
# reads these whenever the caller does not pass an explicit override.
CONFIG = dict(
    model_name='distilbert-base-uncased',
    max_len=256,
    num_epochs=4,          # BERT converges fast (2-4 epochs is standard)
    batch_size=32,
    learning_rate=2e-5,    # Standard BERT fine-tuning LR
    weight_decay=0.01,
    warmup_ratio=0.1,
    dropout=0.2,
    freeze_bert=False,     # Set True for faster training (only train head)
)


def run_bert_experiment(train_ds, val_ds, test_ds,
                        loss_fn='bce', learning_rate=None,
                        batch_size=None, num_epochs=None,
                        freeze_bert=None):
    """Run a single DistilBERT experiment and return history + model.

    Builds a fresh ``DistilBertClassifier``, selects the loss, and delegates
    the training loop to ``train_bert_model``.

    Args:
        train_ds, val_ds, test_ds: Datasets forwarded unchanged to
            ``train_bert_model``.
        loss_fn: ``'bce'`` (``nn.BCEWithLogitsLoss``) or ``'focal'``
            (``FocalLoss(alpha=0.25, gamma=2.0)``).
        learning_rate: Override for ``CONFIG['learning_rate']``; ``None``
            means "use the default".
        batch_size: Override for ``CONFIG['batch_size']``; ``None`` means
            "use the default".
        num_epochs: Override for ``CONFIG['num_epochs']``; ``None`` means
            "use the default".
        freeze_bert: Override for ``CONFIG['freeze_bert']``; ``None`` means
            "use the default".

    Returns:
        The ``(history, best_model)`` pair returned by ``train_bert_model``.

    Raises:
        ValueError: If ``loss_fn`` is not one of the supported names.
    """
    # Fail fast: reject a bad loss name before a model is built and moved to
    # DEVICE, so a typo does not cost a model load.
    if loss_fn not in ('bce', 'focal'):
        raise ValueError(f"Unknown loss: {loss_fn}")

    # Use `is None` rather than `or` so that an explicit falsy override
    # (e.g. learning_rate=0.0) is honoured instead of silently replaced by the
    # default -- the same rule already applied to `freeze_bert`.
    lr = CONFIG['learning_rate'] if learning_rate is None else learning_rate
    bs = CONFIG['batch_size'] if batch_size is None else batch_size
    ep = CONFIG['num_epochs'] if num_epochs is None else num_epochs
    freeze = CONFIG['freeze_bert'] if freeze_bert is None else freeze_bert

    model = DistilBertClassifier(
        model_name=CONFIG['model_name'],
        dropout=CONFIG['dropout'],
        freeze_bert=freeze
    ).to(DEVICE)

    total, trainable = model.count_parameters()
    print(f"  Model: {total/1e6:.1f}M total, {trainable/1e6:.1f}M trainable")

    if loss_fn == 'bce':
        criterion = nn.BCEWithLogitsLoss()
    else:  # 'focal' -- the only other name accepted by the check above
        criterion = FocalLoss(alpha=0.25, gamma=2.0)
    criterion = criterion.to(DEVICE)

    history, best_model = train_bert_model(
        model, train_ds, val_ds, test_ds, criterion, DEVICE,
        num_epochs=ep, batch_size=bs, learning_rate=lr,
        weight_decay=CONFIG['weight_decay'], warmup_ratio=CONFIG['warmup_ratio'])

    return history, best_model


# One call, newline-separated: same three output lines as separate prints.
print(
    "‚úì Configuration ready",
    f"  Device: {DEVICE}",
    f"  Config: {CONFIG}",
    sep="\n",
)

## Step 4: Run Experiments (Fig 1-6)

DistilBERT converges in 2-4 epochs (vs 10-20 for LSTM). Each experiment takes ~5-10 min on T4 GPU.

In [None]:
# ============================================================
# Cell 10: Experiment 1 - BCE Loss (default config) -> Fig 1
# ============================================================
print("=" * 60, "[Exp 1] DistilBERT + BCE Loss (default config)", "=" * 60, sep="\n")

# No overrides are passed, so run_bert_experiment falls back to CONFIG for
# learning rate, batch size, epochs and freezing. Wall-clock time is reported
# in minutes.
t0 = time.time()
history_bce, model_bce = run_bert_experiment(
    train_ds=train_dataset,
    val_ds=val_dataset,
    test_ds=test_dataset,
    loss_fn='bce',
)
print(f"  ‚úì Done in {(time.time()-t0)/60:.1f} min")

# Fig 1 is drawn from the history returned above and written under OUTPUT_DIR.
plot_training_curves(
    history_bce,
    title="Fig 1: DistilBERT + BCE Loss ‚Äî Default Config",
    save_path=os.path.join(OUTPUT_DIR, "fig1_bert_bce.png"),
)

In [None]:
# ============================================================
# Cell 11: Experiment 2 - Focal Loss (default config) -> Fig 2
# ============================================================
print("=" * 60, "[Exp 2] DistilBERT + Focal Loss (default config)", "=" * 60, sep="\n")

# Same call as Experiment 1 except for the loss name; every other setting
# comes from CONFIG. Wall-clock time is reported in minutes.
t0 = time.time()
history_focal, model_focal = run_bert_experiment(
    train_ds=train_dataset,
    val_ds=val_dataset,
    test_ds=test_dataset,
    loss_fn='focal',
)
print(f"  ‚úì Done in {(time.time()-t0)/60:.1f} min")

# Fig 2 is drawn from the history returned above and written under OUTPUT_DIR.
plot_training_curves(
    history_focal,
    title="Fig 2: DistilBERT + Focal Loss ‚Äî Default Config",
    save_path=os.path.join(OUTPUT_DIR, "fig2_bert_focal.png"),
)

In [None]:
# ============================================================
# Cell 12: Experiment 3 - Learning Rate Comparison -> Fig 3 & 4
# ============================================================
print("=" * 60, "[Exp 3] Learning Rate Comparison", "=" * 60, sep="\n")

# BERT-specific LR range (much smaller than LSTM)
learning_rates = [5e-5, 2e-5, 1e-5, 5e-6]

# --- Fig 3: one BCE run per learning rate ---
# Only the history of each run is kept (keyed by LR); the returned model is
# discarded.
print("\n  === BCE Loss ===")
lr_histories_bce = {}
for lr in learning_rates:
    print(f"\n  LR = {lr}:")
    t0 = time.time()
    h, _ = run_bert_experiment(
        train_ds=train_dataset,
        val_ds=val_dataset,
        test_ds=test_dataset,
        loss_fn='bce',
        learning_rate=lr,
    )
    lr_histories_bce[lr] = h
    print(f"    ({(time.time()-t0)/60:.1f} min)")

plot_lr_comparison(
    lr_histories_bce,
    title="Fig 3: DistilBERT + BCE ‚Äî Learning Rate Comparison",
    save_path=os.path.join(OUTPUT_DIR, "fig3_bert_lr_bce.png"),
)

# --- Fig 4: one Focal run per learning rate ---
# Same sweep with the focal loss; histories are again keyed by LR.
print("\n  === Focal Loss ===")
lr_histories_focal = {}
for lr in learning_rates:
    print(f"\n  LR = {lr}:")
    t0 = time.time()
    h, _ = run_bert_experiment(
        train_ds=train_dataset,
        val_ds=val_dataset,
        test_ds=test_dataset,
        loss_fn='focal',
        learning_rate=lr,
    )
    lr_histories_focal[lr] = h
    print(f"    ({(time.time()-t0)/60:.1f} min)")

plot_lr_comparison(
    lr_histories_focal,
    title="Fig 4: DistilBERT + Focal ‚Äî Learning Rate Comparison",
    save_path=os.path.join(OUTPUT_DIR, "fig4_bert_lr_focal.png"),
)

In [None]:
# ============================================================
# Cell 13: Experiment 4 - Batch Size Comparison -> Fig 5 & 6
# ============================================================
print("=" * 60, "[Exp 4] Batch Size Comparison", "=" * 60, sep="\n")

batch_sizes = [16, 32, 64]

# --- Fig 5: one BCE run per batch size ---
# Only the history of each run is kept (keyed by batch size); the returned
# model is discarded.
print("\n  === BCE Loss ===")
bs_histories_bce = {}
for bs in batch_sizes:
    print(f"\n  Batch Size = {bs}:")
    t0 = time.time()
    h, _ = run_bert_experiment(
        train_ds=train_dataset,
        val_ds=val_dataset,
        test_ds=test_dataset,
        loss_fn='bce',
        batch_size=bs,
    )
    bs_histories_bce[bs] = h
    print(f"    ({(time.time()-t0)/60:.1f} min)")

plot_batch_size_comparison(
    bs_histories_bce,
    title="Fig 5: DistilBERT + BCE ‚Äî Batch Size Comparison",
    save_path=os.path.join(OUTPUT_DIR, "fig5_bert_bs_bce.png"),
)

# --- Fig 6: one Focal run per batch size ---
# Same sweep with the focal loss; histories are again keyed by batch size.
print("\n  === Focal Loss ===")
bs_histories_focal = {}
for bs in batch_sizes:
    print(f"\n  Batch Size = {bs}:")
    t0 = time.time()
    h, _ = run_bert_experiment(
        train_ds=train_dataset,
        val_ds=val_dataset,
        test_ds=test_dataset,
        loss_fn='focal',
        batch_size=bs,
    )
    bs_histories_focal[bs] = h
    print(f"    ({(time.time()-t0)/60:.1f} min)")

plot_batch_size_comparison(
    bs_histories_focal,
    title="Fig 6: DistilBERT + Focal ‚Äî Batch Size Comparison",
    save_path=os.path.join(OUTPUT_DIR, "fig6_bert_bs_focal.png"),
)

## Step 5: Final Evaluation (Fig 7, 8 + ROC-AUC + Classification Report)

In [None]:
# ============================================================
# Cell 14: Final Evaluation -> Fig 7, 8, ROC-AUC, Classification Report
# ============================================================
from sklearn.metrics import classification_report

print("=" * 60, "[Final Evaluation] Best DistilBERT + BCE Model", "=" * 60, sep="\n")

# Score the model returned by Experiment 1 (model_bce) on the test split.
# shuffle=False keeps the loader's batch order fixed.
criterion_eval = nn.BCEWithLogitsLoss().to(DEVICE)
test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=2,
)
test_metrics, all_preds, all_labels, all_probs = evaluate_model(
    model_bce,
    test_loader,
    criterion_eval,
    DEVICE,
)

print(f"\n  ‚ïî‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïó")
print(f"  ‚ïë  Final Test Results (DistilBERT)     ‚ïë")
print(f"  ‚ï†‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ï£")
print(f"  ‚ïë  Accuracy:   {test_metrics['accuracy']:.4f}                 ‚ïë")
print(f"  ‚ïë  Precision:  {test_metrics['precision']:.4f}                 ‚ïë")
print(f"  ‚ïë  Recall:     {test_metrics['recall']:.4f}                 ‚ïë")
print(f"  ‚ïë  F1 Score:   {test_metrics['f1']:.4f}                 ‚ïë")
print(f"  ‚ïö‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïù")

# Per-class precision / recall / F1 breakdown (from reference notebook)
print("\n  --- Classification Report ---")
print(
    classification_report(
        all_labels,
        all_preds,
        target_names=['Fake News (0)', 'Real News (1)'],
    )
)

# Fig 7: table built from the first 100 texts, labels and predictions
plot_predictions_table(
    list(X_test[:100]),
    all_labels[:100],
    all_preds[:100],
    n=100,
    save_path=os.path.join(OUTPUT_DIR, "fig7_bert_predictions.png"),
)

# Fig 8: confusion matrix over the full set of labels and predictions
plot_confusion_matrix(
    all_labels,
    all_preds,
    save_path=os.path.join(OUTPUT_DIR, "fig8_bert_confusion.png"),
)

# ROC-AUC curve (from reference notebook); saved as fig9, and the value
# returned by the plot helper is printed below.
roc_score = plot_roc_auc(
    all_labels,
    all_probs,
    save_path=os.path.join(OUTPUT_DIR, "fig9_bert_roc_auc.png"),
)
print(f"\n  ROC-AUC Score: {roc_score:.4f}")

## Step 6: Chain-of-Thought Reasoning Demo

Uses DistilBERT's multi-head self-attention (6 layers x 12 heads = 72 attention heads) to explain predictions.

In [None]:
# ============================================================
# Cell 15: Chain-of-Thought Reasoning Demo
# ============================================================
print("=" * 60)
print("[CoT] Chain-of-Thought Reasoning Demo")
print("=" * 60)

# Explainer wrapped around the BCE model trained in Experiment 1.
cot = BertChainOfThought(model_bce, tokenizer, DEVICE, MAX_LEN)

# Analyze (up to) the first 5 test samples.
# Slicing replaces the former `range(5)` + `X_test[i]` lookup for two reasons:
#   * it cannot raise when fewer than 5 samples are available, and
#   * it selects rows the same way the Fig 7 cell does (`X_test[:100]`).
# NOTE(review): if X_test is a pandas Series with a shuffled index, `X_test[i]`
# would have been a label lookup rather than "the i-th row" -- confirm the type.
for i, sample_text in enumerate(list(X_test[:5])):
    print(f"\n{'‚îÄ' * 60}")
    print(f"  Sample {i+1}:")
    print(f"{'‚îÄ' * 60}")
    result = cot.analyze(str(sample_text))
    print(f"\nText: {result['text_preview']}")
    print(f"\nPrediction: {result['prediction']} (Confidence: {result['confidence']})")
    print(f"Top Attention: {result['top_attention_words']}")
    print(f"\n{result['reasoning']}")

print(f"\n{'‚ïê' * 60}")
print("  ‚úì Chain-of-Thought analysis complete")
print(f"{'‚ïê' * 60}")

## Step 7: Save Outputs & Model

In [None]:
# ============================================================
# Cell 16: Save Outputs to Google Drive
# ============================================================
import shutil

# Save model weights (state_dict only) of the Experiment 1 BCE model
model_save_path = os.path.join(OUTPUT_DIR, "best_distilbert_model.pt")
torch.save(model_bce.state_dict(), model_save_path)
print(f"  ‚úì Model saved: {model_save_path} ({os.path.getsize(model_save_path)/1e6:.1f} MB)")

# Copy every generated file to Google Drive
drive_output = os.path.join(DATA_ROOT, 'outputs_bert')
os.makedirs(drive_output, exist_ok=True)

count = 0
# sorted() gives a deterministic copy/log order (os.listdir order is arbitrary).
for fname in sorted(os.listdir(OUTPUT_DIR)):
    src = os.path.join(OUTPUT_DIR, fname)
    # shutil.copy2 only copies regular files and raises on a directory (for
    # example a stray .ipynb_checkpoints folder). Skip such entries so one
    # sub-directory cannot abort the export half-way through.
    if not os.path.isfile(src):
        print(f"  (skipped, not a regular file: {fname})")
        continue
    dst = os.path.join(drive_output, fname)
    shutil.copy2(src, dst)
    print(f"  ‚úì {fname}")
    count += 1

print(f"\n  All {count} files saved to: {drive_output}")

# Final summary; relies on test_metrics and roc_score from the evaluation cell.
print("\n" + "=" * 60)
print("  ‚úÖ All DistilBERT experiments completed!")
print("  Generated: fig1~fig9 + model weights")
print("  Model: DistilBERT (66M params)")
print(f"  Test Accuracy: {test_metrics['accuracy']:.4f}")
print(f"  ROC-AUC: {roc_score:.4f}")
print("=" * 60)