# First BERT Experiments

In this notebook we run some first experiments with BERT: we fine-tune a pretrained BERT model plus a classification layer separately on each of our datasets and compute the accuracy of the resulting classifier on its test data.

For these experiments we use the `pytorch_transformers` package, which implements a range of neural network architectures for transfer learning and provides pretrained models for them, including BERT and XLNet.

Two BERT models are relevant for our experiments:

- `bert-base-uncased`: the smaller model (12 layers, about 110M parameters), which should already give reasonable results;
- `bert-large-uncased`: a larger model (24 layers, about 340M parameters) that should give state-of-the-art results but needs considerably more GPU memory.

In [1]:
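BERT_MODEL = 'bert-base-uncased'
# bert-large-uncased needs much more GPU memory: use batches of 2 and accumulate
# gradients over 8 steps, giving an effective batch of 16 in both settings
BATCH_SIZE = 16 if "base" in BERT_MODEL else 2
GRADIENT_ACCUMULATION_STEPS = 1 if "base" in BERT_MODEL else 8
MAX_SEQ_LENGTH = 100     # maximum input length in tokens
PREFIX = "junkfood_but"  # selects the prompt whose data files are loaded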

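With `bert-large-uncased`, the settings above use batches of 2 and accumulate gradients over 8 steps, so each optimizer update still covers 2 × 8 = 16 examples, the same as a single batch for `bert-base-uncased`. The stdlib-only sketch below illustrates why this is equivalent when the loss is averaged over each batch: adding up the gradients of 8 equally sized micro-batches, each divided by 8, gives exactly the gradient of one batch of 16. The one-parameter linear model and its data are invented for illustration, and we assume (without having checked) that `quillnlp`'s `train` scales the loss by `GRADIENT_ACCUMULATION_STEPS` in this way.

```python
import math

# Toy model: y_hat = w * x, trained with the mean squared error of each batch.
def mse_grad(w, batch):
    """Gradient of the batch-mean squared error with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

# 16 invented (x, y) pairs and an arbitrary current weight
data = [(float(i), 3.0 * i + (i % 3)) for i in range(16)]
w = 0.5

MICRO_BATCH_SIZE = 2  # plays the role of BATCH_SIZE for bert-large-uncased
ACCUM_STEPS = 8       # plays the role of GRADIENT_ACCUMULATION_STEPS

# Gradient of one batch containing all 16 examples
full_grad = mse_grad(w, data)

# Gradient accumulated over 8 micro-batches of 2, each scaled by 1 / ACCUM_STEPS
accumulated = 0.0
for step in range(ACCUM_STEPS):
    micro_batch = data[step * MICRO_BATCH_SIZE:(step + 1) * MICRO_BATCH_SIZE]
    accumulated += mse_grad(w, micro_batch) / ACCUM_STEPS

print(MICRO_BATCH_SIZE * ACCUM_STEPS)                      # effective batch size: 16
print(math.isclose(full_grad, accumulated, rel_tol=1e-9))  # True
```

The equivalence relies on all micro-batches having the same size; a smaller final micro-batch would give each of its examples slightly more weight.
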
## Data

We use the same data as in all our previous experiments. Here we load the training, development and test data for one prompt, selected by `PREFIX`.

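The data files are in newline-delimited JSON (NDJSON) format, with one JSON object per line, and `ndjson.load` returns them as a list of dicts. A minimal stdlib-only equivalent is sketched below; the helper name `load_ndjson` is ours, and the two sample records are invented but use the `text` and `label` keys that the notebook reads from each item later on.

```python
import io
import json

def load_ndjson(fh):
    """Parse newline-delimited JSON into a list of dicts, skipping blank lines."""
    return [json.loads(line) for line in fh if line.strip()]

# Two invented records with the same keys as the notebook's data
sample = io.StringIO(
    '{"text": "Schools should not allow junk food to be sold on campus but students may bring it anyway", '
    '"label": "Students without choice"}\n'
    '{"text": "Schools should not allow junk food to be sold on campus but they can sell healthy snacks instead", '
    '"label": "Schools providing healthy alternatives"}\n'
)
records = load_ndjson(sample)
print(len(records))         # 2
print(records[0]["label"])  # Students without choice
```
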
In [2]:
import sys
sys.path.append('../')  # make the local quillnlp package importable

import glob

import ndjson

from quillnlp.models.bert.preprocessing import preprocess, create_label_vocabulary

train_file = f"../data/interim/{PREFIX}_train_withprompt_diverse200.ndjson"
# Note: this pattern also matches train_file itself; synth_data is not used below
synth_files = glob.glob(f"../data/interim/{PREFIX}_train_withprompt_*.ndjson")
dev_file = f"../data/interim/{PREFIX}_dev_withprompt.ndjson"
test_file = f"../data/interim/{PREFIX}_test_withprompt.ndjson"

with open(train_file) as fh:
    train_data = ndjson.load(fh)

# Collect the other training files, skipping those whose name contains "allsynth"
synth_data = []
for synth_file in synth_files:
    if "allsynth" in synth_file:
        continue
    with open(synth_file) as fh:
        synth_data += ndjson.load(fh)

with open(dev_file) as fh:
    dev_data = ndjson.load(fh)

with open(test_file) as fh:
    test_data = ndjson.load(fh)

# Map label strings to integer ids and back; target_names[i] is the name of class i
label2idx = create_label_vocabulary(train_data)
idx2label = {v: k for k, v in label2idx.items()}
target_names = [idx2label[s] for s in range(len(idx2label))]

# Build the data loaders; shuffle=False keeps the test batches in the same order
# as test_data, which the per-item analysis at the end relies on
train_dataloader = preprocess(train_data, BERT_MODEL, label2idx, MAX_SEQ_LENGTH, BATCH_SIZE)
dev_dataloader = preprocess(dev_data, BERT_MODEL, label2idx, MAX_SEQ_LENGTH, BATCH_SIZE)
test_dataloader = preprocess(test_data, BERT_MODEL, label2idx, MAX_SEQ_LENGTH, BATCH_SIZE, shuffle=False)

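We have not looked inside `preprocess`, but for single-sentence classification BERT expects a standard input format, which the function most likely builds: the tokens are truncated so that `[CLS]` and `[SEP]` still fit, and then the token ids, an attention mask (1 for real tokens, 0 for padding) and segment ids (all 0 for a single sentence) are padded to `MAX_SEQ_LENGTH`. The sketch below shows that format for one example. `encode_example` is a hypothetical helper, and its lowercase whitespace tokenizer and six-entry vocabulary stand in for BERT's WordPiece tokenizer and full vocabulary.

```python
def encode_example(text, vocab, max_seq_length):
    """Build BERT-style input ids, attention mask and segment ids for one sentence."""
    # Placeholder tokenizer: lowercase + whitespace split (BERT really uses WordPiece)
    tokens = text.lower().split()
    # Reserve two positions for the [CLS] and [SEP] special tokens
    tokens = tokens[:max_seq_length - 2]
    tokens = ["[CLS]"] + tokens + ["[SEP]"]

    input_ids = [vocab.get(tok, vocab["[UNK]"]) for tok in tokens]
    input_mask = [1] * len(input_ids)   # 1 = real token, 0 = padding
    segment_ids = [0] * len(input_ids)  # single-sentence input: all segment 0

    # Pad every list with zeros up to max_seq_length
    padding = [0] * (max_seq_length - len(input_ids))
    return input_ids + padding, input_mask + padding, segment_ids + padding

# Hypothetical toy vocabulary; [PAD] has id 0, as in BERT's real vocabulary
toy_vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "junk": 4, "food": 5}
ids, mask, segments = encode_example("Junk food is tasty", toy_vocab, max_seq_length=8)
print(ids)   # [2, 4, 5, 1, 1, 3, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 1, 0, 0]
```
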
## Model

In [3]:
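import torch
from quillnlp.models.bert.models import get_bert_classifier

# Use the GPU when one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
# BERT with a classification layer that has one output per label
model = get_bert_classifier(BERT_MODEL, len(label2idx), device=device)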

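`get_bert_classifier` puts a classification layer on top of BERT. Assuming it behaves like `BertForSequenceClassification` in `pytorch_transformers` (dropout and a linear layer applied to the pooled `[CLS]` representation, trained with cross-entropy), the loss for one example is the negative log of the softmax probability assigned to its gold label. The stdlib sketch below computes this for invented logits. With the seven labels of this prompt, a model that spreads its probability evenly scores ln 7 ≈ 1.95, a useful chance-level reference for the dev losses printed during training, if those are mean cross-entropies.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, gold_index):
    """Negative log-probability of the gold label under softmax(logits)."""
    return -math.log(softmax(logits)[gold_index])

num_labels = 7  # the classification report below lists seven labels

# Uninformative logits: every label gets probability 1/7
uniform_loss = cross_entropy([0.0] * num_labels, gold_index=0)
print(round(uniform_loss, 4))  # 1.9459

# Invented logits that favour the gold label (index 2) give a lower loss
confident_loss = cross_entropy([0.1, -0.3, 2.5, 0.0, -1.0, 0.2, 0.4], gold_index=2)
print(confident_loss < uniform_loss)  # True
```
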
## Training

In [4]:
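from quillnlp.models.bert.train import train

# Fine-tune on the training data, computing the dev loss after every epoch;
# train() returns the path of the saved model file
output_model_file = train(model, train_dataloader, dev_dataloader, BATCH_SIZE, GRADIENT_ACCUMULATION_STEPS, device)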

Epoch: 100%|██████████| 20/20 [01:04<00:00,  3.03s/it]

Loss history: [1.6459675550460815, 1.5981466054916382, 1.494327187538147, 1.2899302244186401, 1.02061585187912, 1.047785222530365, 0.9007106423377991, 0.7825210511684417, 0.8229950904846192, 0.803441333770752, 0.72945476770401, 0.7813196420669556, 0.7008043587207794, 0.7750071108341217, 0.7170183479785919, 0.6437321200966835, 0.7638194292783738, 0.7193580329418182, 0.7742080330848694]
Dev loss: 0.7039884328842163

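The dev loss drops quickly during the first eight epochs and then fluctuates between about 0.64 and 0.82, reaching its lowest value at epoch 16. The output does not show whether `train` keeps the weights from the best epoch or from the last one. The sketch below takes the dev losses printed above, finds the best epoch and checks when a simple patience-based early-stopping rule would have halted training. The `early_stopping` helper and its `patience` parameter are our own illustration, not part of `quillnlp`.

```python
# Dev losses per epoch, copied from the training output above
dev_losses = [
    1.6459675550460815, 1.5981466054916382, 1.494327187538147, 1.2899302244186401,
    1.02061585187912, 1.047785222530365, 0.9007106423377991, 0.7825210511684417,
    0.8229950904846192, 0.803441333770752, 0.72945476770401, 0.7813196420669556,
    0.7008043587207794, 0.7750071108341217, 0.7170183479785919, 0.6437321200966835,
    0.7638194292783738, 0.7193580329418182, 0.7742080330848694, 0.7039884328842163,
]

def early_stopping(losses, patience):
    """Return (stop_epoch, best_epoch), both 1-based.

    Training stops once the loss has failed to improve on the best value
    for `patience` consecutive epochs; otherwise it runs to the end.
    """
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch
    return len(losses), best_epoch

best = min(range(len(dev_losses)), key=dev_losses.__getitem__) + 1
print(best)                           # 16
print(early_stopping(dev_losses, 3))  # (19, 16)
print(early_stopping(dev_losses, 2))  # (10, 8)
```

With a patience of 3, training would have stopped after epoch 19 and kept epoch 16; a patience of 2 would already have stopped after epoch 10, keeping epoch 8.
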

## Evaluation

In [5]:
from quillnlp.models.bert.train import evaluate
from sklearn.metrics import precision_recall_fscore_support, classification_report

print("Loading model from", output_model_file)
device = "cpu"  # evaluation is run on the CPU

# Rebuild the classifier and load the fine-tuned weights saved by train()
model = get_bert_classifier(BERT_MODEL, len(label2idx), model_file=output_model_file, device=device)
model.eval()  # disable dropout for evaluation

_, test_correct, test_predicted = evaluate(model, test_dataloader, device)

# With average="micro", the last element of the returned tuple (support) is None
print("Test performance:", precision_recall_fscore_support(test_correct, test_predicted, average="micro"))
print(classification_report(test_correct, test_predicted, target_names=target_names))

Loading model from /tmp/model.bin




Test performance: (0.7712418300653595, 0.7712418300653595, 0.7712418300653595, None)
                                           precision    recall  f1-score   support

   Schools providing healthy alternatives       0.95      0.93      0.94        75
                  Students without choice       0.70      0.79      0.74        33
                   Schools generate money       1.00      0.75      0.86         8
                           Student choice       0.46      0.86      0.60         7
Students can still bring/access junk food       0.00      0.00      0.00         3
                   Unclassified Off-Topic       0.36      0.45      0.40        11
          School without generating money       0.56      0.31      0.40        16

                              avg / total       0.77      0.77      0.76       153

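Each row of the report can be recomputed from three counts per label: true positives, the number of times the label was predicted, and its support. For example, the "Student choice" row (precision 0.46, recall 0.86, support 7) is only consistent with 6 correct predictions out of 13 predictions of that label. These counts are inferred from the rounded figures rather than printed by the notebook. The sketch below recomputes the row from them; `prf` is a hypothetical helper, not a scikit-learn function.

```python
def prf(true_positives, predicted, support):
    """Precision, recall and F1 for one label from raw counts (0.0 when undefined)."""
    precision = true_positives / predicted if predicted else 0.0
    recall = true_positives / support if support else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    f1 = 2 * precision * recall / (precision + recall)  # harmonic mean
    return precision, recall, f1

# "Student choice": 6 true positives, 13 predictions, support 7 (inferred counts)
p, r, f1 = prf(6, 13, 7)
print(f"{p:.2f} {r:.2f} {f1:.2f}")  # 0.46 0.86 0.60
```

A label that is never predicted gets precision 0.0 here; scikit-learn does the same but also emits an `UndefinedMetricWarning`.
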




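In single-label classification, micro-averaged precision, recall and F1 all equal accuracy. Every wrong prediction counts once as a false positive (for the predicted label) and once as a false negative (for the gold label), so pooled precision and recall share the same denominator, the number of test items. This is why the three values under "Test performance" are identical: 0.7712 × 153 ≈ 118, so 118 of the 153 test responses were classified correctly, the same ratio the next cell computes as `c / len(test_data)`. A stdlib sketch with made-up labels; `micro_prf` is a hypothetical helper:

```python
from collections import Counter

def micro_prf(gold, predicted):
    """Micro-averaged precision, recall and F1 for single-label predictions."""
    tp = sum(g == p for g, p in zip(gold, predicted))
    fp = Counter()  # false positives per predicted label
    fn = Counter()  # false negatives per gold label
    for g, p in zip(gold, predicted):
        if g != p:
            fp[p] += 1
            fn[g] += 1
    # Pool the counts over all labels before dividing
    precision = tp / (tp + sum(fp.values()))
    recall = tp / (tp + sum(fn.values()))
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    return precision, recall, f1

gold = [0, 0, 1, 2, 2, 3]
predicted = [0, 1, 1, 2, 0, 3]
accuracy = sum(g == p for g, p in zip(gold, predicted)) / len(gold)
print(micro_prf(gold, predicted), accuracy)  # every value is 4/6 ≈ 0.667

print(round(118 / 153, 4))  # 0.7712, the micro score reported above
```
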
In [6]:
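# Print every test item as text#gold label#predicted label and count correct predictions
c = 0
for item, predicted, correct in zip(test_data, test_predicted, test_correct):
    assert item["label"] == idx2label[correct]  # holds because test_dataloader is not shuffled
    c += (item["label"] == idx2label[predicted])
    print("{}#{}#{}".format(item["text"], idx2label[correct], idx2label[predicted]))

print(c)                   # number of correctly classified test items
print(c / len(test_data))  # accuracy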

Schools should not allow junk food to be sold on campus but kids will still bring in unhealthy food#Students without choice#Students without choice
Schools should not allow junk food to be sold on campus but some think students should be able to choose what they eat#Student choice#Student choice
Schools should not allow junk food to be sold on campus but maybe on certain special occasions or at events#Unclassified Off-Topic#Unclassified Off-Topic
Schools should not allow junk food to be sold on campus but students may bring it anyway#Students without choice#Students without choice
Schools should not allow junk food to be sold on campus but they can choose to sell food that are nutritious and healthy#Schools providing healthy alternatives#Schools providing healthy alternatives
Schools should not allow junk food to be sold on campus but ultimately it is up to the individual student what they eat or drink#Students without choice#Student choice
Schools should not allow junk food to be sold