In [2]:
import string
import email
import nltk

nltk.download('stopwords')  # ikuya

punctuations = list(string.punctuation)
stopwords = set(nltk.corpus.stopwords.words('english'))
stemmer = nltk.PorterStemmer()

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\mi\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [3]:
## Utility functions to browse a list of words
# ikuya
def browse(L, name='(list)', num_samples=5):
    n = num_samples if len(L)>=num_samples else len(L)
    print('%s: len=%d, samples: %s' % (name, len(L), ' '.join(L[:n])))

In [4]:
browse(punctuations, 'punctuations')
browse(list(stopwords), 'stopwords')

punctuations: len=32, samples: ! " # $ %
stopwords: len=179, samples: until yourself all mustn't him


In [5]:
## functions to load and parse emails

# Combine the different parts of the email into a flat list of strings
def flatten_to_string(parts):
    ret = []
    if type(parts) == str:
        ret.append(parts)
    elif type(parts) == list:
        for part in parts:
            ret += flatten_to_string(part)
    elif parts.get_content_type == 'text/plain':
        ret += parts.get_payload()
    return ret

# Extract subject and body text from a single email file
def extract_email_text(path):
    # Load a single email from an input file
    with open(path, errors='ignore') as f:
        msg = email.message_from_file(f)
    if not msg:
        return ""
    
    # Read the email subject
    subject = msg['Subject']
    if not subject:
        subject = ""
    
    # Read the email body
    body = ' '.join(m for m in flatten_to_string(msg.get_payload())
                    if type(m) == str)
    if not body:
        body = ""
    
    return subject + ' ' + body

# Process a single email file into stemmed tokens
def load(path):
    email_text = extract_email_text(path)
    if not email_text:
        return []
    
    # Tokenize the message
    tokens = nltk.word_tokenize(email_text)
    
    # Remove punctuation from tokens
    tokens = [i.strip("".join(punctuations)) for i in tokens
             if i not in punctuations]
    
    # Remove stopwords and stem tokens
    if len(tokens) > 2:
        return [stemmer.stem(w) for w in tokens if w not in stopwords]
    return []


In [6]:
## Load dataset

import os

DATA_DIR = 'datasets/trec07p/data/'
LABELS_FILE = 'datasets/trec07p/full/index'
TRAINING_SET_RATIO = 0.7

nltk.download('punkt')  # ikuya

labels = {}
spam_words = set()
ham_words = set()

# Read the labels
with open(LABELS_FILE) as f:
    for line in f:
        line = line.strip()
        label, key = line.split()
        labels[key.split('/')[-1]] = 1 if label.lower()=='ham' else 0

# Split corpus into training and test sets
filelist = os.listdir(DATA_DIR)
num_train = int(len(filelist) * TRAINING_SET_RATIO)
X_train = filelist[:num_train]
X_test  = filelist[num_train:]

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\mi\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


- Jaccard 類似度
    - $J(S, T) = |S \cap T| / |S \cup T|$
    - まったく違う（共通要素がない）なら 0、まったく同じなら 1、と正規化される
    - S, T の要素数に大きな差がある（＝一方が極端に大きい）と直感に反した値になる
