In [48]:
import datasets
from transformers import BertTokenizer, BertConfig, BertForSequenceClassification, TrainingArguments, Trainer
import numpy as np
from src.eval import eval_preds

dataset = datasets.load_dataset("michaelginn/latent-trees-agreement-ID")


tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
max_length = 100
def tokenize_function(example):
    return tokenizer(example['text'], max_length=max_length, truncation=True)
dataset = dataset.map(tokenize_function, batched=True, load_from_cache_file=False)

toy_dataset = dataset['train'].select(range(1, 11))

id2label = {0: "VIOLATION", 1: "GRAMMATICAL"}
label2id = {"VIOLATION": 0, "GRAMMATICAL": 1}

pretrained = True
if pretrained:
    config = BertConfig.from_pretrained('bert-base-uncased', num_labels=2, id2label=id2label, label2id=label2id)
else:
    # Create random initialized BERT model
    config = BertConfig(num_labels=2, id2label=id2label, label2id=label2id)

model = BertForSequenceClassification(config=config).to('mps')

args = TrainingArguments(
    output_dir=f"../training-checkpoints",
    learning_rate=2e-5,
    evaluation_strategy="epoch",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    save_strategy="epoch",
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=10,
    load_best_model_at_end=False,
    logging_strategy='epoch',
)

def compute_metrics(eval_pred):
    labels = eval_pred.label_ids
    preds = np.argmax(eval_pred.predictions, axis=-1)
    print(eval_pred.predictions)
    return eval_preds(preds, labels)


trainer = Trainer(
    model,
    args,
    train_dataset=toy_dataset,
    eval_dataset=toy_dataset, # dataset['test'].select(range(20)),
    compute_metrics=compute_metrics,
    tokenizer=tokenizer
)

trainer.train()

preds = trainer.predict(dataset['test'].select(range(20)))
preds

loading file https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt from cache at /Users/milesper/.cache/huggingface/transformers/45c3f7a79a80e1cf0a489e5c62b43f173c15db47864303a55d623bb3c96f72a5.d789d64ebfe299b0e416afc4a169632f903f693095b4629a7ea271d5a0cf2c99
loading file https://huggingface.co/bert-base-uncased/resolve/main/added_tokens.json from cache at None
loading file https://huggingface.co/bert-base-uncased/resolve/main/special_tokens_map.json from cache at None
loading file https://huggingface.co/bert-base-uncased/resolve/main/tokenizer_config.json from cache at /Users/milesper/.cache/huggingface/transformers/c1d7f0a763fb63861cc08553866f1fc3e5a6f4f07621be277452d26d71303b7e.20430bd8e10ef77a7d2977accefe796051e01bc2fc4aa146bc862997a1a15e79
loading configuration file https://huggingface.co/bert-base-uncased/resolve/main/config.json from cache at /Users/milesper/.cache/huggingface/transformers/3c61d016573b14f7f008c02c4e51a366c67ab274726fe2910691e2a761acf43e.37395cee442ab110

Map:   0%|          | 0/2800 [00:00<?, ? examples/s]

Map:   0%|          | 0/1200 [00:00<?, ? examples/s]

loading configuration file https://huggingface.co/bert-base-uncased/resolve/main/config.json from cache at /Users/milesper/.cache/huggingface/transformers/3c61d016573b14f7f008c02c4e51a366c67ab274726fe2910691e2a761acf43e.37395cee442ab11005bcd270f3c34464dc1704b715b5d7d52b1a461abe3b9e4e
Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "VIOLATION",
    "1": "GRAMMATICAL"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "GRAMMATICAL": 1,
    "VIOLATION": 0
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.21.3",
  "type_vocab_size": 2,
  "use_cache

Epoch,Training Loss,Validation Loss


The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[-0.46431568 -0.09398253]
 [-0.38338035 -0.11335935]
 [-0.5430481   0.07136033]
 [-0.5524909   0.07485996]
 [-0.5505021   0.08642656]
 [-0.51906747 -0.02079204]
 [-0.42728794 -0.1279895 ]
 [-0.3785798  -0.12420638]
 [-0.40356916 -0.13772671]
 [-0.4894873  -0.02433187]]
