In [1]:
import torch
from transformers import TrainingArguments, Trainer, AutoTokenizer, AutoModelForSequenceClassification
from transformers import DataCollatorWithPadding
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Initialize the tokenizer and the sequence-classification model from the same
# pretrained checkpoint.
checkpoint = 'bert-base-uncased'

# Seed torch's global RNG *before* loading the model: the classification head is
# not part of the checkpoint and is newly (randomly) initialized, so without a
# fixed seed every fresh-kernel run starts fine-tuning from different head
# weights and the reported scores are not reproducible. The value 42 is arbitrary.
torch.manual_seed(42)

# NOTE(review): cache_dir is a machine-specific absolute path; adjust it when
# running this notebook on another machine.
tokenizer = AutoTokenizer.from_pretrained(checkpoint, cache_dir='E:\\cache')
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, cache_dir='E:\\cache')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly i

In [3]:
# Load the 'mrpc' configuration of the 'glue' dataset, reusing the local cache directory.
raw_dataset = load_dataset(
    'glue',
    'mrpc',
    cache_dir='E:\\cache',
)
def tokenize_function(example):
    """Tokenize the 'sentence1'/'sentence2' entries of ``example`` as a pair.

    ``example`` is indexed by the keys 'sentence1' and 'sentence2'; both values
    are handed to the module-level ``tokenizer`` together, with truncation
    enabled. The tokenizer's return value is passed back unchanged.
    """
    encoded = tokenizer(
        example['sentence1'],
        example['sentence2'],
        truncation=True,
    )
    return encoded
# Run tokenize_function over the dataset in batched mode; no padding is applied
# here -- the collator built below is what pads, using the same tokenizer.
tokenized_dataset = raw_dataset.map(
    tokenize_function,
    batched=True,
)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Found cached dataset glue (E:/cache/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
100%|██████████| 3/3 [00:00<00:00, 599.87it/s]
Loading cached processed dataset at E:\cache\glue\mrpc\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-763d215295102c54.arrow
Loading cached processed dataset at E:\cache\glue\mrpc\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-78443801478c2082.arrow
Loading cached processed dataset at E:\cache\glue\mrpc\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-51f1e1641e481bbf.arrow


In [12]:
import evaluate
import numpy as np
import pdb

#scoring metrics

_mrpc_metric = None  # GLUE/MRPC metric, loaded lazily on first use and then reused


def compute_metrics(eval_preds):
    """Score model predictions with the GLUE/MRPC metric.

    Args:
        eval_preds: object exposing ``predictions`` (per-class scores; when it
            is a tuple, the first element is taken as the scores) and
            ``label_ids`` (the reference labels).

    Returns:
        Whatever the metric's ``compute`` returns (in this notebook's output,
        a dict with 'accuracy' and 'f1').
    """
    global _mrpc_metric
    # Load the metric once instead of on every call: the loaded object does not
    # depend on the inputs, so reloading it per evaluation is wasted work.
    if _mrpc_metric is None:
        _mrpc_metric = evaluate.load('glue', 'mrpc')

    scores = eval_preds.predictions
    # Some models return extra outputs alongside the class scores, in which case
    # predictions arrive as a tuple; argmax over the whole tuple would be wrong.
    if isinstance(scores, tuple):
        scores = scores[0]

    predictions = np.argmax(scores, axis=-1)
    return _mrpc_metric.compute(predictions=predictions, references=eval_preds.label_ids)

In [13]:
# Training setup: only the output directory is specified; every other training
# argument is left at the library default.
args = TrainingArguments(output_dir='E:\\cache\\bert_sequence_classification')
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['validation'],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
# Kept as the cell's last expression so the returned TrainOutput is displayed.
trainer.train()

 36%|███▋      | 500/1377 [00:45<01:23, 10.50it/s]

{'loss': 0.5062, 'learning_rate': 3.184458968772695e-05, 'epoch': 1.09}


 73%|███████▎  | 1000/1377 [01:38<00:36, 10.26it/s]

{'loss': 0.2686, 'learning_rate': 1.3689179375453886e-05, 'epoch': 2.18}


100%|██████████| 1377/1377 [02:18<00:00,  9.91it/s]

{'train_runtime': 138.9108, 'train_samples_per_second': 79.216, 'train_steps_per_second': 9.913, 'train_loss': 0.31628257291595746, 'epoch': 3.0}





TrainOutput(global_step=1377, training_loss=0.31628257291595746, metrics={'train_runtime': 138.9108, 'train_samples_per_second': 79.216, 'train_steps_per_second': 9.913, 'train_loss': 0.31628257291595746, 'epoch': 3.0})

In [14]:
# Testing: run the fine-tuned model over the validation split, then score the
# resulting predictions with compute_metrics and print the scores.
predictions = trainer.predict(tokenized_dataset['validation'])
validation_scores = compute_metrics(predictions)
print(validation_scores)


  0%|          | 0/51 [00:00<?, ?it/s]

100%|██████████| 51/51 [00:03<00:00, 13.18it/s]


{'accuracy': 0.8921568627450981, 'f1': 0.9246575342465753}
