In [1]:
# Locations of the tab-separated 20 Newsgroups train and test files that the
# dataset-loading cell reads.
TRAIN_PATH, TEST_PATH = (
    "data/text_classification/20newsgroups_train.tsv",
    "data/text_classification/20newsgroups_test.tsv",
)

In [2]:
from sklearn.datasets import fetch_20newsgroups

# Fetch the training split only to read its class names. The result is bound
# to `newsgroups_train` rather than `train` because a function called `train`
# is defined further down in this notebook; sharing the name meant that
# re-running this cell would silently replace the training function.
newsgroups_train = fetch_20newsgroups(subset="train")

# Map each class name to its position in `target_names`.
label2idx = {name: position for position, name in enumerate(newsgroups_train.target_names)}

In [3]:
import csv
import sys
from torchtext.data import TabularDataset, Field, BucketIterator

csv.field_size_limit(sys.maxsize)

# `text` is a sequential field tokenized with spaCy. `label` is a single
# non-sequential value that bypasses vocabulary building: the raw class name
# is converted to its integer id through `label2idx` during preprocessing.
text = Field(sequential=True, tokenize="spacy")
label = Field(
    sequential=False,
    use_vocab=False,
    preprocessing=lambda class_name: label2idx[class_name],
)

# Both splits are TSV files read with the same column layout: label first,
# then text. Each dataset gets its own freshly built `fields` list.
train_data, test_data = (
    TabularDataset(path=split_path, format='tsv', fields=[('label', label), ('text', text)])
    for split_path in (TRAIN_PATH, TEST_PATH)
)

In [4]:
# Display the mapping from column name to the Field object attached to it.
train_data.fields

{'label': <torchtext.data.field.Field at 0x1a2265cc50>,
 'text': <torchtext.data.field.Field at 0x1a2265cbe0>}

In [4]:
# Upper bound passed as `max_size` when building the text vocabulary.
VOCAB_SIZE = 30000

# The vocabulary is built from the training split only, so the test split
# does not contribute any entries.
text.build_vocab(
    train_data,
    max_size=VOCAB_SIZE,
)

In [5]:
BATCH_SIZE = 32

# BucketIterator can only group examples of similar length when it is given a
# sort key and told to sort within batches. Without them each batch is an
# arbitrary slice of the data and is padded to its longest document.
# NOTE(review): behaviour described per the torchtext Iterator/BucketIterator
# docs; confirm against the installed torchtext version.
train_iter = BucketIterator(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    sort_key=lambda example: len(example.text),
    sort_within_batch=True,
)
dev_iter = BucketIterator(
    dataset=test_data,
    batch_size=BATCH_SIZE,
    sort_key=lambda example: len(example.text),
    sort_within_batch=True,
)

In [6]:
import torch.nn as nn

class LSTMClassifier(nn.Module):
    """Two-layer bidirectional LSTM text classifier.

    Token indices are embedded, run through the LSTM, and the final hidden
    states of the top layer (both directions) are passed through dropout and
    a linear layer to produce one score per class.

    Args:
        embedding_dim: size of each token embedding.
        hidden_dim: hidden size of the LSTM, per direction.
        vocab_size: number of rows in the embedding table.
        output_size: number of classes (width of the output layer).
    """

    def __init__(self, embedding_dim, hidden_dim, vocab_size, output_size):
        super(LSTMClassifier, self).__init__()

        # 1. Embedding Layer
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)

        # 2. LSTM Layer
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True, num_layers=2)

        # 3. Dense Layer
        # The LSTM is bidirectional, so the sentence representation is the
        # forward and backward final states side by side: 2 * hidden_dim.
        self.hidden2out = nn.Linear(2 * hidden_dim, output_size)

        # Optional dropout layer
        self.dropout_layer = nn.Dropout(p=0.4)

    def forward(self, batch_text):
        """Return unnormalized class scores for a batch of token-index sequences.

        Args:
            batch_text: integer tensor of token indices. The LSTM is created
                with its default layout, so this is expected to be
                (seq_len, batch) -- TODO confirm against the iterator output.

        Returns:
            Tensor of shape (batch, output_size).
        """
        embeddings = self.embeddings(batch_text)
        _, (ht, _) = self.lstm(embeddings)

        # `ht` stacks the final hidden state of every (layer, direction)
        # pair; the last two entries belong to the top layer (forward, then
        # backward). Using only ht[-1] would discard the forward direction.
        top_layer = ht[-2:]
        lstm_output = top_layer.transpose(0, 1).reshape(top_layer.size(1), -1)

        lstm_output = self.dropout_layer(lstm_output)
        final_output = self.hidden2out(lstm_output)
        return final_output

In [7]:
import torch
import torch.optim as optim
from sklearn.metrics import precision_recall_fscore_support
from tqdm import tqdm_notebook as tqdm

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')


def _run_epoch(model, batches, criterion, optimizer=None):
    """Run one pass over `batches`, updating weights only if `optimizer` is given.

    Returns:
        (total_loss, correct, predictions): the summed per-batch loss, the
        gold label ids, and the predicted label ids, in iteration order.
    """
    training = optimizer is not None
    # Dropout must be active while training and disabled while evaluating.
    model.train(training)

    total_loss = 0
    predictions, correct = [], []
    # Skip autograd bookkeeping entirely when no update will be made.
    with torch.set_grad_enabled(training):
        for batch in batches:
            labels = batch.label.to(device)
            scores = model(batch.text.to(device))
            loss = criterion(scores, labels)
            total_loss += loss.item()

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            predictions.extend(scores.argmax(dim=1).tolist())
            correct.extend(labels.tolist())

    return total_loss, correct, predictions


def train(model, train_iter, dev_iter, batch_size, num_batches, max_epochs=20):
    """Train `model` with Adam and cross-entropy, evaluating after every epoch.

    For each epoch the summed loss and per-class precision/recall/F-score/
    support are printed, first for the training data and then for the
    development data.

    Args:
        model: module mapping `batch.text` to per-class scores.
        train_iter: iterable of training batches exposing `.text` and `.label`.
        dev_iter: iterable of development batches with the same attributes.
        batch_size: unused; kept so existing callers keep working.
        num_batches: total shown on the training progress bar.
        max_epochs: number of passes over the training data (default 20).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())

    for epoch in range(max_epochs):
        progress = tqdm(train_iter, total=num_batches)
        total_loss, correct, predictions = _run_epoch(model, progress, criterion, optimizer)

        print("=== Epoch", epoch, "===")
        print("Total training loss:", total_loss)
        print("Training performance:", precision_recall_fscore_support(correct, predictions))

        # No optimizer: evaluation mode, no gradient tracking, no updates.
        total_loss, correct, predictions = _run_epoch(model, dev_iter, criterion)

        print("Total development loss:", total_loss)
        print("Development performance:", precision_recall_fscore_support(correct, predictions))
            

In [None]:
EMBEDDING_DIM = 300
HIDDEN_DIM = 256
NUM_CLASSES = len(label2idx)

# Ask the iterator for its batch count: integer-dividing the dataset size by
# the batch size drops the final partial batch, which left the progress bar
# total one short.
num_batches = len(train_iter)

# Size the embedding table from the vocabulary that was actually built
# (special tokens included) instead of assuming it reached VOCAB_SIZE + 2.
classifier = LSTMClassifier(EMBEDDING_DIM, HIDDEN_DIM, len(text.vocab), NUM_CLASSES)

train(classifier.to(device), train_iter, dev_iter, BATCH_SIZE, num_batches)


HBox(children=(IntProgress(value=0, max=353), HTML(value='')))