- MinHash
    - 二つの集合の Jaccard 類似度の近似値を得る手法。集合の sketch のみあれば計算できるので、集合そのものを保存するより保存領域が少なくて済む。
    - Andrei Z. Broder, http://cs.brown.edu/courses/cs253/papers/nearduplicate.pdf, 2000
    - 集合$S$の部分集合$S_A$と$S_B$があったとき、$S_A$と$S_B$のJaccard類似度は、ランダム置換 $\pi$ が $min(\pi(S_A)) == min(\pi(S_B))$ である確率に等しい。（$\pi$は$S→S$のランダムな置換）
    - そこで、たとえば $n=100$ 個の置換 ${ \pi_1, \pi_2, \ldots, \pi_n }$に対する ${ min(\pi_1(S_A)), min(\pi_2(S_A)), \ldots, min(\pi_n(S_A)) }$を集合 $S_A$ のsketchとし、同様に計算した $S_B$ のsketchといくつの要素が一致するか調べると、それがJaccard類似度の近似となる。
    - 実用的には S の要素は n-shingle (単語版の n-gram で、連続した n 語の列）などを使う（文書の場合）。
    - 実用的にはランダム置換 $\pi$ に（ソルト入りの）ハッシュ関数を使う。ハッシュ値の最小値を保存するため MinHash と呼ばれる（たぶん）。
- Locality Sensitive Hashing (LSH)
    - 概要
        - 対象のアイテム群の中から、クエリのアイテムと距離が近い（＝類似した）アイテム群（の部分集合）を効率よく（ただし確率的に）見つける方法。
        - 文書検索などにおいて、大量（n個）の文書群から類似した文書を見つけるには、素朴には少なくとも O(n) の比較が必要だが、これを速くしたい。二分検索等が使えるハッシュ値であれば高速に見つけることができる、という性質を使う。
    - 手法
        - 基本的なアイデアは、アイテムとアイテムが近ければ一致する確率が高い locality sensitive なハッシュ関数 h(.) を複数個用意し、各アイテムに対して複数のハッシュ値を保存しておいて、クエリアイテムのハッシュ値群と多く一致すれば類似していると見なす。
        - locality sensitive なハッシュ関数は、アイテムの性質やアイテム間での距離の定義に応じて設計する必要がある。


In [11]:
%%time
from datasketch import MinHash, MinHashLSH

# Extract only spam files for inserting into the LSH matcher
spam_files = [x for x in X_train if labels[x] == 0]

# Initialize MinHashLSH matcher with a Jaccard
# threshold of 0.5 and 128 MinHash Permutation functions
lsh = MinHashLSH(threshold=0.5, num_perm=128)

# Populate the LSH matcher with training spam MinHashes
for idx, f in enumerate(spam_files):
    #print(idx, f)
    minhash = MinHash(num_perm=128)
    stems = load(os.path.join(DATA_DIR, f))
    if len(stems) < 2: continue
    for s in stems:
        minhash.update(s.encode('utf-8'))
    lsh.insert(f, minhash)

Wall time: 2min 46s


In [12]:
def lsh_predict_label(lsh, stems):
    '''
    Queries the LSH matcher and returns:
        0 if predicted spam
        1 if predicted ham
       -1 if parsing error
    '''
    minhash = MinHash(num_perm=128)
    if len(stems) < 2:
        return -1
    for s in stems:
        minhash.update(s.encode('utf-8'))
    matches = lsh.query(minhash)
    if matches:
        return 0
    else:
        return 1

In [18]:
def predict(X, lsh):
    n = len(X)
    y_truth = [-1] * n
    y_predict = [-1] * n
    for i in range(n):
        filename = X[i]
        label = labels[filename]
        y_truth[i] = label
        path = os.path.join(DATA_DIR, filename)
        stems = load(path)
        y_predict[i] = lsh_predict_label(lsh, stems)
    return (y_truth, y_predict)

In [20]:
%%time

(y_truth, y_predict) = predict(X_test, lsh)
print("%d/%d" % (sum([a==b for a,b in zip(y_truth, y_predict)]), len(y_truth)))

15760/22626
Wall time: 4min 6s


In [21]:
def percent(n, total):
    return n / total * 100.0
    
def print_result(y_truth, y_predict):
    ham_ham = ham_spam = spam_ham = spam_spam = 0
    confusion_matrix = (ham_ham, ham_spam, spam_ham, spam_spam)
    
    for (truth, predict) in zip(y_truth, y_predict):
        if truth==1:
            if predict==1:
                ham_ham += 1
            else:
                ham_spam += 1
        else:
            if predict==1:
                spam_ham += 1
            else:
                spam_spam += 1
    
    print("#y_truth, #y_predict = %d, %d" % (len(y_truth), len(y_predict)))
    num_predicted = ham_ham + ham_spam + spam_ham + spam_spam
    
    def pct(n):
        return percent(n, num_predicted)
    
    print("             Predicted HAM, Predicted SPAM")
    print("Actual HAM :         %5d,        %5d" % (ham_ham, ham_spam))
    print("Actual SPAM:         %5d,        %5d" % (spam_ham, spam_spam))
    print()
    print("             Predicted HAM, Predicted SPAM")
    print("Actual HAM :         %2.1f%%,        %2.1f%%" % (pct(ham_ham), pct(ham_spam)))
    print("Actual SPAM:         %2.1f%%,        %2.1f%%" % (pct(spam_ham), pct(spam_spam)))
    print()
    print("accuracy = %2.1f%%" % (pct(ham_ham + spam_spam)))

print_result(y_truth, y_predict)


#y_truth, #y_predict = 22626, 22626
             Predicted HAM, Predicted SPAM
Actual HAM :          6439,          136
Actual SPAM:          5283,        10768

             Predicted HAM, Predicted SPAM
Actual HAM :         28.5%,        0.6%
Actual SPAM:         23.3%,        47.6%

accuracy = 76.0%


In [22]:
# ham/spam ratio in all emails
num_hams  = sum([1 for label in labels.values() if label==1])
num_spams = sum([1 for label in labels.values() if label==0])
num_total = num_hams + num_spams
print("all (ham, spam) = (%d, %d)" % (num_hams, num_spams))
print("all (ham, spam) = (%2.1f%%, %2.1f%%)" % (percent(num_hams, num_total), percent(num_spams, num_total)))

# → 全データの ham/spam 割合はおよそ 33/67

all (ham, spam) = (25220, 50199)
all (ham, spam) = (33.4%, 66.6%)


In [23]:
# ham/spam ratio in test emails
y_test = [labels[fn] for fn in X_test]
num_hams  = sum([1 for label in y_test if label==1])
num_spams = sum([1 for label in y_test if label==0])
num_total = num_hams + num_spams
print("test (ham, spam) = (%d, %d)" % (num_hams, num_spams))
print("test (ham, spam) = (%2.1f%%, %2.1f%%)" % (percent(num_hams, num_total), percent(num_spams, num_total)))

# → テストデータの ham/spam 割合はおよそ 30/70

test (ham, spam) = (6575, 16051)
test (ham, spam) = (29.1%, 70.9%)


In [62]:
y_truth_2 = y_truth
y_predict_2 = [0] * len(y_truth_2)
print_result(y_truth_2, y_predict_2)
# → 全部 spam (0) と判定しても accuracy は 70% になる

#y_truth, #y_predict = 22626, 22626
             Predicted HAM, Predicted SPAM
Actual HAM :             0,         6575
Actual SPAM:             0,        16051

             Predicted HAM, Predicted SPAM
Actual HAM :         0.0%,        29.1%
Actual SPAM:         0.0%,        70.9%

accuracy = 70.9%


In [63]:
import random

y_truth_3 = y_truth
y_predict_3 = [random.randint(0, 1) for _ in y_truth_3]
print_result(y_truth_3, y_predict_3)
# → 0/1 五分五分で判定すると accuracy は 50%

#y_truth, #y_predict = 22626, 22626
             Predicted HAM, Predicted SPAM
Actual HAM :          3265,         3310
Actual SPAM:          7965,         8086

             Predicted HAM, Predicted SPAM
Actual HAM :         14.4%,        14.6%
Actual SPAM:         35.2%,        35.7%

accuracy = 50.2%
