In [1]:
import string
import email
import nltk

nltk.download('stopwords')  # ikuya

punctuations = list(string.punctuation)
stopwords = set(nltk.corpus.stopwords.words('english'))
stemmer = nltk.PorterStemmer()

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\mi\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [2]:
## Utility functions to browse a list of words
# ikuya
def browse(L, name='(list)', num_samples=5):
    n = num_samples if len(L)>=num_samples else len(L)
    print('%s: len=%d, samples: %s' % (name, len(L), ' '.join(L[:n])))

In [3]:
browse(punctuations, 'punctuations')
browse(list(stopwords), 'stopwords')

punctuations: len=32, samples: ! " # $ %
stopwords: len=179, samples: after an each doing these


In [6]:
## functions to load and parse emails

# Combine the different parts of the email into a flat list of strings
def flatten_to_string(parts):
    ret = []
    if type(parts) == str:
        ret.append(parts)
    elif type(parts) == list:
        for part in parts:
            ret += flatten_to_string(part)
    elif parts.get_content_type == 'text/plain':
        ret += parts.get_payload()
    return ret

# Extract subject and body text from a single email file
def extract_email_text(path):
    # Load a single email from an input file
    with open(path, errors='ignore') as f:
        msg = email.message_from_file(f)
    if not msg:
        return ""
    
    # Read the email subject
    subject = msg['Subject']
    if not subject:
        subject = ""
    
    # Read the email body
    body = ' '.join(m for m in flatten_to_string(msg.get_payload())
                    if type(m) == str)
    if not body:
        body = ""
    
    return subject + ' ' + body

# Process a single email file into stemmed tokens
PUNCTS = "".join(punctuations)
def load(path):
    email_text = extract_email_text(path)
    if not email_text:
        return []
    
    # Tokenize the message
    tokens = nltk.word_tokenize(email_text)
    
    # Remove punctuation from tokens
    tokens = [i.strip(PUNCTS) for i in tokens if i not in punctuations]
    
    # Remove stopwords and stem tokens
    if len(tokens) > 2:
        return [stemmer.stem(w) for w in tokens if w not in stopwords]
    return []

In [7]:
%%time
## Load dataset

import os

DATA_DIR = 'datasets/trec07p/data/'
LABELS_FILE = 'datasets/trec07p/full/index'
TRAINING_SET_RATIO = 0.7

nltk.download('punkt')  # ikuya

labels = {}
spam_words = set()
ham_words = set()

# Read the labels
with open(LABELS_FILE) as f:
    for line in f:
        line = line.strip()
        label, key = line.split()
        labels[key.split('/')[-1]] = 1 if label.lower()=='ham' else 0

# Split corpus into training and test sets
filelist = os.listdir(DATA_DIR)
num_train = int(len(filelist) * TRAINING_SET_RATIO)
X_train = filelist[:num_train]
X_test  = filelist[num_train:]

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\mi\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Wall time: 118 ms


In [8]:
%%time
# Learn (extract ham_words and spam_words)
for filename in X_train:
    path = os.path.join(DATA_DIR, filename)
    if filename in labels:
        label = labels[filename]
        stems = load(path)
        if not stems:
            continue
        if label == 1:
            ham_words.update(stems)
        elif label == 0:
            spam_words.update(stems)
        else:
            continue
blacklist = spam_words - ham_words

Wall time: 4min 21s


In [35]:
# ham, spam, blacklist の数を見てみる。 blacklist の中身を見てみる。
print("#X_train=%d, #X_test=%d" % (len(X_train), len(X_test)))
browse(list(ham_words), "ham_words")
browse(list(spam_words), "spam_words")
browse(list(blacklist), "blacklist")

#blacklist

