In [1]:
import numpy as np
from torchnlp.datasets import imdb_dataset # run pip install pytorch-nlp if you dont have this
from tamnun.bert import BertClassifier, BertVectorizer
from sklearn.pipeline import make_pipeline
from sklearn.metrics import classification_report, accuracy_score

## Prepare the data

In [2]:
train_data, test_data = imdb_dataset(train=True, test=True)

Extract the texts and the labels from the imdb datasets

In [3]:
train_texts, train_labels = list(zip(*map(lambda d: (d['text'], d['sentiment']), train_data)))
test_texts, test_labels = list(zip(*map(lambda d: (d['text'], d['sentiment']), test_data)))

len(train_texts), len(train_labels), len(test_texts), len(test_labels)

(25000, 25000, 25000, 25000)

Convert the target variable to a binary vector neg=0, pos=1

In [4]:
train_y = np.array(train_labels) == 'pos'
test_y = np.array(test_labels) == 'pos'
train_y.shape, test_y.shape, np.mean(train_y), np.mean(test_y)

((25000,), (25000,), 0.5, 0.5)

## Fine-Tune BERT

In [5]:
clf = make_pipeline(BertVectorizer(do_truncate=True), BertClassifier(num_of_classes=2, lr=1e-5)).fit(train_texts, train_y)

Epoch 1/5:
6249/6250 batch loss: 0.5317028760910034    avg loss: 0.2504522643852234
Epoch 2/5:
6249/6250 batch loss: 0.4307538568973541 6  avg loss: 0.12871999006032944
Epoch 3/5:
6249/6250 batch loss: 0.006784558296203613  avg loss: 0.07053191775560379
Epoch 4/5:
6249/6250 batch loss: 0.003382742404937744   avg loss: 0.04482248139500618
Epoch 5/5:
6249/6250 batch loss: 0.0027724504470825195  avg loss: 0.03316646183252334


In [6]:
predicted = clf.predict(test_texts)

In [7]:
print('Accuracy:', accuracy_score(predicted, test_y))
print(classification_report(predicted, test_y))

Accuracy: 0.93604
              precision    recall  f1-score   support

           0       0.94      0.93      0.94     12663
           1       0.93      0.94      0.94     12337

   micro avg       0.94      0.94      0.94     25000
   macro avg       0.94      0.94      0.94     25000
weighted avg       0.94      0.94      0.94     25000

