In [28]:
import datasets
import torch
from transformers import BertConfig, BertForSequenceClassification

from src.WhitespaceTokenizer import WhitespaceTokenizer

# Load the agreement dataset; the printed DatasetDict below shows its train and test splits.
dataset = datasets.load_dataset("michaelginn/latent-trees-agreement")

# Learn the vocabulary from the training sentences only, then tokenize every split.
# load_from_cache_file=False recomputes the mapping instead of reusing a cached
# result from an earlier run.
tokenizer = WhitespaceTokenizer(max_length=50)
tokenizer.learn_vocab([row['text'] for row in dataset['train']])
dataset = dataset.map(tokenizer.tokenize_batch, batched=True, load_from_cache_file=False)
print(dataset)

# Size the model's vocabulary and position embeddings from the tokenizer so the
# two stay in agreement; the task is binary classification (num_labels=2).
config = BertConfig(
    vocab_size=tokenizer.vocab_size,
    num_labels=2,
    max_position_embeddings=tokenizer.model_max_length,
)
print(config)

# Prefer Apple's MPS backend (the device this notebook originally hard-coded),
# but fall back to CUDA or CPU so the cell also runs on machines without MPS
# instead of raising when the model is moved.
if torch.backends.mps.is_available():
    model_device = 'mps'
elif torch.cuda.is_available():
    model_device = 'cuda'
else:
    model_device = 'cpu'
model = BertForSequenceClassification(config=config).to(model_device)

Downloading readme:   0%|          | 0.00/583 [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/36.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/16.5k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/2800 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1200 [00:00<?, ? examples/s]

Map:   0%|          | 0/2800 [00:00<?, ? examples/s]

Map:   0%|          | 0/1200 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask'],
        num_rows: 2800
    })
    test: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask'],
        num_rows: 1200
    })
})
BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 50,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.21.3",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 80
}



In [35]:
from transformers import Trainer, TrainingArguments
from src.eval import eval_preds
import numpy as np

# Hyperparameters consumed by the TrainingArguments built later in this cell.
batch_size = 1
train_epochs = 10

# NOTE(review): `device` is not read by any other statement in the visible cells.
device = 'mps'

# Two-row subset of the training split (rows 0 and 1).
# Alternative kept for reference:
# toy_data = [dataset['train'][0], dataset['train'][26]]
toy_data = dataset['train'].select(indices=[0, 1])

def compute_metrics(eval_pred):
    """Convert raw evaluation output into metrics via `eval_preds`.

    Args:
        eval_pred: Object exposing `predictions` (per-class scores, reduced
            with argmax over the last axis) and `label_ids` (gold labels).

    Returns:
        Whatever `eval_preds` returns for the predicted class ids and the
        gold labels.
    """
    # The highest-scoring class along the last axis is the predicted label.
    predicted_classes = np.argmax(eval_pred.predictions, axis=-1)
    return eval_preds(predicted_classes, eval_pred.label_ids)

# Evaluate and checkpoint once per epoch (keeping at most three checkpoints),
# log after every optimisation step, and report to Weights & Biases.
args = TrainingArguments(
    output_dir="../training-checkpoints",
    num_train_epochs=train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=1,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    load_best_model_at_end=False,
    logging_strategy='steps',
    logging_steps=1,
    report_to='wandb',
)

# The same toy subset serves as both the training and the evaluation set.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=toy_data,
    eval_dataset=toy_data,
    compute_metrics=compute_metrics,
)

# Left as the final expression so the returned TrainOutput is shown as the cell result.
trainer.train()


PyTorch: setting up devices
The following columns in the training set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 2
  Num Epochs = 10
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 20
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Epoch,Training Loss,Validation Loss


The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 2
  Batch size = 1
Saving model checkpoint to ../training-checkpoints/checkpoint-2
Configuration saved in ../training-checkpoints/checkpoint-2/config.json
Model weights saved in ../training-checkpoints/checkpoint-2/pytorch_model.bin
Deleting older checkpoint [../training-checkpoints/checkpoint-16] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 2
  Batch size = 1
Saving model checkpoint to ../training-ch

TrainOutput(global_step=20, training_loss=0.00016453956204713903, metrics={'train_runtime': 20.6267, 'train_samples_per_second': 0.97, 'train_steps_per_second': 0.97, 'total_flos': 513888780000.0, 'train_loss': 0.00016453956204713903, 'epoch': 10.0})

In [36]:
# Run inference on the toy subset; the returned PredictionOutput is displayed as the cell result.
trainer.predict(test_dataset=toy_data)

The following columns in the test set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 2
  Batch size = 1


PredictionOutput(predictions=array([[ 5.2712135, -5.3930755],
       [-5.2554264,  5.377078 ]], dtype=float32), label_ids=array([0, 1]), metrics={'test_loss': 2.3722368496237323e-05, 'test_accuracy': 1.0, 'test_precision': 0.9999999999999001, 'test_recall': 0.9999999999999001, 'test_f1': 0.9999999999998501, 'test_runtime': 0.1103, 'test_samples_per_second': 18.138, 'test_steps_per_second': 18.138})

In [22]:
# Print the toy example at index 1 to inspect its fields.
# NOTE(review): the saved output below carries execution count 22, earlier than
# the cells that define `dataset` (28) and `toy_data` (35). It shows a 'label'
# key and the text 'I walks', whereas `toy_data` is now rows 0 and 1 of the
# training split and that split uses 'labels'. The output appears stale;
# re-run this cell after the ones above to refresh it.
print(toy_data[1])

{'text': 'I walks', 'label': 0, 'input_ids': [1, 57, 37, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}


In [32]:
# Print row 1 of the tokenized training split to inspect its text, label,
# input ids and attention mask.
print(dataset['train'][1])

{'text': 'a mad dog that ponders fights a dude', 'labels': 1, 'input_ids': [1, 7, 20, 12, 5, 21, 22, 7, 23, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