PREDS [1 1 1 1 1 1 1 1 1 1]
LABELS [0 0 1 1 1 1 1 0 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-10
Configuration saved in ../training-checkpoints/checkpoint-10/config.json
Model weights saved in ../training-checkpoints/checkpoint-10/pytorch_model.bin
tokenizer config file saved in ../training-checkpoints/checkpoint-10/tokenizer_config.json
Special tokens file saved in ../training-checkpoints/checkpoint-10/special_tokens_map.json
Deleting older checkpoint [../training-checkpoints/checkpoint-80] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[-1.2974775   0.9850861 ]
 [-1.2307818   0.95143634]
 [-1.3605208   1.1232607 ]
 [-1.3670791   1.1233885 ]
 [-1.3608583   1.1224539 ]
 [-1.3606353   1.0652703 ]
 [-1.2911935   0.9917485 ]
 [-1.2272487   0.93031317]
 [-1.2442553   0.92800194]
 [-1.3267769   1.0438648 ]]
PREDS [1 1 1 1 1 1 1 1 1 1]
LABELS [0 0 1 1 1 1 1 0 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-20
Configuration saved in ../training-checkpoints/checkpoint-20/config.json
Model weights saved in ../training-checkpoints/checkpoint-20/pytorch_model.bin
tokenizer config file saved in ../training-checkpoints/checkpoint-20/tokenizer_config.json
Special tokens file saved in ../training-checkpoints/checkpoint-20/special_tokens_map.json
Deleting older checkpoint [../training-checkpoints/checkpoint-90] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[-0.73431355  0.4241377 ]
 [-0.6430441   0.37142813]
 [-0.84299666  0.5973026 ]
 [-0.8539642   0.6018497 ]
 [-0.8564709   0.6084635 ]
 [-0.8355526   0.5362897 ]
 [-0.7349311   0.4393542 ]
 [-0.6439443   0.36146802]
 [-0.6631613   0.35393503]
 [-0.7884347   0.5014813 ]]
PREDS [1 1 1 1 1 1 1 1 1 1]
LABELS [0 0 1 1 1 1 1 0 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-30
Configuration saved in ../training-checkpoints/checkpoint-30/config.json
Model weights saved in ../training-checkpoints/checkpoint-30/pytorch_model.bin
tokenizer config file saved in ../training-checkpoints/checkpoint-30/tokenizer_config.json
Special tokens file saved in ../training-checkpoints/checkpoint-30/special_tokens_map.json
Deleting older checkpoint [../training-checkpoints/checkpoint-100] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[-0.89702904  0.56345814]
 [-0.7756403   0.4772451 ]
 [-1.0818349   0.8322886 ]
 [-1.0936857   0.8375097 ]
 [-1.0909638   0.8360023 ]
 [-1.0646697   0.7553661 ]
 [-0.9113325   0.59612584]
 [-0.7685956   0.4627576 ]
 [-0.80454296  0.46504027]
 [-0.99561816  0.6947219 ]]
PREDS [1 1 1 1 1 1 1 1 1 1]
LABELS [0 0 1 1 1 1 1 0 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-40
Configuration saved in ../training-checkpoints/checkpoint-40/config.json
Model weights saved in ../training-checkpoints/checkpoint-40/pytorch_model.bin
tokenizer config file saved in ../training-checkpoints/checkpoint-40/tokenizer_config.json
Special tokens file saved in ../training-checkpoints/checkpoint-40/special_tokens_map.json
Deleting older checkpoint [../training-checkpoints/checkpoint-10] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[-0.56759036  0.19747162]
 [-0.37455934  0.05177116]
 [-0.8969219   0.6251585 ]
 [-0.90813255  0.6261324 ]
 [-0.9069947   0.62404275]
 [-0.8550825   0.5161369 ]
 [-0.6102451   0.26658   ]
 [-0.3665326   0.03779297]
 [-0.42419553  0.05496217]
 [-0.7305125   0.39843807]]
PREDS [1 1 1 1 1 1 1 1 1 1]
LABELS [0 0 1 1 1 1 1 0 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-50
Configuration saved in ../training-checkpoints/checkpoint-50/config.json
Model weights saved in ../training-checkpoints/checkpoint-50/pytorch_model.bin
tokenizer config file saved in ../training-checkpoints/checkpoint-50/tokenizer_config.json
Special tokens file saved in ../training-checkpoints/checkpoint-50/special_tokens_map.json
Deleting older checkpoint [../training-checkpoints/checkpoint-20] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[-0.48493037  0.10595788]
 [-0.12446091 -0.1929117 ]
 [-1.1381111   0.90916336]
 [-1.1552248   0.9083038 ]
 [-1.1477555   0.89978456]
 [-1.0379833   0.7133919 ]
 [-0.5776015   0.2343863 ]
 [-0.12830043 -0.19567537]
 [-0.21150765 -0.16468304]
 [-0.81719124  0.49114057]]
