# Multilabel BERT Experiments

In this notebook we run our first experiments with BERT: we fine-tune a BERT model with a classification layer on each of our datasets separately, and evaluate the resulting multilabel classifier on the test data with precision, recall, F1 and exact-match accuracy.

For these experiments we use the `pytorch_transformers` package, which provides a variety of neural network architectures for transfer learning, together with pretrained models, including BERT and XLNet.

Two BERT models are relevant for our experiments:

- BERT-base-uncased: a relatively small model that should already give reasonable results;
- BERT-large-uncased: a larger model that may reach state-of-the-art results, at the cost of more memory and training time.

In [1]:
from multilabel import EATINGMEAT_BECAUSE_MAP, EATINGMEAT_BUT_MAP, JUNKFOOD_BECAUSE_MAP, JUNKFOOD_BUT_MAP

LABEL_MAP = JUNKFOOD_BUT_MAP
BERT_MODEL = 'bert-base-uncased'
BATCH_SIZE = 16 if "base" in BERT_MODEL else 2
GRADIENT_ACCUMULATION_STEPS = 1 if "base" in BERT_MODEL else 8
MAX_SEQ_LENGTH = 100
PREFIX = "junkfood_but"

## Data

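We use the same data as in all our previous experiments. Here we load the training, development and test data for a single prompt (selected by `PREFIX`), together with any synthetic training data, and print the label distribution of the training set and the number of synthetic examples.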

In [2]:
import ndjson
import glob
from collections import Counter

train_file = f"../data/interim/{PREFIX}_train_withprompt.ndjson"
synth_files = glob.glob(f"../data/interim/{PREFIX}_train_withprompt_allsynth.ndjson")
dev_file = f"../data/interim/{PREFIX}_dev_withprompt.ndjson"
test_file = f"../data/interim/{PREFIX}_test_withprompt.ndjson"

with open(train_file) as i:
    train_data = ndjson.load(i)

synth_data = []
for f in synth_files:
    with open(f) as i:
        synth_data += ndjson.load(i)
    
with open(dev_file) as i:
    dev_data = ndjson.load(i)
    
with open(test_file) as i:
    test_data = ndjson.load(i)
    
labels = Counter([item["label"] for item in train_data])
print(labels)
print(len(synth_data))

Counter({'Schools providing healthy alternatives': 137, 'Students without choice': 46, 'Unclassified Off-Topic': 32, 'School without generating money': 26, 'Student choice': 24, 'Schools generate money': 16, 'Students can still bring/access junk food': 3})
2556
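`ndjson.load` parses newline-delimited JSON, with one JSON object per line. For readers without the `ndjson` package, here is a roughly equivalent sketch that uses only the standard library. The records are hypothetical stand-ins for the real files, and `load_ndjson` is a hypothetical helper name:

```python
import io
import json
from collections import Counter


def load_ndjson(fp):
    """Parse newline-delimited JSON: one JSON object per non-empty line."""
    return [json.loads(line) for line in fp if line.strip()]


# Hypothetical records in the same {"text": ..., "label": ...} shape as our files.
sample_file = io.StringIO(
    '{"text": "example answer one", "label": "Student choice"}\n'
    '{"text": "example answer two", "label": "Schools providing healthy alternatives"}\n'
    '{"text": "example answer three", "label": "Schools providing healthy alternatives"}\n'
)
sample_records = load_ndjson(sample_file)

# Count how often each label occurs, as the notebook does for the training data.
label_counts = Counter(item["label"] for item in sample_records)
print(label_counts.most_common())
```

The heavily skewed distribution printed above (137 examples for the most frequent label, 3 for the rarest) is worth keeping in mind when reading the per-label results later on.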


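Next, we use `LABEL_MAP` to map each original label to its multilabel categories, so that one answer can receive several labels (for example, both *schools* and *healthy alternatives*). We then build the label vocabulary, which maps every label in the training data to an index, and turn each dataset into a dataloader, using `MAX_SEQ_LENGTH` as the maximum sequence length.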

In [3]:
def map_to_multilabel(items):
    return [{"text": item["text"], "label": LABEL_MAP[item["label"]]} for item in items]

train_data = map_to_multilabel(train_data)
dev_data = map_to_multilabel(dev_data)
synth_data = map_to_multilabel(synth_data)
test_data = map_to_multilabel(test_data)
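The contents of `JUNKFOOD_BUT_MAP` are not shown in this notebook. To illustrate what `map_to_multilabel` does, the sketch below uses a hypothetical excerpt of a label map. It assumes the map's values are lists of category names, and the pairings are guessed from the category names in the test output, not taken from the real map. `EXAMPLE_LABEL_MAP` and `map_to_multilabel_sketch` are hypothetical names:

```python
# Hypothetical excerpt of a label map; the real JUNKFOOD_BUT_MAP lives in the
# project's multilabel module, and its contents and value types may differ.
EXAMPLE_LABEL_MAP = {
    "Schools providing healthy alternatives": ["schools", "healthy alternatives"],
    "Student choice": ["students", "choice"],
    "Unclassified Off-Topic": ["off-topic"],
}


def map_to_multilabel_sketch(items, label_map):
    """Replace each item's single label with its list of multilabel categories."""
    return [{"text": item["text"], "label": label_map[item["label"]]} for item in items]


example_items = [
    {"text": "Schools should not allow junk food to be sold on campus but should provide healthier choices",
     "label": "Schools providing healthy alternatives"},
    {"text": "Schools should not allow junk food to be sold on campus but maybe on certain special occasions or at events",
     "label": "Unclassified Off-Topic"},
]
mapped_examples = map_to_multilabel_sketch(example_items, EXAMPLE_LABEL_MAP)
for example in mapped_examples:
    print(example["label"])
```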

In [4]:
import sys
sys.path.append('../')

from quillnlp.models.bert.preprocessing import preprocess, create_label_vocabulary

label2idx = create_label_vocabulary(train_data)
idx2label = {v:k for k,v in label2idx.items()}
target_names = [idx2label[s] for s in range(len(idx2label))]

train_dataloader = preprocess(train_data, BERT_MODEL, label2idx, MAX_SEQ_LENGTH, BATCH_SIZE)
dev_dataloader = preprocess(dev_data, BERT_MODEL, label2idx, MAX_SEQ_LENGTH, BATCH_SIZE)
test_dataloader = preprocess(test_data, BERT_MODEL, label2idx, MAX_SEQ_LENGTH, BATCH_SIZE, shuffle=False)
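`create_label_vocabulary` and `preprocess` come from the project's own `quillnlp` package, whose code is not shown here. The sketch below shows how the label vocabulary and the multi-hot target vectors that a multilabel classifier trains on could be built. It assumes labels are indexed in order of first appearance, though the real function may order them differently. `build_label_vocabulary` and `to_multi_hot` are hypothetical names:

```python
def build_label_vocabulary(items):
    """Map every label in the data to an index, in order of first appearance."""
    vocab = {}
    for item in items:
        for label in item["label"]:
            if label not in vocab:
                vocab[label] = len(vocab)
    return vocab


def to_multi_hot(item_labels, vocab):
    """Encode a list of labels as a 0/1 vector with one position per label."""
    vector = [0] * len(vocab)
    for label in item_labels:
        vector[vocab[label]] = 1
    return vector


toy_items = [
    {"text": "answer a", "label": ["schools", "healthy alternatives"]},
    {"text": "answer b", "label": ["students", "choice"]},
    {"text": "answer c", "label": ["students"]},
]
toy_vocab = build_label_vocabulary(toy_items)

# Same inversion as in the notebook: index -> label, then labels in index order.
toy_idx2label = {v: k for k, v in toy_vocab.items()}
toy_target_names = [toy_idx2label[i] for i in range(len(toy_idx2label))]
toy_targets = [to_multi_hot(item["label"], toy_vocab) for item in toy_items]
print(toy_target_names)
print(toy_targets)
```

`target_names` lists the labels in index order. This matters later, because `classification_report` pairs each name with the column at the same position in the multi-hot vectors.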

## Model

We load the pretrained BERT model with a multilabel classification layer on top, and put it on a GPU if one is available. The model must also be in "training" mode, which enables dropout, while we update its parameters on our data.

In [5]:
import torch
from quillnlp.models.bert.models import get_multilabel_bert_classifier

# BERT_MODEL is set in the first cell; use the GPU if one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = get_multilabel_bert_classifier(BERT_MODEL, len(label2idx), device=device)
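A multilabel classifier usually scores every label independently: each output logit goes through a sigmoid, and training minimizes binary cross-entropy per label rather than a softmax over all labels. We assume `get_multilabel_bert_classifier` follows this standard formulation, but its code is not shown. The pure-Python sketch below illustrates the computation with hypothetical logits for four labels. The 0.5 decision threshold is also an assumption:

```python
import math


def sigmoid(z):
    """Logistic function: maps a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def multilabel_bce(logits, targets):
    """Mean binary cross-entropy over labels, each label scored independently."""
    total = 0.0
    for z, y in zip(logits, targets):
        p = sigmoid(z)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(logits)


def predict_labels(logits, threshold=0.5):
    """Switch each label on independently when its probability exceeds the threshold."""
    return [1 if sigmoid(z) > threshold else 0 for z in logits]


example_logits = [2.0, 1.5, -3.0, -0.5]  # hypothetical scores for four labels
example_targets = [1, 1, 0, 0]
print(predict_labels(example_logits))
print(multilabel_bce(example_logits, example_targets))
```

Because the labels do not compete for probability mass, one answer can receive *schools* and *healthy alternatives* at the same time.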

## Training
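
With `bert-large-uncased`, batches of 2 are combined over 8 gradient accumulation steps, so both settings in the first cell amount to 16 examples per weight update. In the usual pattern, each mini-batch loss is divided by the number of accumulation steps and the gradients add up until a single optimizer step. We assume `train()` follows this pattern, though its code is not shown. For equal-size mini-batches and a mean loss, this gives the same gradient as one larger batch. The toy check below uses a hypothetical one-parameter linear model with mean squared error:

```python
def grad_mse(w, batch):
    """Gradient w.r.t. w of the mean squared error of the model y = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


toy_pairs = [(1.0, 2.0), (2.0, 3.5), (3.0, 6.5), (4.0, 7.0)]  # hypothetical (x, y) data
w = 0.5
steps = 2  # number of micro-batches accumulated before one optimizer step
micro_batches = [toy_pairs[:2], toy_pairs[2:]]

# Gradient of a single batch of 4 examples.
full_grad = grad_mse(w, toy_pairs)

# Gradient accumulated over 2 micro-batches of 2, each scaled by 1/steps,
# as when every micro-batch loss is divided by the number of accumulation steps.
accumulated_grad = 0.0
for batch in micro_batches:
    accumulated_grad += grad_mse(w, batch) / steps

print(full_grad, accumulated_grad)
```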

In [6]:
from quillnlp.models.bert.train import train

# Batch size and gradient accumulation steps are set in the first cell.
output_model_file = train(model, train_dataloader, dev_dataloader, BATCH_SIZE, GRADIENT_ACCUMULATION_STEPS, device)

Dev loss per epoch (18 training iterations and 5 evaluation iterations per epoch):

Epoch  1: 0.46770793199539185
Epoch  2: 0.3351816892623901
Epoch  3: 0.2613398343324661
Epoch  4: 0.2466729998588562
Epoch  5: 0.21184660196304322
Epoch  6: 0.20426980555057525
Epoch  7: 0.1862332671880722
Epoch  8: 0.16687243729829787
Epoch  9: 0.15970188975334168
Epoch 10: 0.16814120709896088
Epoch 11: 0.15619001239538194
Epoch 12: 0.16108817160129546
Epoch 13: 0.1498269259929657
Epoch 14: 0.15279495120048522
Epoch 15: 0.15603040903806686
Epoch 16: 0.14055202677845954
Epoch 17: 0.1639494016766548
Epoch 18: 0.14596846848726272
Epoch 19: 0.1474771097302437
Epoch 20: 0.14006498232483863

Epoch: 100%|██████████| 20/20 [01:27<00:00,  4.29s/it]
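
The dev loss drops steadily over the first nine epochs and then fluctuates between about 0.14 and 0.17. The notebook does not show whether `train()` saves the model from the last epoch or the one with the lowest dev loss. The sketch below takes the values printed in the training log and finds the best epoch, along with the epochs where the dev loss went up. `best_epoch` is a hypothetical helper:

```python
# Dev losses copied from the training log, one per epoch.
dev_losses = [
    0.46770793199539185, 0.3351816892623901, 0.2613398343324661, 0.2466729998588562,
    0.21184660196304322, 0.20426980555057525, 0.1862332671880722, 0.16687243729829787,
    0.15970188975334168, 0.16814120709896088, 0.15619001239538194, 0.16108817160129546,
    0.1498269259929657, 0.15279495120048522, 0.15603040903806686, 0.14055202677845954,
    0.1639494016766548, 0.14596846848726272, 0.1474771097302437, 0.14006498232483863,
]


def best_epoch(losses):
    """Return the 1-based epoch with the lowest loss, and that loss."""
    index = min(range(len(losses)), key=losses.__getitem__)
    return index + 1, losses[index]


# 1-based epochs whose dev loss is higher than in the previous epoch.
increases = [i + 1 for i in range(1, len(dev_losses)) if dev_losses[i] > dev_losses[i - 1]]

print(best_epoch(dev_losses))
print(increases)
```

Here the lowest dev loss happens to fall in the final epoch, so either checkpointing strategy would keep the same model.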


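## Evaluation

We reload the saved model on the CPU and switch it to evaluation mode, which disables dropout. We then report micro-averaged precision, recall and F1 on the test set, followed by a per-label classification report.
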
In [7]:
from quillnlp.models.bert.train import evaluate
from sklearn.metrics import precision_recall_fscore_support, classification_report

print("Loading model from", output_model_file)
device = "cpu"  # run evaluation on the CPU, even if training used a GPU

model = get_multilabel_bert_classifier(BERT_MODEL, len(label2idx), model_file=output_model_file, device=device)
model.eval()

_, test_correct, test_predicted = evaluate(model, test_dataloader, device)

print("Test performance:", precision_recall_fscore_support(test_correct, test_predicted, average="micro"))
print(classification_report(test_correct, test_predicted, target_names=target_names))

Loading model from /tmp/model.bin




Test performance: (0.9227467811158798, 0.8739837398373984, 0.8977035490605428, None)
                      precision    recall  f1-score   support

           off-topic       1.00      0.36      0.53        11
             schools       0.97      0.92      0.94        99
healthy alternatives       0.95      0.96      0.95        75
            students       0.81      0.79      0.80        43
              choice       0.67      0.86      0.75         7
               money       1.00      1.00      1.00         8
 can bring junk food       0.00      0.00      0.00         3

         avg / total       0.92      0.87      0.89       246



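The 0.00 row for *can bring junk food* (3 test examples) is consistent with the model never predicting that label. With no predicted positives, precision is 0/0. In that case scikit-learn reports 0.0 and, by default, emits an `UndefinedMetricWarning`. The sketch below computes per-label precision, recall and F1 from multi-hot vectors, using the same zero convention, on hypothetical toy data. `per_label_prf` is a hypothetical name:

```python
def per_label_prf(gold, predicted, n_labels):
    """Per-label (precision, recall, F1) from multi-hot vectors.

    Undefined ratios (0/0) are reported as 0.0, as scikit-learn does by default.
    """
    scores = []
    for j in range(n_labels):
        tp = sum(1 for g, p in zip(gold, predicted) if g[j] == 1 and p[j] == 1)
        fp = sum(1 for g, p in zip(gold, predicted) if g[j] == 0 and p[j] == 1)
        fn = sum(1 for g, p in zip(gold, predicted) if g[j] == 1 and p[j] == 0)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores.append((precision, recall, f1))
    return scores


# Hypothetical data: label 2 occurs twice in the gold data but is never predicted.
toy_gold = [[1, 0, 0], [1, 0, 1], [0, 1, 0], [0, 0, 1]]
toy_pred = [[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 0]]
for scores in per_label_prf(toy_gold, toy_pred, 3):
    print(scores)
```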


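Next, we compute micro-averaged precision and recall by hand, together with the exact-match accuracy: the proportion of test sentences for which the full set of predicted labels is correct.

In [8]:
all_correct = 0
fp, fn, tp, tn = 0, 0, 0, 0
for c, p in zip(test_correct, test_predicted):
    # Exact match: every label of this sentence is predicted correctly.
    if all(ci == pi for ci, pi in zip(c, p)):
        all_correct += 1
    # Count per-label outcomes, pooled over all labels (micro-averaging).
    for ci, pi in zip(c, p):
        if pi == 1 and ci == 1:
            tp += 1
        elif pi == 1 and ci == 0:
            fp += 1
        elif pi == 0 and ci == 1:
            fn += 1
        else:
            tn += 1

precision = tp / (tp + fp)
recall = tp / (tp + fn)
print("P:", precision)
print("R:", recall)
print("A:", all_correct / len(test_correct))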

P: 0.9227467811158798
R: 0.8739837398373984
A: 0.7908496732026143
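
The hand-computed P and R match the micro-averaged scores reported by scikit-learn. Micro-averaging pools counts over all labels, so it is dominated by the frequent labels (*schools*, *healthy alternatives*). A macro average weights every label equally and gives a better sense of how the rare labels fare. The sketch below computes it from the rounded per-label values in the classification report, so the result is only approximate:

```python
# Per-label (precision, recall) copied from the classification report (rounded to 2 decimals).
report_scores = {
    "off-topic": (1.00, 0.36),
    "schools": (0.97, 0.92),
    "healthy alternatives": (0.95, 0.96),
    "students": (0.81, 0.79),
    "choice": (0.67, 0.86),
    "money": (1.00, 1.00),
    "can bring junk food": (0.00, 0.00),
}

# Macro average: unweighted mean over labels, regardless of their support.
macro_p = sum(p for p, _ in report_scores.values()) / len(report_scores)
macro_r = sum(r for _, r in report_scores.values()) / len(report_scores)
print(f"macro P: {macro_p:.2f}, macro R: {macro_r:.2f}")
```

For exact values, `precision_recall_fscore_support(test_correct, test_predicted, average="macro")` computes the same averages from the unrounded predictions.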


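Finally, we print every test sentence with its correct and predicted labels, separated by `#`, for error analysis. This pairing relies on the test dataloader keeping the order of `test_data` (it was created with `shuffle=False`).

In [9]: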
for item, predicted, correct in zip(test_data, test_predicted, test_correct):
    correct_labels = [idx2label[i] for i, l in enumerate(correct) if l == 1]
    predicted_labels = [idx2label[i] for i, l in enumerate(predicted) if l == 1]
    print("{}#{}#{}".format(item["text"], ";".join(correct_labels), ";".join(predicted_labels)))


Schools should not allow junk food to be sold on campus but kids will still bring in unhealthy food#students#students
Schools should not allow junk food to be sold on campus but some think students should be able to choose what they eat#students;choice#students;choice
Schools should not allow junk food to be sold on campus but maybe on certain special occasions or at events#off-topic#off-topic
Schools should not allow junk food to be sold on campus but students may bring it anyway#students#students
Schools should not allow junk food to be sold on campus but they can choose to sell food that are nutritious and healthy#schools;healthy alternatives#schools;healthy alternatives
Schools should not allow junk food to be sold on campus but ultimately it is up to the individual student what they eat or drink#students#students
Schools should not allow junk food to be sold on campus but should provide healthier choices#schools;healthy alternatives#schools;healthy alternatives
Schools should not 