# Text Classifier

Chapter 5 of Programming PyTorch for Deep Learning, but using samples from the ~~[TREC 2005 Spam Corpus](https://trec.nist.gov/data/spam.html)~~ preformatted [Enron Spam/Ham datasets](http://nlp.cs.aueb.gr/software_and_datasets/Enron-Spam/index.html) instead of tweets.

### Data preparation

Converting the preformatted spam/ham samples into a tabular dataset was simple:
    
```sh
#!/bin/sh

rm -f ham-spam-samples.tsv

hams=$(ls ham/*.txt)

for ham in $hams
do
    printf "ham\t" >> ham-spam-samples.tsv
    tr '\n' ' ' < $ham | tr -d $'\r'| tr '\t' ' ' | sed -e 's/  / /g' >> ham-spam-samples.tsv
done

lines=$(wc -l ham-spam-samples.tsv)
echo "hams done: $lines"

spams=$(ls spam/*.txt)

for spam in $spams
do
    printf "spam\t" >> ham-spam-samples.tsv
    tr '\n' ' ' < $spam | tr -cd '\11\12\15\40-\176' | tr -d $'\r'| tr '\t' ' ' | sed -e 's/  / /g' >> ham-spam-samples.tsv
done

lines=$(wc -l ham-spam-samples.tsv)
echo "spams done: $lines"
```

In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchtext import data 
import torchtext

In [2]:
def my_tokenize(s):
    return s.split()

In [3]:
TEXT = data.Field(lower=True, tokenize=my_tokenize)
LABEL = data.Field(lower=True)
samples = data.TabularDataset(path='./data/ham-spam-samples.tsv',
                              format='tsv', 
                              fields=[("label",LABEL), ("statement",TEXT)],
                              skip_header=False)

In [4]:
(training, testing, validating) = samples.split(split_ratio=[0.6,0.2,0.2])
(len(training),len(testing),len(validating))

(3103, 1035, 1034)

In [5]:
vocab_size = 30000
TEXT.build_vocab(training, max_size = vocab_size)
LABEL.build_vocab(training)
TEXT.vocab.freqs.most_common(10)

[('-', 51758),
 ('.', 33125),
 ('/', 25324),
 (',', 24081),
 ('the', 15294),
 (':', 14956),
 ('to', 11960),
 ('ect', 7879),
 ('and', 7784),
 ('@', 7603)]

In [6]:
len(TEXT.vocab)

30002

In [7]:
vars(training.examples[0])

{'label': ['ham'],
 'statement': ['subject:',
  'gary',
  ',',
  'could',
  'you',
  'please',
  'remove',
  'scada',
  '/',
  'mips',
  'from',
  'the',
  'following',
  'hpl',
  'meters',
  '.',
  'these',
  'meters',
  'are',
  'where',
  'mid',
  '-',
  'texas',
  'interconnects',
  'with',
  'pg',
  '&',
  'e',
  "'",
  's',
  'system',
  ',',
  'and',
  'there',
  'should',
  'not',
  'be',
  'any',
  'measurement',
  'taking',
  'place',
  'on',
  'hpl',
  '.',
  'these',
  'meters',
  'were',
  'set',
  'up',
  'originally',
  'when',
  'we',
  'thought',
  'that',
  'every',
  'meter',
  'on',
  'mid',
  '-',
  'texas',
  'should',
  'automatically',
  'have',
  'a',
  'corresponding',
  'meter',
  'on',
  'hpl',
  '.',
  'this',
  'assumption',
  'is',
  'now',
  'incorecct',
  '.',
  '980388',
  '980389',
  '980390',
  '980391',
  '980392',
  '980393',
  '987271',
  '987260',
  '987283',
  'thanks',
  'george']}

In [8]:
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

In [9]:
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
(training, validating, testing), 
batch_size = 32,
device = device,
sort_key = lambda x: len(x.statement),
sort_within_batch = False)

# Defining the model

