# Building a Spam Filter with Naive Bayes
In the previous lesson, we saw that to classify messages as spam or non-spam, the computer:
1. Learns how humans classify messages.
2. Uses that human knowledge to estimate probabilities for new messages: one probability for spam and one for non-spam.
3. Classifies a new message based on these probability values. If the probability for spam is greater, it classifies the message as spam; otherwise, it classifies it as non-spam. If the two probability values are equal, we may need a human to classify the message.

So our first task is to "teach" the computer how to classify messages. To do that, we'll use the multinomial Naive Bayes algorithm along with a dataset of 5,572 SMS messages that have already been classified by humans.

The dataset was put together by Tiago A. Almeida and José María Gómez Hidalgo, and it can be downloaded from the UCI Machine Learning Repository.

In [19]:
import pandas as pd

In [20]:
sms=pd.read_csv('SMSSpamCollection',sep='\t',header=None,names=['Label','SMS'])

In [36]:
# explore the data a bit
sms.head()

Unnamed: 0,Label,SMS
0,ham,"Go until jurong point, crazy.. Available only ..."
1,ham,Ok lar... Joking wif u oni...
2,spam,Free entry in 2 a wkly comp to win FA Cup fina...
3,ham,U dun say so early hor... U c already then say...
4,ham,"Nah I don't think he goes to usf, he lives aro..."


In [23]:
# find how many rows in the dataset
sms.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5572 entries, 0 to 5571
Data columns (total 2 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   Label   5572 non-null   object
 1   SMS     5572 non-null   object
dtypes: object(2)
memory usage: 87.2+ KB


In [25]:
# find the percentage of 'spam' and 'non-spam' messages
sms['Label'].value_counts(normalize=True).mul(100)

ham     86.593683
spam    13.406317
Name: Label, dtype: float64

About 87% of the messages are ham ("ham" means non-spam), and the remaining 13% are spam. 

## Training Set and Test Set
We're now going to split our dataset into a training and a test set, where the training set accounts for 80% of the data, and the test set for the remaining 20%.

In [37]:
# Randomize the dataset
sms_randomized=sms.sample(frac=1,random_state=1)

# Calculate index for split
training_index=round(len(sms_randomized)*0.8)

# Training set and Test set split
tr_set=sms_randomized[:training_index].reset_index(drop=True)
test_set=sms_randomized[training_index:].reset_index(drop=True)

print(tr_set.shape)
print(test_set.shape)

(4458, 2)
(1114, 2)


We'll now analyze the percentage of spam and ham messages in the training and test sets. We expect the percentages to be close to what we have in the full dataset, where about 87% of the messages are ham, and the remaining 13% are spam.

In [38]:
tr_set['Label'].value_counts(normalize=True).mul(100)

ham     86.54105
spam    13.45895
Name: Label, dtype: float64

In [34]:
test_set['Label'].value_counts(normalize=True).mul(100)

ham     86.804309
spam    13.195691
Name: Label, dtype: float64

## Data Cleaning
To calculate all the probabilities required by the algorithm, we'll first need to clean the data and bring it into a format that lets us easily extract all the information we need.

We'll begin by removing all the punctuation and converting every letter to lowercase.
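As a quick illustration of this cleaning step, here is a minimal sketch using Python's `re` module on a made-up message. The `clean` helper and the sample string are assumptions for illustration and are not part of the notebook. The pattern `\W` matches any character that is not a letter, a digit, or an underscore.

```python
import re

def clean(message):
    """Replace non-word characters with spaces, then lowercase (hypothetical helper)."""
    return re.sub(r'\W', ' ', message).lower()

sample = "Yep, by the PRETTY sculpture!"  # made-up example message
cleaned = clean(sample)

# Splitting on whitespace drops the extra spaces left behind by the punctuation
print(cleaned.split())
```

Note that an apostrophe is also replaced, so a word like "don't" ends up as the two tokens "don" and "t".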

In [39]:
tr_set.head()

Unnamed: 0,Label,SMS
0,ham,"Yep, by the pretty sculpture"
1,ham,"Yes, princess. Are you going to make me moan?"
2,ham,Welp apparently he retired
3,ham,Havent.
4,ham,I forgot 2 ask ü all smth.. There's a card on ...


In [40]:
tr_set['SMS']=tr_set['SMS'].str.replace(r'\W',' ',regex=True)  # replace every non-word character with a space
tr_set['SMS']=tr_set['SMS'].str.lower()
tr_set.head()


Unnamed: 0,Label,SMS
0,ham,yep by the pretty sculpture
1,ham,yes princess are you going to make me moan
2,ham,welp apparently he retired
3,ham,havent
4,ham,i forgot 2 ask ü all smth there s a card on ...


## Creating the Vocabulary
We'll now move on to creating the vocabulary, which in this context means a list of all the unique words in our training set.
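To make the idea of a vocabulary concrete, here is a small sketch on three made-up messages. The data and the `toy_` names are assumptions for illustration only. Each message is split into words, and a set keeps one copy of each word.

```python
# Toy training messages, already cleaned and lowercased (made-up data)
toy_messages = [
    "secret prize claim now",
    "coming to my secret party",
    "winner claim secret code",
]

# Split each message into words, then collect the distinct words with a set
tokenized = [message.split() for message in toy_messages]
toy_vocabulary = sorted({word for words in tokenized for word in words})

print(len(toy_vocabulary))
print(toy_vocabulary)
```

The three messages contain 13 words in total but only 10 distinct ones, because "secret" and "claim" are repeated.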

In [48]:
tr_set['SMS']=tr_set['SMS'].str.split()
test_set.head()

Unnamed: 0,Label,SMS
0,ham,Later i guess. I needa do mcat study too.
1,ham,But i haf enuff space got like 4 mb...
2,spam,Had your mobile 10 mths? Update to latest Oran...
3,ham,All sounds good. Fingers . Makes it difficult ...
4,ham,"All done, all handed in. Don't know if mega sh..."


In [43]:
vocabulary=[]
for message in tr_set['SMS']:  # 'message' avoids overwriting the sms DataFrame
    for word in message:
        vocabulary.append(word)
vocabulary=list(set(vocabulary))  # keep each word only once
len(vocabulary)

7783

It looks like there are 7,783 unique words across all the messages in our training set.

In [46]:
# one column per vocabulary word, one row per training message
word_counts_per_sms={unique_word:[0]*len(tr_set['SMS']) for unique_word in vocabulary}
for index, message in enumerate(tr_set['SMS']):  # 'message' avoids overwriting the sms DataFrame
    for word in message:
        word_counts_per_sms[word][index]+=1
word_counts=pd.DataFrame(word_counts_per_sms)
word_counts.head()

Unnamed: 0,lautech,jaykwon,girlfrnd,81010,yalrigu,dats,pert,08718723815,drinks,landline,...,tkls,community,splat,aint,doinat,careful,aka,eng,memory,ese
0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
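The table above has one column per vocabulary word and one row per message, and most entries are zero because each message uses only a few words. The same counting can be sketched on made-up data as follows. The messages and the `toy_` names are assumptions for illustration.

```python
from collections import Counter

# Made-up messages, already cleaned and split into words
toy_messages = [["secret", "prize", "secret"], ["claim", "prize"]]
toy_vocabulary = sorted({word for message in toy_messages for word in message})

# One list of counts per word, with one position per message
toy_counts = {word: [0] * len(toy_messages) for word in toy_vocabulary}
for index, message in enumerate(toy_messages):
    for word, n in Counter(message).items():
        toy_counts[word][index] = n

for word in toy_vocabulary:
    print(word, toy_counts[word])
```

Here "secret" appears twice in the first message and never in the second, so its column of counts is `[2, 0]`.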


In [47]:
training_set_clean=pd.concat([tr_set,word_counts],axis=1)
training_set_clean.head()

Unnamed: 0,Label,SMS,lautech,jaykwon,girlfrnd,81010,yalrigu,dats,pert,08718723815,...,tkls,community,splat,aint,doinat,careful,aka,eng,memory,ese
0,ham,"[yep, by, the, pretty, sculpture]",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,ham,"[yes, princess, are, you, going, to, make, me,...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,ham,"[welp, apparently, he, retired]",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,ham,[havent],0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,ham,"[i, forgot, 2, ask, ü, all, smth, there, s, a,...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


## Calculating Constants

In [50]:
spam=training_set_clean[training_set_clean['Label']=='spam']
ham=training_set_clean[training_set_clean['Label']=='ham']

# p(spam) / p(ham)
p_spam=len(spam)/len(training_set_clean)
p_ham=len(ham)/len(training_set_clean)

# N_spam
n_words_per_spam_message=spam['SMS'].apply(len)
n_spam=n_words_per_spam_message.sum()

# N_ham
n_words_per_ham_message=ham['SMS'].apply(len)
n_ham=n_words_per_ham_message.sum()

# N_vocabulary
n_vocabulary=len(vocabulary)

# Laplace smoothing
alpha=1

## Calculating Parameters

In [51]:
# initiate parameters
parameters_spam={unique_word:0 for unique_word in vocabulary}
parameters_ham={unique_word:0 for unique_word in vocabulary}

# Calculate parameters
for word in vocabulary:
    n_word_given_spam=spam[word].sum()
    p_word_given_spam=(n_word_given_spam+alpha)/(n_spam+alpha*n_vocabulary)
    parameters_spam[word]=p_word_given_spam
    
    n_word_given_ham=ham[word].sum()
    p_word_given_ham=(n_word_given_ham+alpha)/(n_ham+alpha*n_vocabulary)
    parameters_ham[word]=p_word_given_ham
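The loop above computes, for each word, (count of the word in the class + alpha) divided by (total words in the class + alpha times the vocabulary size). Below is a minimal sketch of the same calculation on a made-up spam corpus. All data and `toy_` names are assumptions for illustration. It shows that a word never seen in spam still gets a small non-zero probability, and that the smoothed probabilities sum to 1.

```python
# Made-up spam messages, already tokenized (illustrative data only)
toy_spam = [["secret", "prize"], ["secret", "claim", "secret"]]
toy_vocabulary = ["claim", "code", "prize", "secret"]  # 'code' never appears in spam

alpha = 1                                     # Laplace smoothing
n_spam = sum(len(m) for m in toy_spam)        # total number of words in spam messages
n_vocabulary = len(toy_vocabulary)

toy_parameters_spam = {}
for word in toy_vocabulary:
    n_word_given_spam = sum(m.count(word) for m in toy_spam)
    toy_parameters_spam[word] = (n_word_given_spam + alpha) / (n_spam + alpha * n_vocabulary)

print(toy_parameters_spam)
print(sum(toy_parameters_spam.values()))  # approximately 1, up to floating-point rounding
```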

## Classifying A New Message
The spam filter can be understood as a function that:

* Takes in as input a new message (w1, w2, ..., wn).
* Calculates P(Spam|w1, w2, ..., wn) and P(Ham|w1, w2, ..., wn).
* Compares the values of P(Spam|w1, w2, ..., wn) and P(Ham|w1, w2, ..., wn), and:
  * If P(Ham|w1, w2, ..., wn) > P(Spam|w1, w2, ..., wn), then the message is classified as ham.
  * If P(Ham|w1, w2, ..., wn) < P(Spam|w1, w2, ..., wn), then the message is classified as spam.
  * If P(Ham|w1, w2, ..., wn) = P(Spam|w1, w2, ..., wn), then the algorithm may request human help.
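Written out, the two quantities being compared are proportional to the class prior multiplied by the per-word parameters. This is a restatement of what the code in this notebook computes, using the same names as above:

```latex
P(\text{Spam} \mid w_1, \ldots, w_n) \propto P(\text{Spam}) \cdot \prod_{i=1}^{n} P(w_i \mid \text{Spam})

P(\text{Ham} \mid w_1, \ldots, w_n) \propto P(\text{Ham}) \cdot \prod_{i=1}^{n} P(w_i \mid \text{Ham})
```

Because both sides share the same normalizing constant, comparing these products is enough to pick a label. Words that are not in the vocabulary are skipped, so they affect neither product.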

In [55]:
import re
def classify(message):
    message=re.sub(r'\W',' ',message)  # raw string avoids an invalid-escape warning
    message=message.lower()
    message=message.split()
    
    p_spam_given_message=p_spam
    p_ham_given_message=p_ham
    for word in message:
        if word in parameters_spam:
            p_spam_given_message*=parameters_spam[word]
            
        if word in parameters_ham:
            p_ham_given_message*=parameters_ham[word]
    print('P(spam|message):', p_spam_given_message)
    print('P(ham|message):', p_ham_given_message)
    
    if p_spam_given_message>p_ham_given_message:
        print('Label: Spam')
    elif p_ham_given_message>p_spam_given_message:
        print('Label: Ham')
    else:
        print('Equal probabilities, have a human classify this!')

In [56]:
classify('''WINNER!! This is the secret code to unlock the money: C3421.''')

P(spam|message): 1.3481290211300841e-25
P(ham|message): 1.9368049028589875e-27
Label: Spam


In [57]:
classify('''Sounds good, Tom, then see u there''')

P(spam|message): 2.4372375665888117e-25
P(ham|message): 3.687530435009238e-21
Label: Ham
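The probabilities printed above are already very small (around 1e-21 to 1e-27) for short messages. For much longer messages, a product of many such factors could underflow to 0.0 for both classes. A common workaround, sketched here as an assumption and not part of the original notebook, is to add logarithms instead of multiplying probabilities. The function name `classify_log` and the toy parameters are hypothetical. Laplace smoothing guarantees every parameter is positive, so taking the logarithm is safe.

```python
import math
import re

def classify_log(message, p_spam, p_ham, parameters_spam, parameters_ham):
    """Return 'spam', 'ham', or 'needs human classification' using log-probabilities."""
    words = re.sub(r'\W', ' ', message).lower().split()
    log_spam = math.log(p_spam)
    log_ham = math.log(p_ham)
    for word in words:
        # Words outside the vocabulary are ignored, as in the original classify()
        if word in parameters_spam:
            log_spam += math.log(parameters_spam[word])
        if word in parameters_ham:
            log_ham += math.log(parameters_ham[word])
    if log_spam > log_ham:
        return 'spam'
    if log_ham > log_spam:
        return 'ham'
    return 'needs human classification'

# Made-up priors and parameters, for illustration only
toy_spam = {'secret': 0.4, 'prize': 0.3, 'see': 0.1, 'you': 0.2}
toy_ham = {'secret': 0.1, 'prize': 0.1, 'see': 0.4, 'you': 0.4}
print(classify_log('Secret prize!!', 0.5, 0.5, toy_spam, toy_ham))
print(classify_log('See you', 0.5, 0.5, toy_spam, toy_ham))
```

Since the logarithm is an increasing function, comparing the sums gives the same label as comparing the products whenever the products do not underflow.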


## Measuring the Spam Filter's Accuracy
We'll now determine how well the spam filter does on our test set of 1,114 messages.

In [62]:
def classify_test_set(message):
    message=re.sub(r'\W', ' ', message)  # raw string avoids an invalid-escape warning
    message=message.lower().split()
    
    p_spam_given_message=p_spam
    p_ham_given_message=p_ham
    
    for word in message:
        if word in parameters_spam:
            p_spam_given_message*=parameters_spam[word]
        if word in parameters_ham:
            p_ham_given_message*=parameters_ham[word]
    if p_spam_given_message > p_ham_given_message:
        return 'spam'
    elif p_ham_given_message > p_spam_given_message:
        return 'ham'
    else:
        return 'needs human classification'


In [63]:
test_set['predicted']=test_set['SMS'].apply(classify_test_set)
test_set.head()

Unnamed: 0,Label,SMS,predicted
0,ham,Later i guess. I needa do mcat study too.,ham
1,ham,But i haf enuff space got like 4 mb...,ham
2,spam,Had your mobile 10 mths? Update to latest Oran...,spam
3,ham,All sounds good. Fingers . Makes it difficult ...,ham
4,ham,"All done, all handed in. Don't know if mega sh...",ham


Now we can compare the predicted values with the actual values to measure how good our spam filter is at classifying new messages.

In [67]:
correct=0
total=len(test_set['SMS'])

for _, row in test_set.iterrows():  # iterrows yields (index, row) pairs
    if row['Label']==row['predicted']:
        correct+=1
print('Correct:', correct)
print('Incorrect:', total-correct)
print('Accuracy:', correct/total)

Correct: 1100
Incorrect: 14
Accuracy: 0.9874326750448833


The accuracy is about 98.74%, which is really good. Our spam filter looked at 1,114 messages that it hadn't seen during training and classified 1,100 of them correctly.
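Because about 87% of the messages are ham, accuracy alone does not show which kind of mistake the filter makes. As an optional follow-up, the sketch below counts outcomes per class on made-up (actual, predicted) pairs. The data and variable names are assumptions for illustration and are not results from this notebook.

```python
from collections import Counter

# Made-up (actual, predicted) label pairs, for illustration only
pairs = [('ham', 'ham'), ('ham', 'ham'), ('spam', 'spam'),
         ('spam', 'ham'), ('ham', 'spam'), ('spam', 'spam')]

confusion = Counter(pairs)
true_pos = confusion[('spam', 'spam')]   # spam correctly flagged
false_pos = confusion[('ham', 'spam')]   # ham wrongly flagged as spam
false_neg = confusion[('spam', 'ham')]   # spam that slipped through

precision = true_pos / (true_pos + false_pos)  # share of flagged messages that were spam
recall = true_pos / (true_pos + false_neg)     # share of actual spam that was caught
print(confusion)
print(precision, recall)
```

The same counting could be applied to the `Label` and `predicted` columns of `test_set` to see how the 14 errors split between the two classes.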