#ALL IMPORTS

In [None]:
!pip install nlpaug nltk --quiet
!pip install shap
!pip install langdetect

import os, nltk, certifi  # OS helpers, NLTK tools, and a trusted SSL cert bundle
from pathlib import Path  # Clean, cross-platform file paths
import pandas as pd  # Data frames for tables
import numpy as np  # Fast arrays and math
import torch  # PyTorch core
from torch.utils.data import DataLoader, TensorDataset, RandomSampler, SequentialSampler, WeightedRandomSampler  # Batching and dataset helpers
from transformers import BertTokenizerFast, BertForSequenceClassification, get_linear_schedule_with_warmup  # BERT tokenizer/model and LR scheduler
from torch.optim import AdamW  # AdamW optimizer (weights decay friendly)
from sklearn.model_selection import train_test_split  # Split data into train/validation/test
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, ConfusionMatrixDisplay  # Evaluation metrics and confusion-matrix plot
import matplotlib.pyplot as plt  # Charts and figures
import random  # Simple random utilities (seeding)
from langdetect import detect  # Detect the language of a text
import shap  # Explain model predictions with SHAP values
import nlpaug.augmenter.word as naw  # Word-level data augmentation
from google.cloud import translate
from google.oauth2 import service_account
from google.auth.transport.requests import Request



# Route HTTPS certificate verification through certifi's CA bundle.
os.environ['SSL_CERT_FILE'] = certifi.where()

# --- Google Cloud Translate auth ---
# Anchor key lookups at the script's folder, or at the CWD inside a notebook.
if "__file__" in globals():
    PROJECT_ROOT = Path(__file__).resolve().parent
else:
    PROJECT_ROOT = Path.cwd()

# Service-account key locations, highest priority first.
CANDIDATES = [
    os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", ""),
    PROJECT_ROOT / "keys" / "translate-sa.json",        # ./keys/translate-sa.json
    PROJECT_ROOT.parent / "keys" / "translate-sa.json", # ../keys/translate-sa.json
    Path.home() / "gcp" / "translate-sa.json",          # ~/gcp/translate-sa.json
]

# First non-empty candidate that exists on disk, or None when none does.
key_path = next(
    filter(Path.is_file, (Path(c).expanduser() for c in CANDIDATES if c)),
    None,
)

if key_path is None:
    # No key file found: remove any stale env pointer and build the client
    # without explicit credentials.
    os.environ.pop("GOOGLE_APPLICATION_CREDENTIALS", None)
    client = translate.TranslationServiceClient()
else:
    # Publish the key path through the environment and hand the loaded
    # credentials to the client explicitly.
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = str(key_path)
    creds = service_account.Credentials.from_service_account_file(str(key_path))
    client = translate.TranslationServiceClient(credentials=creds)
# --- end auth ---

# Make ./nltk_data the only NLTK search path and download every resource there.
nltk.data.path[:] = ['nltk_data']
for _nltk_resource in (
    'averaged_perceptron_tagger',      # POS tagger model
    'averaged_perceptron_tagger_eng',  # English-only POS tagger
    'wordnet',                         # WordNet lexical database
    'punkt',                           # sentence/word tokenizer models
    'omw-1.4',                         # multilingual WordNet data
):
    nltk.download(_nltk_resource, download_dir='nltk_data')



# Setting up translation environment

In [None]:
# ----- Translation settings: OAuth scope, API location, key-file search order -----
SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
# The TRANSLATE_LOCATION environment variable overrides the "global" default.
LOCATION = os.environ.get("TRANSLATE_LOCATION", "global")

# Script folder when run as a file; current working directory otherwise.
if "__file__" in globals():
    PROJECT_ROOT = Path(__file__).resolve().parent
else:
    PROJECT_ROOT = Path.cwd()

# Candidate service-account keys, checked in this order.
CANDIDATE_KEYS = [
    os.environ.get("GOOGLE_APPLICATION_CREDENTIALS"),
    PROJECT_ROOT.joinpath("keys", "translate-sa.json"),         # ./keys/translate-sa.json
    PROJECT_ROOT.parent.joinpath("keys", "translate-sa.json"),  # ../keys/translate-sa.json
    Path.home().joinpath("gcp", "translate-sa.json"),           # ~/gcp/translate-sa.json
]

def _find_key(cands):
    """Return the first candidate that names an existing file, as a string.

    Falsy entries (``None``, ``""``) are skipped and ``~`` is expanded before
    the existence check. Returns ``None`` when no candidate exists on disk.
    """
    expanded = (Path(entry).expanduser() for entry in cands if entry)
    return next((str(path) for path in expanded if path.is_file()), None)

# Resolve credentials and the GCP project: an explicit service-account key
# wins; otherwise fall back to Application Default Credentials (ADC).
KEY = _find_key(CANDIDATE_KEYS)

if KEY:
    creds = service_account.Credentials.from_service_account_file(KEY, scopes=SCOPES)
    # Environment variables take precedence over the project stored in the key.
    project_id = os.getenv("GOOGLE_CLOUD_PROJECT") or os.getenv("GCP_PROJECT") or creds.project_id
