In [5]:
import datasets
from transformers import BertTokenizer, BertConfig, BertForSequenceClassification, TrainingArguments, Trainer
import numpy as np
from src.eval import eval_preds

dataset = datasets.load_dataset("michaelginn/latent-trees-agreement-ID")


tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
max_length = 100
def tokenize_function(example):
    return tokenizer(example['text'], max_length=max_length, padding='max_length', truncation=True)
dataset = dataset.map(tokenize_function, batched=True, load_from_cache_file=False)

toy_dataset = dataset['train'].select(range(10))

pretrained = True
if pretrained:
    config = BertConfig.from_pretrained('bert-base-uncased', num_labels=2)
else:
    # Create random initialized BERT model
    config = BertConfig(num_labels=2)

model = BertForSequenceClassification(config=config).to('mps')

args = TrainingArguments(
    output_dir=f"../training-checkpoints",
    evaluation_strategy="epoch",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=3,
    save_strategy="epoch",
    save_total_limit=3,
    num_train_epochs=10,
    load_best_model_at_end=False,
    logging_strategy='epoch',
    report_to='wandb'
)

def compute_metrics(eval_pred):
    labels = eval_pred.label_ids
    preds = np.argmax(eval_pred.predictions, axis=-1)
    print(eval_pred.predictions)
    return eval_preds(preds, labels)


trainer = Trainer(
    model,
    args,
    train_dataset=toy_dataset,
    eval_dataset=toy_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()

loading file https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt from cache at /Users/milesper/.cache/huggingface/transformers/45c3f7a79a80e1cf0a489e5c62b43f173c15db47864303a55d623bb3c96f72a5.d789d64ebfe299b0e416afc4a169632f903f693095b4629a7ea271d5a0cf2c99
loading file https://huggingface.co/bert-base-uncased/resolve/main/added_tokens.json from cache at None
loading file https://huggingface.co/bert-base-uncased/resolve/main/special_tokens_map.json from cache at None
loading file https://huggingface.co/bert-base-uncased/resolve/main/tokenizer_config.json from cache at /Users/milesper/.cache/huggingface/transformers/c1d7f0a763fb63861cc08553866f1fc3e5a6f4f07621be277452d26d71303b7e.20430bd8e10ef77a7d2977accefe796051e01bc2fc4aa146bc862997a1a15e79
loading configuration file https://huggingface.co/bert-base-uncased/resolve/main/config.json from cache at /Users/milesper/.cache/huggingface/transformers/3c61d016573b14f7f008c02c4e51a366c67ab274726fe2910691e2a761acf43e.37395cee442ab110

Map:   0%|          | 0/2800 [00:00<?, ? examples/s]

Map:   0%|          | 0/1200 [00:00<?, ? examples/s]

loading configuration file https://huggingface.co/bert-base-uncased/resolve/main/config.json from cache at /Users/milesper/.cache/huggingface/transformers/3c61d016573b14f7f008c02c4e51a366c67ab274726fe2910691e2a761acf43e.37395cee442ab11005bcd270f3c34464dc1704b715b5d7d52b1a461abe3b9e4e
Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.21.3",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

PyTorch: setting up devices
The following columns in the training set don't have a correspondi

Epoch,Training Loss,Validation Loss


The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[-1.1179357   1.3664817 ]
 [-1.0935471   1.3240856 ]
 [-0.9150632   1.1935802 ]
 [-1.2055941   1.4077724 ]
 [-1.2706304   1.4327023 ]
 [-1.2617466   1.4232585 ]
 [-1.2208904   1.4263747 ]
 [-1.1233699   1.3454708 ]
 [-0.90578747  1.2026079 ]
 [-1.0126162   1.250495  ]]
PREDS [1 1 1 1 1 1 1 1 1 1]
LABELS [1 0 0 1 1 1 1 1 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-3
Configuration saved in ../training-checkpoints/checkpoint-3/config.json
Model weights saved in ../training-checkpoints/checkpoint-3/pytorch_model.bin
Deleting older checkpoint [../training-checkpoints/checkpoint-14] due to args.save_total_limit
Deleting older checkpoint [../training-checkpoints/checkpoint-16] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[-0.29804322  0.43019196]
 [-0.2520776   0.3678865 ]
 [-0.06757163  0.2216199 ]
 [-0.42850274  0.4985996 ]
 [-0.42320266  0.4943235 ]
 [-0.42653924  0.5018227 ]
 [-0.3867669   0.4814942 ]
 [-0.28525406  0.38903785]
 [-0.0827678   0.22735633]
 [-0.17032826  0.2847673 ]]
