# Text Classifier

Chapter 5 of Programming PyTorch for Deep Learning, but using samples from the ~~[TREC 2005 Spam Corpus](https://trec.nist.gov/data/spam.html)~~ preformatted [Enron Spam/Ham datasets](http://nlp.cs.aueb.gr/software_and_datasets/Enron-Spam/index.html) instead of tweets.

### Data preparation

Converting the preformatted spam/ham samples into a tabular dataset was simple:
    
```sh
#!/bin/sh

rm -f ham-spam-samples.tsv

hams=$(ls ham/*.txt)

for ham in $hams
do
    printf "ham\t" >> ham-spam-samples.tsv
    tr '\n' ' ' < $ham | tr -d $'\r'| tr '\t' ' ' | sed -e 's/  / /g' >> ham-spam-samples.tsv
done

lines=$(wc -l ham-spam-samples.tsv)
echo "hams done: $lines"

spams=$(ls spam/*.txt)

for spam in $spams
do
    printf "spam\t" >> ham-spam-samples.tsv
    tr '\n' ' ' < $spam | tr -cd '\11\12\15\40-\176' | tr -d $'\r'| tr '\t' ' ' | sed -e 's/  / /g' >> ham-spam-samples.tsv
done

lines=$(wc -l ham-spam-samples.tsv)
echo "spams done: $lines"
```

In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchtext import data 
import torchtext

In [2]:
def my_tokenize(s):
    return s.split()

In [3]:
TEXT = data.Field(lower=True, tokenize=my_tokenize)
LABEL = data.LabelField()
samples = data.TabularDataset(path='./data/ham-spam-samples.tsv',
                              format='tsv', 
                              fields=[("label",LABEL), ("statement",TEXT)],
                              skip_header=False)

In [4]:
(training, testing, validating) = samples.split(split_ratio=[0.6,0.2,0.2], stratified=True, strata_field='label')
(len(training),len(testing),len(validating))

(3103, 1035, 1034)

In [5]:
vocab_size = 50000
TEXT.build_vocab(training, max_size = vocab_size)
LABEL.build_vocab(training)
TEXT.vocab.freqs.most_common(10)

[('-', 51644),
 ('.', 33018),
 ('/', 26300),
 (',', 25270),
 ('the', 15666),
 (':', 15017),
 ('to', 12512),
 ('ect', 8867),
 ('and', 7817),
 ('@', 7453)]

In [6]:
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

In [7]:
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
(training, validating, testing), 
batch_size = 32,
device = device,
sort_key = lambda x: len(x.statement),
sort_within_batch = False)

# Defining the model

Start with a simple [Long short-term memory (LSTM)](https://en.wikipedia.org/wiki/Long_short-term_memory) model. 

Unlike the book, which relies on a three-part classifier, this model is doing a binary comparison, with an activation ([sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function)) function.

In [8]:
class BasicLSTM(nn.Module):
    def __init__(self, hidden_size, embedding_dim, vocab_size):
        super(BasicLSTM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.encoder = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=1)
        self.predictor = nn.Linear(hidden_size, 1)
        self.activator = nn.Sigmoid()

    def forward(self, seq):
        output, (hidden, _) = self.encoder(self.embedding(seq))
        prediction = self.predictor(torch.squeeze(hidden))
        prediction = self.activator(prediction)
        return prediction

In [9]:
def train(epochs, model, optimizer, criterion, train_iterator, valid_iterator):
    for epoch in range(1, epochs + 1):
     
        training_loss = 0.0
        valid_loss = 0.0
        model.train()
        for batch_idx, batch in enumerate(train_iterator):
            optimizer.zero_grad()
            predict = model(batch.statement)
            loss = criterion(predict, batch.label.float())
            loss.backward()
            optimizer.step()
            training_loss += loss.data.item() * batch.statement.size(0)
        training_loss /= len(train_iterator)
 
        model.eval()
        for batch_idx, batch in enumerate(valid_iterator):
            predict = model(batch.statement)
            loss = criterion(predict, batch.label.float())
            valid_loss += loss.data.item() * batch.statement.size(0)
 
        valid_loss /= len(valid_iterator)
        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}'.format(epoch, training_loss, valid_loss))

In [10]:
my_model = BasicLSTM(100, 300, len(TEXT.vocab))
my_model.to(device)

