In [31]:
import random
import re
import tarfile
import urllib
import urllib.request
from pathlib import Path
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [20]:
def is_spamassassin_sample(filename: Path) -> bool:
    """Return True if ``filename`` is named like a SpamAssassin corpus sample.

    Sample files are named ``<serial>.<md5>``: a five-digit serial number, a
    literal dot, and 32 lowercase hexadecimal digits (the md5sum of the
    contents of the email). Only the final path component is examined; the
    file itself is never opened.

    Args:
        filename: Path whose last component is checked.

    Returns:
        True when the whole name has the sample format, False otherwise.
    """
    # Raw string: "\d" inside a plain string literal is an invalid escape
    # sequence. The dot is escaped so it only matches a literal ".", and
    # fullmatch() rejects names with trailing characters, which matters
    # because is_filename_in_training_set() parses the last two characters
    # of the name as hexadecimal digits.
    return re.fullmatch(r"\d{5}\.[0-9a-f]{32}", filename.name) is not None

def load_samples_from_tarball(dataset_name: str, tarball_name: str) -> List[Path]:
    """Download and extract one corpus tarball, then list its sample files.

    The tarball is cached as ``datasets/<tarball_name>`` and is only
    downloaded when that file does not exist yet. It is extracted into
    ``datasets`` unless ``datasets/<dataset_name>`` already contains at least
    one sample file.

    Args:
        dataset_name: Directory under ``datasets`` that holds the samples.
            Presumably this is also the top-level directory inside the
            tarball -- verify against the archive contents.
        tarball_name: File name of the tarball, both locally and on the
            SpamAssassin public corpus server.

    Returns:
        Sorted list of paths in ``datasets/<dataset_name>`` whose names look
        like SpamAssassin samples.
    """
    datasets_dir = Path("datasets")
    tarball_path = Path("datasets/" + tarball_name)
    dataset_dir = Path("datasets/" + dataset_name)

    if not tarball_path.is_file():
        datasets_dir.mkdir(parents=True, exist_ok=True)
        url = "https://spamassassin.apache.org/old/publiccorpus/" + tarball_name
        # Download under a temporary name and rename afterwards, so that an
        # interrupted download is never mistaken for a complete tarball.
        partial_path = tarball_path.with_name(tarball_path.name + ".part")
        urllib.request.urlretrieve(url, partial_path)
        partial_path.replace(tarball_path)

    # Make sure the directory exists so that it can be listed below even
    # before anything has been extracted.
    dataset_dir.mkdir(parents=True, exist_ok=True)

    already_extracted = any(
        is_spamassassin_sample(entry) for entry in dataset_dir.iterdir()
    )
    if not already_extracted:
        with tarfile.open(tarball_path) as tarball:
            if hasattr(tarfile, "data_filter"):
                # The archive comes from the network: refuse members that
                # would be written outside of the destination directory.
                tarball.extractall(path="datasets", filter="data")
            else:
                # Python versions without extraction filters.
                tarball.extractall(path="datasets")

    # Sorted so that the result does not depend on directory listing order.
    return sorted(
        entry for entry in dataset_dir.iterdir() if is_spamassassin_sample(entry)
    )

def load_samples():
    """Load every corpus and return a dict of dataset name -> sample paths.

    Each dataset is fetched with load_samples_from_tarball(), in the order
    listed below.
    """
    tarball_for_dataset = (
        ("easy_ham", "20030228_easy_ham.tar.bz2"),
        ("easy_ham_2", "20030228_easy_ham_2.tar.bz2"),
        ("hard_ham", "20030228_hard_ham.tar.bz2"),
        ("spam", "20030228_spam.tar.bz2"),
        ("spam_2", "20050311_spam_2.tar.bz2"),
    )

    samples = {}
    for name, tarball in tarball_for_dataset:
        samples[name] = load_samples_from_tarball(name, tarball)
    return samples
    
# Load every corpus once and report how many samples each one contains.
raw_data = load_samples()
for dataset_name, filenames in sorted(raw_data.items()):
    print(f"{dataset_name} has {len(filenames)} samples")

easy_ham has 2500 samples
easy_ham_2 has 1400 samples
hard_ham has 250 samples
spam has 500 samples
spam_2 has 1396 samples


In [41]:
# 0xCD / 0x100 = 0.80078125, so if the last 2 hex digits of the md5sum are
# less than 0xCD, then the sample is part of the training set. Otherwise
# it's in the test set.
def is_filename_in_training_set(filename: Path):
    """Decide from the file name alone whether a sample is in the training set.

    The last two characters of the file's name are parsed as a hexadecimal
    number. The sample belongs to the training set when that number is below
    0xCD, and to the test set otherwise.

    Raises:
        AssertionError: if ``filename`` is not named like a SpamAssassin sample.
    """
    assert is_spamassassin_sample(filename)
    basename = filename.absolute().name
    hash_tail = int(f"0x{basename[-2:]}", 16)
    return hash_tail < 0xCD

# Split every sample into a training set and a test set based on its file
# name, labelling it True (spam) or False (ham) according to the name of the
# dataset it came from.
train_filenames = []
train_labels = []
test_filenames = []
test_labels = []

count_train_spams = 0
count_test_spams = 0

for dataset_name, filenames in sorted(raw_data.items()):
    # Every dataset whose name contains "spam" holds spam; the rest hold ham.
    is_spam = "spam" in dataset_name
    # Sorted so that the order before shuffling does not depend on the order
    # in which the directory was listed.
    for filename in sorted(filenames):
        if is_filename_in_training_set(filename):
            train_filenames.append(filename)
            train_labels.append(is_spam)
            if is_spam:
                count_train_spams += 1
        else:
            test_filenames.append(filename)
            test_labels.append(is_spam)
            if is_spam:
                count_test_spams += 1

# The ratios below divide by the set sizes, and zip(*...) further down cannot
# be unpacked for an empty list.
assert train_filenames and test_filenames, "expected a non-empty training set and test set"

# Despite the names, these hold fractions in [0, 1]; they are scaled to
# percentages when printed.
train_percent_spam = count_train_spams / (len(train_filenames))
test_percent_spam = count_test_spams / (len(test_filenames))
test_percent = len(test_filenames) / (len(train_filenames) + len(test_filenames))

print(f"Training set is {len(train_filenames)} samples with {train_percent_spam*100:.2f}% spam")
print(f"Test set is {len(test_filenames)} samples with {test_percent_spam*100:.2f}% spam")
print(f"Test set is {test_percent*100:.2f}% of the samples")

# Dedicated, seeded generator so that the shuffled order is the same on every
# run and the global `random` state is left alone.
shuffle_rng = random.Random(42)

# Shuffle file names and labels together so that they stay aligned. After
# unpacking, both names refer to tuples rather than lists.
train_shuffled = list(zip(train_filenames, train_labels))
shuffle_rng.shuffle(train_shuffled)
train_filenames, train_labels = zip(*train_shuffled)

test_shuffled = list(zip(test_filenames, test_labels))
shuffle_rng.shuffle(test_shuffled)
test_filenames, test_labels = zip(*test_shuffled)

Training set is 4834 samples with 31.63% spam
Test set is 1212 samples with 30.28% spam
Test set is 20.05% of the samples