else:
    # Application Default Credentials (gcloud auth application-default login).
    # The helper is imported here because it was previously referenced without
    # ever being imported, which made this branch fail with NameError.
    from google.auth import default as google_auth_default

    creds, project_id = google_auth_default(scopes=SCOPES)

# Fetch an access token now so authentication problems surface immediately.
creds.refresh(Request())

if not project_id:
    raise RuntimeError(
        "No GCP project id resolved. Set GOOGLE_CLOUD_PROJECT or use a service account key with project_id."
    )

client = translate.TranslationServiceClient(credentials=creds)
parent = f"projects/{project_id}/locations/{LOCATION}"
# -----------------------------------------------------

# Sanity check: one tiny request proves that client, project and location work.
resp = client.translate_text(request={
    "parent": parent,
    "contents": ["Bonjour"],
    "mime_type": "text/plain",
    "target_language_code": "en",
})
print(resp.translations[0].translated_text)

# SETTING SEED FOR REPRODUCIBILITY

In [None]:
# ========================= REPRODUCIBILITY SEEDS =========================
def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs with `seed`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_rng in seeders:
        seed_rng(seed)


set_seed(42)

# DATA LOADING & MERGING SECTION (FEEDBACK + MAIN)

In [None]:
def load_and_merge_datasets():
    """Load the raw news CSVs, clean and de-duplicate them, then merge user feedback.

    Steps:
      1. Read ``True.csv`` and ``Fake.csv`` from the ``Datasets`` folder located
         next to this script (or under the working directory in a notebook),
         trying ``,``/``;`` separators and utf-8/latin1 encodings.
      2. Build ``input`` from the title/text columns, translate non-English rows
         to English (best effort) and strip URLs, handles, hashtags, image
         credits, agency names and label-leaking words.
      3. Map labels to integers: 1 = fake/false, 0 = true/real.
      4. Drop case-insensitive and whitespace-normalised duplicates and save
         ``main_data.csv`` into the same ``Datasets`` folder.
      5. If ``user_feedback.csv`` exists there, its rows replace main rows with
         an identical ``input`` and ``main_data.csv`` is saved again.

    Returns:
        pandas.DataFrame with columns ``input`` (str) and ``label`` (int 0/1),
        shuffled with a fixed seed.

    Raises:
        FileNotFoundError: the ``Datasets`` folder does not exist.
        Exception: no source file produced a usable frame.
        KeyError: the feedback file lacks an ``input`` or ``label`` column.
    """
    import re
    import unicodedata
    import warnings
    from functools import lru_cache

    import numpy as np
    import pandas as pd

    # ---- regex helpers ----
    AGENCY_RE   = re.compile(r"\b(reuters|associated\s+press|ap|afp|xinhua|bbc|cnn|fox\s+news|the\s+cook\s+political\s+report|center\s+for\s+politics)\b", re.I)
    URL_RE      = re.compile(r"https?://\S+|www\.\S+|t\.co/\S+|pic\.twitter\.com/\S+", re.I)
    HANDLE_RE   = re.compile(r"@\w{2,}", re.I)
    HASHTAG_RE  = re.compile(r"#\w+", re.I)
    CREDIT_RE   = re.compile(r"^\s*(featured image via|photo by|image credit|via:)\b.*$", re.I|re.M)
    LABEL_RE    = re.compile(r"\b(fake\s+news|satire|rumou?r|hoax|fact[-\s]?check(?:ed|ing)?)\b", re.I)

    # Every accepted spelling of a label -> integer class (1 = fake, 0 = real).
    # One explicit table is used for all dtypes: casting a boolean column with
    # astype(int) would turn True into 1 (fake), the opposite of this mapping,
    # and float-coded labels ("1.0") would otherwise be discarded.
    LABEL_TO_INT = {
        '1': 1, '1.0': 1, 'fake': 1, 'false': 1,
        '0': 0, '0.0': 0, 'true': 0, 'real': 0,
    }

    def _canon(s: str) -> str:
        """Lower-cased, NFKC-normalised, whitespace-collapsed form used for de-duplication."""
        s = unicodedata.normalize("NFKC", str(s)).lower()
        return re.sub(r"\s+", " ", s).strip()

    def _clean_text(s: str) -> str:
        """Remove URLs, handles, hashtags, credits and label words; mask agency names as [ORG]."""
        s = unicodedata.normalize("NFKC", str(s))
        for pattern, replacement in (
            (URL_RE, " "),
            (HANDLE_RE, " "),
            (HASHTAG_RE, " "),
            (CREDIT_RE, " "),
            (AGENCY_RE, " [ORG] "),
            (LABEL_RE, " "),
        ):
            s = pattern.sub(replacement, s)
        return re.sub(r"\s+", " ", s).strip()

    def _normalize_labels(series: pd.Series) -> pd.Series:
        """Map raw label values to 0/1; unrecognised values become NaN."""
        keys = series.astype(str).str.strip().str.lower()
        return keys.map(LABEL_TO_INT)

    # ---- language detection + translation (sample-first) ----
    BATCH = 128  # rows per translation request; adjust to quota

    def _looks_english(s: str) -> bool:
        """Cheap pre-filter: (almost) pure ASCII text containing a vowel counts as English."""
        if not s:
            return True
        ascii_ratio = sum(1 for ch in s if ord(ch) < 128) / max(1, len(s))
        return ascii_ratio >= 0.98 and any(v in s.lower() for v in "aeiou")

    try:
        from langdetect import detect
        HAVE_DETECT = True
    except Exception:
        HAVE_DETECT = False

    @lru_cache(maxsize=200_000)
    def _lang(s: str) -> str:
        """Language code for `s`; falls back to "en" whenever detection is unavailable or fails."""
        if _looks_english(s) or not HAVE_DETECT:
            return "en"
        try:
            return detect(s)
        except Exception:
            return "en"

    def _needs_translation(series: pd.Series, sample_n: int = 100) -> bool:
        """True when any of the first `sample_n` non-empty rows is not English."""
        s = series.astype(str)
        sample = s[s.str.strip().ne("")].head(sample_n)
        return any(_lang(text) != "en" for text in sample)

    def _translate_batch(texts):
        """Translate `texts` to English; returns them unchanged (with a warning) on any failure."""
        # The client and its parent path are module-level objects created by the
        # credential-setup cells. They are looked up at call time, under either
        # spelling, because only lower-case `parent` exists when this loader
        # first runs.
        module_scope = globals()
        translator = module_scope.get("client")
        parent_path = module_scope.get("PARENT") or module_scope.get("parent")
        if translator is None or not parent_path:
            warnings.warn("Translation client is not configured; text left untranslated.")
            return texts
        try:
            resp = translator.translate_text(request={
                "parent": parent_path,
                "contents": texts,
                "mime_type": "text/plain",
                "target_language_code": "en",
            })
        except Exception as exc:
            # Fail open so loading can continue, but do not hide the failure.
            warnings.warn(f"Translation failed ({exc!r}); {len(texts)} row(s) left untranslated.")
            return texts
        translated = [t.translated_text for t in resp.translations]
        # Guard the positional assignment done by the caller.
        return translated if len(translated) == len(texts) else texts

    def translate_series_to_en(series: pd.Series) -> pd.Series:
        """Return `series` as strings with every non-English row translated in batches."""
        s = series.astype(str)
        uniq = pd.Series(s.unique())
        langs = uniq.map(_lang)
        non_en_set = set(uniq[langs != "en"].tolist())
        if not non_en_set:
            return s
        out = s.copy()
        idx = np.flatnonzero(s.isin(non_en_set).values)
        for start in range(0, len(idx), BATCH):
            rows = idx[start:start + BATCH]
            out.iloc[rows] = _translate_batch(s.iloc[rows].tolist())
        return out

    def _read_csv(path):
        """Try separator/encoding combinations; return the first frame with more than one column."""
        for sep in (",", ";"):
            for encoding in ("utf-8", "latin1"):
                try:
                    frame = pd.read_csv(path, encoding=encoding, sep=sep, on_bad_lines="skip", low_memory=False)
                except Exception:
                    continue
                if frame.shape[1] > 1:
                    return frame
        return None

    # ---- load files ----
    DATASET_DIR = (Path(__file__).resolve().parent if "__file__" in globals() else Path.cwd()) / "Datasets"
    if not DATASET_DIR.is_dir():
        raise FileNotFoundError(f"Dataset folder not found: {DATASET_DIR}")

    dfs = []
    for fname in ("True.csv", "Fake.csv"):
        path = DATASET_DIR / fname
        df = _read_csv(path)
        if df is None:
            print(f"Skipping this file: {path.name} (could not parse)")
            continue

        # find columns (case-insensitive substring match)
        cols = {c.lower(): c for c in df.columns}
        title_col = next((cols[c] for c in cols if "title" in c), None)
        text_col  = next((cols[c] for c in cols if "text" in c), None)
        label_col = next((cols[c] for c in cols if "label" in c), None)

        if not (title_col or text_col):
            print(f"SKIPPING: {path.name} (no title or text column found)")
            continue
        if label_col is None:
            print(f"SKIPPING: {path.name} (no label column found)")
            continue

        # build input: fill missing cells BEFORE casting to str, otherwise NaN
        # becomes the literal text "nan" and fillna has nothing left to fill
        parts = [df[c].fillna('').astype(str) for c in (title_col, text_col) if c]
        df['input'] = parts[0] + ' [SEP] ' + parts[1] if len(parts) == 2 else parts[0]
        df['label'] = df[label_col]

        # translate only if the first 100 non-empty rows indicate non-English
        if _needs_translation(df['input'], sample_n=100):
            df['input'] = translate_series_to_en(df['input'])
            print(f"{path.name}: non-English detected in sample. Translated affected rows.")
        else:
            print(f"{path.name}: sample looks English. Skipping translation.")

        # clean and filter; .copy() keeps the assignments below off a slice view
        df = df[['input', 'label']].dropna().copy()
        df['input_clean'] = df['input'].astype(str).apply(_clean_text)
        df['label'] = _normalize_labels(df['label'])
        df = df.dropna(subset=['label'])
        df['label'] = df['label'].astype(int)
        dfs.append(df)

    if not dfs:
        raise Exception("No valid datasets loaded!")

    data = pd.concat(dfs, ignore_index=True)

    # use cleaned text
    data['input'] = data['input_clean']
    data = data.drop(columns=['input_clean'])

    initial_count = data.shape[0]

    # duplicate handling: case-insensitive first, then whitespace/Unicode-normalised
    data['input'] = data['input'].str.strip()
    data['input_lower'] = data['input'].str.lower()
    dup_exact = data.duplicated('input_lower').sum()
    data = data.drop_duplicates('input_lower')

    data['canon'] = data['input'].apply(_canon)
    dup_canon = data.duplicated('canon').sum()
    data = data.drop_duplicates('canon').drop(columns=['canon', 'input_lower'])

    final_count = data.shape[0]
    print("="*50)
    print(f"Initial rows before duplicate removal: {initial_count}")
    print(f"Exact dups removed (raw): {dup_exact}")
    print(f"Canonical dups removed (cleaned): {dup_canon}")
    print(f"Rows left after removing duplicates: {final_count}")
    print(f"Shape after merge and duplicate removal: {data.shape}")
    print("="*50)
    print(data.head())

    # save merged/cleaned data next to the source files (same folder that was read)
    MAIN_DATA_PATH = DATASET_DIR / "main_data.csv"
    data.to_csv(MAIN_DATA_PATH, index=False)
    print("Merged data saved as main_data.csv")

    # ---- feedback integration ----
    FEEDBACK_PATH = DATASET_DIR / "user_feedback.csv"
    if FEEDBACK_PATH.exists():
        print("="*50)
        print("Feedback file found. Integrating corrections...")
        feedback = pd.read_csv(FEEDBACK_PATH)

        missing = [c for c in ('input', 'label') if c not in feedback.columns]
        if missing:
            raise KeyError(f"{FEEDBACK_PATH.name} is missing required column(s): {missing}")

        if _needs_translation(feedback['input'], sample_n=100):
            feedback['input'] = translate_series_to_en(feedback['input'])
            print("Feedback: non-English detected in sample. Translated affected rows.")
        else:
            print("Feedback: sample looks English. Skipping translation.")
        # Persist the (possibly translated) feedback, as before.
        feedback.to_csv(FEEDBACK_PATH, index=False)

        # Same label table as the main data, so "fake"/"FAKE"/False/1 all agree.
        feedback = feedback[['input', 'label']].dropna().copy()
        feedback['label'] = _normalize_labels(feedback['label'])
        feedback = feedback.dropna(subset=['label'])
        feedback['label'] = feedback['label'].astype(int)

        # NOTE(review): feedback text is matched against *cleaned* main text but
        # is not itself passed through _clean_text — confirm whether it should be.
        data = data[~data['input'].isin(feedback['input'])]
        data = pd.concat([data, feedback], ignore_index=True)
        data = data.sample(frac=1, random_state=42).reset_index(drop=True)
        data.to_csv(MAIN_DATA_PATH, index=False)
        print("Feedback integrated and main_data.csv updated.")

    # final shuffle for randomness
    data = data.sample(frac=1, random_state=42).reset_index(drop=True)
    return data