Start with a simple [Long short-term memory (LSTM)](https://en.wikipedia.org/wiki/Long_short-term_memory) model. 

Unlike the book, which relies on a three-part classifier, this model is doing a binary comparison, with an activation ([sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function)) function.

In [10]:
class BasicLSTM(nn.Module):
    def __init__(self, hidden_size, embedding_dim, vocab_size):
        super(BasicLSTM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.encoder = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=1)
        self.predictor = nn.Linear(hidden_size, 1)
        self.activator = nn.Sigmoid()

    def forward(self, seq):
        output, (hidden, _) = self.encoder(self.embedding(seq))
        prediction = self.predictor(torch.squeeze(hidden))
        prediction = self.activator(prediction)
        return prediction

In [11]:
training_example = next(iter(train_iterator))
vars(training_example)

{'batch_size': 32,
 'dataset': <torchtext.data.dataset.Dataset at 0x7ff11e2223d0>,
 'fields': dict_keys(['label', 'statement']),
 'input_fields': ['label', 'statement'],
 'target_fields': [],
 'label': tensor([[2, 3, 2, 3, 2, 2, 2, 3, 2, 2, 3, 2, 2, 2, 3, 3, 2, 2, 3, 3, 2, 3, 2, 2,
          2, 2, 2, 2, 2, 2, 2, 3]]),
 'statement': tensor([[  29,   29,   29,  ...,   29,   29,   29],
         [ 320,  972,  324,  ..., 1941,  796,  701],
         [   7,  660, 2387,  ...,    8,  215,    2],
         ...,
         [   1,    1,    1,  ...,    1,    1,    1],
         [   1,    1,    1,  ...,    1,    1,    1],
         [   1,    1,    1,  ...,    1,    1,    1]])}

In [12]:
model = BasicLSTM(100, 300, 30002)
model.to(device)

BasicLSTM(
  (embedding): Embedding(30002, 300)
  (encoder): LSTM(300, 100)
  (predictor): Linear(in_features=100, out_features=1, bias=True)
  (activator): Sigmoid()
)

In [13]:
def train(epochs, nn, optimizer, criterion, train_iterator, valid_iterator):
    for epoch in range(1, epochs + 1):

        training_loss = 0.0
        valid_loss = 0.0
        nn.train()
        for batch_idx, batch in enumerate(train_iterator):
            optimizer.zero_grad()
            predict = nn(batch.statement)
            loss = criterion(predict, batch.label.float())
            loss.backward()
            optimizer.step()
            training_loss += loss.data.item() * batch.statement.size(0)
        training_loss /= len(train_iterator)
 
        
        model.eval()
        for batch_idx,batch in enumerate(valid_iterator):
            predict = model(batch.statement)
            loss = criterion(predict, batch.label.float())
            valid_loss += loss.data.item() * batch.statement.size(0)
 
        valid_loss /= len(valid_iterator)
        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}'.format(epoch, training_loss, valid_loss))

In [14]:
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCELoss()

In [15]:
train(5, model, optimizer, criterion, train_iterator, valid_iterator)

  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)


Epoch: 1, Training Loss: -13849.14, Validation Loss: -5616.56
Epoch: 2, Training Loss: -148892.76, Validation Loss: -40282.94
Epoch: 3, Training Loss: -181550.52, Validation Loss: -40282.94
Epoch: 4, Training Loss: -186219.99, Validation Loss: -40282.94
Epoch: 5, Training Loss: -187523.51, Validation Loss: -40282.94


# Making Predictions

In [16]:
def classify_text(text):
    categories = {0: 'ham', 1: 'spam'}
    processed = TEXT.process([TEXT.preprocess(text)])
    processed = processed.to(device)
    return categories[model(processed).argmax().item()]

In [17]:
classify_text(testing.examples[0].statement)

'ham'

In [18]:
testing.examples[0].label[0]

'ham'

In [20]:
correct = 0
examined = 0
for test_example in testing.examples:
    actual = test_example.label[0]
    predicted = classify_text(test_example.statement)
    examined += 1
    if actual == predicted:
        correct += 1
        print('Correct   --> {}/{} {:.2%} correct overall'.format(correct, examined, correct / examined))
    else:
        print('Incorrect --> {}/{} {:.2%} correct overall'.format(correct, examined, correct / examined))

Correct   --> 1/1 100.00% correct overall
Correct   --> 2/2 100.00% correct overall
Correct   --> 3/3 100.00% correct overall
Incorrect --> 3/4 75.00% correct overall
Correct   --> 4/5 80.00% correct overall
Correct   --> 5/6 83.33% correct overall
Incorrect --> 5/7 71.43% correct overall
Incorrect --> 5/8 62.50% correct overall
Incorrect --> 5/9 55.56% correct overall
Correct   --> 6/10 60.00% correct overall
Correct   --> 7/11 63.64% correct overall
Correct   --> 8/12 66.67% correct overall
Correct   --> 9/13 69.23% correct overall
Correct   --> 10/14 71.43% correct overall
Correct   --> 11/15 73.33% correct overall
Correct   --> 12/16 75.00% correct overall
Correct   --> 13/17 76.47% correct overall
Incorrect --> 13/18 72.22% correct overall
Correct   --> 14/19 73.68% correct overall
Incorrect --> 14/20 70.00% correct overall
Incorrect --> 14/21 66.67% correct overall
Correct   --> 15/22 68.18% correct overall
Correct   --> 16/23 69.57% correct overall
Incorrect --> 16/24 66.67% cor

Correct   --> 144/192 75.00% correct overall
Incorrect --> 144/193 74.61% correct overall
Correct   --> 145/194 74.74% correct overall
Correct   --> 146/195 74.87% correct overall
Correct   --> 147/196 75.00% correct overall
Correct   --> 148/197 75.13% correct overall
Incorrect --> 148/198 74.75% correct overall
Correct   --> 149/199 74.87% correct overall
Correct   --> 150/200 75.00% correct overall
Correct   --> 151/201 75.12% correct overall
Correct   --> 152/202 75.25% correct overall
Incorrect --> 152/203 74.88% correct overall
Correct   --> 153/204 75.00% correct overall
Incorrect --> 153/205 74.63% correct overall
Correct   --> 154/206 74.76% correct overall
Correct   --> 155/207 74.88% correct overall
Incorrect --> 155/208 74.52% correct overall
Correct   --> 156/209 74.64% correct overall
Correct   --> 157/210 74.76% correct overall
Correct   --> 158/211 74.88% correct overall
Correct   --> 159/212 75.00% correct overall
Correct   --> 160/213 75.12% correct overall
Correct   

