In [None]:
!pip install transformers datasets

In [None]:
from datasets import load_dataset
import numpy as np

In [None]:
raw_datasets = load_dataset("glue", "rte")

In [None]:
raw_datasets

In [None]:
raw_datasets["train"].features

In [None]:
raw_datasets['train']['sentence1'][:10]

In [None]:
checkpoint = "distilbert-base-cased"

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

In [None]:
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
tokenizer(
    raw_datasets['train']['sentence1'][0],
    raw_datasets['train']['sentence2'][0],
)

In [None]:
result = _

# distilbert doesn't use token_type_ids
result.keys()

In [None]:
tokenizer.decode(result['input_ids'])

In [None]:
from transformers import set_seed

# Seed *before* building the model: the sequence-classification head is not
# part of the base checkpoint and is randomly initialised by `from_pretrained`,
# and Trainer only reseeds later, after the model already exists.
# 42 matches TrainingArguments' default `seed`.
set_seed(42)
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint, num_labels = 2  # binary labels, as in the metric demo below
)

In [None]:
training_args = TrainingArguments(
    output_dir='training_dir',
    # `eval_strategy` replaced the deprecated `evaluation_strategy`
    # (transformers >= 4.41); the old keyword is removed in later releases.
    eval_strategy='epoch',   # evaluate on eval_dataset once per epoch
    save_strategy='epoch',   # checkpoint at the same cadence as evaluation
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    # Log the training loss every 150 steps; with the default interval the
    # per-epoch evaluation table shows "No log" under Training Loss.
    logging_steps=150,
    seed=42,  # the library default, made explicit for reproducibility
)

In [None]:
from datasets import load_metric 
metric = load_metric("glue", "rte")

In [None]:
metric.compute(predictions=[1, 0, 1], references = [1, 0, 0])

In [None]:
from sklearn.metrics import f1_score

In [None]:
def compute_metrics(logits_and_labels):
    """Summarise one evaluation pass for the Trainer.

    Args:
        logits_and_labels: a (logits, labels) pair as handed over by the
            Trainer; logits hold one score per class along the last axis.

    Returns:
        dict with 'accuracy' (fraction of argmax predictions equal to the
        label) and 'f1' (sklearn ``f1_score`` with its default ``pos_label=1``).
    """
    scores, gold = logits_and_labels
    # The predicted class is the one with the highest score.
    predicted = np.argmax(scores, axis=-1)
    # NOTE(review): with pos_label=1 this is the F1 of label 1, presumably
    # 'not_entailment' in GLUE RTE; confirm via
    # raw_datasets['train'].features['label'].names.
    f1 = f1_score(gold, predicted)
    accuracy = np.mean(predicted == gold)
    return dict(accuracy=accuracy, f1=f1)

In [None]:
def tokenize_fn(batch):
    """Encode a batch of examples as (sentence1, sentence2) input pairs.

    ``batch`` is a dict of column lists, as passed by ``map(..., batched=True)``.
    Over-long pairs are truncated. No padding is applied here; it is presumably
    done per batch by the Trainer's collator.
    """
    return tokenizer(
        text=batch['sentence1'],
        text_pair=batch['sentence2'],
        truncation=True,
    )

# Add the tokenizer's output columns (see `result.keys()` above) to every split.
tokenized_datasets = raw_datasets.map(function=tokenize_fn, batched=True)


In [None]:
trainer = Trainer(
    model, 
    training_args, 
    train_dataset=tokenized_datasets['train'], 
    eval_dataset = tokenized_datasets['validation'], 
    tokenizer=tokenizer, 
    compute_metrics=compute_metrics
)

In [None]:
trainer.train()