BERT for binary or multiclass document classification using the [CLS] token as the document representation; trains a model (on `train.txt`), uses `dev.txt` for early stopping, and evaluates performance on `test.txt`.  Reports test accuracy with 95% confidence intervals.

Before executing this notebook on Colab, make sure you're running on cuda (`Runtime > Change runtime type > GPU`) to make use of GPU speedups.

In [1]:
from transformers import BertModel, BertTokenizer
import nltk
import torch
import torch.nn as nn
import numpy as np
import random
from scipy.stats import norm
import math

import pandas as pd 

from sklearn.model_selection import train_test_split
from sklearn.metrics import average_precision_score
from tqdm import tqdm

In [2]:
# Change this to the directory with your data
directory="../data/"

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on {}".format(device))

Running on cpu


In [6]:
labels = [0,1]

In [48]:
df = pd.read_csv("../data/train_test_set.csv")
train_x, train_y = \
    df[df["train"] == 1]["clean_items"].to_list(), df[df["train"] == 1]["bank_status"].to_list()
test_x, test_y= df[df["train"] == 0]["clean_items"].to_list(), df[df["train"] == 0]["bank_status"].to_list()

In [49]:
def evaluate(model, x, y):
    model.eval()
    y_true = []
    y_preds = []
    with torch.no_grad():
        for x, y in zip(x, y):
            y_true.extend(y.tolist())
            y_preds.extend((nn.functional.softmax(model.forward(x), dim=1)[:,1]).tolist())  
    return average_precision_score(y_true, y_preds)

