In [None]:
from src.modeling.predict import predict_spam, load_model   
from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    precision_recall_curve,
    classification_report
)
from src.config import PROCESSED_DATA_DIR, MODELS_DIR
from pathlib import Path
import pandas as pd 
from src.modeling.train import FEATURE_COLUMNS

import numpy as np 
import matplotlib.pyplot as plt
import seaborn as sns
from tabulate import tabulate  

In [None]:
# Path to the held-out test split (parquet); every model in this notebook is evaluated on it.
TEST_SET_PATH = PROCESSED_DATA_DIR / "test.parquet"

def get_model_path(model_name: str, models_dir=None) -> Path:
    """Return the path of a saved ``.pkl`` model artifact, verifying it exists.

    Args:
        model_name: File stem of the model (e.g. ``'random_forest'``); the
            ``.pkl`` suffix is appended here.
        models_dir: Directory to look in. When omitted, ``MODELS_DIR`` from
            ``src.config`` is used, so existing one-argument calls are unchanged.

    Returns:
        ``Path`` to ``<models_dir>/<model_name>.pkl``.

    Raises:
        FileNotFoundError: If no regular file exists at that path.
    """
    base_dir = MODELS_DIR if models_dir is None else Path(models_dir)
    model_path = base_dir / f"{model_name}.pkl"
    # Check the filesystem: a Path object is always truthy, so a plain
    # truthiness test would never detect a missing model file.
    if not model_path.is_file():
        raise FileNotFoundError(f"Model file not found: {model_path}")
    return model_path


In [None]:
# Load the three trained classifiers from their .pkl artifacts in MODELS_DIR.
# NOTE(review): 'Logistic_regression' is capitalised unlike the other stems —
# confirm it matches the file on disk.
lr_model, rf_model, xg_model = (
    load_model(get_model_path(stem))
    for stem in ('Logistic_regression', 'random_forest', 'xg_boosting')
)

# Held-out split: features restricted to FEATURE_COLUMNS (imported from the
# training module), target is the encoded label column.
test = pd.read_parquet(TEST_SET_PATH)
X_test, y_test = test[FEATURE_COLUMNS], test['label_encoded']

# Hard class predictions of each model on the same test features.
y_pred_lr, y_pred_rf, y_pred_xg = (
    clf.predict(X_test) for clf in (lr_model, rf_model, xg_model)
)


In [None]:
# Trivial reference model: predicts class 0 for every test row (same length and
# dtype as y_test), i.e. "everything is Ham". The trained models should beat it.
y_pred_dumb = np.zeros_like(y_test)  # assuming Ham = 0, Spam = 1
def class_report_model(y_pred_model, y_true=None, pos_label=1):
    """Compute display-ready binary classification metrics for one model.

    Args:
        y_pred_model: Predicted labels, aligned element-wise with ``y_true``.
        y_true: Ground-truth labels. Defaults to the notebook-level ``y_test``
            so existing calls ``class_report_model(y_pred)`` keep working.
        pos_label: Label treated as the positive class (Spam = 1 by default).

    Returns:
        Tuple of strings ``(precision, recall, f1, accuracy)``, each formatted
        to 4 decimal places for the results table.
    """
    if y_true is None:
        y_true = y_test
    # zero_division=0 on every ratio metric, consistently: the all-Ham baseline
    # has no predicted positives (precision is 0/0), and a split without
    # positives would make recall 0/0. Report 0 instead of warning.
    precision = precision_score(y_true, y_pred_model, pos_label=pos_label, zero_division=0)
    recall = recall_score(y_true, y_pred_model, pos_label=pos_label, zero_division=0)
    f1 = f1_score(y_true, y_pred_model, pos_label=pos_label, zero_division=0)
    accuracy = accuracy_score(y_true, y_pred_model)
    return tuple(f"{score:.4f}" for score in (precision, recall, f1, accuracy))

