In [2]:
import datasets
from transformers import BertTokenizer, BertConfig, BertForSequenceClassification, TrainingArguments, Trainer
import numpy as np
from src.eval import eval_preds
from src.TreeTransformer import TreeBertForSequenceClassification
import random
import torch

# torch.manual_seed(10)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

dataset = datasets.load_dataset("michaelginn/latent-trees-agreement-ID")


tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
max_length = 100
def tokenize_function(example):
    return tokenizer(example['text'], max_length=max_length, truncation=True)
dataset = dataset.map(tokenize_function, batched=True, load_from_cache_file=False)

toy_dataset = dataset['train'].select([1, 4]) # dataset['train'].select(range(1, 11))

id2label = {0: "VIOLATION", 1: "GRAMMATICAL"}
label2id = {"VIOLATION": 0, "GRAMMATICAL": 1}

pretrained = False
if pretrained:
    config = BertConfig.from_pretrained('bert-base-uncased', num_labels=2, id2label=id2label, label2id=label2id)
else:
    # Create random initialized BERT model
    config = BertConfig(num_labels=2, id2label=id2label, label2id=label2id)

# model = TreeBertForSequenceClassification(config=config, disable_treeing=True).to('mps')
model = BertForSequenceClassification(config=config).to('mps')

args = TrainingArguments(
    output_dir=f"../training-checkpoints",
    learning_rate=2e-5,
    evaluation_strategy="epoch",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    save_strategy="epoch",
    weight_decay=0,
    save_total_limit=3,
    num_train_epochs=100,
    load_best_model_at_end=False,
    logging_strategy='epoch',
)

def compute_metrics(eval_pred):
    labels = eval_pred.label_ids
    preds = np.argmax(eval_pred.predictions, axis=-1)
    print(eval_pred.predictions)
    return eval_preds(preds, labels)


trainer = Trainer(
    model,
    args,
    train_dataset= toy_dataset, # dataset['train'],
    eval_dataset= toy_dataset, # dataset['eval'], # dataset['test'].select(range(20)),
    compute_metrics=compute_metrics,
    tokenizer=tokenizer
)

trainer.train()

# preds = trainer.predict(dataset['test'].select(range(20)))
# preds

Map:   0%|          | 0/2400 [00:00<?, ? examples/s]

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Epoch,Training Loss,Validation Loss


[[-0.27624887  0.07273558]
 [-0.32365113  0.08062568]]
PREDS [1 1]
LABELS [1 0]
[[-0.27624887  0.07273558]
 [-0.32365113  0.08062568]]
PREDS [1 1]
LABELS [1 0]
[[-0.27624887  0.07273558]
 [-0.32365113  0.08062568]]
PREDS [1 1]
LABELS [1 0]
[[-0.27624887  0.07273558]
 [-0.32365113  0.08062568]]
PREDS [1 1]
LABELS [1 0]
[[-0.27624887  0.07273558]
 [-0.32365113  0.08062568]]
PREDS [1 1]
LABELS [1 0]
[[-0.27624887  0.07273558]
 [-0.32365113  0.08062568]]
PREDS [1 1]
LABELS [1 0]
[[-0.27624887  0.07273558]
 [-0.32365113  0.08062568]]
PREDS [1 1]
LABELS [1 0]
[[-0.27624887  0.07273558]
 [-0.32365113  0.08062568]]
PREDS [1 1]
LABELS [1 0]
[[-0.27624887  0.07273558]
 [-0.32365113  0.08062568]]
PREDS [1 1]
LABELS [1 0]
[[-0.27624887  0.07273558]
 [-0.32365113  0.08062568]]
PREDS [1 1]
LABELS [1 0]
[[-0.27624887  0.07273558]
 [-0.32365113  0.08062568]]
PREDS [1 1]
LABELS [1 0]
[[-0.27624887  0.07273558]
 [-0.32365113  0.08062568]]
PREDS [1 1]
LABELS [1 0]
[[-0.27624887  0.07273558]
 [-0.32365113

KeyboardInterrupt: 

In [1]:
from transformers import BertConfig
from src.TreeTransformer import TreeBertForSequenceClassification
import torch


config = BertConfig(num_labels=2)
model = TreeBertForSequenceClassification(config=config)
model(input_ids=torch.tensor([[1, 2, 3, 0], [2, 1, 0, 1]]), attention_mask=torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]), return_dict=True)

hidden torch.Size([2, 4, 768])
a tensor([[0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
        [0, 0, 0, 0]], dtype=torch.int32)
b tensor([[1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]], dtype=torch.int32)
c tensor([[0, 0, 0, 0],
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0]], dtype=torch.int32)
attention mask torch.Size([2, 4])
a+c torch.Size([4, 4])
mask torch.Size([2, 4, 4])
scores torch.Size([2, 4, 4])
scores masked torch.Size([2, 4, 4])
query torch.Size([2, 12, 4, 64])
combined tensor([[[1, 1, 1, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 1]],

        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 0],
         [1, 1, 0, 1]]])
combined unsqueezed tensor([[[[1, 1, 1, 0],
          [1, 1, 1, 0],
          [1, 1, 1, 0],
          [1, 1, 1, 1]],

         [[1, 1, 1, 0],
          [1, 1, 1, 0],
          [1, 1, 1, 0],
          [1, 1, 1, 1]],

         [[1, 1, 1, 0],
          [1, 1, 1, 

SequenceClassifierOutputWithConstituentAttention(loss=None, logits=tensor([[-0.0112, -0.4408],
        [-0.0646, -0.1347]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [3]:
a = torch.IntTensor([[1, 1, 1, 0]])
mask = torch.IntTensor([[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])

results = []
for row in a:
    results.append(row & mask)
result = torch.stack(results)
result.size()

torch.Size([1, 3, 4])

In [9]:
import numpy as np

a = torch.from_numpy(np.diag(np.ones(4 - 1, dtype=np.int32), 1))
a

tensor([[0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
        [0, 0, 0, 0]], dtype=torch.int32)

In [9]:
toy_dataset[2]

{'text': 'these ducks and these wugs kiss the ducks that laughs',
 'labels': 0,
 'input_ids': [101,
  2122,
  14875,
  1998,
  2122,
  8814,
  5620,
  3610,
  1996,
  14875,
  2008,
  11680,
  102],
 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [2]:
toy_dataset['labels']

[1, 1, 0, 0, 0, 0, 1, 1, 0, 1]

In [12]:
dataset['train'].select([1, 4])['text']

['she ponders', 'I kicks a duck']

In [4]:
dataset['train']['labels']

[0,
 1,
 1,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 1,
 0,
 1,
 0,
 1,
 1,
 1,
 0,
 1,
 0,
 0,
 1,
 1,
 0,
 1,
 1,
 0,
 1,
 0,
 0,
 1,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 1,
 1,
 0,
 1,
 0,
 1,
 0,
 1,
 1,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 1,
 1,
 1,
 0,
 1,
 0,
 1,
 1,
 0,
 1,
 1,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 0,
 1,
 1,
 0,
 0,
 1,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 1,
 0,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 0,
 0,
 1,
 1,
 0,
 1,
 0,
 0,
 1,
 0,
 1,
 1,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 0,
 0,
 1,
 1,
 1,
 0,
 0,
 0,
 1,
 1,
 1,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 1,
 0,
 1,
 0,
 0,
 1,
 1,
 0,
 0,
