In [2]:
# %load ./dsfs/text.py
from typing import Set
import re

def tokenize(text: str) -> Set[str]:
    text = text.lower()
    all_words = re.findall("[a-z0-9']+", text)
    return set(all_words)

assert tokenize("Data Science is science") == {"data", "science", "is"}


In [3]:
from typing import NamedTuple
class Message(NamedTuple):
    text: str
    is_spam: bool

In [4]:
from typing import List, Tuple, Dict, Iterable
import math
from collections import defaultdict
from dsfs.text import tokenize

class NaiveBayesSpamClassifier:
    def __init__(self, k:float = 0.5):
        self.k = k
        self.tokens: Set[str] = set()
        self.token_spam_counts: Dict[str, int] = defaultdict(int)
        self.tokem_ham_counts: Dict[str, int] = defaultdict(int)
        self.spam_messages = self.ham_messages = 0
    
    def train(self, messages: Iterable[Message]) -> None:
        for message in messages:
            if message.is_spam:
                self.spam_messages += 1
            else:
                self.ham_messages += 1
                
            for token in tokenize(message.text):
                self.tokens.add(token)
                if message.is_spam: 
                    self.token_spam_counts[token] += 1
                else:
                    self.tokem_ham_counts[token] += 1
    
    def _probabilities(self, token: str) -> Tuple[float, float]:
        """Return P(tokem|spam) and P(token|ham)"""
        spam = self.token_spam_counts[token]
        ham = self.tokem_ham_counts[token]
        p_spam =(spam + self.k)/(self.spam_messages + 2 * self.k)
        p_ham = (ham + self.k)/(self.ham_messages + 2 * self.k)
        return p_spam, p_ham
    
    def predict(self, text:str) -> float:
        text_tokens = tokenize(text)
        log_prob_if_spam = log_prob_if_ham = 0.0
        for token in self.tokens:
            prob_if_spam, prob_if_ham = self._probabilities(token)
            if token in text_tokens:
                log_prob_if_spam += math.log(prob_if_spam)
                log_prob_if_ham += math.log(prob_if_ham)
            else:
                log_prob_if_spam += math.log(1.0 - prob_if_spam)
                log_prob_if_ham += math.log(1.0 - prob_if_ham)
        
        prob_if_spam = math.exp(log_prob_if_spam)
        prob_if_ham = math.exp(log_prob_if_ham)
        return prob_if_spam/ (prob_if_spam + prob_if_ham)
    

In [5]:
messages = [
    Message("spam rules", is_spam=True),
    Message("ham rules", is_spam=False),
    Message("hello ham", is_spam=False),
]
model = NaiveBayesSpamClassifier()
model.train(messages)

assert model.tokens == {"spam", "rules", "ham", "hello"}
assert model.spam_messages == 1
assert model.ham_messages == 2

assert model.token_spam_counts == {"spam": 1, "rules": 1}
assert model.tokem_ham_counts == {"ham": 2, "rules": 1, "hello": 1}

text = "hello spam"
display(model.predict(text))

0.8350515463917525

In [6]:
from io import BytesIO
import requests
import tarfile

BASE_URL = "https://spamassassin.apache.org/old/publiccorpus"
FILES = [
    "20021010_easy_ham.tar.bz2",
    "20021010_hard_ham.tar.bz2",
    "20021010_spam.tar.bz2"
]

OUTPUT_DIR = "spam_data"
for filename in FILES:
    content = requests.get(f"{BASE_URL}/{filename}").content
    fin = BytesIO(content)
    with tarfile.open(fileobj=fin, mode='r:bz2') as tf:
        tf.extractall(OUTPUT_DIR)

In [7]:
import glob, re
path = f'{OUTPUT_DIR}/*/*'

data: List[Message] = []
for filename in glob.glob(path):
    is_spam = "ham" not in filename
    with open(filename, errors='ignore') as email_file:
        for line in email_file:
            if line.startswith('Subject:'):
                subject = line.lstrip("Subject: ")
                data.append(Message(subject, is_spam=is_spam))
                break

    

In [8]:
import random
from dsfs.ml import split_data

random.seed(0)
train_messages, test_messages = split_data(data, 0.75)
model = NaiveBayesSpamClassifier()
model.train(train_messages)

In [9]:
from collections import Counter
predictions = [(message, model.predict(message.text)) for message in test_messages]
confusion_matrix = Counter((message.is_spam, spam_probabality > 0.5) for message, spam_probabality in predictions)
display(confusion_matrix)

Counter({(False, False): 668,
         (True, True): 85,
         (True, False): 54,
         (False, True): 18})

In [10]:
tp = confusion_matrix[(True, True)]
fn = confusion_matrix[(True, False)]
fp = confusion_matrix[(False, True)]
tn = confusion_matrix[(False, False)]

from dsfs.scoring import f1_score, precision, recall
print(precision(tp=tp, tn=tn, fp=fp, fn=fn))
print(recall(tp=tp, tn=tn, fp=fp, fn=fn))
print(f1_score(tp=tp, fn=fn, tn=tn, fp=fp))

0.8252427184466019
0.6115107913669064
0.7024793388429752
