In [1]:
import os
from typing import List

%pip install datasets
%pip install transformers
import torch
from datasets import load_dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import DataLoader
from transformers import RobertaForSequenceClassification, RobertaTokenizer, Trainer, TrainingArguments



In [2]:
# Fix all randomness in this cell so "Restart & Run All" reproduces the same
# data subset, the same train/validation split and the same classifier-head init.
SEED = 42
torch.manual_seed(SEED)  # the classification head added on top of roberta-base is newly (randomly) initialised

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForSequenceClassification.from_pretrained('roberta-base', return_dict=True)

# Take only a subset of the data for faster training
# (fine-tuning on the whole dataset takes ~20 hours per epoch).
# Without a seed here the 10k subset would change on every run.
DATA_SIZE = 10_000
dataset = load_dataset("yelp_polarity", split="train").train_test_split(train_size=DATA_SIZE, seed=SEED)["train"]
train_test_split = dataset.train_test_split(train_size=0.9, seed=SEED)  # 90% train / 10% validation
train_dataset = train_test_split["train"]
val_dataset = train_test_split["test"]

print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}")

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifie

Train size: 9000, Validation size: 1000


In [3]:
class DataCollator:
    """Tokenize a batch of raw ``{'text': ..., 'label': ...}`` examples into model inputs.

    Padding is dynamic (``padding=True``): each batch is padded only to its
    longest sequence, and over-long sequences are truncated.

    Args:
        tokenizer: a Hugging Face tokenizer (e.g. ``RobertaTokenizer``).
        max_length: optional cap on the tokenized sequence length. ``None``
            (the default) keeps the original behaviour, where truncation falls
            back to the tokenizer's own maximum model length. A smaller value
            trades truncated reviews for faster training.
    """

    def __init__(self, tokenizer, max_length=None):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, examples: List[dict]):
        """Collate ``examples`` into a dict of ``labels``, ``input_ids`` and ``attention_mask`` tensors."""
        texts = [example['text'] for example in examples]
        labels = [example['label'] for example in examples]
        tokenizer_output = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=self.max_length,
            return_tensors='pt',  # let the tokenizer build the tensors directly
        )
        return {
            'labels': torch.tensor(labels),
            'input_ids': tokenizer_output['input_ids'],
            'attention_mask': tokenizer_output['attention_mask'],
        }
    
# One shared collator instance: tokenizes raw text batches on the fly for the Trainer.
data_collator = DataCollator(tokenizer=tokenizer)

I expected that using my own DataCollator would slow things down. However, it turns out to run at ~1.20 s/it, compared to 1.46 s/it with the default data collator, so data loading is not a bottleneck here. The speedup probably comes from dynamic padding: padding each batch only to its longest sequence gives shorter inputs for some batches. With the padding='max_length' strategy, the speed is the same as with the default collator.


In [4]:
def compute_metrics(pred):
    """Evaluation metrics for the Trainer on a binary classification task.

    Args:
        pred: the Trainer's evaluation output, exposing ``label_ids`` (gold
            labels) and ``predictions`` (per-class scores; the last axis
            indexes the classes).

    Returns:
        dict with ``accuracy``, ``f1``, ``precision`` and ``recall``. The last
        three are computed with ``average='binary'``, i.e. for label 1.
    """
    gold = pred.label_ids
    # The predicted class is the highest-scoring entry along the last axis.
    predicted = pred.predictions.argmax(-1)

    accuracy = accuracy_score(gold, predicted)
    precision, recall, f1, _support = precision_recall_fscore_support(
        gold, predicted, average='binary'
    )

    metrics = {}
    metrics['accuracy'] = accuracy
    metrics['f1'] = f1
    metrics['precision'] = precision
    metrics['recall'] = recall
    return metrics