PREDS [1 1 1 1 1 1 1 1 1 1]
LABELS [1 0 0 1 1 1 1 1 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-6
Configuration saved in ../training-checkpoints/checkpoint-6/config.json
Model weights saved in ../training-checkpoints/checkpoint-6/pytorch_model.bin
Deleting older checkpoint [../training-checkpoints/checkpoint-18] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[ 0.47202072 -0.18887904]
 [ 0.5784718  -0.30033913]
 [ 0.7452614  -0.4567238 ]
 [ 0.31946757 -0.09457489]
 [ 0.34048903 -0.11347176]
 [ 0.33718264 -0.10397281]
 [ 0.3978076  -0.15101987]
 [ 0.5474052  -0.28453022]
 [ 0.74125516 -0.46360838]
 [ 0.6715506  -0.4096954 ]]
PREDS [0 0 0 0 0 0 0 0 0 0]
LABELS [1 0 0 1 1 1 1 1 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-9
Configuration saved in ../training-checkpoints/checkpoint-9/config.json
Model weights saved in ../training-checkpoints/checkpoint-9/pytorch_model.bin
Deleting older checkpoint [../training-checkpoints/checkpoint-20] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[-0.16394049  0.34273928]
 [ 0.01937813  0.17554685]
 [ 0.25546134 -0.04028283]
 [-0.3510821   0.45982942]
 [-0.321734    0.43410522]
 [-0.3330962   0.45109624]
 [-0.26890352  0.40121636]
 [-0.03289691  0.20996882]
 [ 0.26233914 -0.05275812]
 [ 0.15693396  0.03410645]]
PREDS [1 1 0 1 1 1 1 1 0 0]
LABELS [1 0 0 1 1 1 1 1 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-12
Configuration saved in ../training-checkpoints/checkpoint-12/config.json
Model weights saved in ../training-checkpoints/checkpoint-12/pytorch_model.bin
Deleting older checkpoint [../training-checkpoints/checkpoint-3] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[-0.539212    0.649873  ]
 [-0.20865627  0.35452166]
 [ 0.1684573   0.01224801]
 [-0.81245714  0.83648944]
 [-0.7798877   0.8079491 ]
 [-0.81234515  0.84246206]
 [-0.73465574  0.77926093]
 [-0.34513533  0.46022457]
 [ 0.19076195 -0.01226859]
 [ 0.01166182  0.13585564]]
PREDS [1 1 0 1 1 1 1 1 0 1]
LABELS [1 0 0 1 1 1 1 1 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-15
Configuration saved in ../training-checkpoints/checkpoint-15/config.json
Model weights saved in ../training-checkpoints/checkpoint-15/pytorch_model.bin
Deleting older checkpoint [../training-checkpoints/checkpoint-6] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[-0.3434649   0.4600704 ]
 [ 0.38856807 -0.19446984]
 [ 1.0211767  -0.78650343]
 [-0.9075805   0.88514864]
 [-0.8567387   0.8384602 ]
 [-0.94372153  0.9187646 ]
 [-0.789724    0.79093957]
 [ 0.03541064  0.09760509]
 [ 1.0450215  -0.81219095]
 [ 0.7716863  -0.5866253 ]]
PREDS [1 0 0 1 1 1 1 1 0 0]
LABELS [1 0 0 1 1 1 1 1 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-18
Configuration saved in ../training-checkpoints/checkpoint-18/config.json
Model weights saved in ../training-checkpoints/checkpoint-18/pytorch_model.bin
Deleting older checkpoint [../training-checkpoints/checkpoint-9] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[-0.5158009   0.5917436 ]
 [ 0.8861027  -0.6591667 ]
 [ 1.6215199  -1.3677391 ]
 [-1.4454768   1.3301756 ]
 [-1.3878131   1.2712854 ]
 [-1.5104678   1.3810422 ]
 [-1.316827    1.219561  ]
 [ 0.11971101  0.00228929]
 [ 1.6437691  -1.3966019 ]
 [ 1.356835   -1.156396  ]]
PREDS [1 0 0 1 1 1 1 0 0 0]
LABELS [1 0 0 1 1 1 1 1 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-21
Configuration saved in ../training-checkpoints/checkpoint-21/config.json
Model weights saved in ../training-checkpoints/checkpoint-21/pytorch_model.bin
Deleting older checkpoint [../training-checkpoints/checkpoint-12] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[-1.7512172   1.6235843 ]
 [ 0.46027535 -0.27537924]
 [ 1.8564813  -1.5747919 ]
 [-2.2927408   2.0875492 ]
 [-2.295187    2.0679202 ]
 [-2.3390903   2.1156523 ]
 [-2.2529528   2.0317397 ]
 [-0.87836915  0.86512554]
 [ 1.8698337  -1.601629  ]
 [ 1.4808506  -1.2712641 ]]
