In [78]:
from math import log
import glob
from collections import Counter

In [79]:
def get_features(text):
    """Extracts features from text
    
    Args:
        text (str): Any document containing strings
    Returns:
        Set of all the unique words in the document.
    
    """
    
    return set([w.lower() for w in text.split(" ")])

In [82]:
""" initialize all probabilities with None"""
log_priors = None
cond_probs = None
features = None

In [127]:
def train(documents, labels):
    """Train a Bernoulli naive Bayes classifier

    Args:
        documents (list): Each element in this list
            is a text
        labels (list): The ground truth label for
            each document
    """
    global log_priors
    global cond_probs
    global features     
    
    
    """Compute log( P(Y) )"""
    label_counts = Counter(labels)
    N = float(sum(label_counts.values()))
    log_priors = {k: log(v/N) for k, v in label_counts.items()}

    
    """Feature extraction"""
    # Extract features from each document
    X = [set(get_features(d)) for d in documents]      # Vectorize X

    # Get all features
    features = set([f for features in X for f in features])

    
    """Compute log( P(X|Y) )

       Use Laplace smoothing v + 1 / N + 2)"""
    
    
    # Structure for conditional Probabs
    cond_probs = {l: {f: 0. for f in features} for l in log_priors}

    # Step through each document - fill cond_probabs
    for f in features:
        for x, l in zip(X, labels):
            if f in x:
                cond_probs[l][f] += 1.

                
    # Now, compute log probs
    for l in cond_probs:
        N = label_counts[l]
        cond_probs[l] = {f: (v + 1.) / (N + 2.) for f, v in cond_probs[l].items()}

In [132]:
def predict(text):
    """Make a prediction from text"""

    global log_priors
    global cond_probs
    global features     

    
    
    # Extract features
    x = get_features(text)

    pred_class = None
    max_ = float("-inf")

    # Compute posterior probability for all classes
    for l in log_priors:
        log_sum = log_priors[l]
        for f in features:
            prob = cond_probs[l][f]
            log_sum += log(prob if f in x else 1. - prob)
        if log_sum > max_:
            max_ = log_sum
            pred_class = l

    return pred_class

In [133]:
def get_labeled_data(type_):

    examples = []
    labels = []

    file_names = glob.glob('./ex6DataEmails/spam-{0}/*.txt'.format(type_))
    for n in file_names:
        f = open(n)
        examples.append(f.read())
        labels.append('spam')
        f.close()

    file_names = glob.glob('./ex6DataEmails/nonspam-{0}/*.txt'.format(type_))
    for n in file_names:
        f = open(n)
        examples.append(f.read())
        labels.append('nonspam')
        f.close()

    return examples, labels

In [134]:
train_docs, train_labels = get_labeled_data('train')
test_docs, test_labels = get_labeled_data('test')

# Train model
print('Number of training examples: {0}'.format(len(train_labels)))
print('Number of test examples: {0}'.format(len(test_labels)))

Number of training examples: 700
Number of test examples: 260


In [135]:

print('Training model...')




train(train_docs, train_labels)



print('Training complete!')


print('Number of features found: {0}'.format(len(features)))



Training model...
Training complete!
Number of features found: 19100


In [140]:
print(test_labels[0])
print(predict(test_docs[0]))

spam
spam


In [141]:
# Simple error test metric
print('Testing model...')

f = lambda doc, l: 1. if predict(doc) != l else 0.
num_missed = sum([f(doc, l) for doc, l in zip(test_docs, test_labels)])


N = len(test_labels) * 1.
error_rate = round(100. * (num_missed / N), 3)

print('Error rate of {0}% ({1}/{2})'.format(error_rate, int(num_missed), int(N)))

Testing model...
Error rate of 4.231% (11/260)


In [113]:
a = ["apple", "mango", "banana"]
b = [40, 50, 90]
c = ["planet", "earth" , "moon"]

In [115]:
for k in zip(a,b, c):
    print(k)

('apple', 40, 'planet')
('mango', 50, 'earth')
('banana', 90, 'moon')
