# VL04 - Text Classification with Naive Bayes
In this seminar we demonstrate the full lifecycle of training and evaluating a **Multinomial Naïve Bayes** spam classifier.
It connects the ideas from:
- **Lab 1 (rule-based spam filter)**: hand-written rules
- **VL03 (text representation)**: the Bag-of-Words model

and shows how these come together in a *learned* classifier.

In [None]:
import pandas as pd
import spacy
from datasets import load_from_disk
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

try:
    nlp = spacy.load("en_core_web_sm", disable=["parser", "ner"])
except OSError:
    print("Warning: spaCy model 'en_core_web_sm' not found. Please run 'python -m spacy download en_core_web_sm'")
    # Fallback: a blank English pipeline (tokenizer only). It has no lemmatizer,
    # so token.lemma_ will be empty and lemma-based normalisation will not work.
    nlp = spacy.blank("en")

## 1. Load and inspect the dataset
We use the same small SMS Spam dataset as in the previous labs. If you don't have it yet, run the following command in your terminal from the root of the repository:

``python scripts/download_dataset.py dbarbedillo/SMS_Spam_Multilingual_Collection_Dataset data/sms_spam``

or, if your notebook has internet access, load it directly from the Hugging Face Hub:
````python
from datasets import load_dataset
ds = load_dataset("dbarbedillo/SMS_Spam_Multilingual_Collection_Dataset")
````

In [None]:
ds = load_from_disk("../../data/sms_spam")  # columns: ['labels', 'text']
df = ds["train"].to_pandas()

df_spam = df[["labels", "text"]].copy()
df_spam.columns = ["label", "text"]
df_spam.info()

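# Apparently we have some duplicated rows! Count every copy (keep=False), then keep only the first occurrence
print("Duplicated rows (all copies):", df_spam.duplicated(keep=False).sum())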
df_spam = df_spam.drop_duplicates()

df_spam["label"].value_counts()

## 2. Split into train/test sets
We split the dataset to estimate how well the model generalizes: we train on one partition and evaluate on data the model has never seen. This avoids the overly optimistic results we would get by *testing on the training set*.

In practice, we use `train_test_split(..., test_size=0.2, random_state=42, stratify=df_spam["label"])` from sklearn. Here we create an 80/20 split **stratified** by the `label` column to preserve the spam/ham ratio in both sets. The `random_state` parameter fixes the random seed so that the split is reproducible.

An 80/20 split is a solid default for small to medium corpora (optionally combined with cross-validation). For very large corpora, a 90/10 (or even 95/5) split usually still leaves a test set large enough for reliable estimates.
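To make stratification concrete, here is a minimal pure-Python sketch of what a stratified split does: it samples the test fraction separately within each class. The helper `stratified_split` and the toy label proportions are hypothetical; sklearn's `train_test_split(..., stratify=...)` shuffles and rounds in its own way, so its exact indices will differ.

```python
import random
from collections import Counter, defaultdict

def stratified_split(labels, test_size=0.2, seed=42):
    """Return (train_idx, test_idx), sampling test_size of the indices within each class."""
    rng = random.Random(seed)                 # fixed seed -> reproducible split
    by_class = defaultdict(list)
    for i, lab in enumerate(labels):
        by_class[lab].append(i)
    train_idx, test_idx = [], []
    for idx in by_class.values():
        rng.shuffle(idx)
        n_test = round(len(idx) * test_size)  # per-class share of the test set
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return sorted(train_idx), sorted(test_idx)

# Hypothetical labels: 87 ham, 13 spam
toy_labels = ["ham"] * 87 + ["spam"] * 13
toy_train_idx, toy_test_idx = stratified_split(toy_labels)
print(Counter(toy_labels[i] for i in toy_train_idx))  # Counter({'ham': 70, 'spam': 10})
print(Counter(toy_labels[i] for i in toy_test_idx))   # Counter({'ham': 17, 'spam': 3})
```

Both parts keep roughly the original 87/13 ratio; without stratification, a small test set could by chance contain very few spam messages.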

In [None]:
X_train, X_test, y_train, y_test = train_test_split(
    df_spam["text"], df_spam["label"], test_size=0.2, random_state=42, stratify=df_spam["label"]
)

print(y_train.value_counts(normalize=True))
print(y_test.value_counts(normalize=True))

## 3. Represent text as Bag-of-Words
**Naïve Bayes** relies on frequency-based representations of text, where each word’s occurrence contributes evidence for a class.
We use the Bag-of-Words (BoW) model to convert messages into numerical feature vectors: each column represents a word, and each entry is that word's count in a message.

As seen in the previous lab, we can use `CountVectorizer` for this task. The vectorizer is first fit on the training set to learn the vocabulary, and the same vocabulary is then used to transform the test set, so both sets share identical feature dimensions. Words that occur only in the test set are not part of the vocabulary and are simply ignored.
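The fit/transform asymmetry can be sketched in plain Python. This is a simplified stand-in for `CountVectorizer`, assuming whitespace tokenisation and lowercasing only (the real class uses a regex token pattern and, in this notebook, also removes English stop words); the helper names `fit_vocabulary` and `transform_counts` are hypothetical. Like `CountVectorizer`, it orders the columns alphabetically.

```python
from collections import Counter

def fit_vocabulary(docs):
    """Assign a column index to every token seen in the training docs (sorted alphabetically)."""
    tokens = sorted({tok for doc in docs for tok in doc.lower().split()})
    return {tok: j for j, tok in enumerate(tokens)}

def transform_counts(docs, vocab):
    """Count known tokens per doc; tokens absent from the training vocabulary are dropped."""
    rows = []
    for doc in docs:
        counts = Counter(tok for tok in doc.lower().split() if tok in vocab)
        row = [0] * len(vocab)
        for tok, c in counts.items():
            row[vocab[tok]] = c
        rows.append(row)
    return rows

toy_train = ["win free prize", "see you at lunch", "free free entry"]
toy_test = ["free lunch voucher"]            # "voucher" never occurs in training

toy_vocab = fit_vocabulary(toy_train)        # learned from the training split only
print(toy_vocab)
print(transform_counts(toy_test, toy_vocab))  # [[0, 0, 1, 1, 0, 0, 0, 0]]
```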

In [None]:
vectorizer = CountVectorizer(stop_words='english')

X_train_bow = vectorizer.fit_transform(X_train) # fit to the training split
X_test_bow  = vectorizer.transform(X_test)      # transform the test split

print(f"Vocabulary size: {len(vectorizer.get_feature_names_out())}")

## 4. Train the Naïve Bayes model
The Multinomial Naïve Bayes classifier learns from word frequencies to estimate how strongly each word supports a class (e.g., spam vs. ham).

We use the `MultinomialNB` class for this purpose. Useful parameters include:
- `alpha`: the additive (Laplace) smoothing parameter (α = 1 by default)
- `class_prior`: overrides the priors `P(c)` that would otherwise be learned from the data (e.g., `[0.9, 0.1]`, given in the order of `nb.classes_`), for instance if you know the real-world proportion of spam or want a more conservative filter.

Then `.fit()` receives the training corpus in BoW format (`X_train_bow`) and the true labels (`y_train`). This step estimates the class priors `P(c)` and the conditional word probabilities `P(w | c)`.
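A minimal NumPy sketch of what this estimation step computes, assuming a dense count matrix and default settings (no `class_prior`): the class priors are relative class frequencies, and each `P(w | c)` is the word's total count in class `c` plus `alpha`, divided by the smoothed total of all word counts in `c`. The helper `fit_multinomial_nb` and the toy corpus are hypothetical, but the outputs should correspond to `nb.class_log_prior_` and `nb.feature_log_prob_`.

```python
import numpy as np

def fit_multinomial_nb(X, y, alpha=1.0):
    """Estimate log P(c) and smoothed log P(w | c) from a document-term count matrix.

    X: (n_docs, n_words) counts; y: class label per document.
    Returns (classes, class_log_prior, feature_log_prob).
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    classes = np.unique(y)                                   # sorted, like nb.classes_
    class_count = np.array([(y == c).sum() for c in classes])
    class_log_prior = np.log(class_count / class_count.sum())
    # Total count of each word within each class, plus alpha (additive smoothing)
    word_count = np.vstack([X[y == c].sum(axis=0) for c in classes]) + alpha
    feature_log_prob = np.log(word_count / word_count.sum(axis=1, keepdims=True))
    return classes, class_log_prior, feature_log_prob

# Toy corpus (hypothetical); columns = ["free", "prize", "lunch"]
toy_X = np.array([[2, 1, 0],   # spam
                  [1, 1, 0],   # spam
                  [0, 0, 2],   # ham
                  [0, 0, 1]])  # ham
toy_y = np.array(["spam", "spam", "ham", "ham"])

toy_classes, toy_log_prior, toy_log_pw = fit_multinomial_nb(toy_X, toy_y, alpha=1.0)
print(toy_classes)                  # ['ham' 'spam']
print(np.exp(toy_log_pw).round(3))  # ham ≈ [0.167, 0.167, 0.667], spam = [0.5, 0.375, 0.125]
```

With `alpha=1`, "lunch" never occurs in the toy spam documents yet still gets `P(lunch | spam) = 1/8` instead of zero, which is exactly what smoothing is for. Passing `class_prior=[0.9, 0.1]` would instead fix the log priors to `np.log([0.9, 0.1])`.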

In [None]:
nb = MultinomialNB(alpha=1.0)
nb.fit(X_train_bow, y_train)

print("Model trained.")

## 5. Evaluate on the test set
After training the model, we evaluate its performance on the test set, which contains messages the model has never seen before.
We call `.predict()` with the vectorized representation of the test messages. This returns the predicted class for each document: the class with the **highest joint log-probability** `log P(c) + log P(d | c)` under the Naïve Bayes model.

With both the true labels (`y_test`) and the predicted labels (`y_pred`), we can now measure how well the model generalizes by comparing its predictions to the ground truth.
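Under the hood, the prediction is an argmax over the per-class scores `log P(c) + ∑ count(w,d) × log P(w | c)`. The sketch below reproduces that decision rule with NumPy on hypothetical parameters for a three-word vocabulary; it illustrates the rule and is not sklearn's own code.

```python
import numpy as np

# Hypothetical learned parameters; rows follow the class order, columns the vocabulary order
toy_classes = np.array(["ham", "spam"])
toy_class_log_prior = np.log([0.5, 0.5])
toy_feature_log_prob = np.log([[1/6, 1/6, 4/6],        # P(w | ham)  for ["free", "prize", "lunch"]
                               [0.5, 0.375, 0.125]])   # P(w | spam)

def predict_from_counts(X_counts):
    """Return the class maximising log P(c) + sum_w count(w, d) * log P(w | c) for each row."""
    jll = X_counts @ toy_feature_log_prob.T + toy_class_log_prior   # shape (n_docs, n_classes)
    return toy_classes[np.argmax(jll, axis=1)]

toy_X_new = np.array([[1, 1, 0],   # "free prize"
                      [0, 0, 2]])  # "lunch lunch"
print(predict_from_counts(toy_X_new))  # ['spam' 'ham']
```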

**Remember**: the task matters when interpreting results. Which metrics are most important to us when it comes to spam detection?
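To reason about this question, it helps to recall how the per-class numbers in `classification_report` are derived for the positive class (here `spam`) from confusion-matrix counts. The helper `positive_class_metrics` is illustrative, and the counts below are hypothetical, not results from this notebook.

```python
def positive_class_metrics(tp, fp, fn):
    """Precision, recall and F1 for the positive class (here: "spam")."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Hypothetical counts, NOT results from this notebook:
# tp = spam flagged as spam, fp = ham wrongly flagged as spam, fn = spam that slipped through as ham
toy_prec, toy_rec, toy_f1 = positive_class_metrics(tp=90, fp=10, fn=30)
print(f"precision={toy_prec:.3f}  recall={toy_rec:.3f}  f1={toy_f1:.3f}")  # precision=0.900  recall=0.750  f1=0.818
```

In `classification_report`, these correspond to the `precision`, `recall` and `f1-score` values on the `spam` row.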

In [None]:
y_pred = nb.predict(X_test_bow)
print(classification_report(y_test, y_pred))

ConfusionMatrixDisplay.from_estimator(nb, X_test_bow, y_test, cmap="Blues")
plt.title("Confusion Matrix – Naïve Bayes Spam Classifier")
plt.show()

### 5.1 Inspecting class probabilities
So far, we used `.predict()` to get only the most likely class for each message — spam or ham.
But Naïve Bayes is a probabilistic model, meaning it actually computes a full probability distribution over classes for every document.

We can access these values with `.predict_proba()`, which returns a matrix where each row corresponds to a message and each column to a class: `P(class | message)`.
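These probabilities come from normalising the per-class joint log-likelihoods so that each row sums to 1. Below is a minimal NumPy sketch of that normalisation using the log-sum-exp trick; the helper `posterior_from_jll` and its input values are hypothetical. sklearn's `predict_proba` performs an equivalent log-space normalisation, but this is not its code.

```python
import numpy as np

def posterior_from_jll(jll):
    """Turn joint log-likelihoods log P(c) + log P(d | c) into rows of P(c | d) that sum to 1.

    Subtracting the row maximum (log-sum-exp trick) avoids underflow when the
    log-likelihoods are very negative, as they are for long messages.
    """
    jll = np.asarray(jll, dtype=float)
    shifted = jll - jll.max(axis=1, keepdims=True)
    probs = np.exp(shifted)
    return probs / probs.sum(axis=1, keepdims=True)

# Hypothetical joint log-likelihoods for two messages, columns ordered ["ham", "spam"]
toy_jll = np.array([[-1000.0, -1002.0],
                    [-850.0, -845.0]])
print(posterior_from_jll(toy_jll).round(3))  # ≈ [[0.881 0.119], [0.007 0.993]]
```

Without the shift, `np.exp(-1000.0)` underflows to `0.0` for both classes and the division would produce NaN.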

In [None]:
# Get the predicted probabilities for each class
y_proba = nb.predict_proba(X_test_bow)

# Check the class order
print("Class order:", nb.classes_)

# Put predicted probabilities, true labels and messages side by side
df_proba = pd.DataFrame(y_proba, columns=nb.classes_)
df_proba["true_label"] = y_test.values
df_proba["message"] = X_test.values

# Reorder columns for readability
df_proba = df_proba[["message", "true_label", "ham", "spam"]]

# Show first few rows
pd.set_option("display.max_colwidth", 100)  # so text isn't truncated
df_proba.head()

### 5.2 Adjusting the decision threshold
By default, `.predict()` labels a message as the class with the highest probability, effectively using a threshold of 0.5 in binary classification.
However, for spam detection we may want to be more conservative, for example labeling a message as spam only if
`P(spam | d) >= τ` for some τ > 0.5.
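A small sketch of the precision/recall trade-off behind τ, using hypothetical scores and labels for eight messages (not data from this notebook); the helper `precision_recall_at` is illustrative.

```python
import numpy as np

# Hypothetical P(spam | d) scores and true labels for eight messages
toy_p_spam = np.array([0.95, 0.80, 0.65, 0.55, 0.40, 0.30, 0.10, 0.05])
toy_true = np.array(["spam", "spam", "spam", "ham", "spam", "ham", "ham", "ham"])

def precision_recall_at(tau, p_spam, y_true):
    """Precision and recall of the "spam" class when predicting spam iff P(spam | d) >= tau."""
    y_pred = np.where(p_spam >= tau, "spam", "ham")
    tp = np.sum((y_pred == "spam") & (y_true == "spam"))
    fp = np.sum((y_pred == "spam") & (y_true == "ham"))
    fn = np.sum((y_pred == "ham") & (y_true == "spam"))
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    return float(precision), float(recall)

for toy_tau in (0.5, 0.6, 0.9):
    prec, rec = precision_recall_at(toy_tau, toy_p_spam, toy_true)
    print(f"tau={toy_tau:.1f}  precision={prec:.2f}  recall={rec:.2f}")
# tau=0.5  precision=0.75  recall=0.75
# tau=0.6  precision=1.00  recall=0.75
# tau=0.9  precision=1.00  recall=0.25
```

Raising τ removes false positives (ham flagged as spam) at the cost of letting more spam through.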

In [None]:
# Use custom threshold τ
tau = 0.6
p_spam = y_proba[:, nb.classes_.tolist().index("spam")]
y_pred_tau = np.where(p_spam >= tau, "spam", "ham")

print(classification_report(y_test, y_pred_tau, digits=3))

# Compute the confusion matrix from our thresholded predictions
cm = confusion_matrix(y_test, y_pred_tau, labels=nb.classes_)

# Pass the precomputed matrix to the display
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=nb.classes_)
disp.plot(cmap="Blues")
plt.title(f"Confusion Matrix – Naïve Bayes (τ = {tau})")
plt.show()

### 5.3 Inspect learned probabilities
Which words are most indicative of **spam** or **ham**?

After training, the Naïve Bayes model has learned how strongly each word supports one class over the other.
For every word w and class c, it stores:
- log P(w | c): how likely the word is under that class (`nb.feature_log_prob_`)
- log P(c): the prior probability of that class (`nb.class_log_prior_`)

When classifying a new message, the model sums these log-probabilities across all words:

`log P(c) + ∑ count(w,d) × log P(w | c)`

A word is spam-indicative if it is relatively more frequent in spam than in ham, that is:

`log P(w | spam) − log P(w | ham) > 0`

and ham-indicative if the opposite is true.

By comparing these learned probabilities, we can see which words the model considers most characteristic of spam or ham.
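A worked example for a single word, with hypothetical smoothed likelihoods, shows how the log-odds and the odds ratio reported by `print_indicative_features` relate: the odds ratio is simply `exp(log-odds)`, i.e. `P(w | spam) / P(w | ham)`.

```python
import math

# Hypothetical smoothed likelihoods for one word
p_w_spam = 0.010   # P(w | spam)
p_w_ham = 0.001    # P(w | ham)

word_log_odds = math.log(p_w_spam) - math.log(p_w_ham)  # > 0, so the word is spam-indicative
word_odds_ratio = math.exp(word_log_odds)               # same as p_w_spam / p_w_ham

print(f"log-odds = {word_log_odds:.3f}, odds ratio = {word_odds_ratio:.1f}")  # log-odds = 2.303, odds ratio = 10.0
# Each occurrence adds the same amount to a message's log-odds, so a count of 2 adds twice as much
print(f"two occurrences add {2 * word_log_odds:.3f}")                           # two occurrences add 4.605
```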

In [None]:
def print_indicative_features(nb, vectorizer, topk=10, verbose=True):
    # 1. Identify class indices (spam = 1, ham = 0)
    classes = nb.classes_.tolist()
    i_spam  = classes.index("spam")
    i_ham   = classes.index("ham")
    
    # 2. Retrieve the learned log probabilities for each word and class
    feature_names = np.array(vectorizer.get_feature_names_out())
    log_pw = nb.feature_log_prob_
    
    # 3. Compute the log-odds for each feature: spam minus ham
    log_odds = log_pw[i_spam] - log_pw[i_ham]
    
    # 4. Build a small DataFrame for inspection
    df_weights = pd.DataFrame({
        "feature": feature_names,
        "logP_w_given_spam": log_pw[i_spam], # log P(w | spam)
        "logP_w_given_ham":  log_pw[i_ham],  # log P(w | ham)
        "logodds_spam_minus_ham": log_odds,  # log P(w | spam) - log P(w | ham)
        "odds_ratio": np.exp(log_odds)       # P(w | spam) / P(w | ham): times more likely under spam than ham
    }).sort_values("logodds_spam_minus_ham", ascending=False)

    # 5. Display the top indicative words for each class
    top_spam = df_weights.head(topk)                 # most spam-indicative
    top_ham  = df_weights.tail(topk).iloc[::-1]      # most ham-indicative

    if verbose:
        print("Top spam-indicative features:")
        display(top_spam[["feature", "logodds_spam_minus_ham", "odds_ratio"]])
        print("\nTop ham-indicative features:")
        display(top_ham[["feature", "logodds_spam_minus_ham", "odds_ratio"]])
    else:
        spam_words = ", ".join(top_spam["feature"].tolist())
        ham_words  = ", ".join(top_ham["feature"].tolist())
        print(f"Top spam-indicative words ({topk}):\n {spam_words}")
        print()
        print(f"Top ham-indicative words  ({topk}):\n {ham_words}")
        
print_indicative_features(nb, vectorizer, topk=15, verbose=False)

## 6. Inference on new messages
We now apply the model to new, hand-written messages to see how it combines the class prior with the word likelihoods.
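For a binary Naïve Bayes model, the posterior can be read as the prior log-odds plus one evidence term per word occurrence, passed through the logistic function. A minimal pure-Python sketch with hypothetical priors and likelihoods for the message "free prize" (not values learned by `nb`); it cross-checks the result against direct normalisation of the two joint probabilities. The same decomposition underlies `explain_message()` in the error-analysis section.

```python
import math

# Hypothetical parameters (not learned from the SMS data)
toy_prior = {"ham": 0.87, "spam": 0.13}
toy_p_w = {
    "ham":  {"free": 0.001, "prize": 0.0005},
    "spam": {"free": 0.02,  "prize": 0.01},
}
toy_counts = {"free": 1, "prize": 1}  # bag of words for "free prize"

# Log-odds = prior term + sum over words of count * (log P(w|spam) - log P(w|ham))
prior_log_odds = math.log(toy_prior["spam"]) - math.log(toy_prior["ham"])
evidence = sum(n * (math.log(toy_p_w["spam"][w]) - math.log(toy_p_w["ham"][w]))
               for w, n in toy_counts.items())
toy_log_odds = prior_log_odds + evidence
toy_p_spam_msg = 1.0 / (1.0 + math.exp(-toy_log_odds))   # logistic (sigmoid) function

# Cross-check: normalise the joint probabilities P(c) * prod_w P(w|c)^count directly
joint = {c: toy_prior[c] * math.prod(toy_p_w[c][w] ** n for w, n in toy_counts.items())
         for c in toy_prior}
direct = joint["spam"] / (joint["spam"] + joint["ham"])

print(f"prior log-odds={prior_log_odds:.3f}  evidence={evidence:.3f}  P(spam)={toy_p_spam_msg:.3f}")
print(f"direct normalisation: P(spam)={direct:.3f}")
```

Even with a prior that favours ham, two strongly spam-indicative words are enough to push the posterior well above 0.5.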

In [None]:
samples = [
    "Win a free prize today!",
    "Lunch meeting tomorrow at noon",
    "Congratulations!!! Claim your reward now",
]
X_samples = vectorizer.transform(samples)
preds = nb.predict(X_samples)
probs = nb.predict_proba(X_samples)

# Find the column index for "spam"
classes = nb.classes_.tolist()
i_spam = classes.index("spam")  # robust way

for msg, label, prob in zip(samples, preds, probs):
    print(f"{msg:50s} → {label.upper()}  (P(spam)={prob[i_spam]:.3f})")

## 7. Custom pre-processing

In [None]:
import re, html, unicodedata

def spacy_tokenizer(text, do_normalise=True):
    """
    Custom spaCy tokenizer: drops punctuation, whitespace and stop words,
    and (optionally) replaces each token by its lemma. Output is lowercased.
    """
    doc = nlp(text)

    tokens = []
    for token in doc:
        if token.is_punct or token.is_space or token.is_stop:
            continue

        lemma = token.lemma_ if do_normalise else token.text
        tokens.append(lemma.lower())

    return [t for t in tokens if t]
    

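def run_nb_pipeline(custom_tokenizer, class_prior=None):
    """Train NB. class_prior = [P(ham), P(spam)] or None to learn from data."""
    # lowercase=False: by default CountVectorizer lowercases the text *before* calling the
    # tokenizer, which would hide the capitalisation cue used by __MANY_CAPS__.
    vectorizer = CountVectorizer(tokenizer=custom_tokenizer, lowercase=False, stop_words=None, token_pattern=None)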
    X_train_bow = vectorizer.fit_transform(X_train)
    X_test_bow  = vectorizer.transform(X_test)
    
    print(f"Vocabulary size: {len(vectorizer.get_feature_names_out())}")
    
    nb = MultinomialNB(alpha=1.0, class_prior=class_prior)
    nb.fit(X_train_bow, y_train)
    
    y_pred = nb.predict(X_test_bow)
    print(classification_report(y_test, y_pred))
    
    ConfusionMatrixDisplay.from_estimator(nb, X_test_bow, y_test, cmap="Blues")
    plt.title("Confusion Matrix – Naïve Bayes Spam Classifier with custom pre-processing ")
    plt.show()

    # We print out the top indicative words
    print_indicative_features(nb, vectorizer, topk=300, verbose=False)

    return (nb, vectorizer, y_pred)

nb_a, vec_a, y_pred_a = run_nb_pipeline(spacy_tokenizer)

## 8. Custom features
We can also give NB features beyond plain words. A simple way to do so is to inject special tokens into the text. For example, we can add features such as:
- `__HAS_EXCLAM__` if the message contains a run of two or more exclamation marks (e.g., `!!`)
- `__MANY_CAPS__` if the message contains at least a given number of all-caps words

We can also clean our text further by collapsing patterns that are spelled in different ways. For example, we can collapse all variants of "terms and conditions", such as "t&c" and "ts&cs", which occur in the dataset and are currently counted as separate tokens.

Numbers are also counted as individual tokens (e.g., '500', '180'), so we can replace them with placeholders that characterise numbers, phone numbers, URLs, etc.
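The idea of injecting signal tokens can be sketched with a few regular expressions. This is a simplified, hypothetical helper (`add_signal_tokens`) with deliberately simple patterns, not the notebook's `hybrid_tokenizer`; it only shows how surface cues become ordinary tokens that `CountVectorizer` will count like words.

```python
import re

MULTI_EXCLAM_RE = re.compile(r"!{2,}")   # runs of two or more "!"
NUMBER_RE = re.compile(r"\b\d+\b")       # standalone digit sequences

def add_signal_tokens(text, caps_threshold=3):
    """Rewrite a message so that surface cues become placeholder tokens."""
    # Count all-caps words on the original text, before any rewriting
    caps_words = [w for w in re.findall(r"[A-Za-z]+", text) if len(w) >= 2 and w.isupper()]
    text = MULTI_EXCLAM_RE.sub(" __HAS_EXCLAM__ ", text)
    text = NUMBER_RE.sub(" __IS_DIGIT__ ", text)
    if len(caps_words) >= caps_threshold:
        text += " __MANY_CAPS__"
    return " ".join(text.split())  # collapse repeated whitespace

print(add_signal_tokens("WIN CASH NOW!!! Call 500 today"))
# WIN CASH NOW __HAS_EXCLAM__ Call __IS_DIGIT__ today __MANY_CAPS__
```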

In [None]:
URL_STRONG_RE = re.compile(r"""(?ix)
\b
(
  (?:https?://|ftp://)?                                  # optional scheme
  (?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\.)+           # subdomain(s)
  (?:                                                    # TLD / eTLD
      com|org|net|info|biz|edu|gov|mil|io|me|tv|cc|
      co|uk|co\.uk|ac\.uk|gov\.uk|au|ca|de|fr|es|it|nl|se|no|fi|pl|in|cn|
      us|za|be|ch|at|dk|ru|ie|nz
      | [a-z]{2}                                         # generic ccTLD fallback
  )\b                                                    # TLD must end at a word boundary (avoids "ok.thanks")
  (?:/[^\s]*)?                                          # optional path
)
""")

EXCLAM_RE = re.compile(r"!{2,}")
PHONE_RE     = re.compile(r'(?:\+?\d[\s\-\(\)]?){7,}\d')                   # ~8+ digits total
TNC_RE = re.compile(
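    # t&c, ts&cs, T's&C's, tandcs, terms and conditions; restricting the first separator avoids "tic", "tac"
    r'\b(?:t[&\'.]?s?\s*(?:&|and)?\s*c(?:s|\'s)?|terms?\s*(?:&|and)?\s*conditions?)\b',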
    re.I
)

from spacy.symbols import ORTH

TAG_RE = re.compile(r"<[^>]+>")
def pre_normalize(text: str) -> str:
    # 1) HTML decode & strip residual tags
    t = html.unescape(text)
    t = TAG_RE.sub(" ", t)

    # 2) Collapse weird entity leftovers and whitespace
    t = re.sub(r"&(?:amp|nbsp);", " ", t, flags=re.I) # leftovers of double-escaped entities (e.g. "&amp;amp;" unescapes to "&amp;")
    t = re.sub(r"\s+", " ", t).strip()
    #t = unicodedata.normalize('NFKC', t)
    return t

def replace_signals_in_text(text: str, normalize = True) -> str:
    t = pre_normalize(text) if normalize else text

    t = TNC_RE.sub(" __HAS_TNC__ ", t)
    t = URL_STRONG_RE.sub(" __HAS_URL__ ", t)
    t = PHONE_RE.sub(" __HAS_PHONE__ ", t)
    t = EXCLAM_RE.sub(" __HAS_EXCLAM__ ", t)

    t = re.sub(r"\s+", " ", t).strip()
    return t

def hybrid_tokenizer(raw_text, do_normalise=True, caps_threshold: int = 3):
    rewritten = replace_signals_in_text(raw_text, normalize=False)       
    doc = nlp(rewritten)

    # --- doc-level ALL-CAPS detection (before filtering); only fires if the text reaching the tokenizer is not lowercased ---
    caps_count = sum(
        1 for tok in doc
        if tok.is_alpha and tok.text.isupper() and len(tok.text) >= 2
    )

    tokens = []
    for tok in doc:
        if tok.is_punct or tok.is_space or tok.is_stop:
            continue

        # numeric normalization
        if tok.like_num:
            tokens.append("__IS_DIGIT__")
            continue
        if any(ch.isdigit() for ch in tok.text):
            tokens.append("__HAS_NUM__")
            continue

        # preserve placeholders exactly as written ---
        if tok.text.startswith("__") and tok.text.endswith("__"):
            tokens.append(tok.text)
            continue        

        lemma = tok.lemma_ if do_normalise else tok.text
        tokens.append(lemma.lower())   # lowercase the word itself; placeholders above keep their case

    # append doc-level caps signal once (no duplicates)
    if caps_count >= caps_threshold:
        tokens.append("__MANY_CAPS__")

    
    return [t for t in tokens if t]

# Tell the spaCy tokenizer to treat each placeholder as a single token (otherwise it might be split)
PLACEHOLDERS = [
    "__HAS_TNC__", "__HAS_URL__", "__HAS_EXCLAM__", 
    "__HAS_PHONE__", "__MANY_CAPS__"
]

for ph in PLACEHOLDERS:
    nlp.tokenizer.add_special_case(ph, [{ORTH: ph}])


# Let's run the pipeline (function from before) with our new custom tokenizer with advanced processing and signals
nb_b, vec_b, y_pred_b = run_nb_pipeline(hybrid_tokenizer)

### Inspect errors
The following helper functions let you analyze misclassified messages and understand why the model made a given prediction. Qualitatively analysing the output of your model is a good way of finding patterns of errors that can be addressed.

Call `preview_errors_explained()` with the true label you want to explore (e.g., `true_label="ham"`). It scans the test set, selects the messages of that class that were misclassified, and uses `explain_message()` to display a table with:
- the true and predicted labels,
- the model's estimated probability of spam,
- the strongest evidence for spam (positive contributions) and for ham (negative contributions), and
- the original text for context.


In [None]:
import numpy as np
import pandas as pd

def explain_message(nb, vectorizer, text, top_k=10, pos_label="spam"):
    """
    Return a dict with posterior, log-odds breakdown, and top contributing features.
    Works for MultinomialNB with a CountVectorizer/TF pipeline.
    """
    # Map class indices
    classes = list(nb.classes_)
    i_pos = classes.index(pos_label)
    i_neg = 1 - i_pos  # binary assumption

    # Vectorize this text
    X = vectorizer.transform([text])               # shape (1, V)
    x = X.tocsr()                                  # ensure CSR
    idx = x.indices
    vals = x.data

    # Feature names and per-class log P(w|c)
    feat = vectorizer.get_feature_names_out()
    log_pw_pos = nb.feature_log_prob_[i_pos]       # shape (V,)
    log_pw_neg = nb.feature_log_prob_[i_neg]

    # Class-log-priors
    lp_pos = nb.class_log_prior_[i_pos]
    lp_neg = nb.class_log_prior_[i_neg]
    prior_diff = lp_pos - lp_neg                   # log-odds prior term

    # Per-feature log-odds weights (spam minus ham)
    w = log_pw_pos - log_pw_neg                    # shape (V,)

    # Contributions for active features: w_i * count_i
    contrib_vals = w[idx] * vals
    contrib_pairs = list(zip(feat[idx], contrib_vals))

    # Split into pro-spam (positive) and pro-ham (negative) evidence
    pos_contrib = sorted([(t, c) for t, c in contrib_pairs if c > 0], key=lambda z: -z[1])[:top_k]
    neg_contrib = sorted([(t, c) for t, c in contrib_pairs if c < 0], key=lambda z: z[1])[:top_k]

    # Total log-odds and posterior
    log_odds = prior_diff + contrib_vals.sum()
    # convert log-odds to probability
    p_pos = 1.0 / (1.0 + np.exp(-log_odds))

    return {
        "p_spam": float(p_pos) if pos_label == "spam" else 1.0 - float(p_pos),
        "log_odds": float(log_odds),
        "prior_log_odds": float(prior_diff),
        "sum_feature_contrib": float(contrib_vals.sum()),
        "top_pos": pos_contrib,    # [(feature, +weight), ...]
        "top_neg": neg_contrib,    # [(feature, -weight), ...]
        "active_count": int(len(idx)),
    }

def preview_errors_explained(X_test, y_test, y_pred, nb, vectorizer,
                             true_label="ham", pos_label="spam", top_k=5):
    # align predictions to y_test index
    y_pred_s = pd.Series(y_pred, index=y_test.index)
    err_mask = (y_test == true_label) & (y_pred_s != true_label)
    err_idx = y_test.index[err_mask]

    rows = []
    for i in err_idx:
        # Handle both DataFrame-with-text-column and plain Series cases
        text = X_test.loc[i, "text"] if hasattr(X_test, "columns") and "text" in X_test.columns else X_test.loc[i]
        exp = explain_message(nb, vectorizer, text, top_k=top_k, pos_label=pos_label)
        rows.append({
            "index": i,
            "true": y_test.loc[i],
            "pred": y_pred_s.loc[i],
            "p_spam": round(exp["p_spam"], 6),
            "log_odds": round(exp["log_odds"], 6),
            "prior_log_odds": round(exp["prior_log_odds"], 6),
            "sum_feature_contrib": round(exp["sum_feature_contrib"], 6),
            "top_pos": exp["top_pos"],
            "top_neg": exp["top_neg"],
            "text": text
        })

    if not rows:
        print("No errors for this selection.")
        return pd.DataFrame(columns=[
            "index","true","pred","p_spam","log_odds",
            "prior_log_odds","sum_feature_contrib","top_pos","top_neg","text"
        ])

    df_exp = pd.DataFrame(rows).sort_values("p_spam", ascending=False)

    with pd.option_context('display.max_colwidth', None, 'display.width', 200):
        display(df_exp[["index","true","pred","p_spam","log_odds","top_pos","top_neg","text"]])

    return df_exp

print("— Errors with simple processing —")
df_err_A = preview_errors_explained(X_test, y_test, y_pred_a, nb_a, vec_a, true_label="ham")

## 9. Discussion
- How does this differ from the **rule-based** filter? Go back to your implementation and compare performance.
- What is the effect of normalising (pre-processing) the text? Train with and without normalisation and reflect on the results.
- Reflect on the performance of the model and of the engineered features. What do we gain and lose with your more advanced pre-processing?