# Text classification with a CNN in PyTorch

## Data

In [1]:
TRAIN_PATH = "data/text_classification/20newsgroups_train.tsv"
DEV_PATH = "data/text_classification/20newsgroups_dev.tsv"
TEST_PATH = "data/text_classification/20newsgroups_test.tsv"

In [2]:
from sklearn.datasets import fetch_20newsgroups

train = fetch_20newsgroups(subset="train")
target_names = train.target_names
label2idx = {label: idx for idx, label in enumerate(target_names)}

Make sure you have the English spacy model installed, as this will be used for tokenization:
    
```
> pip install spacy
> python -m spacy download en
```

In [3]:
import csv
import sys
from torchtext.data import TabularDataset, Field, BucketIterator

csv.field_size_limit(sys.maxsize)

text = Field(sequential=True, tokenize="spacy")
label = Field(sequential=False, use_vocab=False, preprocessing=lambda x: label2idx[x])

train_data = TabularDataset(path=TRAIN_PATH, format='tsv', fields=[('label', label), ('text', text)])
dev_data = TabularDataset(path=DEV_PATH, format='tsv', fields=[('label', label), ('text', text)])
test_data = TabularDataset(path=TEST_PATH, format='tsv', fields=[('label', label), ('text', text)])

In [4]:
VOCAB_SIZE = 30000

text.build_vocab(train_data, max_size=VOCAB_SIZE)

In [5]:
BATCH_SIZE = 16
train_iter = BucketIterator(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
dev_iter = BucketIterator(dataset=dev_data, batch_size=BATCH_SIZE)
test_iter = BucketIterator(dataset=test_data, batch_size=BATCH_SIZE)

## Model

In [6]:
import torch.nn as nn
import torch.nn.functional as F


class CNNClassifier(nn.Module):

    def __init__(self, embedding_dim, filter_sizes, num_filters, vocab_size, output_size):
        super(CNNClassifier, self).__init__()
        
        # 1. Embedding Layer
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        
        # 2. LSTM Layer
        self.cnn = nn.ModuleList([nn.Conv1d(1, num_filters, (fs, embedding_dim)) for fs in filter_sizes])

        # 3. Dense Layer
        self.hidden2out = nn.Linear(num_filters*len(filter_sizes), output_size)
        
        # Optional dropout layer
        self.dropout_layer = nn.Dropout(p=0.4)

    def forward(self, batch_text):

        embeddings = self.embeddings(batch_text)

        embeddings = embeddings.transpose(0,1)  # (batch, length, embed_dim)
        embeddings = embeddings.unsqueeze(1)    # (batch, channels, length, embed_dim)
        conv_out = [conv(embeddings) for conv in self.cnn]  # (batch, num_filters, output_length, 1)
        conv_out = [F.relu(t).squeeze(3) for t in conv_out]
        conv_out = [F.max_pool1d(t, t.size(2)).squeeze(2) for t in conv_out]
        conv_out = torch.cat(conv_out, 1)

        conv_out = self.dropout_layer(conv_out)
        final_output = self.hidden2out(conv_out)
        return final_output

## Training

In [11]:
import torch
import torch.optim as optim
from tqdm import tqdm_notebook as tqdm
from sklearn.metrics import precision_recall_fscore_support

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def train(model, train_iter, dev_iter, batch_size, num_batches):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())

    max_epochs = 20
    loss_history = []
    patience = 2
    best_state_path = None
    for epoch in range(max_epochs):

        total_loss = 0
        predictions, correct = [], []
        for batch in tqdm(train_iter, total=num_batches):
            optimizer.zero_grad()

            pred = model(batch.text.to(device))
            loss = criterion(pred, batch.label.to(device))
            total_loss += loss.item()

            loss.backward()
            optimizer.step()

            _, pred_indices = torch.max(pred, 1)
            predictions += list(pred_indices.cpu().numpy())
            correct += list(batch.label.cpu().numpy())

        print("=== Epoch", epoch, "===")
        print("Total training loss:", total_loss)
        print("Training performance:", precision_recall_fscore_support(correct, predictions, average="micro"))
        
        total_loss = 0
        predictions, correct = [], []
        for batch in dev_iter:

            pred = model(batch.text.to(device))
            loss = criterion(pred, batch.label.to(device))
            total_loss += loss.item()

            _, pred_indices = torch.max(pred, 1)
            pred_indices = list(pred_indices.cpu().numpy())
            predictions += pred_indices
            correct += list(batch.label.cpu().numpy())

        print("Total development loss:", total_loss)
        dev_stats = precision_recall_fscore_support(correct, predictions, average="micro")
        print("Development performance:", dev_stats)
        
        if len(loss_history) == 0 or total_loss < min(loss_history): 
            fscore = dev_stats[2]
            path = f"model_state_{epoch}_{round(total_loss,2)}_{round(fscore,2)}"
            torch.save(model.state_dict(), path)
            best_state_path = path
            
        if len(loss_history) > 0 and total_loss > max(loss_history[-patience:]):
            print("No improvement on development set. Finishing training.")
            break
            
        loss_history.append(total_loss)
        
    return best_state_path