# Build the merged dataset, then report its shape followed by a divider line.
data = load_and_merge_datasets()
print(f"after merging the feedback data and main data = {data.shape}", "="*50, sep="\n")

# Count current label distribution

In [None]:
# Rows per class before any balancing: label 0 = true news, label 1 = fake news.
true_count = data['label'].eq(0).sum()
fake_count = data['label'].eq(1).sum()

print(f"Before balancing: \nTrue label news = {true_count} \nFake label news = {fake_count}")

# Choosing Augmenter: synonym replacement

In [None]:
# # Autodetect and balance by augmenting the minority class with WordNet synonyms
# aug = naw.SynonymAug(aug_src='wordnet', aug_max=10)

# # Current class counts
# counts = data['label'].value_counts()
# if len(counts) != 2: # check if only 1 and 0 exists.
#     raise ValueError("Expected binary labels 0/1. Got: " + str(counts.to_dict()))

# minority_label = counts.idxmin() # finds the label with the fewest rows.
# majority_label = counts.idxmax() # finds the label with the most rows.

# n_to_augment = counts[majority_label] - counts[minority_label] # calculate how many rows need to be augmented.
# print("Detecting minority class....")
# if minority_label==1:
#     print(f"Minority class: Fake | 1 | Need to add: {n_to_augment}")
# else:
#     print(f"Minority class: True | 0 | Need to add: {n_to_augment}")