In [None]:
# Metric tuples (precision, recall, f1, accuracy) for the baseline and each model,
# computed in the same order as the table rows below.
dump_repo, lr_repo, rf_repo, xg_repo = [
    class_report_model(preds)
    for preds in (y_pred_dumb, y_pred_lr, y_pred_rf, y_pred_xg)
]


In [None]:
# Gather results: the first row holds the MVP metric targets; the remaining rows
# hold the all-Ham baseline and each trained model. Metric columns follow
# class_report_model's order: precision, recall, F1, accuracy.
results = [["MVP Baseline", ">=0.85", ">=0.75", ">=0.80", "-"]]
results += [
    [label, *scores]
    for label, scores in (
        ("Dumb Baseline (All Ham)", dump_repo),
        ("Logistic Regression", lr_repo),
        ("Random Forest", rf_repo),
        ("XGboosting", xg_repo),
    )
]

print(
    tabulate(
        results,
        headers=["Model", "Precision", "Recall", "F1-Score", "Accuracy"],
        tablefmt="github",
    )
)

In [None]:
# Confusion matrices for the baseline and each model on the same test set, in a
# 2x2 grid. Rows = true label, columns = predicted label (sklearn convention).
# labels=[0, 1] forces a 2x2 matrix so it always matches the Ham/Spam ticks,
# assuming Ham = 0, Spam = 1 as in the rest of the notebook.
CM_LABELS = [0, 1]
CM_TICK_NAMES = ['Ham', 'Spam']

cm_dump = confusion_matrix(y_test, y_pred_dumb, labels=CM_LABELS)
cm_lr = confusion_matrix(y_test, y_pred_lr, labels=CM_LABELS)
cm_rf = confusion_matrix(y_test, y_pred_rf, labels=CM_LABELS)
cm_xg = confusion_matrix(y_test, y_pred_xg, labels=CM_LABELS)

cm_panels = [
    (cm_dump, 'Dumb model Confusion Matrix'),
    (cm_lr, 'Logistic Regression Confusion Matrix'),
    (cm_rf, 'Random Forest Confusion Matrix'),
    (cm_xg, 'XGBoosting Confusion Matrix'),
]

# Create exactly one figure: calling plt.figure() before plt.subplots() would
# leave an extra, empty figure open alongside the grid.
fig, axes = plt.subplots(2, 2, figsize=(10, 8))
# axes.flat walks the grid row-major: [0,0], [0,1], [1,0], [1,1].
for ax, (cm, title) in zip(axes.flat, cm_panels):
    sns.heatmap(cm, ax=ax, annot=True, fmt='d', cmap='Blues',
                xticklabels=CM_TICK_NAMES,
                yticklabels=CM_TICK_NAMES)
    ax.set_title(title)
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
plt.tight_layout()
plt.show()

In [None]:
# Rank features by the magnitude of the Logistic Regression weights (largest
# |coefficient| first). coef_[0] is the first (for a binary model, only) weight
# row; its entries are assumed to line up with FEATURE_COLUMNS, the same columns
# passed to predict above.
importance = (
    pd.DataFrame({'feature': FEATURE_COLUMNS, 'coefficient': lr_model.coef_[0]})
    .sort_values('coefficient', key=lambda col: col.abs(), ascending=False)
)

print("\n🔍 FEATURE IMPORTANCE (Logistic Regression Coefficients):")
print(importance)

In [None]:
# Plot the coefficients as horizontal bars, in the |coefficient| order of
# `importance`. hue mirrors y so the viridis palette colours each feature bar.
coef_fig, coef_ax = plt.subplots(figsize=(8, 5))
sns.barplot(data=importance, y='feature', x='coefficient',
            hue='feature', palette='viridis', ax=coef_ax)
coef_ax.set_title('Feature Importance (Coefficients)')
coef_ax.set_xlabel('Coefficient (positive = spammy)')
plt.show()