PREDS [1 0 1 1 1 1 1 0 1 1]
LABELS [0 0 1 1 1 1 1 0 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-60
Configuration saved in ../training-checkpoints/checkpoint-60/config.json
Model weights saved in ../training-checkpoints/checkpoint-60/pytorch_model.bin
tokenizer config file saved in ../training-checkpoints/checkpoint-60/tokenizer_config.json
Special tokens file saved in ../training-checkpoints/checkpoint-60/special_tokens_map.json
Deleting older checkpoint [../training-checkpoints/checkpoint-30] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[-0.20435876 -0.18196955]
 [ 0.46274093 -0.74940467]
 [-1.415817    1.2591076 ]
 [-1.4758974   1.2992947 ]
 [-1.4894832   1.3129562 ]
 [-1.2792498   0.9858619 ]
 [-0.3970693   0.05820819]
 [ 0.45260555 -0.75121105]
 [ 0.29184556 -0.65355814]
 [-0.8519279   0.5287627 ]]
PREDS [1 0 1 1 1 1 1 0 0 1]
LABELS [0 0 1 1 1 1 1 0 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-70
Configuration saved in ../training-checkpoints/checkpoint-70/config.json
Model weights saved in ../training-checkpoints/checkpoint-70/pytorch_model.bin
tokenizer config file saved in ../training-checkpoints/checkpoint-70/tokenizer_config.json
Special tokens file saved in ../training-checkpoints/checkpoint-70/special_tokens_map.json
Deleting older checkpoint [../training-checkpoints/checkpoint-40] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[-0.31256682 -0.06805959]
 [ 0.71766055 -0.9741449 ]
 [-1.9551431   1.9841692 ]
 [-2.010648    2.0361893 ]
 [-2.0308142   2.050364  ]
 [-1.8632245   1.7267557 ]
 [-0.7548868   0.46609837]
 [ 0.72172284 -0.99236363]
 [ 0.47986355 -0.8195762 ]
 [-1.3503804   1.111244  ]]
PREDS [1 0 1 1 1 1 1 0 0 1]
LABELS [0 0 1 1 1 1 1 0 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-80
Configuration saved in ../training-checkpoints/checkpoint-80/config.json
Model weights saved in ../training-checkpoints/checkpoint-80/pytorch_model.bin
tokenizer config file saved in ../training-checkpoints/checkpoint-80/tokenizer_config.json
Special tokens file saved in ../training-checkpoints/checkpoint-80/special_tokens_map.json
Deleting older checkpoint [../training-checkpoints/checkpoint-50] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[ 0.24375689 -0.6347455 ]
 [ 1.3267066  -1.5692118 ]
 [-2.2447357   2.3709633 ]
 [-2.2946167   2.4184656 ]
 [-2.3195634   2.4349923 ]
 [-2.143381    2.0783303 ]
 [-0.45168376  0.16582389]
 [ 1.3231709  -1.5911183 ]
 [ 1.1375697  -1.4491277 ]
 [-1.4357418   1.208245  ]]
PREDS [0 0 1 1 1 1 1 0 0 1]
LABELS [0 0 1 1 1 1 1 0 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-90
Configuration saved in ../training-checkpoints/checkpoint-90/config.json
Model weights saved in ../training-checkpoints/checkpoint-90/pytorch_model.bin
tokenizer config file saved in ../training-checkpoints/checkpoint-90/tokenizer_config.json
Special tokens file saved in ../training-checkpoints/checkpoint-90/special_tokens_map.json
Deleting older checkpoint [../training-checkpoints/checkpoint-60] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[ 0.56818163 -0.9523997 ]
 [ 1.5320762  -1.7794728 ]
 [-2.3002353   2.4384127 ]
 [-2.3496423   2.482844  ]
 [-2.3765564   2.5007553 ]
 [-2.1907153   2.1300793 ]
 [-0.2567394  -0.02931753]
 [ 1.5295653  -1.8047838 ]
 [ 1.3675306  -1.6769116 ]
 [-1.3370025   1.0929039 ]]