#X_train=52793, #X_test=22626
ham_words: len=187154, samples:  gpo_copy_fil www.wbir.com/news/national/story.asp slashdot.org/comments.pl parlato
spam_words: len=122577, samples:  事務局以外でサポートとして、私個人の心情が入りメールしています。 myrangeinternet.com/x/mjuxodm5ndk5|mzezmjg4|chjvzhvjdhrlc3rwyw5lbebzcgvlzhkudxdhdgvybg9vlmnh|mjm0odi=|ng==|||.html ｻﾘｱｨｴﾊｩﾋﾄ｣ｺﾍｨｹﾐｽｳｷﾖｲ羚ﾖﾀ犲ﾍｽ盪ｹｵﾄｲ霻ｻｯﾔｭﾔｹｽｨﾖﾐｸﾟｲ羯ﾜﾀ柀ﾋﾔｱﾐｽｳﾌ袞ｵ www.4ulanesdealer.com/i/svrlha/y13s3gzy/viloy_5971/kaleh_15.jpg
blacklist: len=99287, samples: myrangeinternet.com/x/mjuxodm5ndk5|mzezmjg4|chjvzhvjdhrlc3rwyw5lbebzcgvlzhkudxdhdgvybg9vlmnh|mjm0odi=|ng==|||.html 事務局以外でサポートとして、私個人の心情が入りメールしています。 ｻﾘｱｨｴﾊｩﾋﾄ｣ｺﾍｨｹﾐｽｳｷﾖｲ羚ﾖﾀ犲ﾍｽ盪ｹｵﾄｲ霻ｻｯﾔｭﾔｹｽｨﾖﾐｸﾟｲ羯ﾜﾀ柀ﾋﾔｱﾐｽｳﾌ袞ｵ www.4ulanesdealer.com/i/svrlha/y13s3gzy/viloy_5971/kaleh_15.jpg a/nod/to/the/fact/that/the/university/is/located/on/the/former/site/of/leland/stanford/s/horse/farm.//the/university/s/founding/grant/was/written/on/november/11


In [29]:
# blacklist から明らかに数値やURLっぽいものを除いて、中身を見てみる
import re
nums = re.compile('([\\d:,=]+|\\d[\\da-fA-F]+)')
bb = blacklist.copy()
bb = [w for w in bb if '/' not in w and '.' not in w and '-' not in w]
bb = [w for w in bb if not nums.fullmatch(w)]
bb.sort(reverse=True)
offset, length = int(len(bb)/3), 20
bb[offset:offset+length]  

['rovinci',
 'rovinc',
 'rovidi',
 'rovid',
 'rovesciano',
 'roversi',
 'rovement',
 'roux',
 'routledg',
 'routine痴',
 'roust',
 'rousseau',
 'rouss',
 'roushil',
 'rousettu',
 'rous=3dmze3njeyn0a4ota4qdiynka',
 'rous=3dmze3njeyn0a4nzy3qdiynka',
 'rous=3dmze3njeyn0a4njqzqdiynka',
 'rourk',
 'roup']

In [40]:
def predict(X, blackset):
    n = len(X)
    y_truth = [-1] * n
    y_predict = [-1] * n
    for i in range(n):
        filename = X[i]
        label = labels[filename]
        y_truth[i] = label
        path = os.path.join(DATA_DIR, filename)
        stems = load(path)
        y_predict[i] = 1
        for w in stems:
            if w in blackset:
                y_predict[i] = 0
                break
        #y_predict[i] = int(bool(set(stems) & blackset))
    return (y_truth, y_predict)

In [41]:
%%time

(y_truth, y_predict) = predict(X_test, blacklist)
print("%d/%d" % (sum([a==b for a,b in zip(y_truth, y_predict)]), len(y_truth)))

14276/22626
Wall time: 1min 59s


In [42]:
def percent(n, total):
    return n / total * 100.0
    