# if n_to_augment > 0: # only augment if the dataset is imbalanced.
#     minority_df = data[data['label'] == minority_label].copy() # takes only the copy of minority rows.
#     augmented_texts = []

#     for i in range(n_to_augment): # loop upto difference in number of rows.
#         original_text = minority_df.sample(1, random_state=42+i)['input'].values[0] # picks one random minority text each time, with a changing seed.
#         new_text = aug.augment(original_text)           # may return str or list
#         if isinstance(new_text, list): # if the augmenter returns a list, take the first string
#             new_text = new_text[0]
#         augmented_texts.append(new_text)

#     aug_df = pd.DataFrame({'input': augmented_texts, 'label': minority_label}) # turns the new texts into a DataFrame with the minority label.
#     data = pd.concat([data, aug_df], ignore_index=True) # adds the new rows to the dataset.
#     data = data.sample(frac=1, random_state=42).reset_index(drop=True) # shuffles all rows and resets row numbers.
# else:
#     print("Already balanced. No augmentation applied.")

# print("After balancing:", data['label'].value_counts())

# LABEL DISTRIBUTION OF THE DATASET (THE AUGMENTATION CELL ABOVE IS CURRENTLY COMMENTED OUT)

In [None]:
# Pie chart of the class split in `data` (label 0 = true, label 1 = fake).
plt.figure(figsize=(5, 5))
plt.pie(
    [(data.label == 0).sum(), (data.label == 1).sum()],
    explode=[0.05, 0.05],            # pull both wedges slightly apart
    labels=['True', 'Fake'],
    colors=['skyblue', 'salmon'],
    autopct='%1.1f%%',               # percentage with one decimal on each wedge
)
plt.title("Label Distribution After Augmentation (0=Real, 1=Fake)")
plt.show()

