In [10]:
import functools

import numpy as np
import torch

In [4]:
import accelerate

# Record the installed accelerate version alongside the results for reproducibility
# (the captured output of this run shows 0.25.0).
print(f"{accelerate.__version__}")

0.25.0


In [8]:
import evaluate

In [2]:
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding

  from .autonotebook import tqdm as notebook_tqdm
2023-12-22 15:29:37.751269: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Pretrained checkpoint shared by the tokenizer here and the model further down.
checkpoint = "bert-base-uncased"
raw_datasets = load_dataset("glue", "mrpc")
tokenizer = AutoTokenizer.from_pretrained(checkpoint)


def tokenize_function(example):
    """Tokenize MRPC sentence pairs with the global ``tokenizer``.

    ``example`` is the mapping handed over by ``Dataset.map``; because the call
    below uses ``batched=True``, its ``"sentence1"``/``"sentence2"`` entries hold
    lists of strings. Pairs are truncated to the tokenizer's maximum length;
    padding is deferred to the data collator.
    """
    first_sentences = example["sentence1"]
    second_sentences = example["sentence2"]
    return tokenizer(first_sentences, second_sentences, truncation=True)


tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
# Pads each batch dynamically to its longest sequence at collation time.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Map: 100%|██████████| 408/408 [00:00<00:00, 8988.79 examples/s]


In [11]:
@functools.lru_cache(maxsize=None)
def _load_mrpc_metric():
    """Load the GLUE/MRPC metric once and reuse it across evaluation rounds."""
    return evaluate.load("glue", "mrpc")


def compute_metrics(eval_preds, metric=None):
    """Convert Trainer evaluation predictions into GLUE/MRPC metric values.

    Args:
        eval_preds: ``(logits, labels)`` pair as passed by ``Trainer`` (an
            ``EvalPrediction`` unpacks the same way). If ``logits`` is a tuple,
            its first element is taken as the classification logits.
        metric: optional object exposing ``compute(predictions=..., references=...)``.
            Defaults to the cached ``evaluate`` GLUE/MRPC metric, which is loaded
            only on the first call instead of on every evaluation round.

    Returns:
        The result of ``metric.compute`` (for GLUE/MRPC the training log shows
        accuracy and F1).
    """
    if metric is None:
        metric = _load_mrpc_metric()
    logits, labels = eval_preds
    # Predictions can arrive as a tuple when the model returns extra outputs;
    # the logits are then the first entry (same handling as HF's run_glue example).
    if isinstance(logits, tuple):
        logits = logits[0]
    # Predicted class = index of the largest logit along the last axis.
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

In [12]:
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, set_seed

# Evaluate (and log metrics) at the end of each of the 3 epochs; checkpoints go to ./test-trainer.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to `eval_strategy`;
# kept as-is for the version installed in this environment.
training_args = TrainingArguments(output_dir="test-trainer", evaluation_strategy="epoch", num_train_epochs=3)

# Seed all RNGs *before* building the model: the classifier head is newly (randomly)
# initialised, as the warning below shows, so without this the run is not reproducible.
# Reuses the same seed the Trainer is configured with.
set_seed(training_args.seed)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
# Fine-tune and keep the returned TrainOutput (global step, training loss, runtime
# metrics) in a named variable rather than relying on Out[...] / `_` later on.
train_result = trainer.train()
train_result

  0%|          | 0/1377 [00:00<?, ?it/s]You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
 33%|███▎      | 459/1377 [03:17<10:24,  1.47it/s]
Downloading builder script: 100%|██████████| 5.75k/5.75k [00:00<00:00, 3.20MB/s]
                                                  
 33%|███▎      | 459/1377 [03:23<10:24,  1.47it/s]

{'eval_loss': 0.5216977000236511, 'eval_accuracy': 0.7450980392156863, 'eval_f1': 0.8424242424242424, 'eval_runtime': 6.2439, 'eval_samples_per_second': 65.343, 'eval_steps_per_second': 8.168, 'epoch': 1.0}


 36%|███▋      | 500/1377 [03:40<05:50,  2.50it/s]

{'loss': 0.5989, 'learning_rate': 3.184458968772695e-05, 'epoch': 1.09}


                                                  
 67%|██████▋   | 918/1377 [06:33<03:52,  1.97it/s]

{'eval_loss': 0.38146254420280457, 'eval_accuracy': 0.8455882352941176, 'eval_f1': 0.8951747088186357, 'eval_runtime': 4.2238, 'eval_samples_per_second': 96.596, 'eval_steps_per_second': 12.075, 'epoch': 2.0}


 73%|███████▎  | 1000/1377 [07:06<02:25,  2.59it/s]

{'loss': 0.3822, 'learning_rate': 1.3689179375453886e-05, 'epoch': 2.18}


                                                   
100%|██████████| 1377/1377 [09:43<00:00,  2.36it/s]

{'eval_loss': 0.5278905630111694, 'eval_accuracy': 0.8578431372549019, 'eval_f1': 0.8989547038327526, 'eval_runtime': 4.5142, 'eval_samples_per_second': 90.382, 'eval_steps_per_second': 11.298, 'epoch': 3.0}
{'train_runtime': 583.0514, 'train_samples_per_second': 18.873, 'train_steps_per_second': 2.362, 'train_loss': 0.42975300087301993, 'epoch': 3.0}





TrainOutput(global_step=1377, training_loss=0.42975300087301993, metrics={'train_runtime': 583.0514, 'train_samples_per_second': 18.873, 'train_steps_per_second': 2.362, 'train_loss': 0.42975300087301993, 'epoch': 3.0})