BasicLSTM(
  (embedding): Embedding(39003, 300)
  (encoder): LSTM(300, 100)
  (predictor): Linear(in_features=100, out_features=1, bias=True)
  (activator): Sigmoid()
)

In [11]:
optimizer = optim.Adam(my_model.parameters(), lr=0.02)
criterion = nn.BCELoss()

In [12]:
train(10, my_model, optimizer, criterion, train_iterator, valid_iterator)

  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)


Epoch: 1, Training Loss: 829.61, Validation Loss: 226.98
Epoch: 2, Training Loss: 788.93, Validation Loss: 214.70
Epoch: 3, Training Loss: 747.78, Validation Loss: 192.94
Epoch: 4, Training Loss: 531.95, Validation Loss: 190.72
Epoch: 5, Training Loss: 418.90, Validation Loss: 195.91
Epoch: 6, Training Loss: 371.30, Validation Loss: 202.34
Epoch: 7, Training Loss: 334.37, Validation Loss: 226.16
Epoch: 8, Training Loss: 323.76, Validation Loss: 200.34
Epoch: 9, Training Loss: 294.99, Validation Loss: 193.62
Epoch: 10, Training Loss: 275.95, Validation Loss: 172.77


# Making Predictions

In [13]:
def classify_text(text):
    categories = {0: 'ham', 1: 'spam'}
    processed = TEXT.process([TEXT.preprocess(text)])
    processed = processed.to(device)
    return categories[my_model(processed).argmax().item()]

In [14]:
classify_text(testing.examples[0].statement)

'ham'

In [15]:
testing.examples[0].label

'ham'

In [16]:
correct = 0
examined = 0
for test_example in testing.examples:
    actual = test_example.label
    predicted = classify_text(test_example.statement)
    examined += 1
    if actual == predicted:
        correct += 1
        print('Correct   --> {}/{} {:.2%} correct overall'.format(correct, examined, correct / examined))
    else:
        print('Incorrect --> {}/{} {:.2%} correct overall'.format(correct, examined, correct / examined))

Correct   --> 1/1 100.00% correct overall
Correct   --> 2/2 100.00% correct overall
Correct   --> 3/3 100.00% correct overall
Correct   --> 4/4 100.00% correct overall
Correct   --> 5/5 100.00% correct overall
Correct   --> 6/6 100.00% correct overall
Correct   --> 7/7 100.00% correct overall
Correct   --> 8/8 100.00% correct overall
Correct   --> 9/9 100.00% correct overall
Correct   --> 10/10 100.00% correct overall
Correct   --> 11/11 100.00% correct overall
Correct   --> 12/12 100.00% correct overall
Correct   --> 13/13 100.00% correct overall
Correct   --> 14/14 100.00% correct overall
Correct   --> 15/15 100.00% correct overall
Correct   --> 16/16 100.00% correct overall
Correct   --> 17/17 100.00% correct overall
Correct   --> 18/18 100.00% correct overall
Correct   --> 19/19 100.00% correct overall
Correct   --> 20/20 100.00% correct overall
Correct   --> 21/21 100.00% correct overall
Correct   --> 22/22 100.00% correct overall
Correct   --> 23/23 100.00% correct overall
Correc

Correct   --> 188/188 100.00% correct overall
Correct   --> 189/189 100.00% correct overall
Correct   --> 190/190 100.00% correct overall
Correct   --> 191/191 100.00% correct overall
Correct   --> 192/192 100.00% correct overall
Correct   --> 193/193 100.00% correct overall
Correct   --> 194/194 100.00% correct overall
Correct   --> 195/195 100.00% correct overall
Correct   --> 196/196 100.00% correct overall
Correct   --> 197/197 100.00% correct overall
Correct   --> 198/198 100.00% correct overall
Correct   --> 199/199 100.00% correct overall
Correct   --> 200/200 100.00% correct overall
Correct   --> 201/201 100.00% correct overall
Correct   --> 202/202 100.00% correct overall
Correct   --> 203/203 100.00% correct overall
Correct   --> 204/204 100.00% correct overall
Correct   --> 205/205 100.00% correct overall
Correct   --> 206/206 100.00% correct overall
Correct   --> 207/207 100.00% correct overall
Correct   --> 208/208 100.00% correct overall
Correct   --> 209/209 100.00% corr