In [50]:
class BERTClassifier(nn.Module):

    def __init__(self, bert_model_name, params):
        super().__init__()
    
        self.model_name=bert_model_name
        self.tokenizer = BertTokenizer.from_pretrained(self.model_name, do_lower_case=params["doLowerCase"], do_basic_tokenize=False)
        self.bert = BertModel.from_pretrained(self.model_name)
        
        self.num_labels = params["label_length"]

        self.fc = nn.Linear(params["embedding_size"], self.num_labels)

    def get_batches(self, all_x, all_y, batch_size=32, max_toks=510, train=True):
        """ Get batches for input x, y data, with data tokenized according to the BERT tokenizer
        (and limited to a maximum number of WordPiece tokens) """

        batches_x = []
        batches_y = []

        if train:
            pos_x = [all_x[i] for i in range(len(all_x)) if all_y[i] == 1]
            neg_x = [all_x[i] for i in range(len(all_x)) if all_y[i] == 0]

            if batch_size // 2 > len(pos_x):
                neg_len = batch_size - len(pos_x)
                neg_batches = [neg_x[i:i + neg_len] for i in range(0, len(neg_x), neg_len)]
                batch_y = [0] * neg_len + [1] * len(pos_x)
                for neg_batch in neg_batches:
                    x = neg_batch + pos_x
                    batch_x = self.tokenizer(x, padding=True, truncation=True, return_tensors="pt", max_length=max_toks)
                    batches_x.append(batch_x.to(device))
                    batches_y.append(torch.LongTensor(batch_y).to(device))
            else:
                neg_batches = [neg_x[i:i + batch_size // 2] for i in range(0, len(neg_x), batch_size // 2)]
                for neg_batch in neg_batches:
                    x = neg_batch + random.sample(pos_x, batch_size // 2)
                    batch_x = self.tokenizer(x, padding=True, truncation=True, return_tensors="pt", max_length=max_toks)
                    batch_y = [0] * len(neg_batch) + [1] * (batch_size // 2)
                    batches_x.append(batch_x.to(device))
                    batches_y.append(torch.LongTensor(batch_y).to(device))
        else:
            for i in range(0, len(all_x), batch_size):

                x=all_x[i:i+batch_size]

                batch_x = self.tokenizer(x, padding=True, truncation=True, return_tensors="pt", max_length=max_toks)
                batch_y=all_y[i:i+batch_size]

                batches_x.append(batch_x.to(device))
                batches_y.append(torch.LongTensor(batch_y).to(device))

        return batches_x, batches_y

    def forward(self, batch_x): 
    
        bert_output = self.bert(input_ids=batch_x["input_ids"],
                         attention_mask=batch_x["attention_mask"],
                         token_type_ids=batch_x["token_type_ids"],
                         output_hidden_states=True)

      # We're going to represent an entire document just by its [CLS] embedding (at position 0)
      # And use the *last* layer output (layer -1)
      # as a result of this choice, this embedding will be optimized for this purpose during the training process.
      
        bert_hidden_states = bert_output['hidden_states']

        out = bert_hidden_states[-1][:,0,:]

        out = self.fc(out)

        return out.squeeze()

In [55]:
def train(bert_model_name, model_filename, train_x, train_y, dev_x, dev_y, labels, embedding_size=768, doLowerCase=None):

    bert_model = BERTClassifier(bert_model_name, params={"label_length": len(labels), "doLowerCase":doLowerCase, "embedding_size":embedding_size})
    bert_model.to(device)

    batch_x, batch_y = bert_model.get_batches(train_x, train_y)

    dev_batch_x, dev_batch_y = bert_model.get_batches(dev_x, dev_y, train=False)


    optimizer = torch.optim.Adam(bert_model.parameters(), lr=1e-5)
    cross_entropy=nn.CrossEntropyLoss()

    num_epochs=30
    best_dev_acc = 0.
    patience=10

    best_epoch=0

    for epoch in tqdm(range(num_epochs)):
        bert_model.train()

        # Train
        for x, y in zip(batch_x, batch_y):
            y_pred = bert_model.forward(x)
            loss = cross_entropy(y_pred.view(-1, bert_model.num_labels), y.view(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluate
        dev_accuracy=evaluate(bert_model, dev_batch_x, dev_batch_y)
        if epoch % 1 == 0:
            print("Epoch %s, dev accuracy: %.3f" % (epoch, dev_accuracy))
            if dev_accuracy > best_dev_acc:
                torch.save(bert_model.state_dict(), model_filename)
                best_dev_acc = dev_accuracy
                best_epoch=epoch
        if epoch - best_epoch > patience:
            print("No improvement in dev accuracy over %s epochs; stopping training" % patience)
            break

    bert_model.load_state_dict(torch.load(model_filename))
    print("\nBest Performing Model achieves dev accuracy of : %.3f" % (best_dev_acc))
    return bert_model


In [56]:
# small BERT -- can run on laptop
bert_model_name="google/bert_uncased_L-2_H-128_A-2"
model_filename="mybert.model"
embedding_size=128
doLowerCase=True

# bert-base -- slow on laptop; better on Colab
# bert_model_name="bert-base-cased"
# model_filename="mybert.model"
# embedding_size=768
# doLowerCase=False

model=train(bert_model_name, model_filename, train_x, train_y, test_x, test_y, labels, embedding_size=embedding_size, doLowerCase=doLowerCase)

  3%|▎         | 1/30 [03:27<1:40:28, 207.87s/it]

Epoch 0, dev accuracy: 0.290


  7%|▋         | 2/30 [06:54<1:36:41, 207.18s/it]

Epoch 1, dev accuracy: 0.328


 10%|█         | 3/30 [10:21<1:33:14, 207.22s/it]

Epoch 2, dev accuracy: 0.307


 13%|█▎        | 4/30 [13:49<1:29:53, 207.46s/it]

Epoch 3, dev accuracy: 0.287


 17%|█▋        | 5/30 [17:18<1:26:42, 208.10s/it]

Epoch 4, dev accuracy: 0.290


 20%|██        | 6/30 [20:46<1:23:10, 207.95s/it]

Epoch 5, dev accuracy: 0.290


 23%|██▎       | 7/30 [24:15<1:19:52, 208.37s/it]

Epoch 6, dev accuracy: 0.246


 27%|██▋       | 8/30 [28:17<1:20:15, 218.89s/it]

Epoch 7, dev accuracy: 0.279


 30%|███       | 9/30 [32:18<1:19:06, 226.01s/it]

Epoch 8, dev accuracy: 0.245


 33%|███▎      | 10/30 [36:20<1:16:59, 230.98s/it]

Epoch 9, dev accuracy: 0.240


 37%|███▋      | 11/30 [40:23<1:14:13, 234.40s/it]

Epoch 10, dev accuracy: 0.235


 40%|████      | 12/30 [44:25<1:11:01, 236.77s/it]

Epoch 11, dev accuracy: 0.231


 40%|████      | 12/30 [48:26<1:12:40, 242.25s/it]

Epoch 12, dev accuracy: 0.213
No improvement in dev accuracy over 10 epochs; stopping training

Best Performing Model achieves dev accuracy of : 0.328





In [57]:
test_batch_x, test_batch_y = model.get_batches(test_x, test_y, len(test_y))

In [58]:
evaluate(model, test_batch_x, test_batch_y)

0.32810212457775983