In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
import random
from math import *

In [2]:
pattern = os.path.join('tr', "*.txt")
spam_files = glob.glob(pattern)
spam_counts = {}

pattern = os.path.join('ham', "*.txt")
ham_files = glob.glob(pattern)
ham_counts = {}

def get_counts(counts, files, sample_size=256):
    #ix = torch.randint(0, len(files), (sample_size,)).tolist()
    ix = range(0, sample_size)
    for x in ix:
        fp = files[x]
        f = open(fp, 'r')
        lines = f.read().split('\n')[5:]
        for line in lines:
            words = line.split(' ')
            words = [w.replace('\t', '').replace('!', '').replace('.', '').replace(',', '').replace('?', '').replace('*', '').lower() for w in words]
            for w in words:
                if counts.get(w): counts[w] += 1
                else: counts[w] = 1
                    
get_counts(spam_counts, spam_files, sample_size=len(spam_files))
get_counts(ham_counts, ham_files, sample_size=len(ham_files))

In [3]:
ss = sum(spam_counts.values())
print(ss)

# WHITELIST
spam_counts[' '] = 1
spam_counts['͏'] = 1
spam_counts[''] = 1
spam_counts['\u2007͏'] = 1
spam_counts['\u200c'] = 1
spam_counts['\xad'] = 1

remove = []
for (k,v) in spam_counts.items():
    if v/ss > 0.02 or v/ss <= 3/ss: remove.append(k)
for r in remove: spam_counts.pop(r)

sorted_spam_counts = sorted(spam_counts.items(), key=lambda x: x[1], reverse=True)
print(sorted_spam_counts[:10])

309383
[('and', 5543), ('to', 5311), ('the', 5095), ('of', 3977), ('you', 2677), ('for', 2623), ('a', 2583), ('on', 2399), ('in', 2248), ('your', 1962)]


In [4]:
sh = sum(ham_counts.values())
print(sh)

sorted_ham_counts = sorted(ham_counts.items(), key=lambda x: x[1], reverse=True)
print(sorted_ham_counts[:10])

638093
[('', 58252), ('͏', 43974), ('and', 9919), ('the', 8977), ('\u200c', 7819), ('to', 7090), ('for', 5733), ('a', 5388), ('of', 5151), ('off', 5135)]


In [5]:
# determine likelihood that the message is spam given that 'w' appears in it
def spam_prob(w):
    ss = sum(spam_counts.values())
    sh = sum(ham_counts.values())
    ps = 0
    if spam_counts.get(w): ps = spam_counts[w]
    ph = 0
    if ham_counts.get(w): ph = ham_counts[w]
    if ph == 0 and ps == 0: return 0.5
    p = (ps / ss) / ((ps / ss) + (ph / sh))
    return p

In [6]:
# words in the spam corpus that are likely spam
spam_words = []
for item in sorted_spam_counts:
    if spam_prob(item[0]) > 0.90:
        spam_words.append(item[0])

for w in spam_words[:5]: print(w)
print(len(spam_words))

university
best
admissions
campus
sponsored
1919


In [7]:
# words in the spam corpus that are likely not spam
not_spam = []
for item in sorted_spam_counts:
    if spam_prob(item[0]) <= 0.90:
        not_spam.append(item[0])
        
for w in not_spam[:5]: print(w)
print(len(not_spam))

and
to
the
of
you
2636


In [11]:
spam_probs = []
for fp in ham_files:
    f = open(fp, 'r')
    lines = f.read().split('\n')[5:]
    for line in lines:
        words = line.split(' ')
        words = [w.replace('\t', '').replace('!', '').replace('.', '').replace(',', '').replace('?', '').replace('*', '').lower() for w in words]
        for w in words:
            spam_probs.append(spam_prob(w))
    spam_likelihood = sum(spam_probs) / len(spam_probs)
    # most ham messages don't exceed 0.4
    if spam_likelihood > 0.39:
        print(spam_likelihood)

0.39958163839324046


In [12]:
spam_probs = []
for fp in spam_files:
    f = open(fp, 'r')
    lines = f.read().split('\n')[5:]
    for line in lines:
        words = line.split(' ')
        words = [w.replace('\t', '').replace('!', '').replace('.', '').replace(',', '').replace('?', '').replace('*', '').lower() for w in words]
        for w in words:
            spam_probs.append(spam_prob(w))
    spam_likelihood = sum(spam_probs) / len(spam_probs)
    # most spam messages don't go below 0.44
    if spam_likelihood < 0.44:
        print(spam_likelihood)

### Further Optimizations
We can just ignore all words that have a spamicity of around $0.5$, ie: words that appear around the same amount of times
in both of our datasets. 