# NORMALIZATION OF INPUT COLUMN
### -LIST TO STRING
### -REMOVE WHITE SPACE
### -REMOVE "NAN, NONE"

In [None]:
def normalize_input_cell(cell_value):
    """Return one cell of the 'input' column as a clean, stripped string.

    - Lists and tuples are flattened into a single space-joined string.
    - ``None``, float NaN and the placeholder texts "nan"/"none" (any case)
      become empty; inside a sequence they, and blank elements, are skipped
      so they cannot leave doubled spaces behind.
    - Any other value is converted with ``str`` and stripped.

    Args:
        cell_value: a string, list/tuple of parts, ``None``, NaN or any scalar.

    Returns:
        The normalized text, or "" when nothing usable remains.
    """

    def _as_text(value):
        # One rule for scalars and sequence elements keeps both paths consistent.
        if value is None or (isinstance(value, float) and pd.isna(value)):
            return ""
        text = str(value).strip()
        return "" if text.lower() in ("nan", "none") else text

    if isinstance(cell_value, (list, tuple)):
        parts = (_as_text(element) for element in cell_value)
        return " ".join(part for part in parts if part)

    return _as_text(cell_value)


# Normalize every cell of the 'input' column in place.
data["input"] = data["input"].map(normalize_input_cell)

# TRAIN / VALIDATION / TEST SET SPLIT (70/15/15) STRATIFIED

In [None]:
# =========== STRATIFIED 70/15/15 SPLIT INTO TRAIN / VALIDATION / TEST ==============
# Step 1: hold out 30% of the rows, keeping the class ratio.
train_text, temp_text, train_labels, temp_labels = train_test_split(
    data['input'],
    data['label'],
    stratify=data['label'],
    test_size=0.3,
    random_state=42,
)
# Step 2: cut the held-out part in half -> 15% validation, 15% test.
val_text, test_text, val_labels, test_labels = train_test_split(
    temp_text,
    temp_labels,
    stratify=temp_labels,
    test_size=0.5,
    random_state=42,
)

# ---- Report split sizes ----
print(
    f"Training set:   {len(train_text)} samples",
    f"Validation set: {len(val_text)} samples",
    f"Test set:       {len(test_text)} samples",
    sep="\n",
)

# INITIALIZATION OF THE TRANSLATOR (GOOGLE CLOUD TRANSLATION API)

In [None]:
from types import SimpleNamespace

# Reuse what the credential-setup cell already resolved instead of overriding
# it with hard-coded values. The old literals remain as last-resort defaults,
# so this cell still runs on its own.
PROJECT = (
    globals().get("project_id")
    or os.getenv("GOOGLE_CLOUD_PROJECT")
    or os.getenv("GCP_PROJECT")
    or "cogent-metric-470213-i8"
)
LOCATION = os.getenv("TRANSLATE_LOCATION", "global")  # same env override as the setup cell
# Keep the already-authenticated client when one exists; build a default one otherwise.
client = globals().get("client") or translate.TranslationServiceClient()
PARENT = f"projects/{PROJECT}/locations/{LOCATION}"