def print_result(y_truth, y_predict):
    ham_ham = ham_spam = spam_ham = spam_spam = 0
    confusion_matrix = (ham_ham, ham_spam, spam_ham, spam_spam)
    
    for (truth, predict) in zip(y_truth, y_predict):
        if truth==1:
            if predict==1:
                ham_ham += 1
            else:
                ham_spam += 1
        else:
            if predict==1:
                spam_ham += 1
            else:
                spam_spam += 1
    
    print("#y_truth, #y_predict = %d, %d" % (len(y_truth), len(y_predict)))
    num_predicted = ham_ham + ham_spam + spam_ham + spam_spam
    
    def pct(n):
        return percent(n, num_predicted)
    
    print("             Predicted HAM, Predicted SPAM")
    print("Actual HAM :         %5d,        %5d" % (ham_ham, ham_spam))
    print("Actual SPAM:         %5d,        %5d" % (spam_ham, spam_spam))
    print()
    print("             Predicted HAM, Predicted SPAM")
    print("Actual HAM :         %2.1f%%,        %2.1f%%" % (pct(ham_ham), pct(ham_spam)))
    print("Actual SPAM:         %2.1f%%,        %2.1f%%" % (pct(spam_ham), pct(spam_spam)))
    print()
    print("accuracy = %2.1f%%" % (pct(ham_ham + spam_spam)))

print_result(y_truth, y_predict)


#y_truth, #y_predict = 22626, 22626
             Predicted HAM, Predicted SPAM
Actual HAM :          5436,         1139
Actual SPAM:          7211,         8840

             Predicted HAM, Predicted SPAM
Actual HAM :         24.0%,        5.0%
Actual SPAM:         31.9%,        39.1%

accuracy = 63.1%


In [60]:
# ham/spam ratio in all emails
num_hams  = sum([1 for label in labels.values() if label==1])
num_spams = sum([1 for label in labels.values() if label==0])
num_total = num_hams + num_spams
print("all (ham, spam) = (%d, %d)" % (num_hams, num_spams))
print("all (ham, spam) = (%2.1f%%, %2.1f%%)" % (percent(num_hams, num_total), percent(num_spams, num_total)))

# → 全データの ham/spam 割合はおよそ 33/67

all (ham, spam) = (25220, 50199)
all (ham, spam) = (33.4%, 66.6%)


In [61]:
# ham/spam ratio in test emails
y_test = [labels[fn] for fn in X_test]
num_hams  = sum([1 for label in y_test if label==1])
num_spams = sum([1 for label in y_test if label==0])
num_total = num_hams + num_spams
print("test (ham, spam) = (%d, %d)" % (num_hams, num_spams))
print("test (ham, spam) = (%2.1f%%, %2.1f%%)" % (percent(num_hams, num_total), percent(num_spams, num_total)))

# → テストデータの ham/spam 割合はおよそ 30/70

test (ham, spam) = (6575, 16051)
test (ham, spam) = (29.1%, 70.9%)


In [62]:
y_truth_2 = y_truth
y_predict_2 = [0] * len(y_truth_2)
print_result(y_truth_2, y_predict_2)
# → 全部 spam (0) と判定しても accuracy は 70% になる

#y_truth, #y_predict = 22626, 22626
             Predicted HAM, Predicted SPAM
Actual HAM :             0,         6575
Actual SPAM:             0,        16051

             Predicted HAM, Predicted SPAM
Actual HAM :         0.0%,        29.1%
Actual SPAM:         0.0%,        70.9%

accuracy = 70.9%


In [63]:
import random

y_truth_3 = y_truth
y_predict_3 = [random.randint(0, 1) for _ in y_truth_3]
print_result(y_truth_3, y_predict_3)
# → 0/1 五分五分で判定すると accuracy は 50%

#y_truth, #y_predict = 22626, 22626
             Predicted HAM, Predicted SPAM
Actual HAM :          3265,         3310
Actual SPAM:          7965,         8086

             Predicted HAM, Predicted SPAM
Actual HAM :         14.4%,        14.6%
Actual SPAM:         35.2%,        35.7%

accuracy = 50.2%
