In [None]:
%load_ext autoreload
%autoreload 2

from email import policy
from email.parser import BytesParser

from src.spam_classifier.constants import PROJECT_ROOT

# Root of the email corpus: each sub-directory holds raw emails, one per file.
# The loading cell labels a sub-directory as spam when its name contains "spam".
DATA_DIR = PROJECT_ROOT / "data"

In [None]:
def parse_email(filepath):
    """Parse one raw email file into the fields used by the classifier.

    Parameters
    ----------
    filepath : pathlib.Path
        Path to a file containing a single raw email; it is read as bytes.

    Returns
    -------
    dict
        ``{"body": str, "type": str}``.
        ``body`` is the decoded text of the first ``text/plain`` part without a
        Content-Disposition header (multipart messages), or the whole decoded
        payload (single-part text messages). It is ``""`` when no text could be
        extracted, e.g. for binary payloads or on a decoding error.
        ``type`` is the top-level content type such as ``"text/plain"`` or
        ``"multipart/alternative"``.
    """
    with open(filepath, "rb") as f:
        msg = BytesParser(policy=policy.default).parse(f)

    # Bound before the try so `body` always exists, even if decoding fails.
    body = ""
    try:
        if msg.is_multipart():
            for part in msg.walk():
                # Keep only the first text/plain part that carries no
                # Content-Disposition header (attachment or inline).
                if part.get_content_type() == "text/plain" and not part.get_content_disposition():
                    body = f"{part.get_content()} "
                    break
        else:
            content = msg.get_content()
            # get_content() returns bytes for non-text payloads (and a message
            # object for message/* types); downstream code expects a string.
            body = content if isinstance(content, str) else ""
    except Exception as e:
        # Best-effort: report the problem (e.g. an unknown charset) and keep
        # the email with an empty body rather than aborting the whole load.
        print(f"Error processing {filepath.name}: {e}")

    return {"body": body, "type": msg.get_content_type()}

In [None]:
ham_emails = []
spam_emails = []
# Sorted iteration: Path.iterdir() order is filesystem-dependent, and the seeded
# train/test split later only reproduces if the records arrive in a fixed order.
for corpus_dir in sorted(DATA_DIR.iterdir()):
    # Only sub-directories hold emails; calling iterdir() on a plain file would raise.
    if not corpus_dir.is_dir():
        continue
    parsed = [parse_email(path) for path in sorted(corpus_dir.iterdir()) if path.is_file()]
    # The folder name decides the label: containing "spam" -> spam, anything else -> ham.
    if "spam" in corpus_dir.name:
        spam_emails += parsed
    else:
        ham_emails += parsed

In [None]:
from sklearn.base import BaseEstimator, TransformerMixin
from tqdm import tqdm

from src.spam_classifier.mail_class import Mail


class MailTransformer(BaseEstimator, TransformerMixin):
    """Convert parsed email records into per-email word representations.

    Each input record is a ``{"body": ..., "type": ...}`` dict; it is wrapped in
    ``Mail`` and converted with ``Mail.transform_mail()``. The output is
    presumably a word -> count mapping (``MailVocabulary`` consumes it that way)
    — NOTE(review): confirm against ``Mail``.
    """

    def __init__(self):
        # No hyperparameters to store.
        pass

    def fit(self, X, y=None):
        """Learn nothing; return ``self`` so the step can sit in a Pipeline."""
        return self

    def transform(self, X):
        """Return ``Mail(body, type).transform_mail()`` for every record, in input order."""
        return [
            Mail(record["body"], record["type"]).transform_mail()
            for record in tqdm(X)
        ]