def gtranslate(text, dest="en", src=None):
    """Translate `text` into `dest` using the module-level `client` and `PARENT`.

    Args:
        text: the string to translate.
        dest: target language code (default "en").
        src: optional source language code, e.g. "fr"; only sent when truthy.

    Returns:
        An object whose ``.text`` attribute holds the first translation.
    """
    source_hint = {"source_language_code": src} if src else {}
    request = {
        "parent": PARENT,
        "contents": [text],
        "mime_type": "text/plain",
        "target_language_code": dest,
        **source_hint,
    }
    first_translation = client.translate_text(request=request).translations[0]
    return SimpleNamespace(text=first_translation.translated_text)

def predict_news_multilingual(news_text):
    """Classify one news text of any language as "Fake" or "True".

    The text is first translated to English with `gtranslate`. If that fails,
    the original text is classified instead and a warning is emitted. The
    module-level `tokenizer`, `model` and `device` are used for inference.

    Args:
        news_text: the headline or article text to classify.

    Returns:
        "Fake" when the predicted class is 1, otherwise "True".
    """
    import warnings

    try:
        translated = gtranslate(news_text, dest="en").text
    except Exception as exc:
        # Best effort: keep predicting, but make the degraded path visible.
        warnings.warn(f"Translation failed ({exc!r}); classifying the original text.")
        translated = news_text

    inputs = tokenizer(translated, padding=True, truncation=True, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Predict in eval mode so dropout cannot make the answer random, then
    # restore whatever mode the model was in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
    finally:
        model.train(was_training)

    pred = torch.argmax(logits, dim=1).item()
    return "Fake" if pred == 1 else "True"

# TOKENIZATION SECTION

In [None]:
# Token budget per example; longer texts are truncated and shorter ones padded by encode_texts.
# NOTE(review): bert-base-uncased presumably cannot go above 512 positions, so this can only be lowered to save VRAM — confirm.
MAX_LENGTH = 512
MODEL_NAME = 'bert-base-uncased'  # checkpoint name shared by the tokenizer and the classifier
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)  # fast tokenizer matching MODEL_NAME

def encode_texts(texts):
    """Tokenize an iterable of strings into fixed-length BERT inputs.

    Every text is truncated or padded to exactly MAX_LENGTH tokens.

    Args:
        texts: iterable of strings (list, pandas Series, ...).

    Returns:
        The tokenizer's batch encoding with PyTorch tensors, including
        ``input_ids`` and ``attention_mask`` (1 = real token, 0 = padding).
    """
    # Calling the tokenizer directly replaces the deprecated batch_encode_plus;
    # a list input is still encoded as one batch.
    return tokenizer(
        list(texts),
        max_length=MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

# Tokenize the train, validation and test splits, in that order.
tokens_train, tokens_val, tokens_test = map(
    encode_texts, (train_text, val_text, test_text)
)

# TORCH DATASET/LOADER SETUP

In [None]:
# batch_size = 64  # number of samples the model sees in one step.

# train_data = TensorDataset(tokens_train['input_ids'], tokens_train['attention_mask'], torch.tensor(train_labels.values).long()) # a PyTorch dataset with three tensors: token ids, attention mask, and labels for training
# val_data   = TensorDataset(tokens_val['input_ids'], tokens_val['attention_mask'], torch.tensor(val_labels.values).long())
# test_data  = TensorDataset(tokens_test['input_ids'], tokens_test['attention_mask'], torch.tensor(test_labels.values).long())

# train_loader = DataLoader(train_data, sampler=RandomSampler(train_data), batch_size=batch_size) # creates an iterator that yields shuffled training batches.
# val_loader   = DataLoader(val_data, sampler=SequentialSampler(val_data), batch_size=batch_size) # no shuffle, stable and reproducible evaluation
# test_loader  = DataLoader(test_data, sampler=SequentialSampler(test_data), batch_size=batch_size) # no shuffle, stable and reproducible test

batch_size = 64  # samples per optimisation step

# Label tensors for the three splits, built once.
y_train, y_val, y_test = (
    torch.tensor(split.values).long()
    for split in (train_labels, val_labels, test_labels)
)

# Each dataset item is (input_ids, attention_mask, label).
train_data = TensorDataset(
    tokens_train['input_ids'], tokens_train['attention_mask'], y_train
)
val_data = TensorDataset(
    tokens_val['input_ids'], tokens_val['attention_mask'], y_val
)
test_data = TensorDataset(
    tokens_test['input_ids'], tokens_test['attention_mask'], y_test
)

# === Class-balanced sampling for training ===
# Each sample is drawn with probability proportional to the inverse size of
# its class; sampling is with replacement and one epoch draws len(train) items.
class_counts = torch.bincount(y_train)             # rows per class, indexed by label
class_weights = 1.0 / class_counts.float()         # inverse-frequency weight per class
sample_weights = class_weights.index_select(0, y_train)  # weight of every training row
train_sampler = WeightedRandomSampler(
    sample_weights,
    len(sample_weights),
    replacement=True,
)

# Training batches come from the weighted sampler; validation and test are
# read sequentially so evaluation order is stable.
train_loader = DataLoader(train_data, batch_size=batch_size, sampler=train_sampler)
val_loader = DataLoader(val_data, batch_size=batch_size, sampler=SequentialSampler(val_data))
test_loader = DataLoader(test_data, batch_size=batch_size, sampler=SequentialSampler(test_data))

# BERT (BASE) MODEL SETUP
### UNFREEZING THE TOP-3 ENCODER LAYERS + CLASSIFIER (EVERYTHING ELSE FROZEN)
### ADAMW, WEIGHT DECAY, SCHEDULE WITH WARMUP, CROSS ENTROPY LOSS

In [None]:
# ============ BERT CLASSIFIER: ONLY THE LAST 3 ENCODER LAYERS + HEAD ARE TRAINED ============
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

# Freeze every parameter first ...
model.requires_grad_(False)
# ... then re-enable gradients for the last three encoder layers and the classifier.
for trainable_module in (*model.bert.encoder.layer[-3:], model.classifier):
    trainable_module.requires_grad_(True)

model = model.to(device)

# Optimise only the parameters that still require gradients.
optimizer = AdamW(
    [param for param in model.parameters() if param.requires_grad],
    lr=2e-5,
    weight_decay=0.01,
)

epochs = 5
total_steps = len(train_loader) * epochs  # one scheduler step per training batch
# Linear warmup over the first 10% of steps, then linear decay.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1*total_steps),
    num_training_steps=total_steps,
)
loss_fn = torch.nn.CrossEntropyLoss()

