In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding

raw_dataset = load_dataset('nyu-mll/glue', 'mrpc')

checkpoint = 'google-bert/bert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)


def tokenize_function(examples):
    return tokenizer(
        examples['sentence1'],
        examples['sentence2'],
        truncation=True,
    )


tokenized_dataset = raw_dataset.map(tokenize_function, batched=True)
tokenized_dataset = tokenized_dataset.remove_columns(['sentence1', 'sentence2', 'idx'])
tokenized_dataset = tokenized_dataset.rename_column('label', 'labels')
tokenized_dataset = tokenized_dataset.with_format('torch')

data_collator = DataCollatorWithPadding(tokenizer)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    tokenized_dataset['train'],
    shuffle=True,
    batch_size=64,
    collate_fn=data_collator,
)

val_dataloder = DataLoader(
    tokenized_dataset['validation'],
    batch_size=64,
    collate_fn=data_collator,
)

In [3]:
for batch in train_dataloader:
    break
print({k: v.shape for k, v in batch.items()})

{'labels': torch.Size([64]), 'input_ids': torch.Size([64, 90]), 'token_type_ids': torch.Size([64, 90]), 'attention_mask': torch.Size([64, 90])}


In [4]:
from transformers import AutoModelForSequenceClassification

checkpoint = 'bert-base-cased'

model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

outputs = model(**batch)
print(outputs.loss, outputs.logits.shape)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tensor(0.6113, grad_fn=<NllLossBackward0>) torch.Size([64, 2])


In [5]:
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=5e-5)

In [7]:
import torch

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model.to(device)

device

device(type='cuda')

In [8]:
from transformers import get_scheduler

num_epochs = 5
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    'linear',
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [10]:
progress = 1

model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        print(f'{progress}/{num_training_steps}. Loss: {loss}')
        progress += 1

1/290. Loss: 0.4184406101703644
2/290. Loss: 0.4247891902923584
3/290. Loss: 0.5021761059761047
4/290. Loss: 0.4245592951774597
5/290. Loss: 0.3845520317554474
6/290. Loss: 0.3684302866458893
7/290. Loss: 0.47829633951187134
8/290. Loss: 0.4160231649875641
9/290. Loss: 0.39048781991004944
10/290. Loss: 0.42099347710609436
11/290. Loss: 0.3905502259731293
12/290. Loss: 0.3801862597465515
13/290. Loss: 0.4951932430267334
14/290. Loss: 0.48998838663101196
15/290. Loss: 0.36677631735801697
16/290. Loss: 0.2566084861755371
17/290. Loss: 0.40626609325408936
18/290. Loss: 0.33837369084358215
19/290. Loss: 0.34840187430381775
20/290. Loss: 0.37006378173828125
21/290. Loss: 0.3054410219192505
22/290. Loss: 0.3960416615009308
23/290. Loss: 0.27872946858406067
24/290. Loss: 0.312818706035614
25/290. Loss: 0.3463079631328583
26/290. Loss: 0.26707780361175537
27/290. Loss: 0.3592855930328369
28/290. Loss: 0.31880733370780945
29/290. Loss: 0.34505394101142883
30/290. Loss: 0.29686683416366577
31/290

In [12]:
import evaluate

metric = evaluate.load('glue', 'mrpc')

model.eval()

for batch in val_dataloder:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch['labels'])

metric.compute()

Downloading builder script: 100%|██████████| 5.75k/5.75k [00:00<00:00, 14.9MB/s]


{'accuracy': 0.8382352941176471, 'f1': 0.8862068965517241}