In [None]:
class MailVocabulary(BaseEstimator, TransformerMixin):
    """Map per-email word counts onto fixed-length count vectors.

    ``fit`` keeps the ``vocab_size`` most frequent words across the training
    emails. ``transform`` turns each email into a list of counts: index 0 sums
    the counts of every out-of-vocabulary word, and index ``i >= 1`` holds the
    count of the vocabulary word mapped to ``i``.

    Parameters
    ----------
    vocab_size : int, default=1000
        Number of most frequent words kept in the vocabulary.

    Attributes
    ----------
    vocabulary_ : dict
        Word -> column index, starting at 1 (column 0 is the unknown-word bucket).
    """

    def __init__(self, vocab_size: int = 1000):
        self.vocab_size = vocab_size

    def fit(self, X, y=None):
        """Learn the vocabulary from an iterable of ``{word: count}`` mappings.

        ``y`` is ignored. Returns ``self``.
        """
        word_counter = {}
        for word_dict in X:
            for word, count in word_dict.items():
                word_counter[word] = word_counter.get(word, 0) + count
        # sorted() is stable even with reverse=True, so words with equal totals
        # keep their first-seen order.
        most_common = sorted(word_counter, key=word_counter.get, reverse=True)[: self.vocab_size]
        # Column 0 is reserved for unknown words, so real words start at 1. This
        # also avoids colliding with a genuine word spelled "unknown".
        self.vocabulary_ = {word: i for i, word in enumerate(most_common, start=1)}
        return self

    def transform(self, X):
        """Return one count vector (a list of length ``len(vocabulary_) + 1``) per email."""
        n_features = len(self.vocabulary_) + 1
        transformed_X = []
        for word_dict in X:
            row = [0] * n_features
            for word, count in word_dict.items():
                # Words outside the vocabulary all accumulate into column 0.
                row[self.vocabulary_.get(word, 0)] += count
            transformed_X.append(row)
        return transformed_X

In [None]:
X = ham_emails + spam_emails
# Labels follow the concatenation order: 0 = ham, 1 = spam. They are built by
# position rather than via `email in ham_emails`, which is O(n*m) and would
# mislabel a spam record that happens to equal a ham record.
y = [0] * len(ham_emails) + [1] * len(spam_emails)

In [None]:
from sklearn.model_selection import train_test_split

random_state = 42  # shared seed for both splits below

# 80 / 20 split, stratified so both parts keep the ham/spam ratio.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    random_state=random_state,
    stratify=y,
    test_size=0.2,
)

# Halve the 20% hold-out: 10% validation, 10% test overall (still stratified).
X_val, X_test, y_val, y_test = train_test_split(
    X_test,
    y_test,
    random_state=random_state,
    stratify=y_test,
    test_size=0.5,
)

In [None]:
from sklearn.pipeline import Pipeline

# Parsed email dicts -> Mail.transform_mail() output -> fixed-length count vectors.
pipeline = Pipeline(
    steps=[
        ("email_transformer", MailTransformer()),
        ("mail_vocab", MailVocabulary()),
    ]
)

In [None]:
# Pipeline.fit returns the pipeline itself, so both names refer to the same fitted object.
pipeline.fit(X_train)
fitted_pipeline = pipeline

In [None]:
# NOTE: overwrites the raw splits with their vectorised form, so this cell is not
# idempotent — re-running it without re-running the split cell will fail.
X_train = fitted_pipeline.transform(X_train)
X_val = fitted_pipeline.transform(X_val)
X_test = fitted_pipeline.transform(X_test)

In [None]:
from sklearn.linear_model import LogisticRegression

# Seeded because the grid below tries the "liblinear" solver, which uses
# random_state to shuffle the data; without it results vary between runs.
lr_clf = LogisticRegression(random_state=random_state)

In [None]:
from sklearn.model_selection import GridSearchCV

# Candidate settings: two solvers, each allowed up to 1000 iterations.
lr_params = dict(
    solver=["liblinear", "lbfgs"],
    max_iter=[1000],
)

# 10-fold CV on the training split, scored by accuracy, using all CPU cores.
lr_param_grid = GridSearchCV(
    lr_clf,
    lr_params,
    scoring="accuracy",
    cv=10,
    n_jobs=-1,
    verbose=2,
)
lr_param_grid.fit(X_train, y_train)

In [None]:
# Best mean CV accuracy, then the estimator configuration that achieved it.
print(lr_param_grid.best_score_, lr_param_grid.best_estimator_, sep="\n")

In [None]:
# Use the best configuration found by the grid search for the evaluation below.
lr_clf = lr_param_grid.best_estimator_

In [None]:
from sklearn.model_selection import cross_val_score

# 10-fold CV accuracy of the selected model on the training split.
# NOTE(review): with the same cv=10 splits this likely repeats best_score_ above.
score = cross_val_score(
    estimator=lr_clf,
    X=X_train,
    y=y_train,
    scoring="accuracy",
    cv=10,
)
print(score.mean())

In [None]:
from sklearn.metrics import precision_score, recall_score

# Final hold-out evaluation; label 1 (spam) is the positive class for both metrics.
y_pred = lr_clf.predict(X_test)
precision, recall = precision_score(y_test, y_pred), recall_score(y_test, y_pred)
print(f"Precision: {precision*100:.2f}%\nRecall: {recall*100:.2f}%")