In [5]:
import datasets
import numpy as np
import torch
from transformers import BertTokenizer, BertConfig, BertForSequenceClassification, TrainingArguments, Trainer

from src.eval import eval_preds
from src.TreeTransformer import TreeBertForSequenceClassification

# Hub identifiers for the corpus and for the tokenizer vocabulary.
DATASET_ID = "michaelginn/latent-trees-agreement-ID"
TOKENIZER_CHECKPOINT = "bert-base-uncased"

dataset = datasets.load_dataset(DATASET_ID)
tokenizer = BertTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)

# Upper bound on tokens per example; longer inputs are truncated when tokenizing.
max_length = 100
def tokenize_function(example):
    """Tokenize the ``'text'`` field of ``example`` with the notebook's tokenizer.

    Inputs longer than the module-level ``max_length`` are truncated.

    Args:
        example: Mapping with a ``'text'`` entry (a batch of rows when called
            through ``dataset.map(..., batched=True)``).

    Returns:
        The tokenizer's encoding for ``example['text']``.
    """
    encoded = tokenizer(
        example["text"],
        truncation=True,
        max_length=max_length,
    )
    return encoded
# Tokenize every split in batches; the cache is bypassed so edits to
# tokenize_function always take effect.
dataset = dataset.map(
    tokenize_function,
    batched=True,
    load_from_cache_file=False,
)

# Ten-example slice of the training split for quick overfitting checks.
# NOTE(review): the range starts at 1, so training example 0 is skipped — confirm intended.
toy_indices = range(1, 11)
toy_dataset = dataset["train"].select(toy_indices)

# Class index -> human-readable label, plus the inverse lookup derived from it.
id2label = {
    0: "VIOLATION",
    1: "GRAMMATICAL",
}
label2id = {name: index for index, name in id2label.items()}

# Choose between the published bert-base-uncased configuration and the
# library-default BertConfig.
# NOTE(review): in both branches the model below is built from the config
# alone, so it appears no pretrained weights are ever loaded — confirm whether
# `pretrained = True` is meant to load weights as well.
pretrained = False
if pretrained:
    config = BertConfig.from_pretrained('bert-base-uncased', num_labels=2, id2label=id2label, label2id=label2id)
else:
    # Create random initialized BERT model
    config = BertConfig(num_labels=2, id2label=id2label, label2id=label2id)

# Prefer the Apple-silicon GPU when it exists, otherwise fall back to CUDA or
# CPU so the cell does not fail on machines without an MPS backend.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = TreeBertForSequenceClassification(config=config).to(device)

# Training configuration: evaluate, checkpoint and log once per epoch.
args = TrainingArguments(
    # Checkpointing
    output_dir="../training-checkpoints",
    save_strategy="epoch",
    save_total_limit=3,
    load_best_model_at_end=False,
    # Optimisation
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=10,
    # Batching
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    # Evaluation and logging cadence
    evaluation_strategy="epoch",
    logging_strategy="epoch",
)

def compute_metrics(eval_pred):
    """Turn raw evaluation outputs into the metrics reported by the Trainer.

    Args:
        eval_pred: Object exposing ``predictions`` (raw model outputs) and
            ``label_ids`` (gold labels).

    Returns:
        Whatever ``eval_preds`` returns for the predicted and gold class ids.
    """
    predictions = eval_pred.predictions
    # Predictions arrive as a tuple when the model emits more than one output.
    # Assumes the classification logits are the first element — TODO confirm
    # against the model's output ordering.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predicted_ids = np.argmax(predictions, axis=-1)
    return eval_preds(predicted_ids, eval_pred.label_ids)


# The toy subset is used for both training and evaluation; the test-split
# alternative is kept in the trailing comment.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=toy_dataset,
    eval_dataset=toy_dataset,  # dataset['test'].select(range(20)),
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# trainer.train()

# preds = trainer.predict(dataset['test'].select(range(20)))
# preds

Map:   0%|          | 0/2400 [00:00<?, ? examples/s]

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

In [1]:
from transformers import BertConfig
from src.TreeTransformer import TreeBertForSequenceClassification
import torch


# Smoke test: build a two-label model from a default config and run a single
# forward pass on one four-token sequence whose last position is masked out.
config = BertConfig(num_labels=2)
model = TreeBertForSequenceClassification(config=config)

input_ids = torch.tensor([[1, 2, 3, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0]])

# Left as the final expression so the notebook displays the model output.
model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)

(tensor([[[3.1623e-05, 7.0538e-01, 3.1623e-05, 3.1623e-05],
         [7.0538e-01, 3.1623e-05, 7.0883e-01, 3.1623e-05],
         [3.1623e-05, 7.0883e-01, 3.1623e-05, 3.1623e-05],
         [3.1623e-05, 3.1623e-05, 3.1623e-05, 3.1623e-05]]],
       grad_fn=<AddBackward0>), tensor([[[6.3245e-05, 9.1367e-01, 5.0001e-01, 4.7435e-05],
         [9.1367e-01, 6.3245e-05, 9.1476e-01, 5.4039e-05],
         [5.0001e-01, 9.1476e-01, 6.3245e-05, 6.3247e-05],
         [4.7435e-05, 5.4039e-05, 6.3247e-05, 6.3245e-05]]],
       grad_fn=<AddBackward0>), tensor([[[9.4865e-05, 9.7482e-01, 8.3579e-01, 8.4483e-05],
         [9.7482e-01, 9.4865e-05, 9.7493e-01, 8.9478e-05],
         [8.3579e-01, 9.7493e-01, 9.4865e-05, 9.4869e-05],
         [8.4483e-05, 8.9478e-05, 9.4869e-05, 9.4865e-05]]],
       grad_fn=<AddBackward0>), tensor([[[1.2649e-04, 9.9262e-01, 9.5038e-01, 1.2178e-04],
         [9.9262e-01, 1.2649e-04, 9.9267e-01, 1.2411e-04],
         [9.5038e-01, 9.9267e-01, 1.2649e-04, 1.2649e-04],
         [1.

SequenceClassifierOutputWithConstituentAttention(loss=None, logits=tensor([[-0.0768,  0.0854]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)