In [None]:
import os, nltk, certifi
from pathlib import Path
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, RandomSampler, SequentialSampler
from transformers import BertTokenizerFast, BertForSequenceClassification, get_linear_schedule_with_warmup
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import random
import shap
import nlpaug.augmenter.word as naw

# Point the SSL_CERT_FILE environment variable at certifi's CA bundle.
os.environ['SSL_CERT_FILE'] = certifi.where()

# Make NLTK look only in the local 'nltk_data' folder, then download each
# resource into that same folder.
nltk.data.path.clear()
nltk.data.path.append('nltk_data')
for _resource in (
    'averaged_perceptron_tagger',
    'averaged_perceptron_tagger_eng',
    'wordnet',
    'punkt',
    'omw-1.4',
):
    nltk.download(_resource, download_dir='nltk_data')

In [None]:
def set_seed(seed=42):
    """Seed Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices).

    Args:
        seed: Integer seed handed to every generator. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


set_seed(42)

In [None]:
# Accepted spellings for each class: 0 = true/real, 1 = fake/false.
# '0.0'/'1.0' cover label columns that pandas parsed as float.
_LABEL_LOOKUP = {
    '0': 0, '0.0': 0, 'true': 0, 'real': 0,
    '1': 1, '1.0': 1, 'fake': 1, 'false': 1,
}


def _as_clean_text(series):
    """Stringify a column, turning missing values into '' instead of the literal 'nan'."""
    return series.astype(str).where(series.notna(), '')


def _normalize_labels(labels):
    """Map raw labels onto {0, 1} via `_LABEL_LOOKUP`; unrecognized values become NaN."""
    return labels.astype(str).str.strip().str.lower().map(_LABEL_LOOKUP)


def load_and_merge_datasets():
    """Load True.csv / Fake.csv, merge them, de-duplicate, and apply user feedback.

    The dataset folder is taken from the DATA_DIR environment variable, falling
    back to a "Datasets" folder next to this file (or the current working
    directory when `__file__` is not defined).

    For each source file the reader tries ',' and ';' separators with utf-8 and
    latin1 encodings and keeps the first parse that yields more than one column.
    Columns are detected by case-insensitive substring match on "title", "text"
    and "label". A file without a title/text column or without a label column is
    skipped with a message.

    Labels are normalized to 0 (true/real) and 1 (fake/false); rows with any
    other label are dropped. Rows are de-duplicated case-insensitively on the
    stripped input text, keeping the first occurrence.

    Side effects:
        Writes the merged frame to <datasets dir>/main_data.csv. If
        <datasets dir>/user_feedback.csv exists, rows of the merged data whose
        input matches a feedback input are replaced by the feedback rows and
        main_data.csv is rewritten.

    Returns:
        pandas.DataFrame with columns 'input' (str) and 'label' (int), shuffled
        with a fixed seed.

    Raises:
        FileNotFoundError: the dataset folder does not exist.
        ValueError: none of the source files produced usable rows.
    """
    files_info = [
        'True.csv',
        'Fake.csv'
    ]

    ROOT = Path(__file__).resolve().parent if "__file__" in globals() else Path.cwd()
    DATASETS_DIR = Path(os.getenv("DATA_DIR", ROOT / "Datasets")).expanduser().resolve()
    if not DATASETS_DIR.is_dir():
        raise FileNotFoundError(f"Datasets folder not found: {DATASETS_DIR}")

    # Every artifact lives under DATASETS_DIR so that reads and writes agree
    # even when DATA_DIR points somewhere other than ./Datasets.
    FEEDBACK_PATH = DATASETS_DIR / "user_feedback.csv"
    MAIN_DATA_PATH = DATASETS_DIR / "main_data.csv"

    dfs = []  # one cleaned frame per successfully parsed source file
    for fname in files_info:
        path = DATASETS_DIR / fname
        df = None
        for sep in [',', ';']:
            for encoding in ['utf-8', 'latin1']:
                try:
                    candidate = pd.read_csv(path, encoding=encoding, sep=sep, on_bad_lines='skip', low_memory=False)
                except Exception:
                    # Best effort: a failed parse just means "try the next combination".
                    continue
                # A single-column result means the separator was wrong.
                if candidate.shape[1] > 1:
                    df = candidate
                    break
            if df is not None:
                break
        if df is None:
            print(f"Skipping this file: {fname}")
            continue

        cols = {c.lower(): c for c in df.columns}
        title_col = next((cols[c] for c in cols if "title" in c), None)
        text_col  = next((cols[c] for c in cols if "text" in c), None)
        label_col = next((cols[c] for c in cols if "label" in c), None)

        if title_col and text_col:
            df['input'] = _as_clean_text(df[title_col]) + ' [SEP] ' + _as_clean_text(df[text_col])
        elif title_col:
            df['input'] = _as_clean_text(df[title_col])
        elif text_col:
            df['input'] = _as_clean_text(df[text_col])
        else:
            print(f"SKIPPING: {fname} (no title/text col)")
            continue
        if label_col is None:
            print(f"SKIPPING: {fname} (no label col)")
            continue

        df = df.assign(label=df[label_col])[['input', 'label']].dropna()
        # Always go through the explicit lookup. A plain astype(int) would turn a
        # boolean True into 1 (fake), contradicting the 'true' -> 0 mapping.
        df = df.assign(label=_normalize_labels(df['label'])).dropna(subset=['label'])
        df = df.assign(label=df['label'].astype(int))
        dfs.append(df)

    if len(dfs) == 0:
        raise ValueError("No valid datasets loaded!")

    data = pd.concat(dfs, ignore_index=True)
    initial_count = data.shape[0]
    data['input'] = data['input'].str.strip()
    # Case-insensitive duplicate detection; the first occurrence is kept.
    is_duplicate = data['input'].str.lower().duplicated()
    duplicate_count = is_duplicate.sum()
    data = data[~is_duplicate]
    final_count = data.shape[0]
    print("="*50)
    print(f"Initial rows before duplicate removal: {initial_count}")
    print(f"Duplicate entries found and removed: {duplicate_count}")
    print(f"Rows left after removing duplicates: {final_count}")
    print(f"Shape after merge and duplicate removal: {data.shape}")
    print("="*50)
    print(data.head())

    data.to_csv(MAIN_DATA_PATH, index=False)
    print("Merged data saved as main_data.csv")

    # --------- FEEDBACK INTEGRATION ---------
    if FEEDBACK_PATH.exists():
        print("="*50)
        print("Feedback file found. Integrating corrections...")
        feedback = pd.read_csv(FEEDBACK_PATH)
        print(f"Shape of the the feedback file is: {feedback.shape}")

        # Drop missing values before stringifying so NaN never becomes 'nan'.
        feedback = feedback.dropna(subset=['input', 'label'])
        feedback = feedback.assign(
            input=feedback['input'].astype(str),
            label=_normalize_labels(feedback['label']),
        )[['input', 'label']].dropna()
        feedback = feedback.assign(label=feedback['label'].astype(int))

        # Use the in-memory frame instead of re-reading main_data.csv: a CSV
        # round trip would turn empty or 'NA'-like inputs into NaN.
        main_data = data[~data['input'].isin(feedback['input'])]

        updated_data = pd.concat([main_data, feedback], ignore_index=True)
        updated_data = updated_data.sample(frac=1, random_state=42).reset_index(drop=True)
        updated_data.to_csv(MAIN_DATA_PATH, index=False)
        print("Feedback integrated and main_data.csv updated.")
        data = updated_data

    data = data.sample(frac=1, random_state=42).reset_index(drop=True)
    return data

# Build the merged corpus, then report its shape followed by a divider line.
data = load_and_merge_datasets()
print(f"after merging the feedback data and main data = {data.shape}", "="*50, sep="\n")

In [None]:
# Rows per class before any balancing: label 0 is true news, label 1 is fake news.
true_count = data['label'].eq(0).sum()
fake_count = data['label'].eq(1).sum()

print(f"Before balancing: \nTrue label news = {true_count} \nFake label news = {fake_count}")

In [None]:
# Detect the minority class and balance the dataset by appending WordNet-synonym
# paraphrases of randomly chosen minority texts.
# NOTE(review): this cell runs on the full dataset; if the train/val/test split
# happens afterwards, a paraphrase and its source text can land in different
# splits. Consider augmenting only the training portion.
aug = naw.SynonymAug(aug_src='wordnet', aug_max=10)

# Current class counts
counts = data['label'].value_counts()
if len(counts) != 2:  # exactly two distinct labels are required
    raise ValueError("Expected binary labels 0/1. Got: " + str(counts.to_dict()))

minority_label = counts.idxmin()  # label with the fewest rows
majority_label = counts.idxmax()  # label with the most rows

# Number of synthetic rows needed for the minority class to match the majority.
n_to_augment = int(counts[majority_label] - counts[minority_label])
print("Detecting minority class....")
if minority_label==1:
    print(f"Minority class: Fake | 1 | Need to add: {n_to_augment}")
else:
    print(f"Minority class: True | 0 | Need to add: {n_to_augment}")

if n_to_augment > 0:  # only augment when the classes are imbalanced
    minority_df = data[data['label'] == minority_label].copy()

    # Draw all source texts in one seeded call (with replacement) instead of
    # calling DataFrame.sample once per generated row.
    source_texts = minority_df['input'].sample(n=n_to_augment, replace=True, random_state=42)

    augmented_texts = []
    for original_text in source_texts:
        new_text = aug.augment(original_text)  # may return str or list
        if isinstance(new_text, list):
            # Take the first variant; if the augmenter produced nothing, fall
            # back to the source text rather than raising IndexError.
            new_text = new_text[0] if new_text else original_text
        augmented_texts.append(new_text)

    aug_df = pd.DataFrame({'input': augmented_texts, 'label': minority_label})
    data = pd.concat([data, aug_df], ignore_index=True)
    data = data.sample(frac=1, random_state=42).reset_index(drop=True)  # shuffle and renumber
else:
    print("Already balanced. No augmentation applied.")

print("After balancing:", data['label'].value_counts())

In [None]:
# Pie chart of the class balance: first slice is label 0, second slice is label 1.
plt.figure(figsize=(5,5))
class_sizes = [len(data[data.label == cls]) for cls in (0, 1)]
plt.pie(
    class_sizes,
    labels=['True', 'Fake'],
    autopct='%1.1f%%',
    explode=[0.05,0.05],
    colors=['skyblue', 'salmon'],
)
plt.title("Label Distribution After Augmentation (0=Real, 1=Fake)")
plt.show()

In [None]:
def normalize_input_cell(cell_value):
    """Return one cell of the 'input' column as a clean, stripped string.

    - A list is flattened: None items and items whose stripped text is
      "nan"/"none" (any case) are dropped, the rest are joined with one space.
    - None and float NaN become "".
    - Anything else is stringified and stripped; the text "nan" (any case)
      becomes "".
    """
    if isinstance(cell_value, list):
        stripped_items = (str(item).strip() for item in cell_value if item is not None)
        kept_items = [piece for piece in stripped_items if piece.lower() not in ("nan", "none")]
        return " ".join(kept_items).strip()

    is_missing = cell_value is None or (isinstance(cell_value, float) and pd.isna(cell_value))
    if is_missing:
        return ""

    text = str(cell_value).strip()
    if text.lower() == "nan":
        return ""
    return text


# Clean every cell of the 'input' column, one value at a time.
data["input"] = [normalize_input_cell(value) for value in data["input"]]

In [None]:
# =========== TRAIN/VAL/TEST SPLIT (70/15/15) STRATIFIED ==============
# Step 1: hold out 30% of the rows, keeping the label ratio.
train_text, temp_text, train_labels, temp_labels = train_test_split(
    data['input'],
    data['label'],
    test_size=0.3,
    random_state=42,
    stratify=data['label'],
)
# Step 2: cut the held-out 30% in half -> validation and test.
val_text, test_text, val_labels, test_labels = train_test_split(
    temp_text,
    temp_labels,
    test_size=0.5,
    random_state=42,
    stratify=temp_labels,
)

# ---- Print basic stats ----
print(
    f"Training set:   {len(train_text)} samples",
    f"Validation set: {len(val_text)} samples",
    f"Test set:       {len(test_text)} samples",
    sep="\n",
)

In [None]:
MAX_LENGTH = 512  # every sequence is padded/truncated to this many tokens
MODEL_NAME = 'bert-base-multilingual-cased'  # multilingual BERT (mBERT)
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)

def encode_texts(texts):
    """Tokenize an iterable of strings into fixed-length PyTorch tensors.

    Args:
        texts: Iterable of strings (e.g. a pandas Series or a list).

    Returns:
        The tokenizer's batch encoding. This file reads its 'input_ids' and
        'attention_mask' entries; every row is padded or truncated to
        MAX_LENGTH tokens.
    """
    # Call the tokenizer directly (as explain_prediction does) rather than the
    # legacy batch_encode_plus entry point.
    return tokenizer(
        list(texts),
        max_length=MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,  # 1 for real tokens, 0 for padding
        return_tensors='pt'  # return PyTorch tensors
    )

tokens_train = encode_texts(train_text)
tokens_val = encode_texts(val_text)
tokens_test = encode_texts(test_text)

In [None]:
batch_size = 64  # number of samples the model sees in one step.


def _as_tensor_dataset(tokens, labels):
    """Bundle input ids, attention mask and int64 labels into one TensorDataset."""
    return TensorDataset(
        tokens['input_ids'],
        tokens['attention_mask'],
        torch.tensor(labels.values).long(),
    )


train_data = _as_tensor_dataset(tokens_train, train_labels)
val_data = _as_tensor_dataset(tokens_val, val_labels)
test_data = _as_tensor_dataset(tokens_test, test_labels)

# Training batches are drawn in random order; validation and test keep dataset order.
train_loader = DataLoader(train_data, batch_size=batch_size, sampler=RandomSampler(train_data))
val_loader = DataLoader(val_data, batch_size=batch_size, sampler=SequentialSampler(val_data))
test_loader = DataLoader(test_data, batch_size=batch_size, sampler=SequentialSampler(test_data))

In [None]:
# =================== mBERT MODEL SETUP (top-3 layers unfrozen) ====================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

# Start from a fully frozen network ...
for p in model.parameters():
    p.requires_grad = False

# ... then re-enable gradients for the last three encoder layers and the
# classification head only.
trainable_modules = [*model.bert.encoder.layer[-3:], model.classifier]
for module in trainable_modules:
    for p in module.parameters():
        p.requires_grad = True

model = model.to(device)

# The optimizer is handed only the parameters that still require gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=2e-5, weight_decay=0.01)

epochs = 10
total_steps = len(train_loader) * epochs
# Linear schedule whose warmup covers the first 10% of all optimizer steps.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps,
)
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
def explain_prediction(text, tokenizer, model, device, max_len=512):
    """Compute per-token SHAP scores for one text with respect to class index 1.

    Args:
        text: The input string to explain.
        tokenizer: Tokenizer used both to encode perturbed texts and to build
            the SHAP text masker.
        model: Sequence-classification model whose output exposes `.logits`.
        device: Device the encoded batches are moved to before the forward pass.
        max_len: Maximum token length used when encoding (longer texts are truncated).

    Returns:
        A list of (token, score) pairs, where score is the SHAP value of that
        token for the class-1 probability (class 1 is labelled 'Fake' elsewhere
        in this file).
    """
    model.eval()

    def predict_proba(batch):
        """Return softmax class probabilities for a batch of strings as a NumPy array."""
        encoded = tokenizer(
            list(batch), padding=True, truncation=True, max_length=max_len, return_tensors="pt"
        ).to(device)
        # The explainer calls this on many perturbed batches. Disabling autograd
        # avoids building a graph for each one when parameters are trainable.
        with torch.no_grad():
            probabilities = model(**encoded).logits.softmax(-1)
        return probabilities.cpu().numpy()

    explainer = shap.Explainer(predict_proba, masker=shap.maskers.Text(tokenizer))
    shap_values = explainer([text])
    tokens = shap_values.data[0]
    scores = shap_values.values[0][:, 1]  # column 1 = class index 1
    return list(zip(tokens, scores))

In [None]:
class EarlyStopper:
    """Signal when the validation loss has stopped improving.

    An epoch counts as an improvement only if its loss is lower than the best
    loss seen so far by more than `min_delta`. After `patience` consecutive
    non-improving calls, `early_stop` returns True.
    """

    def __init__(self, patience=2, min_delta=0.001):
        self.patience, self.min_delta = patience, min_delta
        self.best_loss = float('inf')
        self.counter = 0  # consecutive calls without a meaningful improvement

    def early_stop(self, val_loss):
        """Record `val_loss`; return True when training should stop."""
        improved = val_loss < self.best_loss - self.min_delta
        if not improved:
            self.counter += 1
            return self.counter >= self.patience
        # New best: remember it and restart the patience countdown.
        self.best_loss = val_loss
        self.counter = 0
        return False

In [None]:
train_losses, val_losses = [], []
early_stopper = EarlyStopper(patience=2)
best_val_loss = float('inf')

for epoch in range(epochs):
    # ---- Training pass ----
    model.train()
    total_train_loss = 0
    for batch in train_loader:
        b_input_ids, b_attn_mask, b_labels = (t.to(device) for t in batch)
        model.zero_grad()
        outputs = model(b_input_ids, attention_mask=b_attn_mask, labels=b_labels)
        loss = outputs.loss
        total_train_loss += loss.item()
        loss.backward()
        # Clip the global gradient norm to 1.0 before the optimizer update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()  # the learning-rate schedule advances once per batch
    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # ---- Validation pass ----
    model.eval()
    total_val_loss = 0
    preds, truths = [], []
    with torch.no_grad():
        for batch in val_loader:
            b_input_ids, b_attn_mask, b_labels = (t.to(device) for t in batch)
            outputs = model(b_input_ids, attention_mask=b_attn_mask, labels=b_labels)
            total_val_loss += outputs.loss.item()
            logits = outputs.logits
            preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            truths.extend(b_labels.cpu().numpy())
    avg_val_loss = total_val_loss / len(val_loader)
    val_losses.append(avg_val_loss)
    val_acc = accuracy_score(truths, preds)
    print(f"Epoch {epoch+1}: train_loss={avg_train_loss:.4f} val_loss={avg_val_loss:.4f} val_acc={val_acc:.4f}")

    # ---- Checkpoint whenever the validation loss reaches a new minimum ----
    is_new_best = avg_val_loss < best_val_loss
    if is_new_best:
        best_val_loss = avg_val_loss
        for artifact in (model, tokenizer):
            artifact.save_pretrained("mbert_fakenews_model")

    # ---- Stop once the early stopper runs out of patience ----
    if early_stopper.early_stop(avg_val_loss):
        print("Early stopping triggered.")
        break

# ======================= TEST SET EVALUATION =======================
# Evaluate the checkpoint saved in "mbert_fakenews_model", not the in-memory weights.
model = BertForSequenceClassification.from_pretrained("mbert_fakenews_model").to(device)
model.eval()

test_preds, test_truths = [], []
with torch.no_grad():
    for batch in test_loader:
        b_input_ids, b_attn_mask, b_labels = (t.to(device) for t in batch)
        logits = model(b_input_ids, attention_mask=b_attn_mask).logits
        test_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
        test_truths.extend(b_labels.cpu().numpy())

print("\n==== TEST SET METRICS ====")
report = classification_report(test_truths, test_preds, target_names=['True','Fake'])
print(report)
acc = accuracy_score(test_truths, test_preds)
print("Test Accuracy: %.2f%%" % (acc*100))

In [None]:
# Create the axes explicitly and hand them to the display. Without `ax=`,
# ConfusionMatrixDisplay.plot() opens its own figure, so a separately created
# plt.figure(figsize=...) would stay empty and its size would not apply.
fig, ax = plt.subplots(figsize=(6,6))
# labels=[0, 1] keeps the matrix 2x2 (matching the two display labels) even if
# one class never appears among the truths or predictions.
cm = confusion_matrix(test_truths, test_preds, labels=[0, 1])
ConfusionMatrixDisplay(cm, display_labels=['Real/True','Fake']).plot(ax=ax, values_format='d', cmap='Blues')
ax.set_title('Confusion Matrix - Test Set')
plt.show()

In [None]:
print("\nPredicting on Unseen News Headlines")
# Hand-written multilingual headlines for a qualitative check of the model.
# The trailing comment on each entry is the author's expected label:
# 0 = true / real   |   1 = fake / false
unseen_news = [
    "中国科学院宣布在上海成立一个新的高端量子计算实验室，目标是在五年内实现100量子比特的突破。",  # 0 (Chinese)
    "El gobierno chino aumentará considerablemente la financiación de la educación rural este año para ayudar a más estudiantes en zonas pobres a terminar la escuela.",  # 0 (Spanish)
    "Experts claim that eating spicy snack sticks every day can dramatically lower cancer risk and recommend a nationwide rollout.",  # 1 (English)
    "北京市政府将在今年下半年为65岁以上的所有居民提供免费流感疫苗，以改善社区健康。",  # 0 (Chinese)
    "WeChat officiellement annonce que toutes les fonctionnalités de transfert d'argent seront définitivement désactivées pour chaque utilisateur à partir de demain.",  # 1 (French)
    "Scientists discover an alien creature over ten meters long in the Yangtze River, sparking nationwide debate.",  # 1 (English)
    "上海市政府与多家医院联合开展免费新冠疫苗加强针接种活动。",  # 0 (Chinese)
    "भारत मौसम विज्ञान विभाग ने भविष्यवाणी की है कि इस सर्दी पूरे देश में लगातार तीन महीने तक हिमपात होगा।",  # 1 (Hindi)
    "Huawei unveils the world’s first foldable 5G laptop, drawing industry attention.",  # 0 (English)
    "Una escuela primaria en Shenzhen introduce un curso de 'Supervivencia en la Luna', donde los estudiantes visitarán una base lunar real.",  # 1 (Spanish)
    "ዶክተር አማርኛ ሕክምና በኢትዮጵያ ዘርፍ በጥራት መግባባት አስቀድሞ ተሳካ።", # 0 (Amharic)
    "Gwamnatin Najeriya ta ƙaddamar da shirin bunkasa kiwon lafiya a karkara.", # 0 (Hausa)
    "የኢትዮጵያ ጤና ሚኒስቴር በዛሬው ቀን በአዲስ አበባ አዲስ የሕክምና ማዕከል መክፈቱን አስታወቀ።ይህ ማዕከል በዓመት ከ50,000 በላይ ታካሚዎችን ለመቀበል ይችላል እና የሳይንስ ዘመናዊ መሳሪያዎች በሙሉ ተዘጋጅቶበታል።ከአሁኑ በኋላ ዜጎች ለመሳሰሉት ሕክምናዊ አገልግሎት ወደ ውጭ ሀገር መሄድ አያስፈልጋቸውም ብለዋል ባለሥልጣኖች። "  # expected label not recorded by the author (Amharic) — TODO confirm
]


# Encode all headlines in one batch and classify them without tracking gradients.
unseen_enc = encode_texts(unseen_news)
model.eval()
with torch.no_grad():
    logits = model(
        unseen_enc['input_ids'].to(device),
        attention_mask=unseen_enc['attention_mask'].to(device),
    ).logits
    predictions = torch.argmax(logits, dim=1).cpu().numpy()

# Print each headline next to its predicted class (1 -> 'Fake', otherwise 'true').
for text, pred in zip(unseen_news, predictions):
    print(f"News: {text}\nPredicted: {'Fake' if pred==1 else 'true'}\n")