Correct   --> 374/374 100.00% correct overall
Correct   --> 375/375 100.00% correct overall
Correct   --> 376/376 100.00% correct overall
Correct   --> 377/377 100.00% correct overall
Correct   --> 378/378 100.00% correct overall
Correct   --> 379/379 100.00% correct overall
Correct   --> 380/380 100.00% correct overall
Correct   --> 381/381 100.00% correct overall
Correct   --> 382/382 100.00% correct overall
Correct   --> 383/383 100.00% correct overall
Correct   --> 384/384 100.00% correct overall
Correct   --> 385/385 100.00% correct overall
Correct   --> 386/386 100.00% correct overall
Correct   --> 387/387 100.00% correct overall
Correct   --> 388/388 100.00% correct overall
Correct   --> 389/389 100.00% correct overall
Correct   --> 390/390 100.00% correct overall
Correct   --> 391/391 100.00% correct overall
Correct   --> 392/392 100.00% correct overall
Correct   --> 393/393 100.00% correct overall
Correct   --> 394/394 100.00% correct overall
Correct   --> 395/395 100.00% corr

Correct   --> 557/557 100.00% correct overall
Correct   --> 558/558 100.00% correct overall
Correct   --> 559/559 100.00% correct overall
Correct   --> 560/560 100.00% correct overall
Correct   --> 561/561 100.00% correct overall
Correct   --> 562/562 100.00% correct overall
Correct   --> 563/563 100.00% correct overall
Correct   --> 564/564 100.00% correct overall
Correct   --> 565/565 100.00% correct overall
Correct   --> 566/566 100.00% correct overall
Correct   --> 567/567 100.00% correct overall
Correct   --> 568/568 100.00% correct overall
Correct   --> 569/569 100.00% correct overall
Correct   --> 570/570 100.00% correct overall
Correct   --> 571/571 100.00% correct overall
Correct   --> 572/572 100.00% correct overall
Correct   --> 573/573 100.00% correct overall
Correct   --> 574/574 100.00% correct overall
Correct   --> 575/575 100.00% correct overall
Correct   --> 576/576 100.00% correct overall
Correct   --> 577/577 100.00% correct overall
Correct   --> 578/578 100.00% corr

Incorrect --> 735/742 99.06% correct overall
Incorrect --> 735/743 98.92% correct overall
Incorrect --> 735/744 98.79% correct overall
Incorrect --> 735/745 98.66% correct overall
Incorrect --> 735/746 98.53% correct overall
Incorrect --> 735/747 98.39% correct overall
Incorrect --> 735/748 98.26% correct overall
Incorrect --> 735/749 98.13% correct overall
Incorrect --> 735/750 98.00% correct overall
Incorrect --> 735/751 97.87% correct overall
Incorrect --> 735/752 97.74% correct overall
Incorrect --> 735/753 97.61% correct overall
Incorrect --> 735/754 97.48% correct overall
Incorrect --> 735/755 97.35% correct overall
Incorrect --> 735/756 97.22% correct overall
Incorrect --> 735/757 97.09% correct overall
Incorrect --> 735/758 96.97% correct overall
Incorrect --> 735/759 96.84% correct overall
Incorrect --> 735/760 96.71% correct overall
Incorrect --> 735/761 96.58% correct overall
Incorrect --> 735/762 96.46% correct overall
Incorrect --> 735/763 96.33% correct overall
Incorrect 

Incorrect --> 735/932 78.86% correct overall
Incorrect --> 735/933 78.78% correct overall
Incorrect --> 735/934 78.69% correct overall
Incorrect --> 735/935 78.61% correct overall
Incorrect --> 735/936 78.53% correct overall
Incorrect --> 735/937 78.44% correct overall
Incorrect --> 735/938 78.36% correct overall
Incorrect --> 735/939 78.27% correct overall
Incorrect --> 735/940 78.19% correct overall
Incorrect --> 735/941 78.11% correct overall
Incorrect --> 735/942 78.03% correct overall
Incorrect --> 735/943 77.94% correct overall
Incorrect --> 735/944 77.86% correct overall
Incorrect --> 735/945 77.78% correct overall
Incorrect --> 735/946 77.70% correct overall
Incorrect --> 735/947 77.61% correct overall
Incorrect --> 735/948 77.53% correct overall
Incorrect --> 735/949 77.45% correct overall
Incorrect --> 735/950 77.37% correct overall
Incorrect --> 735/951 77.29% correct overall
Incorrect --> 735/952 77.21% correct overall
Incorrect --> 735/953 77.12% correct overall
Incorrect 