In [1]:
import random as rn

import numpy as np
import torch
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.metrics import classification_report, accuracy_score
from sklearn.pipeline import make_pipeline
from tamnun.bert import BertClassifier, BertVectorizer
from tamnun.transfer import Distiller
from torchnlp.datasets import imdb_dataset # run pip install pytorch-nlp if you dont have this

# Seed every RNG the notebook may touch, so that Restart & Run All reproduces the results.
# `random` drives the dataset shuffle below. numpy/torch are presumably used by the BERT
# fine-tuning inside tamnun (weight init, dropout, batch order) -- GPU kernels may still
# be nondeterministic.
SEED = 321
rn.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Prepare the data

In [2]:
# Sizes of the subsets used below (kept small to speed things up).
N_TRAIN = 1000      # labeled reviews used to train the baseline and the teacher
N_TEST = 1000       # held-out reviews used for evaluation
N_UNLABELED = 5000  # extra training texts, used without their labels to train the student model
SPLIT_SEED = 321    # same value as the global seed set in the first cell

train_data_full, test_data_full = imdb_dataset(train=True, test=True)

# Shuffle with a dedicated, freshly seeded RNG rather than the global `random` state.
# Re-running this cell then reproduces exactly the same split. On a fresh kernel it should
# also match the order the global `rn.seed(321)` would have produced.
split_rng = rn.Random(SPLIT_SEED)
split_rng.shuffle(train_data_full)
split_rng.shuffle(test_data_full)

train_data = train_data_full[:N_TRAIN]
test_data = test_data_full[:N_TEST]
# The unlabeled pool starts right after the labeled slice, so the two never overlap.
unlabeled_data = train_data_full[N_TRAIN:N_TRAIN + N_UNLABELED]


Extract the review texts and sentiment labels from the IMDB records (only the texts are kept for the unlabeled pool).

In [3]:
def split_texts_and_labels(records):
    """Return ``(texts, sentiments)`` as two parallel tuples built from IMDB records.

    Each record is read through its ``'text'`` and ``'sentiment'`` keys.
    """
    pairs = [(record['text'], record['sentiment']) for record in records]
    texts, sentiments = zip(*pairs)
    return texts, sentiments


train_texts, train_labels = split_texts_and_labels(train_data)
test_texts, test_labels = split_texts_and_labels(test_data)
# Only the texts of the unlabeled pool are kept; their sentiments are deliberately ignored.
unlabeled_texts = [record['text'] for record in unlabeled_data]

len(train_texts), len(train_labels), len(test_texts), len(test_labels), len(unlabeled_texts)

(1000, 1000, 1000, 1000, 5000)

Convert the sentiment labels to a binary float vector (neg = 0, pos = 1).

In [4]:
def encode_sentiment(labels, positive='pos', negative='neg'):
    """Encode sentiment strings as a float32 0/1 vector (``negative`` -> 0.0, ``positive`` -> 1.0).

    Raises ValueError if any label is neither ``positive`` nor ``negative``. A plain
    ``== 'pos'`` comparison would otherwise silently encode such labels as 0.0.
    """
    label_array = np.asarray(labels)
    unexpected = set(label_array.tolist()) - {positive, negative}
    if unexpected:
        raise ValueError('unexpected sentiment labels: {}'.format(sorted(unexpected)))
    return (label_array == positive).astype(np.float32)


train_y = encode_sentiment(train_labels)
test_y = encode_sentiment(test_labels)
train_y.shape, test_y.shape, np.mean(train_y), np.mean(test_y)

((1000,), (1000,), 0.489, 0.478)

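## Baseline (without distillation)

A bag-of-n-grams logistic regression trained only on the 1000 labeled reviews, for comparison with the distilled model.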

In [5]:
# Non-BERT reference point: bag of 1- to 3-grams + logistic regression, trained only on
# the labeled reviews.
# The solver is pinned explicitly. 'liblinear' was the implicit default (with a
# FutureWarning) through sklearn 0.21, while sklearn >= 0.22 defaults to 'lbfgs', which can
# give different results (or convergence warnings) on these large sparse n-gram features.
baseline_model = make_pipeline(
    CountVectorizer(ngram_range=(1, 3)),
    LogisticRegression(solver='liblinear'),
).fit(train_texts, train_labels)



In [6]:
# Hard 'neg'/'pos' predictions of the baseline pipeline for the held-out test reviews.
baseline_predicted = baseline_model.predict(X=test_texts)

In [7]:
# Score the baseline's string predictions against the true test sentiments.
baseline_accuracy = accuracy_score(y_true=test_labels, y_pred=baseline_predicted)
print('Accuracy:', baseline_accuracy)
print(classification_report(y_true=test_labels, y_pred=baseline_predicted))

Accuracy: 0.807
              precision    recall  f1-score   support

         neg       0.83      0.80      0.81       522
         pos       0.79      0.82      0.80       478

   micro avg       0.81      0.81      0.81      1000
   macro avg       0.81      0.81      0.81      1000
weighted avg       0.81      0.81      0.81      1000



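## Distilling BERT

A BERT classifier (the teacher) is trained on the 1000 labeled reviews. A cheap n-gram `LinearRegression` (the student) is then trained to reproduce the teacher's `decision_function` scores, with the 5000 unlabeled reviews as additional training text.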

In [9]:
# Teacher: BERT vectorizer (do_truncate=True presumably cuts reviews longer than BERT's
# maximum input length) feeding a 2-class BERT classifier with a small learning rate.
bert_clf = make_pipeline(
    BertVectorizer(do_truncate=True),
    BertClassifier(num_of_classes=2, lr=1e-5),
)
# Student: the same 1- to 3-gram count features as the baseline, but with a
# LinearRegression, since it is fit to the teacher's real-valued scores rather than labels.
distilled_clf = make_pipeline(
    CountVectorizer(ngram_range=(1, 3)),
    LinearRegression(),
)

In [10]:
# Train teacher and student together. The Distiller presumably fits the teacher on the
# labeled reviews (the epoch log below is BERT training) and then fits the student to the
# teacher's decision_function scores, using unlabeled_texts as extra input -- confirm
# against tamnun's Distiller documentation.
distiller = Distiller(
    teacher_model=bert_clf,
    teacher_predict_func=bert_clf.decision_function,
    student_model=distilled_clf,
).fit(train_texts, train_y, unlabeled_X=unlabeled_texts)

Epoch 1/5:
249/250 batch loss: 0.6024092435836792 5 avg loss: 0.49781549847126005
Epoch 2/5:
249/250 batch loss: 0.37161239981651306  avg loss: 0.24189828193187712
Epoch 3/5:
249/250 batch loss: 0.02311527729034424  avg loss: 0.11768910476565361
Epoch 4/5:
249/250 batch loss: 0.004524528980255127  avg loss: 0.05928406158089638
Epoch 5/5:
249/250 batch loss: 0.002069532871246338  avg loss: 0.03904806500673294


In [11]:
# Scores for the test reviews. They presumably come from the distilled student (a regression
# onto the teacher's decision_function outputs), one column per class, rather than being
# true logits -- NOTE(review): confirm against tamnun's Distiller.transform.
predicted_logits = distiller.transform(test_texts)

In [12]:
# Predicted class per review = index of the highest-scoring column. This assumes column 0
# corresponds to 0.0/neg and column 1 to 1.0/pos, matching the encoding of train_y.
predicted = np.argmax(a=predicted_logits, axis=1)

In [13]:
print('Accuracy:', accuracy_score(test_y, predicted))
print(classification_report(test_y, predicted))

Accuracy: 0.852
              precision    recall  f1-score   support

         0.0       0.84      0.89      0.86       522
         1.0       0.87      0.82      0.84       478

   micro avg       0.85      0.85      0.85      1000
   macro avg       0.85      0.85      0.85      1000
weighted avg       0.85      0.85      0.85      1000

