# 5. Linear binary classification

In [1]:
import numpy as np
import matplotlib.pyplot as plt

## 5.9 Implement spam filter

### 5.9.1 Download data

In [2]:
import urllib.request
filename, _ = urllib.request.urlretrieve(
    'https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip',
    'smsspamcollection.zip'
)

In [3]:
import zipfile
with zipfile.ZipFile(filename, 'r') as fi:
    fi.extractall('.')

In [4]:
with open('SMSSpamCollection', encoding="utf-8") as fi:
    for n, line in enumerate(fi):
        if n < 10:
            print(line, end='')
        else:
            break

ham	Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...
ham	Ok lar... Joking wif u oni...
spam	Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's
ham	U dun say so early hor... U c already then say...
ham	Nah I don't think he goes to usf, he lives around here though
spam	FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, £1.50 to rcv
ham	Even my brother is not like to speak with me. They treat me like aids patent.
ham	As per your request 'Melle Melle (Oru Minnaminunginte Nurungu Vettam)' has been set as your callertune for all Callers. Press *9 to copy your friends Callertune
spam	WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only.
spam	H

### 5.9.2 Load data

In [6]:
import collections

def tokenize(s):
    # decompose sentences into words
    return [t.rstrip('.') for t in s.split(' ')]

def vectorize(tokens):
    # count each word
    return collections.Counter(tokens)

def readiter(fi):
    for line in fi:
        fields = line.strip('\n').split('\t')
        x = vectorize(tokenize(fields[1])) # wordcount for each word in the mail
        y = fields[0] # label
        yield x, y

with open('SMSSpamCollection', encoding="utf-8") as fi:
    D = [d for d in readiter(fi)]

In [7]:
D[6]

(Counter({'Even': 1,
          'my': 1,
          'brother': 1,
          'is': 1,
          'not': 1,
          'like': 2,
          'to': 1,
          'speak': 1,
          'with': 1,
          'me': 2,
          'They': 1,
          'treat': 1,
          'aids': 1,
          'patent': 1}),
 'ham')

In [8]:
from sklearn.model_selection import train_test_split 
Dtrain, Dtest = train_test_split(D, test_size=0.1, random_state=0)

In [10]:
len(Dtrain), len(Dtest)

(5016, 558)

### 5.9.3 Convert each data to other formats

In [12]:
from sklearn.preprocessing import LabelEncoder
from sklearn.feature_extraction import DictVectorizer

VX = DictVectorizer() # convert collections.Counter to sparce vector (scipy.Sparse)
VY = LabelEncoder() # convert each label to integer

Xtrain = VX.fit_transform([d[0] for d in Dtrain])
Ytrain = VY.fit_transform([d[1] for d in Dtrain])
Xtest = VX.transform([d[0] for d in Dtest])
Ytest = VY.transform([d[1] for d in Dtest])

In [13]:
print(Dtrain[10])
print(Xtrain[10])
print(Ytrain[10])

(Counter({'we': 3, 'it': 2, 'have': 2, 'I': 1, 'take': 1, "didn't": 1, 'the': 1, 'phone': 1, 'callon': 1, 'Friday': 1, 'Can': 1, 'assume': 1, "won't": 1, 'this': 1, 'year': 1, 'now?': 1}), 'ham')
  (0, 1831)	1.0
  (0, 2385)	1.0
  (0, 2769)	1.0
  (0, 5546)	1.0
  (0, 6110)	1.0
  (0, 6923)	1.0
  (0, 8101)	2.0
  (0, 8587)	2.0
  (0, 9821)	1.0
  (0, 10231)	1.0
  (0, 11832)	1.0
  (0, 11957)	1.0
  (0, 12014)	1.0
  (0, 12653)	3.0
  (0, 12862)	1.0
  (0, 13030)	1.0
0


In [14]:
print(VX.feature_names_[12653])
print(VY.classes_)

we
['ham' 'spam']


### 5.9.4 Fit logistic regression model to data by SGD method

In [15]:
from sklearn.linear_model import SGDClassifier 
model = SGDClassifier(loss='log')
model.fit(Xtrain, Ytrain)

SGDClassifier(loss='log')

### 5.9.5 Prediction and model evaluation

In [16]:
model.predict(Xtest[0])

array([0], dtype=int64)

In [17]:
model.predict_proba(Xtest[0])

array([[0.99454788, 0.00545212]])

In [18]:
# Accuracy
# (TN + TP) / (TN + TP + FN + FP)
model.score(Xtest, Ytest)

0.9731182795698925

In [19]:
# Apply model to a new input
msg = "Your account has been credited with 500 FREE Text Messages."
model.predict_proba(VX.transform(vectorize(tokenize(msg))))

array([[0.2194571, 0.7805429]])

### 5.9.6 Check model parameter

In [20]:
model.coef_

array([[-0.87028695, -0.24992692, -0.00955703, ...,  0.23285088,
        -0.15255926, -0.00234113]])

In [21]:
F = sorted(zip(VX.feature_names_, model.coef_[0]), key=lambda x: x[1])

In [23]:
F[:20] # if a SMS contains these words, it tend to be regarded as "ham" 