In [5]:
# Step arithmetic: 9000 train examples / (8 per device x 2 accumulation steps)
# ~= 562 optimizer steps per epoch, ~1124 over 2 epochs. All *_steps below are
# counted in optimizer steps.
training_args = TrainingArguments(
    output_dir='./results',
    logging_dir='./logs',

    # Optimisation
    num_train_epochs=2,
    learning_rate=3e-5,
    weight_decay=0.01,
    warmup_steps=250,  # not tuned; chosen without a strong prior
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,  # effective batch size 16 (as suggested in the BERT paper)

    # Logging / evaluation / checkpointing cadence
    logging_steps=25,
    evaluation_strategy='steps',  # evaluate every eval_steps (without it no evaluation is done)
    eval_steps=250,
    save_steps=250,

    # Keep the raw 'text'/'label' columns: DataCollator reads them itself.
    remove_unused_columns=False,
    no_cuda=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=2.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1125.0, style=ProgressStyle(description_w…

{'loss': 0.706366195678711, 'learning_rate': 3e-06, 'epoch': 0.044444444444444446, 'total_flos': 120486946054080, 'step': 25}
{'loss': 0.6933075714111329, 'learning_rate': 6e-06, 'epoch': 0.08888888888888889, 'total_flos': 231341158810560, 'step': 50}
{'loss': 0.6836430358886719, 'learning_rate': 9e-06, 'epoch': 0.13333333333333333, 'total_flos': 347825434941600, 'step': 75}
{'loss': 0.4069061279296875, 'learning_rate': 1.2e-05, 'epoch': 0.17777777777777778, 'total_flos': 467917498761120, 'step': 100}
{'loss': 0.2678303527832031, 'learning_rate': 1.5e-05, 'epoch': 0.2222222222222222, 'total_flos': 580961513000160, 'step': 125}
{'loss': 0.2453839111328125, 'learning_rate': 1.8e-05, 'epoch': 0.26666666666666666, 'total_flos': 690619112924640, 'step': 150}
{'loss': 0.2029010009765625, 'learning_rate': 2.1e-05, 'epoch': 0.3111111111111111, 'total_flos': 802801565924640, 'step': 175}
{'loss': 0.224925537109375, 'learning_rate': 2.4e-05, 'epoch': 0.35555555555555557, 'total_flos': 9188131799

HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=125.0, style=ProgressStyle(description_w…


{'eval_loss': 0.1470115454941988, 'eval_accuracy': 0.949, 'eval_f1': 0.9486404833836858, 'eval_precision': 0.9476861167002012, 'eval_recall': 0.9495967741935484, 'epoch': 0.4444444444444444, 'total_flos': 1137571954869120, 'step': 250}
{'loss': 0.2541351318359375, 'learning_rate': 2.914187643020595e-05, 'epoch': 0.4888888888888889, 'total_flos': 1253182703632800, 'step': 275}
{'loss': 0.1855731201171875, 'learning_rate': 2.8283752860411898e-05, 'epoch': 0.5333333333333333, 'total_flos': 1373107241655840, 'step': 300}
{'loss': 0.2211431884765625, 'learning_rate': 2.742562929061785e-05, 'epoch': 0.5777777777777777, 'total_flos': 1497369501194880, 'step': 325}
{'loss': 0.24103790283203125, 'learning_rate': 2.65675057208238e-05, 'epoch': 0.6222222222222222, 'total_flos': 1610736600898560, 'step': 350}
{'loss': 0.1191278076171875, 'learning_rate': 2.5709382151029748e-05, 'epoch': 0.6666666666666666, 'total_flos': 1716445378477440, 'step': 375}
{'loss': 0.2090057373046875, 'learning_rate': 

HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=125.0, style=ProgressStyle(description_w…


{'eval_loss': 0.26251811466738584, 'eval_accuracy': 0.932, 'eval_f1': 0.9347408829174665, 'eval_precision': 0.891941391941392, 'eval_recall': 0.9818548387096774, 'epoch': 0.8888888888888888, 'total_flos': 2293553798158560, 'step': 500}
{'loss': 0.14956298828125, 'learning_rate': 2.0560640732265445e-05, 'epoch': 0.9333333333333333, 'total_flos': 2409918413006400, 'step': 525}
{'loss': 0.15126953125, 'learning_rate': 1.9702517162471395e-05, 'epoch': 0.9777777777777777, 'total_flos': 2528861728507200, 'step': 550}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1125.0, style=ProgressStyle(description_w…

{'loss': 0.1243646240234375, 'learning_rate': 1.884439359267735e-05, 'epoch': 1.023111111111111, 'total_flos': 2650886322050400, 'step': 575}
{'loss': 0.1008349609375, 'learning_rate': 1.7986270022883295e-05, 'epoch': 1.0675555555555556, 'total_flos': 2769566382728160, 'step': 600}
{'loss': 0.062974853515625, 'learning_rate': 1.7128146453089245e-05, 'epoch': 1.112, 'total_flos': 2884955758117920, 'step': 625}
{'loss': 0.1151953125, 'learning_rate': 1.6270022883295195e-05, 'epoch': 1.1564444444444444, 'total_flos': 2998478417489760, 'step': 650}
{'loss': 0.1976959228515625, 'learning_rate': 1.5411899313501142e-05, 'epoch': 1.200888888888889, 'total_flos': 3116709748355520, 'step': 675}
{'loss': 0.122261962890625, 'learning_rate': 1.4553775743707094e-05, 'epoch': 1.2453333333333334, 'total_flos': 3224961328202400, 'step': 700}
{'loss': 0.1538525390625, 'learning_rate': 1.3695652173913042e-05, 'epoch': 1.2897777777777777, 'total_flos': 3342558454267200, 'step': 725}
{'loss': 0.10355834960

HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=125.0, style=ProgressStyle(description_w…


{'eval_loss': 0.25463570963125676, 'eval_accuracy': 0.95, 'eval_f1': 0.9498997995991985, 'eval_precision': 0.9442231075697212, 'eval_recall': 0.9556451612903226, 'epoch': 1.3342222222222222, 'total_flos': 3454441754059200, 'step': 750}
{'loss': 0.1496527099609375, 'learning_rate': 1.1979405034324942e-05, 'epoch': 1.3786666666666667, 'total_flos': 3565002796671840, 'step': 775}
{'loss': 0.1710833740234375, 'learning_rate': 1.1121281464530894e-05, 'epoch': 1.423111111111111, 'total_flos': 3677077554516960, 'step': 800}
{'loss': 0.1243231201171875, 'learning_rate': 1.0263157894736843e-05, 'epoch': 1.4675555555555555, 'total_flos': 3787806122926080, 'step': 825}
{'loss': 0.1135552978515625, 'learning_rate': 9.405034324942791e-06, 'epoch': 1.512, 'total_flos': 3908293068980160, 'step': 850}
{'loss': 0.1376123046875, 'learning_rate': 8.546910755148743e-06, 'epoch': 1.5564444444444443, 'total_flos': 4026560298230880, 'step': 875}
{'loss': 0.144462890625, 'learning_rate': 7.688787185354691e-0

HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=125.0, style=ProgressStyle(description_w…


{'eval_loss': 0.18014563730172814, 'eval_accuracy': 0.958, 'eval_f1': 0.9579158316633266, 'eval_precision': 0.952191235059761, 'eval_recall': 0.9637096774193549, 'epoch': 1.7786666666666666, 'total_flos': 4603004597790240, 'step': 1000}
{'loss': 0.04727783203125, 'learning_rate': 3.3981693363844395e-06, 'epoch': 1.8231111111111111, 'total_flos': 4718758940093760, 'step': 1025}
{'loss': 0.08021484375, 'learning_rate': 2.540045766590389e-06, 'epoch': 1.8675555555555556, 'total_flos': 4833789331633920, 'step': 1050}
{'loss': 0.1810028076171875, 'learning_rate': 1.6819221967963386e-06, 'epoch': 1.912, 'total_flos': 4949663335220640, 'step': 1075}
{'loss': 0.06732666015625, 'learning_rate': 8.237986270022883e-07, 'epoch': 1.9564444444444444, 'total_flos': 5066805748409280, 'step': 1100}




TrainOutput(global_step=1124, training_loss=0.19723334261531084)