# Text Classifier

Chapter 5 of Programming PyTorch for Deep Learning, but using samples from the ~~[TREC 2005 Spam Corpus](https://trec.nist.gov/data/spam.html)~~ preformatted [Enron Spam/Ham datasets](http://nlp.cs.aueb.gr/software_and_datasets/Enron-Spam/index.html) instead of tweets.

### Data preparation

Converting the preformatted spam/ham samples into a tabular dataset was simple:
    
```sh
#!/bin/sh

rm -f ham-spam-samples.tsv

hams=$(ls ham/*.txt)

for ham in $hams
do
    printf "ham\t" >> ham-spam-samples.tsv
    tr '\n' ' ' < $ham | tr -d $'\r'| tr '\t' ' ' | sed -e 's/  / /g' >> ham-spam-samples.tsv
done

lines=$(wc -l ham-spam-samples.tsv)
echo "hams done: $lines"

spams=$(ls spam/*.txt)

for spam in $spams
do
    printf "spam\t" >> ham-spam-samples.tsv
    tr '\n' ' ' < $spam | tr -cd '\11\12\15\40-\176' | tr -d $'\r'| tr '\t' ' ' | sed -e 's/  / /g' >> ham-spam-samples.tsv
done

lines=$(wc -l ham-spam-samples.tsv)
echo "spams done: $lines"
```

In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchtext import data 
import torchtext

In [2]:
def my_tokenize(s):
    return s.split()

In [3]:
TEXT = data.Field(lower=True, tokenize=my_tokenize)
LABEL = data.Field(lower=True)
samples = data.TabularDataset(path='./data/ham-spam-samples.tsv',
                              format='tsv', 
                              fields=[("label",LABEL), ("statement",TEXT)],
                              skip_header=False)

In [4]:
(training, testing, validating) = samples.split(split_ratio=[0.6,0.2,0.2])
(len(training),len(testing),len(validating))

(3103, 1035, 1034)

In [5]:
vocab_size = 50000
TEXT.build_vocab(training, max_size = vocab_size)
LABEL.build_vocab(training)
TEXT.vocab.freqs.most_common(10)

[('-', 51212),
 ('.', 33714),
 ('/', 25634),
 (',', 23960),
 (':', 14989),
 ('the', 14925),
 ('to', 11947),
 ('ect', 7904),
 ('@', 7773),
 ('and', 7419)]

In [6]:
len(TEXT.vocab)

38053

In [7]:
vars(training.examples[0])

{'label': ['ham'],
 'statement': ['subject:',
  'supply',
  'for',
  'midlothian',
  'daren',
  '-',
  '-',
  'fyi',
  '-',
  'exxon',
  'supply',
  'at',
  'mobil',
  'cayanosa',
  'has',
  'fallen',
  'off',
  'due',
  'to',
  'problems',
  ',',
  'so',
  'starting',
  '9',
  '/',
  '9',
  'there',
  'nom',
  'will',
  'go',
  'from',
  '5',
  '.',
  '000',
  'down',
  'to',
  '.',
  '520',
  'and',
  'going',
  'at',
  'least',
  'through',
  'the',
  'weekend',
  '9',
  '/',
  '11',
  '.',
  'stacey',
  'has',
  'already',
  'bought',
  'gas',
  'from',
  'duke',
  'to',
  'resupply',
  '(',
  'the',
  'difference',
  '4',
  '.',
  '480',
  ')',
  'midlothian',
  'at',
  'duke',
  'pegasus',
  ',',
  'in',
  'addition',
  'to',
  'the',
  '8',
  '.',
  '600',
  'incremental',
  'she',
  'bought',
  'there',
  '.',
  'so',
  'for',
  'the',
  'weekend',
  ',',
  'we',
  'should',
  'have',
  'a',
  'sale',
  'to',
  'midlothian',
  'of',
  '13',
  ',',
  '080',
  'at',
  'pegasus',


In [8]:
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

In [9]:
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
(training, validating, testing), 
batch_size = 32,
device = device,
sort_key = lambda x: len(x.statement),
sort_within_batch = False)

# Defining the model

Start with a simple [Long short-term memory (LSTM)](https://en.wikipedia.org/wiki/Long_short-term_memory) model. 

Unlike the book, which relies on a three-part classifier, this model is doing a binary comparison, with an activation ([sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function)) function.

In [10]:
class BasicLSTM(nn.Module):
    def __init__(self, hidden_size, embedding_dim, vocab_size):
        super(BasicLSTM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.encoder = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=1)
        self.predictor = nn.Linear(hidden_size, 1)
        self.activator = nn.Sigmoid()

    def forward(self, seq):
        output, (hidden, _) = self.encoder(self.embedding(seq))
        prediction = self.predictor(torch.squeeze(hidden))
        prediction = self.activator(prediction)
        return prediction

In [11]:
training_example = next(iter(train_iterator))
vars(training_example)

{'batch_size': 32,
 'dataset': <torchtext.data.dataset.Dataset at 0x7fe868233850>,
 'fields': dict_keys(['label', 'statement']),
 'input_fields': ['label', 'statement'],
 'target_fields': [],
 'label': tensor([[3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 3, 3, 2, 2, 2, 2,
          2, 2, 2, 2, 2, 2, 2, 2]]),
 'statement': tensor([[   28,    28,    28,  ...,    28,    28,    28],
         [  363,   288,    23,  ...,    99,    62,    76],
         [31714,   730,     4,  ...,  2629,  8762,     6],
         ...,
         [ 3288,     1,     1,  ...,     1,     1,     1],
         [ 2910,     1,     1,  ...,     1,     1,     1],
         [  443,     1,     1,  ...,     1,     1,     1]])}

In [12]:
my_model = BasicLSTM(100, 300, 38053)
my_model.to(device)

BasicLSTM(
  (embedding): Embedding(38053, 300)
  (encoder): LSTM(300, 100)
  (predictor): Linear(in_features=100, out_features=1, bias=True)
  (activator): Sigmoid()
)

In [13]:
def train(epochs, model, optimizer, criterion, train_iterator, valid_iterator):
    for epoch in range(1, epochs + 1):
     
        training_loss = 0.0
        valid_loss = 0.0
        model.train()
        for batch_idx, batch in enumerate(train_iterator):
            optimizer.zero_grad()
            predict = model(batch.statement)
            loss = criterion(predict, batch.label.float())
            loss.backward()
            optimizer.step()
            training_loss += loss.data.item() * batch.statement.size(0)
        training_loss /= len(train_iterator)
 
        
        model.eval()
        for batch_idx, batch in enumerate(valid_iterator):
            predict = model(batch.statement)
            loss = criterion(predict, batch.label.float())
            valid_loss += loss.data.item() * batch.statement.size(0)
 
        valid_loss /= len(valid_iterator)
        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}'.format(epoch, training_loss, valid_loss))

In [14]:
optimizer = optim.Adam(my_model.parameters(), lr=0.02)
criterion = nn.BCELoss()

In [15]:
train(5, my_model, optimizer, criterion, train_iterator, valid_iterator)

  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)


Epoch: 1, Training Loss: -147430.23, Validation Loss: -41134.60
Epoch: 2, Training Loss: -186169.90, Validation Loss: -41134.60
Epoch: 3, Training Loss: -184700.01, Validation Loss: -41134.60
Epoch: 4, Training Loss: -187422.17, Validation Loss: -41134.60
Epoch: 5, Training Loss: -183625.47, Validation Loss: -41134.60


# Making Predictions

In [18]:
def classify_text(text):
    categories = {0: 'ham', 1: 'spam'}
    processed = TEXT.process([TEXT.preprocess(text)])
    processed = processed.to(device)
    return categories[my_model(processed).argmax().item()]

In [19]:
classify_text(testing.examples[0].statement)

'ham'

In [20]:
testing.examples[0].label[0]

'ham'

In [21]:
correct = 0
examined = 0
for test_example in testing.examples:
    actual = test_example.label[0]
    predicted = classify_text(test_example.statement)
    examined += 1
    if actual == predicted:
        correct += 1
        print('Correct   --> {}/{} {:.2%} correct overall'.format(correct, examined, correct / examined))
    else:
        print('Incorrect --> {}/{} {:.2%} correct overall'.format(correct, examined, correct / examined))

Correct   --> 1/1 100.00% correct overall
Correct   --> 2/2 100.00% correct overall
Correct   --> 3/3 100.00% correct overall
Correct   --> 4/4 100.00% correct overall
Correct   --> 5/5 100.00% correct overall
Correct   --> 6/6 100.00% correct overall
Correct   --> 7/7 100.00% correct overall
Correct   --> 8/8 100.00% correct overall
Correct   --> 9/9 100.00% correct overall
Correct   --> 10/10 100.00% correct overall
Correct   --> 11/11 100.00% correct overall
Incorrect --> 11/12 91.67% correct overall
Correct   --> 12/13 92.31% correct overall
Correct   --> 13/14 92.86% correct overall
Incorrect --> 13/15 86.67% correct overall
Incorrect --> 13/16 81.25% correct overall
Correct   --> 14/17 82.35% correct overall
Correct   --> 15/18 83.33% correct overall
Correct   --> 16/19 84.21% correct overall
Correct   --> 17/20 85.00% correct overall
Correct   --> 18/21 85.71% correct overall
Correct   --> 19/22 86.36% correct overall
Incorrect --> 19/23 82.61% correct overall
Correct   --> 20/2

Correct   --> 146/194 75.26% correct overall
Correct   --> 147/195 75.38% correct overall
Correct   --> 148/196 75.51% correct overall
Correct   --> 149/197 75.63% correct overall
Correct   --> 150/198 75.76% correct overall
Correct   --> 151/199 75.88% correct overall
Correct   --> 152/200 76.00% correct overall
Correct   --> 153/201 76.12% correct overall
Correct   --> 154/202 76.24% correct overall
Correct   --> 155/203 76.35% correct overall
Correct   --> 156/204 76.47% correct overall
Correct   --> 157/205 76.59% correct overall
Incorrect --> 157/206 76.21% correct overall
Incorrect --> 157/207 75.85% correct overall
Correct   --> 158/208 75.96% correct overall
Correct   --> 159/209 76.08% correct overall
Correct   --> 160/210 76.19% correct overall
Correct   --> 161/211 76.30% correct overall
Incorrect --> 161/212 75.94% correct overall
Incorrect --> 161/213 75.59% correct overall
Correct   --> 162/214 75.70% correct overall
Correct   --> 163/215 75.81% correct overall
Incorrect 

Correct   --> 277/380 72.89% correct overall
Correct   --> 278/381 72.97% correct overall
Correct   --> 279/382 73.04% correct overall
Correct   --> 280/383 73.11% correct overall
Incorrect --> 280/384 72.92% correct overall
Correct   --> 281/385 72.99% correct overall
Incorrect --> 281/386 72.80% correct overall
Correct   --> 282/387 72.87% correct overall
Correct   --> 283/388 72.94% correct overall
Incorrect --> 283/389 72.75% correct overall
Incorrect --> 283/390 72.56% correct overall
Incorrect --> 283/391 72.38% correct overall
Correct   --> 284/392 72.45% correct overall
Correct   --> 285/393 72.52% correct overall
Correct   --> 286/394 72.59% correct overall
Incorrect --> 286/395 72.41% correct overall
Incorrect --> 286/396 72.22% correct overall
Correct   --> 287/397 72.29% correct overall
Correct   --> 288/398 72.36% correct overall
Incorrect --> 288/399 72.18% correct overall
Incorrect --> 288/400 72.00% correct overall
Incorrect --> 288/401 71.82% correct overall
Incorrect 

Correct   --> 408/566 72.08% correct overall
Incorrect --> 408/567 71.96% correct overall
Incorrect --> 408/568 71.83% correct overall
Correct   --> 409/569 71.88% correct overall
Correct   --> 410/570 71.93% correct overall
Correct   --> 411/571 71.98% correct overall
Correct   --> 412/572 72.03% correct overall
Incorrect --> 412/573 71.90% correct overall
Correct   --> 413/574 71.95% correct overall
Incorrect --> 413/575 71.83% correct overall
Correct   --> 414/576 71.88% correct overall
Correct   --> 415/577 71.92% correct overall
Correct   --> 416/578 71.97% correct overall
Correct   --> 417/579 72.02% correct overall
Incorrect --> 417/580 71.90% correct overall
Correct   --> 418/581 71.94% correct overall
Incorrect --> 418/582 71.82% correct overall
Correct   --> 419/583 71.87% correct overall
Correct   --> 420/584 71.92% correct overall
Correct   --> 421/585 71.97% correct overall
Correct   --> 422/586 72.01% correct overall
Correct   --> 423/587 72.06% correct overall
Correct   

Correct   --> 543/753 72.11% correct overall
Correct   --> 544/754 72.15% correct overall
Correct   --> 545/755 72.19% correct overall
Incorrect --> 545/756 72.09% correct overall
Incorrect --> 545/757 71.99% correct overall
Correct   --> 546/758 72.03% correct overall
Incorrect --> 546/759 71.94% correct overall
Correct   --> 547/760 71.97% correct overall
Incorrect --> 547/761 71.88% correct overall
Correct   --> 548/762 71.92% correct overall
Incorrect --> 548/763 71.82% correct overall
Correct   --> 549/764 71.86% correct overall
Incorrect --> 549/765 71.76% correct overall
Correct   --> 550/766 71.80% correct overall
Incorrect --> 550/767 71.71% correct overall
Correct   --> 551/768 71.74% correct overall
Correct   --> 552/769 71.78% correct overall
Correct   --> 553/770 71.82% correct overall
Correct   --> 554/771 71.85% correct overall
Incorrect --> 554/772 71.76% correct overall
Correct   --> 555/773 71.80% correct overall
Correct   --> 556/774 71.83% correct overall
Correct   

Correct   --> 677/936 72.33% correct overall
Correct   --> 678/937 72.36% correct overall
Correct   --> 679/938 72.39% correct overall
Correct   --> 680/939 72.42% correct overall
Correct   --> 681/940 72.45% correct overall
Correct   --> 682/941 72.48% correct overall
Correct   --> 683/942 72.51% correct overall
Correct   --> 684/943 72.53% correct overall
Correct   --> 685/944 72.56% correct overall
Correct   --> 686/945 72.59% correct overall
Correct   --> 687/946 72.62% correct overall
Correct   --> 688/947 72.65% correct overall
Correct   --> 689/948 72.68% correct overall
Correct   --> 690/949 72.71% correct overall
Correct   --> 691/950 72.74% correct overall
Correct   --> 692/951 72.77% correct overall
Correct   --> 693/952 72.79% correct overall
Incorrect --> 693/953 72.72% correct overall
Incorrect --> 693/954 72.64% correct overall
Correct   --> 694/955 72.67% correct overall
Incorrect --> 694/956 72.59% correct overall
Correct   --> 695/957 72.62% correct overall
Correct   