[('me', -1.194638109049447),
 ('&lt;#&gt;', -1.0749890334061118),
 ('him', -1.0744443600719897),
 ("I'll", -1.0549658603306475),
 ('I', -1.0165661117902196),
 ('i', -0.9667037491192448),
 ('Its', -0.9182350542387557),
 ('my', -0.9019855179546786),
 ('good', -0.8805952607782656),
 ('ask', -0.8717738880799215),
 ('', -0.870286951189069),
 ('What', -0.861321765698758),
 ('And', -0.8514916544747598),
 ('way', -0.8376640493879023),
 ('Happy', -0.8364719148845247),
 ('something', -0.8360421876667266),
 ('u', -0.803753016778358),
 ('Yes', -0.7993034259531921),
 ('still', -0.7849465723548167),
 ('always', -0.7844191941847404)]

In [25]:
F[-20:] # if a SMS contains these words, it tend to be regarded as "spam" 


[('FREE>Ringtone!Reply', 1.4981841204024866),
 ('To', 1.5897534636196589),
 ('Reply', 1.5932434796860635),
 ('-', 1.6069742027512757),
 ('Text', 1.617982301285652),
 ("let's", 1.646270972909076),
 ('84484', 1.7236123342228586),
 ('ringtoneking', 1.7236123342228586),
 ('146tf150p', 1.7346711881330812),
 ('2/2', 1.7346711881330812),
 ('text', 1.8075763801739368),
 ('service', 1.8396200086776857),
 ('won', 1.8445201633856914),
 ('mobile', 1.9577346787946288),
 ('&', 1.9911340715306087),
 ('STOP', 2.0132017317644375),
 ('txt', 2.0157868310297453),
 ('now!', 2.0634541524171204),
 ('Txt', 2.0930732307338555),
 ('Call', 2.37362002879792)]

## 5.10 Exercise

In [102]:
# implement logistic model with SGD method
# Use the same Xtrain, Ytrain, Xtest, Ytest as 5.9

def sigmoid(a):
    if a >= 0:
        return 1 / (1 + np.exp(-a))
    else:
        return 1 - 1 / (1 + np.exp(a))


def log_sgd(Xtrain, Ytrain, max_epochs=400000, lr0=0.03, eps=1e-5):
    N = Xtrain.shape[0]
    w = np.zeros(len(VX.feature_names_)) # parameter
    for t in range(max_epochs):
        lr = lr0 / np.sqrt(t + 1) # learning rate
        i = np.random.randint(0, N)
        pt = sigmoid((Xtrain[i] @ w)[0])
        grad = -(Ytrain[i] - pt) * Xtrain[i] # gradient
        if np.sum(np.abs(grad)) < eps:
            break
        w -= lr * grad.toarray()[0] # update parameter
    return w, t


def accuracy(Xtest, Ytest, w):
    N = Xtest.shape[0]
    correct = 0
    for i in range(N):
        wx = (Xtest[i] @ w)[0]
        if (wx > 0 & Ytest[i] == 1) | (wx < 0 & Ytest[i] == 0):
            correct += 1
    return correct / N
        



In [103]:
w, t = log_sgd(Xtrain, Ytrain, max_epochs=100000)
print(f'loop num : {t}')

loop num : 22808


In [104]:
# Accuracy
print(f'accuracy : {accuracy(Xtest, Ytest, w)}')

accuracy : 0.9587813620071685


In [105]:
F = sorted(zip(VX.feature_names_, w), key=lambda x: x[1])

In [106]:
F[:20] # if a SMS contains these words, it tend to be regarded as "ham"

[('I', -0.5057903348095482),
 ('you', -0.4484355483823495),
 ('i', -0.32578881136785337),
 ('', -0.3152527662234237),
 ('u', -0.26173899941214723),
 ('me', -0.25197415357438024),
 ('in', -0.25193862282627605),
 ('my', -0.23806125125348165),
 ('the', -0.23222304945330213),
 ('it', -0.190065643677985),
 ('and', -0.18942464709828516),
 ('that', -0.17279321755165747),
 ('is', -0.14858685869998972),
 ("I'm", -0.13913598468995092),
 ('at', -0.13890915033755252),
 ('not', -0.12581018430000693),
 ('be', -0.11221464347553248),
 ('but', -0.1118362789582083),
 ('to', -0.11082681947534066),
 ('Ok', -0.10943764106786377)]

In [107]:
F[-20:] # if a SMS contains these words, it tend to be regarded as "spam"

[('call', 0.037844736240140564),
 ('Text', 0.03797698973181187),
 ('now!', 0.0393851157161856),
 ('-', 0.03968963275758024),
 ('text', 0.042381186331310064),
 ('Your', 0.042857278466695486),
 ('won', 0.050488142691532456),
 ('contact', 0.05099825015131854),
 ('STOP', 0.05392127664940743),
 ('from', 0.05436170634800177),
 ('Txt', 0.05441386892418338),
 ('txt', 0.06269217325152589),
 ('or', 0.06423881699338911),
 ('prize', 0.06691027880392306),
 ('To', 0.06947837603181871),
 ('mobile', 0.07551175673683519),
 ('claim', 0.08141343474069138),
 ('&', 0.0831991159607921),
 ('FREE', 0.09942219487226042),
 ('Call', 0.11928860930264168)]