# SHAP EXPLANATION FUNCTION SETUP

In [None]:
def explain_prediction(text, tokenizer, model, device, max_len=512):
    """Compute per-token SHAP scores toward the 'Fake' class for one text.

    Args:
        text: the input string to explain.
        tokenizer: tokenizer used both for model input and as the SHAP text masker.
        model: sequence classifier whose output exposes ``.logits``.
        device: device the tokenized batch is moved to before the forward pass.
        max_len: truncation length passed to the tokenizer.

    Returns:
        A list of ``(token, score)`` pairs, where score is the SHAP value of
        that token for class index 1 (Fake).
    """
    model.eval()  # disable dropout so repeated evaluations are stable

    def _predict_proba(batch):
        """Class probabilities for a batch of (masked) texts, as a NumPy array."""
        encoded = tokenizer(
            list(batch),
            padding=True,
            truncation=True,
            max_length=max_len,
            return_tensors="pt",
        ).to(device)
        # The explainer calls this for many masked variants of the text; no
        # gradients are needed, so skip building autograd graphs.
        with torch.no_grad():
            logits = model(**encoded).logits
        return logits.softmax(-1).cpu().numpy()

    explainer = shap.Explainer(_predict_proba, masker=shap.maskers.Text(tokenizer))
    shap_values = explainer([text])

    tokens = shap_values.data[0]            # tokens of the single input
    scores = shap_values.values[0][:, 1]    # column 1 = contribution toward 'Fake'
    return list(zip(tokens, scores))

# EARLY STOPPING SETUP

In [None]:
class EarlyStopper:
    """Signal when validation loss has stopped improving.

    A check counts as an improvement only when the loss falls below the best
    value so far by more than `min_delta`. After `patience` checks in a row
    without improvement, `early_stop` returns True.
    """

    def __init__(self, patience=2, min_delta=0.001):
        self.patience, self.min_delta = patience, min_delta
        self.best_loss = float('inf')  # best validation loss recorded so far
        self.counter = 0               # consecutive checks without improvement

    def early_stop(self, val_loss):
        """Record `val_loss`; return True when training should stop."""
        threshold = self.best_loss - self.min_delta
        if not (val_loss < threshold):
            self.counter += 1
            return self.counter >= self.patience
        # Meaningful improvement: remember it and reset the patience counter.
        self.best_loss = val_loss
        self.counter = 0
        return False

# TRAINING LOOP

In [None]:
# Fine-tune the model: one training pass and one validation pass per epoch,
# checkpoint whenever validation loss reaches a new minimum, stop early when
# it stops improving.
train_losses, val_losses = [], []  # per-epoch average losses
early_stopper = EarlyStopper(patience=2)
best_val_loss = float('inf')  # lowest validation loss so far; decides when to checkpoint