In [12]:
EMBEDDING_DIM = 300
NUM_FILTERS = 128
FILTER_SIZES = [3,4,5]
NUM_CLASSES = len(label2idx)
num_batches = int(len(train_data) / BATCH_SIZE)

classifier = CNNClassifier(EMBEDDING_DIM, FILTER_SIZES, NUM_FILTERS, VOCAB_SIZE+2, NUM_CLASSES)  

best_state_path = train(classifier.to(device), train_iter, dev_iter, BATCH_SIZE, num_batches)


HBox(children=(IntProgress(value=0, max=707), HTML(value='')))


=== Epoch 0 ===
Total training loss: 1744.2471042275429
Training performance: (0.31474279653526605, 0.31474279653526605, 0.31474279653526605, None)
Total development loss: 1002.7418279647827
Development performance: (0.3990971853425385, 0.3990971853425385, 0.3990971853425385, None)


HBox(children=(IntProgress(value=0, max=707), HTML(value='')))


=== Epoch 1 ===
Total training loss: 807.0345216393471
Training performance: (0.6590065405692063, 0.6590065405692063, 0.6590065405692063, None)
Total development loss: 741.3295960128307
Development performance: (0.5536378120021243, 0.5536378120021243, 0.5536378120021243, None)


HBox(children=(IntProgress(value=0, max=707), HTML(value='')))


=== Epoch 2 ===
Total training loss: 404.7720814496279
Training performance: (0.8332154852395263, 0.8332154852395263, 0.8332154852395263, None)
Total development loss: 623.7967542409897
Development performance: (0.6317047265002655, 0.6317047265002655, 0.6317047265002655, None)


HBox(children=(IntProgress(value=0, max=707), HTML(value='')))


=== Epoch 3 ===
Total training loss: 214.2018345296383
Training performance: (0.9164751635142302, 0.9164751635142302, 0.9164751635142302, None)
Total development loss: 648.8073130249977
Development performance: (0.6509559214020181, 0.6509559214020181, 0.6509559214020181, None)


HBox(children=(IntProgress(value=0, max=707), HTML(value='')))


=== Epoch 4 ===
Total training loss: 137.84940619766712
Training performance: (0.9481173767014318, 0.9481173767014318, 0.9481173767014319, None)
Total development loss: 665.8534139245749
Development performance: (0.658656399362719, 0.658656399362719, 0.658656399362719, None)
No improvement on development set. Finishing training.


## Testing

In [13]:
from sklearn.metrics import classification_report

def test(model, state_path, test_iter, batch_size, num_batches, target_names):
    
    model.load_state_dict(torch.load(state_path))
    
    predictions, correct = [], []
    for batch in test_iter:

        pred = model(batch.text.to(device))
        _, pred_indices = torch.max(pred, 1)

        pred_indices = list(pred_indices.cpu().numpy())
        predictions += pred_indices
        correct += list(batch.label.cpu().numpy())

    print(classification_report(correct, predictions, target_names=target_names))

In [14]:
STATE_PATH = best_state_path
num_batches = int(len(test_data) / BATCH_SIZE)

classifier = CNNClassifier(EMBEDDING_DIM, FILTER_SIZES, NUM_FILTERS, VOCAB_SIZE+2, NUM_CLASSES)  

test(classifier.to(device), STATE_PATH, test_iter, BATCH_SIZE, num_batches, target_names)

                          precision    recall  f1-score   support

             alt.atheism       0.55      0.63      0.58       319
           comp.graphics       0.58      0.38      0.46       389
 comp.os.ms-windows.misc       0.56      0.67      0.61       394
comp.sys.ibm.pc.hardware       0.51      0.47      0.49       392
   comp.sys.mac.hardware       0.59      0.65      0.62       385
          comp.windows.x       0.67      0.69      0.68       395
            misc.forsale       0.76      0.71      0.73       390
               rec.autos       0.75      0.63      0.68       396
         rec.motorcycles       0.75      0.83      0.79       398
      rec.sport.baseball       0.70      0.60      0.65       397
        rec.sport.hockey       0.74      0.81      0.78       399
               sci.crypt       0.84      0.72      0.78       396
         sci.electronics       0.34      0.48      0.40       393
                 sci.med       0.56      0.56      0.56       396
         