PREDS [0 0 1 1 1 1 1 0 0 1]
LABELS [0 0 1 1 1 1 1 0 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-100
Configuration saved in ../training-checkpoints/checkpoint-100/config.json
Model weights saved in ../training-checkpoints/checkpoint-100/pytorch_model.bin
tokenizer config file saved in ../training-checkpoints/checkpoint-100/tokenizer_config.json
Special tokens file saved in ../training-checkpoints/checkpoint-100/special_tokens_map.json
Deleting older checkpoint [../training-checkpoints/checkpoint-70] due to args.save_total_limit


Training completed. Do not forget to share your model on huggingface.co/models =)


The following columns in the test set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 20
  Batch size = 1


[[ 0.83150387 -1.0455612 ]
 [ 1.3024168  -1.5945535 ]
 [ 1.2472796  -1.5266438 ]
 [-1.2660668   0.96118355]
 [ 1.043215   -1.3469826 ]
 [-1.7565308   1.653706  ]
 [ 0.8188604  -1.1564572 ]
 [-0.31361818  0.03787396]
 [ 0.91278684 -1.1873815 ]
 [ 1.2342386  -1.5319548 ]
 [ 0.5119903  -0.81276834]
 [ 0.7178125  -0.9798484 ]
 [-1.8974295   1.8141869 ]
 [-2.3183742   2.4706788 ]
 [-1.6356996   1.4102662 ]
 [ 0.52811825 -0.8062937 ]
 [ 0.8965265  -1.2483277 ]
 [-1.9545257   1.8107259 ]
 [ 0.9581102  -1.214535  ]
 [ 0.21868679 -0.54523504]]
PREDS [0 0 0 1 0 1 0 1 0 0 0 0 1 1 1 0 0 1 0 0]
LABELS [0 0 1 1 0 0 0 1 1 1 1 1 1 1 0 0 1 1 0 0]


PredictionOutput(predictions=array([[ 0.83150387, -1.0455612 ],
       [ 1.3024168 , -1.5945535 ],
       [ 1.2472796 , -1.5266438 ],
       [-1.2660668 ,  0.96118355],
       [ 1.043215  , -1.3469826 ],
       [-1.7565308 ,  1.653706  ],
       [ 0.8188604 , -1.1564572 ],
       [-0.31361818,  0.03787396],
       [ 0.91278684, -1.1873815 ],
       [ 1.2342386 , -1.5319548 ],
       [ 0.5119903 , -0.81276834],
       [ 0.7178125 , -0.9798484 ],
       [-1.8974295 ,  1.8141869 ],
       [-2.3183742 ,  2.4706788 ],
       [-1.6356996 ,  1.4102662 ],
       [ 0.52811825, -0.8062937 ],
       [ 0.8965265 , -1.2483277 ],
       [-1.9545257 ,  1.8107259 ],
       [ 0.9581102 , -1.214535  ],
       [ 0.21868679, -0.54523504]], dtype=float32), label_ids=array([0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0]), metrics={'test_loss': 1.0961285829544067, 'test_accuracy': 0.6, 'test_precision': 0.6263736263736193, 'test_recall': 0.6161616161616098, 'test_f1': 0.6212256566016944, 'test_r

In [50]:
eval = trainer.evaluate(toy_dataset)
{k.replace('eval', 'test'): eval[k] for k in eval}

The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[ 0.56818163 -0.9523997 ]
 [ 1.5320762  -1.7794728 ]
 [-2.3002353   2.4384127 ]
 [-2.3496423   2.482844  ]
 [-2.3765564   2.5007553 ]
 [-2.1907153   2.1300793 ]
 [-0.2567394  -0.02931753]
 [ 1.5295653  -1.8047838 ]
 [ 1.3675306  -1.6769116 ]
 [-1.3370025   1.0929039 ]]
PREDS [0 0 1 1 1 1 1 0 0 1]
LABELS [0 0 1 1 1 1 1 0 0 0]


{'test_loss': 0.34526562690734863,
 'test_accuracy': 0.9,
 'test_precision': 0.9166666666666472,
 'test_recall': 0.8999999999999819,
 'test_f1': 0.9082568807338762,
 'test_runtime': 0.4494,
 'test_samples_per_second': 22.253,
 'test_steps_per_second': 22.253,
 'epoch': 10.0}