Incorrect --> 268/375 71.47% correct overall
Incorrect --> 268/376 71.28% correct overall
Correct   --> 269/377 71.35% correct overall
Incorrect --> 269/378 71.16% correct overall
Incorrect --> 269/379 70.98% correct overall
Correct   --> 270/380 71.05% correct overall
Correct   --> 271/381 71.13% correct overall
Incorrect --> 271/382 70.94% correct overall
Incorrect --> 271/383 70.76% correct overall
Correct   --> 272/384 70.83% correct overall
Correct   --> 273/385 70.91% correct overall
Correct   --> 274/386 70.98% correct overall
Correct   --> 275/387 71.06% correct overall
Correct   --> 276/388 71.13% correct overall
Incorrect --> 276/389 70.95% correct overall
Correct   --> 277/390 71.03% correct overall
Correct   --> 278/391 71.10% correct overall
Correct   --> 279/392 71.17% correct overall
Incorrect --> 279/393 70.99% correct overall
Correct   --> 280/394 71.07% correct overall
Correct   --> 281/395 71.14% correct overall
Correct   --> 282/396 71.21% correct overall
Correct   

Correct   --> 419/578 72.49% correct overall
Incorrect --> 419/579 72.37% correct overall
Incorrect --> 419/580 72.24% correct overall
Incorrect --> 419/581 72.12% correct overall
Incorrect --> 419/582 71.99% correct overall
Correct   --> 420/583 72.04% correct overall
Incorrect --> 420/584 71.92% correct overall
Incorrect --> 420/585 71.79% correct overall
Incorrect --> 420/586 71.67% correct overall
Correct   --> 421/587 71.72% correct overall
Incorrect --> 421/588 71.60% correct overall
Correct   --> 422/589 71.65% correct overall
Correct   --> 423/590 71.69% correct overall
Incorrect --> 423/591 71.57% correct overall
Correct   --> 424/592 71.62% correct overall
Correct   --> 425/593 71.67% correct overall
Correct   --> 426/594 71.72% correct overall
Correct   --> 427/595 71.76% correct overall
Incorrect --> 427/596 71.64% correct overall
Correct   --> 428/597 71.69% correct overall
Correct   --> 429/598 71.74% correct overall
Correct   --> 430/599 71.79% correct overall
Correct   

Correct   --> 555/769 72.17% correct overall
Correct   --> 556/770 72.21% correct overall
Incorrect --> 556/771 72.11% correct overall
Correct   --> 557/772 72.15% correct overall
Correct   --> 558/773 72.19% correct overall
Incorrect --> 558/774 72.09% correct overall
Correct   --> 559/775 72.13% correct overall
Incorrect --> 559/776 72.04% correct overall
Incorrect --> 559/777 71.94% correct overall
Correct   --> 560/778 71.98% correct overall
Correct   --> 561/779 72.02% correct overall
Correct   --> 562/780 72.05% correct overall
Correct   --> 563/781 72.09% correct overall
Correct   --> 564/782 72.12% correct overall
Correct   --> 565/783 72.16% correct overall
Incorrect --> 565/784 72.07% correct overall
Correct   --> 566/785 72.10% correct overall
Correct   --> 567/786 72.14% correct overall
Incorrect --> 567/787 72.05% correct overall
Incorrect --> 567/788 71.95% correct overall
Correct   --> 568/789 71.99% correct overall
Correct   --> 569/790 72.03% correct overall
Correct   

Correct   --> 684/954 71.70% correct overall
Correct   --> 685/955 71.73% correct overall
Incorrect --> 685/956 71.65% correct overall
Incorrect --> 685/957 71.58% correct overall
Correct   --> 686/958 71.61% correct overall
Correct   --> 687/959 71.64% correct overall
Correct   --> 688/960 71.67% correct overall
Incorrect --> 688/961 71.59% correct overall
Correct   --> 689/962 71.62% correct overall
Correct   --> 690/963 71.65% correct overall
Correct   --> 691/964 71.68% correct overall
Correct   --> 692/965 71.71% correct overall
Correct   --> 693/966 71.74% correct overall
Incorrect --> 693/967 71.66% correct overall
Correct   --> 694/968 71.69% correct overall
Correct   --> 695/969 71.72% correct overall
Correct   --> 696/970 71.75% correct overall
Correct   --> 697/971 71.78% correct overall
Correct   --> 698/972 71.81% correct overall
Correct   --> 699/973 71.84% correct overall
Correct   --> 700/974 71.87% correct overall
Correct   --> 701/975 71.90% correct overall
Correct   