# First BERT Experiments

In this notebook we run our first experiments with BERT: we fine-tune a BERT model with a classification layer on each of our datasets separately and compute the accuracy of the resulting classifier on the test data.

For these experiments we use the `pytorch_transformers` package. It provides a variety of neural network architectures and pretrained models for transfer learning, including BERT and XLNet.

Two different BERT models are relevant for our experiments: 

- BERT-base-uncased: a relatively small BERT model that should already give reasonable results,
- BERT-large-uncased: a larger model intended for state-of-the-art results.

In [2]:
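# Name of the pretrained model to fine-tune.
BERT_MODEL = 'bert-base-uncased'
# The large model gets smaller batches, compensated by more gradient accumulation steps.
BATCH_SIZE = 16 if "base" in BERT_MODEL else 2
GRADIENT_ACCUMULATION_STEPS = 1 if "base" in BERT_MODEL else 8
# Maximum sequence length passed to preprocess().
MAX_SEQ_LENGTH = 100
# Prefix of the data files for the prompt we train on.
PREFIX = "junkfood_but"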
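
The two batch settings are chosen together. Assuming `train` applies gradient accumulation in the usual way, the optimizer only updates the weights after several batches, so the effective batch size is the product of the two values. A minimal sketch of that arithmetic (the helper `effective_batch_size` is hypothetical and only mirrors the constants above):

```python
def effective_batch_size(model_name: str) -> int:
    """Number of examples that contribute to each optimizer step."""
    batch_size = 16 if "base" in model_name else 2
    accumulation_steps = 1 if "base" in model_name else 8
    return batch_size * accumulation_steps

# Both model sizes end up with the same effective batch size.
for name in ("bert-base-uncased", "bert-large-uncased"):
    print(name, effective_batch_size(name))
```

Under that assumption, both configurations update the weights on 16 examples per step; only the amount of data processed at once differs.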

## Data

We use the same data as in all our previous experiments. Here we load the training, development and test data for the prompt selected by `PREFIX`.

In [3]:
import sys
sys.path.append('../')  # make packages in the parent directory importable

import ndjson

from quillnlp.models.bert.preprocessing import preprocess, create_label_vocabulary, get_data_loader

# One newline-delimited JSON file per split.
train_file = f"../data/interim/{PREFIX}_train_withprompt.ndjson"
dev_file = f"../data/interim/{PREFIX}_dev_withprompt.ndjson"
test_file = f"../data/interim/{PREFIX}_test_withprompt.ndjson"

with open(train_file) as f:
    train_data = ndjson.load(f)

with open(dev_file) as f:
    dev_data = ndjson.load(f)

with open(test_file) as f:
    test_data = ndjson.load(f)


train_data[:3]

[{'text': 'Schools should not allow junk food to be sold on campus but its not healthy',
  'label': 'Unclassified Off-Topic'},
 {'text': 'Schools should not allow junk food to be sold on campus but it is up to the discretion of each school or the school board',
  'label': 'School without generating money'},
 {'text': 'Schools should not allow junk food to be sold on campus but BUT HEALTHY SNACKS',
  'label': 'Schools providing healthy alternatives'}]
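
As the output shows, every item is a dictionary with a `text` and a `label` field. As a sketch of what loading such a file amounts to, the same list of dictionaries can be built with the standard library alone, assuming one JSON object per line. The helper `load_ndjson` and the in-memory sample are illustrative and not part of the project:

```python
import io
import json

def load_ndjson(handle):
    """Parse newline-delimited JSON: one object per non-empty line."""
    return [json.loads(line) for line in handle if line.strip()]

# Two lines in the same shape as the training items shown above.
sample = io.StringIO(
    '{"text": "Schools should not allow junk food to be sold on campus but its not healthy", '
    '"label": "Unclassified Off-Topic"}\n'
    '{"text": "Schools should not allow junk food to be sold on campus but BUT HEALTHY SNACKS", '
    '"label": "Schools providing healthy alternatives"}\n'
)
items = load_ndjson(sample)
print(len(items), items[0]["label"])
```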

In [4]:
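# Map every label to an integer index, and build the inverse mapping for decoding predictions.
label2idx = create_label_vocabulary(train_data)
idx2label = {idx: label for label, idx in label2idx.items()}
# Label names ordered by index, as expected by classification_report.
target_names = [idx2label[i] for i in range(len(idx2label))]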

print(label2idx)
print(idx2label)
print(target_names)

{'Unclassified Off-Topic': 0, 'School without generating money': 1, 'Schools providing healthy alternatives': 2, 'Student choice': 3, 'Students without choice': 4, 'Schools generate money': 5, 'Students can still bring/access junk food': 6}
{0: 'Unclassified Off-Topic', 1: 'School without generating money', 2: 'Schools providing healthy alternatives', 3: 'Student choice', 4: 'Students without choice', 5: 'Schools generate money', 6: 'Students can still bring/access junk food'}
['Unclassified Off-Topic', 'School without generating money', 'Schools providing healthy alternatives', 'Student choice', 'Students without choice', 'Schools generate money', 'Students can still bring/access junk food']
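
The first three indices match the order of the three training items printed earlier, which is consistent with numbering labels in order of first appearance. The implementation of `create_label_vocabulary` is not shown in this notebook, so the following is only a sketch of that assumed behaviour, under the hypothetical name `build_label_vocabulary`:

```python
def build_label_vocabulary(data):
    """Map each distinct label to an index, in order of first appearance."""
    label2idx = {}
    for item in data:
        # setdefault assigns a new index only the first time a label is seen.
        label2idx.setdefault(item["label"], len(label2idx))
    return label2idx

examples = [
    {"text": "a", "label": "Unclassified Off-Topic"},
    {"text": "b", "label": "School without generating money"},
    {"text": "c", "label": "Unclassified Off-Topic"},
]
vocab = build_label_vocabulary(examples)
inverse = {idx: label for label, idx in vocab.items()}
print(vocab)
print(inverse)
```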


In [5]:
train_dataloader = get_data_loader(preprocess(train_data, BERT_MODEL, label2idx, MAX_SEQ_LENGTH), BATCH_SIZE)
dev_dataloader = get_data_loader(preprocess(dev_data, BERT_MODEL, label2idx, MAX_SEQ_LENGTH), BATCH_SIZE)
test_dataloader = get_data_loader(preprocess(test_data, BERT_MODEL, label2idx, MAX_SEQ_LENGTH), BATCH_SIZE, shuffle=False)

I1009 17:26:21.039733 140302331909952 tokenization_utils.py:373] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/yves/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
I1009 17:26:21.946402 140302331909952 tokenization_utils.py:373] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/yves/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
I1009 17:26:22.463976 140302331909952 tokenization_utils.py:373] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/yves/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d
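
`MAX_SEQ_LENGTH` fixes the length of every input sequence so that examples can be stacked into batches. What `preprocess` does internally is not shown in this notebook. A common approach, sketched here as an assumption, is to truncate longer token-id sequences and pad shorter ones, with a mask that marks the real tokens. The function name `pad_or_truncate`, the sample ids and the padding id 0 are illustrative choices:

```python
def pad_or_truncate(token_ids, max_seq_length, pad_id=0):
    """Return (input_ids, attention_mask), both of length max_seq_length."""
    kept = token_ids[:max_seq_length]
    padding = [pad_id] * (max_seq_length - len(kept))
    input_ids = kept + padding
    # 1 marks a real token, 0 marks padding the model should ignore.
    attention_mask = [1] * len(kept) + [0] * len(padding)
    return input_ids, attention_mask

# Made-up token ids; a real pipeline would get them from the tokenizer
# and would keep the special tokens in place when truncating.
ids, mask = pad_or_truncate([101, 2816, 2323, 102], max_seq_length=6)
print(ids)
print(mask)
```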

## Model

In [3]:
import torch
from quillnlp.models.bert.models import get_bert_classifier

device = "cuda" if torch.cuda.is_available() else "cpu"
model = get_bert_classifier(BERT_MODEL, len(label2idx), device=device)

I1008 10:41:23.733693 140067368421184 configuration_utils.py:151] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json from cache at /home/yves/.cache/torch/transformers/a41e817d5c0743e29e86ff85edc8c257e61bc8d88e4271bb1b243b6e7614c633.1ccd1a11c9ff276830e114ea477ea2407100f4a3be7bdc45d37be9e37fa71c7e
I1008 10:41:23.735444 140067368421184 configuration_utils.py:168] Model config {
  "activation": "gelu",
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "finetuning_task": null,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "n_heads": 12,
  "n_layers": 6,
  "num_labels": 7,
  "output_attentions": false,
  "output_hidden_states": false,
  "pruned_heads": {},
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "torchscript": false,
  "use_bfloat16": false,
  "vocab_size": 30522
}

I1008 10:41:24.216669 140067368421184 mod

## Training

In [4]:
from quillnlp.models.bert.train import train

output_model_file = train(model, train_dataloader, dev_dataloader, BATCH_SIZE, GRADIENT_ACCUMULATION_STEPS, device)

Epoch:   0%|          | 0/20 [00:00<?, ?it/s]



Loss history: []
Dev loss: 1.6473965406417848


Epoch:   5%|▌         | 1/20 [00:02<00:47,  2.51s/it]



Loss history: [1.6473965406417848]
Dev loss: 1.2385924696922301


Epoch:  10%|█         | 2/20 [00:04<00:44,  2.45s/it]



Loss history: [1.6473965406417848, 1.2385924696922301]
Dev loss: 1.0272655010223388


Epoch:  15%|█▌        | 3/20 [00:07<00:40,  2.40s/it]



Loss history: [1.6473965406417848, 1.2385924696922301, 1.0272655010223388]
Dev loss: 0.9899182200431824


Epoch:  20%|██        | 4/20 [00:09<00:37,  2.36s/it]



Loss history: [1.6473965406417848, 1.2385924696922301, 1.0272655010223388, 0.9899182200431824]
Dev loss: 0.9163646578788758


Epoch:  25%|██▌       | 5/20 [00:11<00:34,  2.32s/it]



Loss history: [1.6473965406417848, 1.2385924696922301, 1.0272655010223388, 0.9899182200431824, 0.9163646578788758]
Dev loss: 0.7455862879753112


Epoch:  30%|███       | 6/20 [00:13<00:32,  2.31s/it]


Epoch:  35%|███▌      | 7/20 [00:15<00:28,  2.21s/it]


Loss history: [1.6473965406417848, 1.2385924696922301, 1.0272655010223388, 0.9899182200431824, 0.9163646578788758, 0.7455862879753112]
Dev loss: 0.8208680033683777




Loss history: [1.6473965406417848, 1.2385924696922301, 1.0272655010223388, 0.9899182200431824, 0.9163646578788758, 0.7455862879753112, 0.8208680033683777]
Dev loss: 0.6801283061504364


Epoch:  40%|████      | 8/20 [00:18<00:26,  2.23s/it]



Loss history: [1.6473965406417848, 1.2385924696922301, 1.0272655010223388, 0.9899182200431824, 0.9163646578788758, 0.7455862879753112, 0.8208680033683777, 0.6801283061504364]
Dev loss: 0.6410963833332062


Epoch:  45%|████▌     | 9/20 [00:20<00:24,  2.24s/it]


Epoch:  50%|█████     | 10/20 [00:22<00:21,  2.17s/it]


Loss history: [1.6473965406417848, 1.2385924696922301, 1.0272655010223388, 0.9899182200431824, 0.9163646578788758, 0.7455862879753112, 0.8208680033683777, 0.6801283061504364, 0.6410963833332062]
Dev loss: 0.6926881432533264




Loss history: [1.6473965406417848, 1.2385924696922301, 1.0272655010223388, 0.9899182200431824, 0.9163646578788758, 0.7455862879753112, 0.8208680033683777, 0.6801283061504364, 0.6410963833332062, 0.6926881432533264]
Dev loss: 0.6163372963666915


Epoch:  55%|█████▌    | 11/20 [00:24<00:19,  2.20s/it]


Epoch:  60%|██████    | 12/20 [00:26<00:17,  2.14s/it]


Loss history: [1.6473965406417848, 1.2385924696922301, 1.0272655010223388, 0.9899182200431824, 0.9163646578788758, 0.7455862879753112, 0.8208680033683777, 0.6801283061504364, 0.6410963833332062, 0.6926881432533264, 0.6163372963666915]
Dev loss: 0.6203684329986572



Epoch:  65%|██████▌   | 13/20 [00:28<00:14,  2.09s/it]


Loss history: [1.6473965406417848, 1.2385924696922301, 1.0272655010223388, 0.9899182200431824, 0.9163646578788758, 0.7455862879753112, 0.8208680033683777, 0.6801283061504364, 0.6410963833332062, 0.6926881432533264, 0.6163372963666915, 0.6203684329986572]
Dev loss: 0.6200320914387702



Epoch:  70%|███████   | 14/20 [00:30<00:12,  2.06s/it]


Loss history: [1.6473965406417848, 1.2385924696922301, 1.0272655010223388, 0.9899182200431824, 0.9163646578788758, 0.7455862879753112, 0.8208680033683777, 0.6801283061504364, 0.6410963833332062, 0.6926881432533264, 0.6163372963666915, 0.6203684329986572, 0.6200320914387702]
Dev loss: 0.7257069826126099



Epoch:  75%|███████▌  | 15/20 [00:32<00:10,  2.05s/it]


Loss history: [1.6473965406417848, 1.2385924696922301, 1.0272655010223388, 0.9899182200431824, 0.9163646578788758, 0.7455862879753112, 0.8208680033683777, 0.6801283061504364, 0.6410963833332062, 0.6926881432533264, 0.6163372963666915, 0.6203684329986572, 0.6200320914387702, 0.7257069826126099]
Dev loss: 0.6236937135457993




Loss history: [1.6473965406417848, 1.2385924696922301, 1.0272655010223388, 0.9899182200431824, 0.9163646578788758, 0.7455862879753112, 0.8208680033683777, 0.6801283061504364, 0.6410963833332062, 0.6926881432533264, 0.6163372963666915, 0.6203684329986572, 0.6200320914387702, 0.7257069826126099, 0.6236937135457993]
Dev loss: 0.6734998047351837
No improvement on development set. Finish training.
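
Training stops well before the 20 scheduled epochs because the development loss has stopped improving. The exact stopping rule lives in `quillnlp.models.bert.train` and is not shown here. The log is consistent with a patience of five epochs, which is the assumption behind this sketch (the function `stopping_epoch` is hypothetical; the losses are the logged values rounded to four decimals):

```python
def stopping_epoch(dev_losses, patience=5):
    """Return the 1-based epoch after which training stops, or None."""
    best = float("inf")
    epochs_since_best = 0
    for epoch, loss in enumerate(dev_losses, start=1):
        if loss < best:
            best = loss
            epochs_since_best = 0
        else:
            epochs_since_best += 1
        # Stop once `patience` consecutive epochs failed to beat the best loss.
        if epochs_since_best >= patience:
            return epoch
    return None

dev_losses = [1.6474, 1.2386, 1.0273, 0.9899, 0.9164, 0.7456, 0.8209, 0.6801,
              0.6411, 0.6927, 0.6163, 0.6204, 0.6200, 0.7257, 0.6237, 0.6735]
print(stopping_epoch(dev_losses))
```

The lowest loss, 0.6163, is reached in epoch 11, and none of the five following epochs beats it, so the sketch stops after epoch 16, matching the 16 development losses in the log.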





## Evaluation

In [5]:
from quillnlp.models.bert.train import evaluate
from sklearn.metrics import precision_recall_fscore_support, classification_report

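print("Loading model from", output_model_file)
device = "cpu"  # run the evaluation on the CPU

# Reload the weights saved during training and switch to inference mode.
model = get_bert_classifier(BERT_MODEL, len(label2idx), model_file=output_model_file, device=device)
model.eval()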

_, test_correct, test_predicted = evaluate(model, test_dataloader, device)

print("Test performance:", precision_recall_fscore_support(test_correct, test_predicted, average="micro"))
print(classification_report(test_correct, test_predicted, target_names=target_names))

Loading model from /tmp/model.bin


I1008 10:42:02.733748 140067368421184 configuration_utils.py:151] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json from cache at /home/yves/.cache/torch/transformers/a41e817d5c0743e29e86ff85edc8c257e61bc8d88e4271bb1b243b6e7614c633.1ccd1a11c9ff276830e114ea477ea2407100f4a3be7bdc45d37be9e37fa71c7e
I1008 10:42:02.735480 140067368421184 configuration_utils.py:168] Model config {
  "activation": "gelu",
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "finetuning_task": null,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "n_heads": 12,
  "n_layers": 6,
  "num_labels": 7,
  "output_attentions": false,
  "output_hidden_states": false,
  "pruned_heads": {},
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "torchscript": false,
  "use_bfloat16": false,
  "vocab_size": 30522
}

I1008 10:42:03.191801 140067368421184 mod



Test performance: (0.7908496732026143, 0.7908496732026143, 0.7908496732026143, None)
                                           precision    recall  f1-score   support

                   Unclassified Off-Topic       0.56      0.45      0.50        11
          School without generating money       0.60      0.38      0.46        16
   Schools providing healthy alternatives       0.91      0.93      0.92        75
                           Student choice       0.47      1.00      0.64         7
                  Students without choice       0.74      0.79      0.76        33
                   Schools generate money       1.00      0.88      0.93         8
Students can still bring/access junk food       0.00      0.00      0.00         3

                              avg / total       0.78      0.79      0.78       153
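
The three identical numbers on the "Test performance" line are no coincidence. When every example has exactly one gold label and one predicted label, each error counts as one false positive and one false negative, so micro-averaged precision, recall and F1 all equal the accuracy (here 121 of 153 test items, about 0.79). A small sketch with made-up labels, using a hypothetical helper `micro_scores`:

```python
def micro_scores(gold, predicted):
    """Micro-averaged precision, recall and F1 for single-label classification."""
    labels = set(gold) | set(predicted)
    tp = fp = fn = 0
    # Pool true positives, false positives and false negatives over all labels.
    for label in labels:
        for g, p in zip(gold, predicted):
            tp += (g == label and p == label)
            fp += (g != label and p == label)
            fn += (g == label and p != label)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

gold = [0, 0, 1, 2, 2, 2]
predicted = [0, 1, 1, 2, 2, 0]
accuracy = sum(g == p for g, p in zip(gold, predicted)) / len(gold)
print(micro_scores(gold, predicted), accuracy)
```

The per-class rows of the report are therefore the more informative part of the output, since they show which labels the pooled score hides.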





In [6]:
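# Print every test item with its gold and predicted label, and count the correct predictions.
num_correct = 0
for item, predicted, correct in zip(test_data, test_predicted, test_correct):
    # Sanity check: the predictions are in the same order as the test data.
    assert item["label"] == idx2label[correct]
    num_correct += (item["label"] == idx2label[predicted])
    print("{}#{}#{}".format(item["text"], idx2label[correct], idx2label[predicted]))

print(num_correct)
print(num_correct / len(test_data))  # accuracy on the test set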

Schools should not allow junk food to be sold on campus but kids will still bring in unhealthy food#Students without choice#Students without choice
Schools should not allow junk food to be sold on campus but some think students should be able to choose what they eat#Student choice#Student choice
Schools should not allow junk food to be sold on campus but maybe on certain special occasions or at events#Unclassified Off-Topic#Unclassified Off-Topic
Schools should not allow junk food to be sold on campus but students may bring it anyway#Students without choice#Students without choice
Schools should not allow junk food to be sold on campus but they can choose to sell food that are nutritious and healthy#Schools providing healthy alternatives#Schools providing healthy alternatives
Schools should not allow junk food to be sold on campus but ultimately it is up to the individual student what they eat or drink#Students without choice#Students without choice
Schools should not allow junk food t