PREDS [1 0 0 1 1 1 1 1 0 0]
LABELS [1 0 0 1 1 1 1 1 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-24
Configuration saved in ../training-checkpoints/checkpoint-24/config.json
Model weights saved in ../training-checkpoints/checkpoint-24/pytorch_model.bin
Deleting older checkpoint [../training-checkpoints/checkpoint-15] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[-2.2711802   2.064136  ]
 [ 0.08755343  0.04653877]
 [ 1.9891806  -1.6843555 ]
 [-2.5673738   2.3597765 ]
 [-2.5866818   2.34948   ]
 [-2.601398    2.376001  ]
 [-2.561885    2.3253427 ]
 [-1.6400346   1.50453   ]
 [ 1.999433   -1.7147781 ]
 [ 1.5091825  -1.2923696 ]]
PREDS [1 0 0 1 1 1 1 1 0 0]
LABELS [1 0 0 1 1 1 1 1 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-27
Configuration saved in ../training-checkpoints/checkpoint-27/config.json
Model weights saved in ../training-checkpoints/checkpoint-27/pytorch_model.bin
Deleting older checkpoint [../training-checkpoints/checkpoint-18] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10
  Batch size = 1


[[-2.3537512   2.1381905 ]
 [ 0.29000133 -0.12994525]
 [ 2.1392136  -1.819892  ]
 [-2.6273084   2.4221878 ]
 [-2.6499925   2.413395  ]
 [-2.661004    2.4375591 ]
 [-2.627305    2.3910153 ]
 [-1.754251    1.6009631 ]
 [ 2.1385055  -1.8438537 ]
 [ 1.6870191  -1.4523613 ]]
PREDS [1 0 0 1 1 1 1 1 0 0]
LABELS [1 0 0 1 1 1 1 1 0 0]


Saving model checkpoint to ../training-checkpoints/checkpoint-30
Configuration saved in ../training-checkpoints/checkpoint-30/config.json
Model weights saved in ../training-checkpoints/checkpoint-30/pytorch_model.bin
Deleting older checkpoint [../training-checkpoints/checkpoint-21] due to args.save_total_limit


Training completed. Do not forget to share your model on huggingface.co/models =)




TrainOutput(global_step=30, training_loss=0.5061087836821874, metrics={'train_runtime': 45.4667, 'train_samples_per_second': 2.199, 'train_steps_per_second': 0.66, 'total_flos': 5087498922000.0, 'train_loss': 0.5061087836821874, 'epoch': 9.9})

In [8]:
preds = trainer.predict(dataset['test'].select(range(20)))
preds

The following columns in the test set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 20
  Batch size = 1


[[ 0.37556407 -0.20810999]
 [ 1.9821789  -1.6856345 ]
 [ 1.8124082  -1.5262196 ]
 [-2.1486216   1.9749408 ]
 [ 1.6546522  -1.3272806 ]
 [-2.480386    2.229885  ]
 [ 1.4240357  -1.1799687 ]
 [-1.0209055   1.003545  ]
 [ 1.0481441  -0.828573  ]
 [ 1.8723347  -1.5579777 ]
 [ 0.133605   -0.02459436]
 [ 1.0626667  -0.8508463 ]
 [-2.4228776   2.2447116 ]
 [-2.7038057   2.4753833 ]
 [-2.4699807   2.2760224 ]
 [-0.47930747  0.4828345 ]
 [ 1.1681967  -0.8876208 ]
 [-2.5493135   2.3202372 ]
 [ 1.0073977  -0.81787634]
 [-0.79542565  0.7664305 ]]
PREDS [0 0 0 1 0 1 0 1 0 0 0 0 1 1 1 1 0 1 0 1]
LABELS [0 0 1 1 0 0 0 1 1 1 1 1 1 1 0 0 1 1 0 0]


PredictionOutput(predictions=array([[ 0.37556407, -0.20810999],
       [ 1.9821789 , -1.6856345 ],
       [ 1.8124082 , -1.5262196 ],
       [-2.1486216 ,  1.9749408 ],
       [ 1.6546522 , -1.3272806 ],
       [-2.480386  ,  2.229885  ],
       [ 1.4240357 , -1.1799687 ],
       [-1.0209055 ,  1.003545  ],
       [ 1.0481441 , -0.828573  ],
       [ 1.8723347 , -1.5579777 ],
       [ 0.133605  , -0.02459436],
       [ 1.0626667 , -0.8508463 ],
       [-2.4228776 ,  2.2447116 ],
       [-2.7038057 ,  2.4753833 ],
       [-2.4699807 ,  2.2760224 ],
       [-0.47930747,  0.4828345 ],
       [ 1.1681967 , -0.8876208 ],
       [-2.5493135 ,  2.3202372 ],
       [ 1.0073977 , -0.81787634],
       [-0.79542565,  0.7664305 ]], dtype=float32), label_ids=array([0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0]), metrics={'test_loss': 1.363548994064331, 'test_accuracy': 0.5, 'test_precision': 0.5050505050504999, 'test_recall': 0.5050505050504999, 'test_f1': 0.5050505050504499, 'test_ru

In [None]:
preds