for epoch in range(epochs): # `epochs` is set in the model-setup cell (5 there)
    model.train()           # training mode (dropout active)
    total_train_loss = 0
    for batch in train_loader:
        b_input_ids, b_attn_mask, b_labels = [b.to(device) for b in batch] # move ids, attention mask and labels to the compute device
        model.zero_grad()  # clear gradients left over from the previous step
        outputs = model(b_input_ids, attention_mask=b_attn_mask, labels=b_labels) # labels are passed in, so the output carries .loss as well as .logits
        loss = outputs.loss
        total_train_loss += loss.item() # accumulate for the epoch average
        loss.backward() # backpropagation: compute gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clip the gradient norm to 1.0 against exploding updates
        optimizer.step() # apply the gradients
        scheduler.step() # advance the learning-rate schedule (warmup, then decay)
    avg_train_loss = total_train_loss / len(train_loader)  # mean of the per-batch losses
    train_losses.append(avg_train_loss)

    # ---- Validation ----
    model.eval()  # evaluation mode (dropout off)
    total_val_loss = 0
    preds, truths = [], []
    with torch.no_grad():  # no gradients needed while evaluating
        for batch in val_loader:
            b_input_ids, b_attn_mask, b_labels = [b.to(device) for b in batch]
            outputs = model(b_input_ids, attention_mask=b_attn_mask, labels=b_labels)
            total_val_loss += outputs.loss.item()
            logits = outputs.logits
            preds += list(torch.argmax(logits, dim=1).cpu().numpy())  # predicted class per sample
            truths += list(b_labels.cpu().numpy())
    avg_val_loss = total_val_loss / len(val_loader)  # mean of the per-batch losses
    val_losses.append(avg_val_loss)
    val_acc = accuracy_score(truths, preds)
    print(f"Epoch {epoch+1}: train_loss={avg_train_loss:.4f} val_loss={avg_val_loss:.4f} val_acc={val_acc:.4f}")

    # ---- SAVE BEST MODEL (overwrites the existing folder) ----
    # Any strict improvement is saved here; the early stopper below additionally
    # requires its min_delta margin before it resets its patience counter.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        model.save_pretrained("bert_fakenews_model")
        tokenizer.save_pretrained("bert_fakenews_model")

    # EARLY STOPPING
    if early_stopper.early_stop(avg_val_loss):
        print("Early stopping triggered.")
        break

# ======================= EVALUATION ON THE HELD-OUT TEST SET =======================
# Reload the checkpoint with the lowest validation loss before scoring.
model = BertForSequenceClassification.from_pretrained("bert_fakenews_model").to(device)
model.eval()

test_preds, test_truths = [], []
with torch.no_grad():
    for batch in test_loader:
        b_input_ids, b_attn_mask, b_labels = (tensor.to(device) for tensor in batch)
        logits = model(b_input_ids, attention_mask=b_attn_mask).logits
        # Collect predicted and true class ids as NumPy scalars.
        test_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
        test_truths.extend(b_labels.cpu().numpy())

print("\n==== TEST SET METRICS ====")
report = classification_report(test_truths, test_preds, target_names=['True','Fake'])
print(report)
acc = accuracy_score(test_truths, test_preds)
print("Test Accuracy: %.2f%%" % (acc*100))

# CONFUSION MATRIX

In [None]:
# Draw on an explicitly created axes. ConfusionMatrixDisplay.plot() opens its
# own figure when no `ax` is given, which used to leave the 6x6 figure empty
# and ignore the requested size.
fig, ax = plt.subplots(figsize=(6, 6))
# Fix the label order so rows/columns always line up with display_labels.
cm = confusion_matrix(test_truths, test_preds, labels=[0, 1])
ConfusionMatrixDisplay(cm, display_labels=['Real/True','Fake']).plot(ax=ax, values_format='d', cmap='Blues')
ax.set_title('Confusion Matrix - Test Set')
plt.show()

# UNSEEN NEWS HEADLINES PREDICTION

In [None]:
print("\nPredicting on Unseen News Headlines")
# Hand-written headlines for a quick qualitative check.
# The trailing comment on each line is the expected label: 0 = true / real, 1 = fake / false.
unseen_news = [
    "The Chinese Academy of Sciences announces a new high-end quantum-computing lab in Shanghai, aiming for a 100-qubit breakthrough within five years.",  # expected 0
    "The Chinese government will greatly increase rural-education funding this year to help more students in impoverished areas finish school.",         # expected 0
    "Experts claim that eating spicy snack sticks every day can dramatically lower cancer risk and recommend a nationwide rollout.",                    # expected 1
    "Beijing will offer free flu shots to all residents aged 65 + in the second half of the year to improve community health.",                         # expected 0
    "WeChat officially announces that all money-transfer features will be permanently shut down for every user starting tomorrow.",                     # expected 1
    "Scientists discover an alien creature over ten meters long in the Yangtze River, sparking nationwide debate.",                                     # expected 1
    "The Shanghai government, together with several hospitals, launches a free COVID-19 booster-shot campaign.",                                        # expected 0
    "The China Meteorological Administration predicts continuous snowfall across the entire country for three straight months this winter.",            # expected 1
    "Huawei unveils the world’s first foldable 5G laptop, drawing industry attention.",                                                                 # expected 0
    "A primary school in Shenzhen introduces a “Moon Survival Experience” course, where students will visit an actual moon base.",                     # expected 1
]


# Tokenize the headlines, run one forward pass, and print a verdict per headline.
unseen_enc = encode_texts(unseen_news)
model.eval()
with torch.no_grad():
    logits = model(
        unseen_enc['input_ids'].to(device),
        attention_mask=unseen_enc['attention_mask'].to(device),
    ).logits
    predictions = logits.argmax(dim=1).cpu().numpy()  # predicted class id per headline

for text, pred in zip(unseen_news, predictions):
    print(f"News: {text}\nPredicted: {'Fake' if pred==1 else 'true'}\n")