# Chapter 3 - Spam Classifier

*The task is to predict whether an email is spam or not spam.*

*In this challenge, we are asked to build a predictive model that classifies an email as spam or not spam, using the text of the email and potentially its metadata (i.e., the headers).*

*This notebook contains all the code for tackling this problem.*

*We will be using examples of spam and ham (non-spam email) from Apache SpamAssassin's [public datasets](https://homl.info/spamassassin).*

# Setup

First, let's import a few common modules, ensure Matplotlib plots figures inline, and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20.

In [1]:
# Python ≥3.5 is required
import sys
assert sys.version_info >= (3, 5)

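# Scikit-Learn ≥0.20 is required (compare version numbers, not strings)
import re
import sklearn
sk_version = tuple(int(n) for n in re.match(r"(\d+)\.(\d+)", sklearn.__version__).groups())
assert sk_version >= (0, 20)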

# Common imports
import numpy as np
import os

# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)

# Where to save the figures
PROJECT_ROOT_DIR = "."
CHAPTER_ID = "spam_classifier"
IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID)
os.makedirs(IMAGES_PATH, exist_ok=True)

def save_fig(fig_id, tight_layout=True, fig_extension="png", resolution=300):
    path = os.path.join(IMAGES_PATH, fig_id + "." + fig_extension)
    print("Saving figure", fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(path, format=fig_extension, dpi=resolution)

# Ignore useless warnings (see SciPy issue #5998)
import warnings
warnings.filterwarnings(action="ignore", message="^internal gelsd")

# Get the data

First, let's fetch the data:

In [60]:
import os
import tarfile
import urllib.request  # urlretrieve lives in the urllib.request submodule
from bs4 import BeautifulSoup
import requests
import shutil

DOWNLOAD_ROOT = "https://spamassassin.apache.org/old/publiccorpus/"
EXT = ('bz2', 'html')
SPAM_PATH = os.path.join("datasets", "spam")

def fetch_spam_data(spam_url=DOWNLOAD_ROOT, spam_path=SPAM_PATH):
    if not os.path.isdir(spam_path):
        os.makedirs(spam_path)
    for file in listFD(spam_url, EXT):
        filename = file.rsplit('/', maxsplit=1)[-1]
        tgz_path = os.path.join(spam_path, filename)
        urllib.request.urlretrieve(file, tgz_path)
        if file.rsplit('.', maxsplit=1)[-1] == "html":
            continue
        # The context manager closes the archive even if extraction fails
        with tarfile.open(tgz_path) as spam_tgz:
            spam_tgz.extractall(path=spam_path)

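def listFD(url, ext=''):
    # Return the full URL of every link on the index page whose target ends with ext
    page = requests.get(url).text
    soup = BeautifulSoup(page, 'html.parser')
    # href=True skips anchors without an href; rstrip avoids a double slash in the result
    return [url.rstrip('/') + '/' + node['href']
            for node in soup.find_all('a', href=True)
            if node['href'].endswith(ext)]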

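def cleanFolders(spam_path=SPAM_PATH):
    # Sketch: move the contents of every "<name>_2" folder into the matching "<name>" folder.
    # A same-named file in the target may be overwritten or raise, depending on the platform.
    dir_list = [name for name in os.listdir(spam_path) if os.path.isdir(os.path.join(spam_path, name))]
    for dir_name in dir_list:
        if not dir_name.endswith("_2"):
            continue
        source_dir = os.path.join(spam_path, dir_name)
        target_dir = os.path.join(spam_path, dir_name[:-len("_2")])
        os.makedirs(target_dir, exist_ok=True)
        for filename in os.listdir(source_dir):
            shutil.move(os.path.join(source_dir, filename), os.path.join(target_dir, filename))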

In [45]:
fetch_spam_data()

*Note: at this point I manually cleaned up the folders, merging each `_2` folder into the folder of the same name without the suffix.*

Next, let's load all the emails:

In [76]:
HAM_DIR = os.path.join(SPAM_PATH, "easy_ham")
SPAM_DIR = os.path.join(SPAM_PATH, "spam")
ham_filenames = [name for name in sorted(os.listdir(HAM_DIR)) if len(name) > 20]
spam_filenames = [name for name in sorted(os.listdir(SPAM_DIR)) if len(name) > 20]
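
The `len(name) > 20` test is a cheap way to keep only message files. Here is a minimal sketch of the idea, assuming that message files have long hash-like names while helper files have short ones. The names below are made up for illustration and are not taken from the corpus:

```python
# Hypothetical directory listing: these names are illustrative, not real corpus files.
listing = [
    "cmds",                                      # short helper file (assumed)
    "00002.fedcba9876543210fedcba9876543210",    # long, hash-like message name (assumed)
    "00001.0123456789abcdef0123456789abcdef",
]

# Same filter as in the cell above: sort, then keep names longer than 20 characters.
message_files = [name for name in sorted(listing) if len(name) > 20]
print(message_files)
```

If your copy of the dataset uses a different naming scheme, the threshold would need adjusting.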

In [77]:
len(ham_filenames)

6451

In [78]:
len(spam_filenames)

2398
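
The two counts above mean the classes are imbalanced, which is worth keeping in mind when reading accuracy figures later on. A quick back-of-the-envelope check, using only the numbers printed above:

```python
n_ham, n_spam = 6451, 2398          # counts reported by the two cells above
n_total = n_ham + n_spam

spam_share = n_spam / n_total
print(f"{n_total} emails, {spam_share:.1%} spam")

# Accuracy of a trivial classifier that labels every email as ham.
baseline_accuracy = n_ham / n_total
print(f"always-ham baseline accuracy: {baseline_accuracy:.1%}")
```

Any model we train should comfortably beat that always-ham baseline to be considered useful.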

We can use Python's email module to parse emails (this handles headers, encoding, and so on):

In [106]:
import email
import email.parser  # imported explicitly because load_email uses email.parser.BytesParser
import email.policy

def load_email(is_spam, filename, spam_path=SPAM_PATH):
    directory = "spam" if is_spam else "easy_ham"
    with open(os.path.join(spam_path, directory, filename), "rb") as f:
        return email.parser.BytesParser(policy=email.policy.default).parse(f)

In [103]:
ham_emails = [load_email(is_spam=False, filename=name) for name in ham_filenames]
spam_emails = [load_email(is_spam=True, filename=name) for name in spam_filenames]
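
To see what the parser gives us without touching the dataset, here is a small sketch that parses a hand-written message from bytes. The message content and addresses are hypothetical; only the `BytesParser` and `email.policy.default` usage mirrors `load_email()`:

```python
import email.parser
import email.policy

# A tiny hand-written message (hypothetical, not from the dataset).
raw_bytes = (
    b"From: alice@example.com\r\n"
    b"To: bob@example.com\r\n"
    b"Subject: Lunch?\r\n"
    b"Content-Type: text/plain; charset=\"utf-8\"\r\n"
    b"\r\n"
    b"Are you free at noon?\r\n"
)

parser = email.parser.BytesParser(policy=email.policy.default)
message = parser.parsebytes(raw_bytes)

subject = message["Subject"]                 # headers are accessed like dictionary keys
content_type = message.get_content_type()
body = message.get_content().strip()         # decoded text of a non-multipart message
print(subject, "|", content_type, "|", body)
```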

Let's look at one example of ham and one example of spam, to get a feel for what the data looks like:

In [104]:
print(ham_emails[1].get_content().strip())

Date:        Wed, 21 Aug 2002 10:54:46 -0500
    From:        Chris Garrigues <cwg-dated-1030377287.06fa6d@DeepEddy.Com>
    Message-ID:  <1029945287.4797.TMDA@deepeddy.vircio.com>


  | I can't reproduce this error.

For me it is very repeatable... (like every time, without fail).

This is the debug log of the pick happening ...

18:19:03 Pick_It {exec pick +inbox -list -lbrace -lbrace -subject ftp -rbrace -rbrace} {4852-4852 -sequence mercury}
18:19:03 exec pick +inbox -list -lbrace -lbrace -subject ftp -rbrace -rbrace 4852-4852 -sequence mercury
18:19:04 Ftoc_PickMsgs {{1 hit}}
18:19:04 Marking 1 hits
18:19:04 tkerror: syntax error in expression "int ...

Note, if I run the pick command by hand ...

delta$ pick +inbox -list -lbrace -lbrace -subject ftp -rbrace -rbrace  4852-4852 -sequence mercury
1 hit

That's where the "1 hit" comes from (obviously).  The version of nmh I'm
using is ...

delta$ pick -version
pick -- nmh-1.0.4 [compiled on fuchsia.cs.mu.OZ.AU at Sun Mar 17 14:55:56 

In [114]:
print(spam_emails[18].get_content().strip())

KeyError: 'multipart/mixed'
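
The `KeyError` above appears because `get_content()` has no handler for a multipart container: the text lives in the parts, not in the container itself. The sketch below reproduces this on a small hand-built message (hypothetical content) and shows that `walk()` reaches the individual parts:

```python
from email.message import EmailMessage

# Build a small multipart message by hand (hypothetical content).
message = EmailMessage()
message["Subject"] = "Demo"
message.set_content("Plain-text version")
message.add_alternative("<p>HTML version</p>", subtype="html")

print(message.get_content_type())  # the container type after add_alternative()

try:
    message.get_content()
    error = None
except KeyError as exc:
    error = exc
print("get_content() raised:", repr(error))

# walk() visits the container first, then each part in order.
part_types = [part.get_content_type() for part in message.walk()]
print(part_types)
```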

Some emails are actually multipart, with images and attachments (which can have their own attachments). Let's look at the various types of structures we have:

In [116]:
def get_email_structure(email):
    if isinstance(email, str):
        return email
    payload = email.get_payload()
    if isinstance(payload, list):
        return "multipart({})".format(", ".join([
            get_email_structure(sub_email)
            for sub_email in payload
        ]))
    else:
        return email.get_content_type()

In [117]:
from collections import Counter

def structures_counter(emails):
    structures = Counter()
    for email in emails:
        structure = get_email_structure(email)
        structures[structure] += 1
    return structures

In [118]:
structures_counter(ham_emails).most_common()

[('text/plain', 6204),
 ('multipart(text/plain, application/pgp-signature)', 173),
 ('multipart(text/plain, text/html)', 28),
 ('multipart(text/plain, text/plain)', 8),
 ('multipart(text/plain)', 6),
 ('multipart(text/plain, application/octet-stream)', 4),
 ('multipart(text/plain, application/ms-tnef, text/plain)', 3),
 ('multipart(text/plain, multipart(text/plain))', 3),
 ('multipart(text/plain, multipart(text/plain, text/plain), text/rfc822-headers)',
  3),
 ('multipart(text/plain, text/enriched)', 2),
 ('text/html', 2),
 ('multipart(multipart(text/plain, text/plain, text/plain), application/pgp-signature)',
  2),
 ('multipart(text/plain, video/mng)', 2),
 ('multipart(text/plain, application/x-pkcs7-signature)', 2),
 ('multipart(text/plain, multipart(text/plain, text/plain), multipart(multipart(text/plain, application/x-pkcs7-signature)))',
  2),
 ('multipart(text/plain, application/x-java-applet)', 2),
 ('multipart(text/plain, application/x-patch)', 1),
 ('multipart(multipart(text/p

In [119]:
structures_counter(spam_emails).most_common()

[('text/plain', 1038),
 ('text/html', 953),
 ('multipart(text/plain, text/html)', 204),
 ('multipart(text/html)', 68),
 ('multipart(text/plain)', 63),
 ('multipart(multipart(text/html))', 28),
 ('multipart(text/plain, image/jpeg)', 6),
 ('multipart(multipart(text/plain, text/html))', 5),
 ('multipart(text/plain, application/octet-stream)', 4),
 ('multipart(text/html, text/plain)', 4),
 ('multipart(text/html, application/octet-stream)', 4),
 ('multipart(text/plain, application/octet-stream, text/plain)', 3),
 ('multipart/alternative', 3),
 ('multipart(multipart(text/html), application/octet-stream, image/jpeg)', 2),
 ('multipart(multipart(text/plain, text/html), image/gif)', 2),
 ('multipart(text/html, image/jpeg)', 2),
 ('multipart(multipart(text/plain), application/octet-stream)', 2),
 ('multipart(text/plain, multipart(text/plain))', 1),
 ('multipart(multipart(text/plain, text/html), image/jpeg, image/jpeg, image/jpeg, image/jpeg, image/jpeg)',
  1),
 ('multipart(multipart(text/plain,

It seems that the ham emails are more often plain text, while spam has quite a lot of HTML. Moreover, quite a few ham emails are signed using PGP, while no spam is. In short, it seems that the email structure is useful information to have.

Now let's take a look at the email headers:

In [122]:
for header, value in spam_emails[2].items():
    print(header,":",value)

Return-Path : <12a1mailbot1@web.de>
Delivered-To : zzzz@localhost.spamassassin.taint.org
Received : from localhost (localhost [127.0.0.1])	by phobos.labs.spamassassin.taint.org (Postfix) with ESMTP id 136B943C32	for <zzzz@localhost>; Thu, 22 Aug 2002 08:17:21 -0400 (EDT)
Received : from mail.webnote.net [193.120.211.219]	by localhost with POP3 (fetchmail-5.9.0)	for zzzz@localhost (single-drop); Thu, 22 Aug 2002 13:17:21 +0100 (IST)
Received : from dd_it7 ([210.97.77.167])	by webnote.net (8.9.3/8.9.3) with ESMTP id NAA04623	for <zzzz@spamassassin.taint.org>; Thu, 22 Aug 2002 13:09:41 +0100
From : 12a1mailbot1@web.de
Received : from r-smtp.korea.com - 203.122.2.197 by dd_it7  with Microsoft SMTPSVC(5.5.1775.675.6);	 Sat, 24 Aug 2002 09:42:10 +0900
To : dcek1a1@netsgo.com
Subject : Life Insurance - Why Pay More?
Date : Wed, 21 Aug 2002 20:31:57 -1600
MIME-Version : 1.0
Message-ID : <0103c1042001882DD_IT7@dd_it7>
Content-Type : text/html; charset="iso-8859-1"
Content-Transfer-Encoding : qu

In [126]:
spam_emails[2]["Subject"]

'Life Insurance - Why Pay More?'
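
A few details of header access are easy to miss: lookups are case-insensitive, a missing header returns `None` rather than raising, and repeated headers such as `Received` need `get_all()`. A minimal sketch on a hand-written message (the header values are hypothetical):

```python
import email
import email.policy

raw_text = (
    "Received: from mail.example.org by localhost\n"
    "Received: from sender.example.net by mail.example.org\n"
    "Subject: Hello there\n"
    "\n"
    "Body text\n"
)
message = email.message_from_string(raw_text, policy=email.policy.default)

subject = message["subject"]               # header lookup is case-insensitive
received = message.get_all("Received")     # every value of a repeated header, in order
missing = message["X-Does-Not-Exist"]      # absent headers give None, not a KeyError

print(subject)
print(len(received), "Received headers")
print(missing)
```

Indexing with `message["Received"]` would only return one of the values, so `get_all()` is the safer choice if the relay chain were ever used as a feature.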

Okay, before we learn too much about the data, let's not forget to split it into a training set and a test set:

In [127]:
import numpy as np
from sklearn.model_selection import train_test_split

# dtype=object keeps each parsed email as a single array element
X = np.array(ham_emails + spam_emails, dtype=object)
y = np.array([0] * len(ham_emails) + [1] * len(spam_emails))

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Prepare the data for machine learning algorithms

We would like to create a pipeline to convert each email into a feature vector.
The pipeline should:
* Build a bag-of-words representation of the content, which counts the number of occurrences of each word. It may also be worth comparing this with simply recording whether each word is present.
* Expose hyperparameters to control whether or not to strip headers (and, likewise, each of the steps below).
* Convert each email to lowercase.
* Remove punctuation.
* Replace all URLs with "URL".
* Replace all numbers with "NUMBER".
* Perform word stemming (i.e., trim off word endings).
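
Before building the real pipeline, the steps above can be sketched on a single string using only the standard library. This is a simplified illustration: the URL pattern is a crude stand-in, stemming is left out, and `toy_preprocess` is a hypothetical helper name, not part of the notebook:

```python
import re
from collections import Counter

def toy_preprocess(text):
    text = text.lower()                               # convert to lowercase
    text = re.sub(r"https?://\S+", " URL ", text)     # crude URL pattern (assumption)
    text = re.sub(r"\d+", "NUMBER", text)             # replace runs of digits
    text = re.sub(r"\W+", " ", text)                  # drop punctuation
    return Counter(text.split())                      # bag of words

counts = toy_preprocess("Win $1000 now! Visit http://example.com/win NOW.")
print(counts)
```

Note that "now" is counted twice because lowercasing happens before counting.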

Okay, let's start writing the preprocessing functions. First, we will need a function to convert HTML to plain text. Arguably the best way to do this would be to use the great [BeautifulSoup](https://www.crummy.com/software/BeautifulSoup/) library, but I would like to avoid adding another dependency to this project, so let's hack a quick & dirty solution using regular expressions (at the risk of [un̨ho͞ly radiańcé destro҉ying all enli̍̈́̂̈́ghtenment](https://stackoverflow.com/a/1732454/38626)). The following function first drops the `<head>` section, then converts all `<a>` tags to the word HYPERLINK, then it gets rid of all HTML tags, leaving only the plain text. For readability, it also replaces multiple newlines with single newlines, and finally it unescapes html entities (such as `&gt;` or `&nbsp;`):

In [128]:
import re
from html import unescape

def html_to_plain_text(html):
    text = re.sub('<head.*?>.*?</head>', '', html, flags=re.M | re.S | re.I)
    text = re.sub(r'<a\s.*?>', ' HYPERLINK ', text, flags=re.M | re.S | re.I)
    text = re.sub('<.*?>', '', text, flags=re.M | re.S)
    text = re.sub(r'(\s*\n)+', '\n', text, flags=re.M | re.S)
    return unescape(text)
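
To make the effect of each substitution concrete, the sketch below applies the same kind of patterns one at a time to a tiny hand-written fragment (hypothetical HTML, with the newline-collapsing step left out):

```python
import re
from html import unescape

snippet = ('<html><head><title>Ad</title></head><body>'
           '<a href="http://x.test">Click</a> &gt; here&nbsp;now</body></html>')

# 1. Drop the whole <head> section, including its contents.
step1 = re.sub(r'<head.*?>.*?</head>', '', snippet, flags=re.S | re.I)
# 2. Replace opening <a ...> tags with a marker word.
step2 = re.sub(r'<a\s.*?>', ' HYPERLINK ', step1, flags=re.S | re.I)
# 3. Remove every remaining tag.
step3 = re.sub(r'<.*?>', '', step2, flags=re.S)
# 4. Turn HTML entities back into characters.
text = unescape(step3)

print(repr(step3))
print(repr(text))
```

Note that `&nbsp;` becomes a non-breaking space (`\xa0`), which `str.split()` still treats as whitespace.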

Let's see if it works. This is HTML spam:

In [131]:
html_spam_emails = [email for email in X_train[y_train==1]
                    if get_email_structure(email) == "text/html"]
sample_html_spam = html_spam_emails[8]
print(sample_html_spam.get_content().strip()[:1000], "...")

<html>
<body bgcolor="#FFFFFF" text="#000000">
<table width="500" border="0" cellspacing="0" cellpadding="0" align="center">
  <tr>
    <td align="center">
      <h2><font face="Verdana, Arial, Helvetica, sans-serif"><a href="http://203.197.254.12/sharworld/enterL.htm"><font size="4">Secretly 
        Attract Women or Men<br>
        Add Some Spice To Your Life</font></a></font></h2>
    </td>
  </tr>
  <tr>
    <td align="center"><a href="http://203.197.254.12/sharworld/enterLhtm"><img src="http://203.197.254.12/sharworld/images/woman.jpeg" width="173.25" height="202.5"></a></td>
  </tr>
  <tr>
    <td align="center">
      <h2><font face="Verdana, Arial, Helvetica, sans-serif"><a href="http://203.197.254.12/sharworld/enterL.htm"><font size="4">SecretlyAttract 
        Women or Men</font></a></font></h2>
      
    </td>
  </tr>
</table>
<p align="center"><font color="#000000" back="#ffffff" style="BACKGROUND-COLOR:" #ffffff="#ffffff" size="2" ptsize="8" family="SANSSERIF" face="arial

And this is the resulting plain text:

In [132]:
print(html_to_plain_text(sample_html_spam.get_content())[:1000], "...")


       HYPERLINK Secretly
        Attract Women or Men
        Add Some Spice To Your Life
     HYPERLINK
       HYPERLINK SecretlyAttract
        Women or Men
 HYPERLINK Delete
 ...


Great! Now let's write a function that takes an email as input and returns its content as plain text, whatever its format is:

In [133]:
def email_to_text(email):
    html = None
    for part in email.walk():
        ctype = part.get_content_type()
        if ctype not in ("text/plain", "text/html"):
            continue
        try:
            content = part.get_content()
        except Exception: # in case of encoding issues
            content = str(part.get_payload())
        if ctype == "text/plain":
            return content
        else:
            html = content
    if html:
        return html_to_plain_text(html)

In [134]:
print(email_to_text(sample_html_spam)[:100], "...")


       HYPERLINK Secretly
        Attract Women or Men
        Add Some Spice To Your Life
     HYP ...
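
As an aside, messages parsed with `email.policy.default` also offer `get_body()`, which picks a part according to a preference list. The sketch below shows it on two hand-built messages (hypothetical content). It is an alternative worth knowing about, not a drop-in replacement for `email_to_text()`, since the two may treat attachments and nested parts differently:

```python
from email.message import EmailMessage

# One message with both plain-text and HTML versions, one with HTML only.
both = EmailMessage()
both.set_content("Plain-text version")
both.add_alternative("<p>HTML version</p>", subtype="html")

html_only = EmailMessage()
html_only.set_content("<p>Only HTML here</p>", subtype="html")

def preferred_text_part(message):
    # get_body() returns the best match for the preference list, or None if there is none.
    return message.get_body(preferencelist=("plain", "html"))

for message in (both, html_only):
    part = preferred_text_part(message)
    print(part.get_content_type(), "->", part.get_content().strip())
```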


Let's throw in some stemming! For this to work, you need to install the Natural Language Toolkit (NLTK). It's as simple as running `pip3 install nltk` (don't forget to activate your virtualenv first; if you don't have one, you will likely need administrator rights, or use the `--user` option):

In [135]:
try:
    import nltk

    stemmer = nltk.PorterStemmer()
    for word in ("Computations", "Computation", "Computing", "Computed", "Compute", "Compulsive"):
        print(word, "=>", stemmer.stem(word))
except ImportError:
    print("Error: stemming requires the NLTK module.")
    stemmer = None

Computations => comput
Computation => comput
Computing => comput
Computed => comput
Compute => comput
Compulsive => compuls
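
The point of stemming for our purposes is that counts for related word forms end up under one key. The sketch below shows that merging step with a deliberately naive suffix stripper. `naive_stem` is a hypothetical stand-in for illustration only and is not the Porter algorithm:

```python
from collections import Counter

def naive_stem(word):
    # Hypothetical stand-in: strip one common suffix. NOT the Porter algorithm.
    for suffix in ("ations", "ation", "ing", "ed", "e", "s"):
        if word.endswith(suffix) and len(word) - len(suffix) >= 3:
            return word[:-len(suffix)]
    return word

word_counts = Counter({"computing": 2, "computed": 1, "compute": 3, "spam": 4})

# Add each word's count to the entry for its stem.
stemmed_counts = Counter()
for word, count in word_counts.items():
    stemmed_counts[naive_stem(word)] += count

print(stemmed_counts)
```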


We will also need a way to replace URLs with the word "URL". For this, we could use hardcore regular expressions, but we will just use the urlextract library. You can install it with `pip3 install urlextract` (don't forget to activate your virtualenv first; if you don't have one, you will likely need administrator rights, or use the `--user` option):

In [136]:
try:
    import urlextract # may require an Internet connection to download root domain names
    
    url_extractor = urlextract.URLExtract()
    print(url_extractor.find_urls("Will it detect github.com and https://youtu.be/7Pq-S557XQU?t=3m32s"))
except ImportError:
    print("Error: replacing URLs requires the urlextract module.")
    url_extractor = None

['github.com', 'https://youtu.be/7Pq-S557XQU?t=3m32s']
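
When several URLs are found in one email, the order in which they are replaced matters: if a short URL is a prefix of a longer one, replacing the short one first leaves a fragment of the longer one behind. The sketch below illustrates this with plain string replacement on a made-up sentence, which is why it helps to replace the longest URLs first:

```python
text = "see http://a.test/page and http://a.test"
urls = ["http://a.test", "http://a.test/page"]   # as a URL extractor might return them

# Replacing in the given order: the short URL clobbers the start of the long one.
naive = text
for url in urls:
    naive = naive.replace(url, " URL ")

# Replacing longest first avoids the leftover "/page" fragment.
safe = text
for url in sorted(set(urls), key=len, reverse=True):
    safe = safe.replace(url, " URL ")

print(naive.split())
print(safe.split())
```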


We are ready to put all this together into a transformer that we will use to convert emails to word counters. Note that we split sentences into words using Python's `split()` method, which uses whitespace for word boundaries. This works for many written languages, but not all. For example, Chinese and Japanese scripts generally don't use spaces between words, and Vietnamese often uses spaces even between syllables. It's okay in this exercise, because the dataset is (mostly) in English.

In [137]:
from sklearn.base import BaseEstimator, TransformerMixin

class EmailToWordCounterTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, strip_headers=True, lower_case=True, remove_punctuation=True,
                 replace_urls=True, replace_numbers=True, stemming=True):
        self.strip_headers = strip_headers
        self.lower_case = lower_case
        self.remove_punctuation = remove_punctuation
        self.replace_urls = replace_urls
        self.replace_numbers = replace_numbers
        self.stemming = stemming
    def fit(self, X, y=None):
        return self
    def transform(self, X, y=None):
        X_transformed = []
        for email in X:
            text = email_to_text(email) or ""
            if self.lower_case:
                text = text.lower()
            if self.replace_urls and url_extractor is not None:
                urls = list(set(url_extractor.find_urls(text)))
                urls.sort(key=len, reverse=True)
                for url in urls:
                    text = text.replace(url, " URL ")
            if self.replace_numbers:
                text = re.sub(r'\d+(?:\.\d*(?:[eE]\d+))?', 'NUMBER', text)
            if self.remove_punctuation:
                text = re.sub(r'\W+', ' ', text, flags=re.M)
            word_counts = Counter(text.split())
            if self.stemming and stemmer is not None:
                stemmed_word_counts = Counter()
                for word, count in word_counts.items():
                    stemmed_word = stemmer.stem(word)
                    stemmed_word_counts[stemmed_word] += count
                word_counts = stemmed_word_counts
            X_transformed.append(word_counts)
        return np.array(X_transformed)

In [138]:
X_few = X_train[:3]
X_few_wordcounts = EmailToWordCounterTransformer().fit_transform(X_few)
X_few_wordcounts

array([Counter({'number': 6, 'url': 2, 'of': 2, 'date': 1, 'numbertnumb': 1, 'world': 1, 'latest': 1, 'hundr': 1, 'palestinian': 1, 'vent': 1, 'their': 1, 'anger': 1, 'as': 1, 'dozen': 1, 'isra': 1, 'tank': 1, 'withdrew': 1, 'after': 1, 'a': 1, 'gruell': 1, 'three': 1, 'hour': 1, 'raid': 1, 'on': 1, 'the': 1, 'gaza': 1, 'strip': 1}),
       Counter({'of': 8, 'number': 5, 'it': 5, 'are': 5, 'your': 5, 'the': 4, 'we': 3, 'all': 3, 'filter': 3, 'but': 3, 'full': 3, 's': 2, 'a': 2, 'that': 2, 'don': 2, 't': 2, 'use': 2, 'our': 2, 'most': 2, 'and': 2, 'is': 2, 'onli': 2, 'choos': 2, 'shoe': 2, 'toe': 2, 'on': 1, 'sep': 1, 'gari': 1, 'lawrenc': 1, 'murphi': 1, 'wrote': 1, 'myth': 1, 'part': 1, 'brain': 1, 'alway': 1, 'just': 1, 'cultur': 1, 'induc': 1, 'focal': 1, 'point': 1, 'caus': 1, 'us': 1, 'to': 1, 'time': 1, 'ignor': 1, 'wast': 1, 'lucid': 1, 'measur': 1, 'notch': 1, 'bandwidth': 1, 'station': 1, 'broadcast': 1, 'easi': 1, 'rock': 1, 'for': 1, 'exampl': 1, 'look': 1, 'now': 1, 'feet':

This looks about right!
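
One detail of the number pattern used in the transformer is worth knowing: the decimal part is only consumed when it is followed by an exponent, so a plain decimal such as `3.14` is replaced as two numbers. The sketch below compares it with one possible variant. The variant is an assumption offered for comparison, not something the results in this notebook were produced with:

```python
import re

ORIGINAL = r'\d+(?:\.\d*(?:[eE]\d+))?'          # pattern used in the transformer above
VARIANT = r'\d+(?:\.\d*)?(?:[eE][+-]?\d+)?'     # assumption: one possible alternative

sample = "pay 100 or 3.14 or 1.5e3"
with_original = re.sub(ORIGINAL, "NUMBER", sample)
with_variant = re.sub(VARIANT, "NUMBER", sample)

print(with_original)
print(with_variant)
```

Switching patterns would change the word counts, so the outputs shown in this notebook would no longer match exactly.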

Now we have the word counts, and we need to convert them to vectors. For this, we will build another transformer whose `fit()` method will build the vocabulary (an ordered list of the most common words) and whose `transform()` method will use the vocabulary to convert word counts to vectors. The output is a sparse matrix.

In [140]:
from scipy.sparse import csr_matrix

class WordCounterToVectorTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, vocabulary_size=1000):
        self.vocabulary_size = vocabulary_size
    def fit(self, X, y=None):
        total_count = Counter()
        for word_count in X:
            for word, count in word_count.items():
                total_count[word] += min(count, 10)
        most_common = total_count.most_common()[:self.vocabulary_size]
        self.most_common_ = most_common
        self.vocabulary_ = {word: index + 1 for index, (word, count) in enumerate(most_common)}
        return self
    def transform(self, X, y=None):
        rows = []
        cols = []
        data = []
        for row, word_count in enumerate(X):
            for word, count in word_count.items():
                rows.append(row)
                cols.append(self.vocabulary_.get(word, 0))
                data.append(count)
        return csr_matrix((data, (rows, cols)), shape=(len(X), self.vocabulary_size + 1))

In [141]:
vocab_transformer = WordCounterToVectorTransformer(vocabulary_size=10)
X_few_vectors = vocab_transformer.fit_transform(X_few_wordcounts)
X_few_vectors

<3x11 sparse matrix of type '<class 'numpy.longlong'>'
	with 22 stored elements in Compressed Sparse Row format>

In [142]:
X_few_vectors.toarray()

array([[21,  6,  2,  2,  1,  0,  0,  0,  1,  0,  1],
       [97,  5,  8,  0,  4,  5,  5,  5,  2,  2,  1],
       [32,  0,  0,  4,  0,  0,  0,  0,  1,  2,  1]], dtype=int64)

In [143]:
vocab_transformer.vocabulary_

{'number': 1,
 'of': 2,
 'url': 3,
 'the': 4,
 'it': 5,
 'are': 6,
 'your': 7,
 'a': 8,
 'and': 9,
 'on': 10}
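
To read the array above, keep in mind that column 0 collects the counts of all words that are not in the vocabulary, and that counts are capped at 10 only while the vocabulary is being built, not in the output vectors. The sketch below reproduces the same logic with plain lists on made-up word counts. In the real transformer, `csr_matrix` sums entries that share the same row and column, which has the same effect as the `+=` used here:

```python
from collections import Counter

word_counts = [
    Counter({"spam": 3, "offer": 1}),
    Counter({"spam": 1, "meeting": 2, "offer": 1, "agenda": 12}),
]
vocabulary_size = 2

# fit(): rank words by total count, capping each email's contribution at 10.
total_count = Counter()
for word_count in word_counts:
    for word, count in word_count.items():
        total_count[word] += min(count, 10)
vocabulary = {word: index + 1
              for index, (word, _) in enumerate(total_count.most_common(vocabulary_size))}

# transform(): index 0 is shared by every out-of-vocabulary word.
vectors = []
for word_count in word_counts:
    row = [0] * (vocabulary_size + 1)
    for word, count in word_count.items():
        row[vocabulary.get(word, 0)] += count
    vectors.append(row)

print(vocabulary)
print(vectors)
```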

We are now ready to train our first spam classifier! Let's transform the whole dataset:

In [144]:
from sklearn.pipeline import Pipeline

preprocess_pipeline = Pipeline([
    ("email_to_wordcount", EmailToWordCounterTransformer()),
    ("wordcount_to_vector", WordCounterToVectorTransformer()),
])

X_train_transformed = preprocess_pipeline.fit_transform(X_train)

In [145]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

log_clf = LogisticRegression(solver="lbfgs", random_state=42)
score = cross_val_score(log_clf, X_train_transformed, y_train, cv=3, verbose=3)
score.mean()

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.


[CV]  ................................................................


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html.
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.4s remaining:    0.0s


[CV] .................................... , score=0.989, total=   0.4s
[CV]  ................................................................


[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.7s remaining:    0.0s


[CV] .................................... , score=0.989, total=   0.3s
[CV]  ................................................................
[CV] .................................... , score=0.990, total=   0.3s


[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    1.0s finished


0.9891229166816352

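Over 98.9%, not bad for a first try! :) However, remember that we are using the "easy" dataset. You can try with the harder datasets, but the results won't be so amazing. You would have to try multiple models, select the best ones and fine-tune them using cross-validation, and so on. Also note the `STOP: TOTAL NO. of ITERATIONS REACHED LIMIT` messages in the output: they are convergence warnings, meaning the solver hit its iteration limit before converging, so increasing `max_iter` or scaling the features, as the warning itself suggests, is worth trying.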

But you get the picture, so let's stop now, and just print out the precision/recall we get on the test set:

In [146]:
from sklearn.metrics import precision_score, recall_score

X_test_transformed = preprocess_pipeline.transform(X_test)

log_clf = LogisticRegression(solver="lbfgs", random_state=42)
log_clf.fit(X_train_transformed, y_train)

y_pred = log_clf.predict(X_test_transformed)

print("Precision: {:.2f}%".format(100 * precision_score(y_test, y_pred)))
print("Recall: {:.2f}%".format(100 * recall_score(y_test, y_pred)))

Precision: 98.77%
Recall: 97.97%
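
As a reminder of what these two numbers measure, precision is the share of emails flagged as spam that really are spam, and recall is the share of actual spam that gets flagged. The sketch below computes both by hand on a small made-up set of labels (not the notebook's test set):

```python
# Made-up labels for illustration: 1 = spam, 0 = ham.
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_hat = [1, 1, 1, 0, 1, 1, 0, 0, 0, 0]

tp = sum(1 for t, p in zip(y_true, y_hat) if t == 1 and p == 1)  # spam caught
fp = sum(1 for t, p in zip(y_true, y_hat) if t == 0 and p == 1)  # ham flagged as spam
fn = sum(1 for t, p in zip(y_true, y_hat) if t == 1 and p == 0)  # spam missed

precision = tp / (tp + fp)
recall = tp / (tp + fn)
print("Precision: {:.2f}%".format(100 * precision))
print("Recall: {:.2f}%".format(100 * recall))
```

For a spam filter, a false positive (a legitimate email sent to the spam folder) is usually the more costly mistake, which is why precision deserves particular attention.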


