# Roberta Classifier on SNLI: baseline
Stanford Natural Langauge Ingerence dataset.
https://huggingface.co/datasets/snli

The SNLI corpus (version 1.0) is a collection of 570k human-written English sentence pairs manually labeled for balanced classification with the labels entailment, contradiction, and neutral, supporting the task of natural language inference (NLI), also known as recognizing textual entailment (RTE).

* training examples: 550152
* test examples: 10000

3 labels.

In [1]:
import json
import os
from typing import List

%pip install datasets
%pip install transformers
import torch
from datasets import load_dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments


ROOT_DIR = "drive/My Drive/Colab Notebooks/nlp/results/snli_baseline"
if not os.path.exists(ROOT_DIR):
    os.mkdir(ROOT_DIR)

Collecting datasets
[?25l  Downloading https://files.pythonhosted.org/packages/f0/f4/2a3d6aee93ae7fce6c936dda2d7f534ad5f044a21238f85e28f0b205adf0/datasets-1.1.2-py3-none-any.whl (147kB)
[K     |██▎                             | 10kB 30.4MB/s eta 0:00:01[K     |████▌                           | 20kB 22.9MB/s eta 0:00:01[K     |██████▊                         | 30kB 17.6MB/s eta 0:00:01[K     |█████████                       | 40kB 14.8MB/s eta 0:00:01[K     |███████████▏                    | 51kB 13.4MB/s eta 0:00:01[K     |█████████████▍                  | 61kB 12.2MB/s eta 0:00:01[K     |███████████████▋                | 71kB 12.2MB/s eta 0:00:01[K     |█████████████████▉              | 81kB 10.9MB/s eta 0:00:01[K     |████████████████████            | 92kB 11.1MB/s eta 0:00:01[K     |██████████████████████▎         | 102kB 11.1MB/s eta 0:00:01[K     |████████████████████████▌       | 112kB 11.1MB/s eta 0:00:01[K     |██████████████████████████▊     | 122kB 11

In [None]:
dataset = load_dataset("snli", split="test")
dataset[1]

Reusing dataset snli (/root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)


{'hypothesis': 'The church is filled with song.',
 'label': 0,
 'premise': 'This church choir sings to the masses as they sing joyous songs from the book at a church.'}

In [2]:
def get_datasets(dataset_name, train_size, val_size=1_000, test_size=None, random_seed: int = 42):
    dataset = load_dataset(dataset_name, split="train").filter(lambda d: d['label'] != -1)
    test_dataset = load_dataset(dataset_name, split="test").filter(lambda d: d['label'] != -1)
    # We want test and validation data to be the same for every experiment
    if test_size:
        test_dataset = test_dataset.train_test_split(test_size=test_size, seed=random_seed)["test"]
    train_val_split = dataset.train_test_split(test_size=val_size, seed=random_seed)
    # Validation and test sets
    train_dataset = train_val_split["train"].train_test_split(train_size=train_size, seed=random_seed)["train"]
    val_dataset = train_val_split["test"]
    return train_dataset, val_dataset, test_dataset


class DataCollator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        
    def __call__(self, examples: List[dict]):
        labels = [example['label'] for example in examples]
        texts = [example['hypothesis'] + f' {self.tokenizer.sep_token} ' + example['premise'] for example in examples]
        tokenizer_output = self.tokenizer(texts, truncation=True, padding=True)
        return {
            'labels': torch.tensor(labels), 
            'input_ids': torch.tensor(tokenizer_output['input_ids']), 
            'attention_mask': torch.tensor(tokenizer_output['attention_mask'])
            }


def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, micro_f1, _ = precision_recall_fscore_support(labels, preds, average='micro', zero_division=0)
    _, _, macro_f1, _ = precision_recall_fscore_support(labels, preds, average='macro', zero_division=0)
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'micro_f1': micro_f1,
        'micro_precision': precision,
        'micro_recall': recall,
        'macro_f1': macro_f1
    }

In [7]:
tokenizer = AutoTokenizer.from_pretrained('roberta-base', use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained('roberta-base', return_dict=True, num_labels=3)
data_collator = DataCollator(tokenizer)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=481.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=898823.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1355863.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=501200538.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifie

In [None]:
TRAIN_SIZES = [20, 100, 1_000, 10_000, 100_000]
for train_size in TRAIN_SIZES:
    train_dataset, val_dataset, test_dataset = get_datasets("snli", train_size, val_size=7_000)
    print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}, Test size: {len(test_dataset)}")
    print(train_dataset[0])
    print(val_dataset[0])
    print(test_dataset[0])
    output_dir = os.path.join(ROOT_DIR, f"train_size_{train_size}")

    num_train_epochs = 6 if train_size <= 10_000 else 3

    # https://huggingface.co/transformers/main_classes/trainer.html#trainingarguments
    training_args = TrainingArguments(
        learning_rate=3e-5,
        weight_decay=0.01,
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        warmup_steps=0,  # don't have any intuition for the right value here
        logging_dir=output_dir,
        logging_steps=10,
        load_best_model_at_end=True,
        evaluation_strategy='epoch',
        remove_unused_columns=False,
        no_cuda=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics
        
    )

    trainer.train()

    test_result = trainer.evaluate(test_dataset)

    print(test_result)

    with open(os.path.join(output_dir, 'test_result.json'), 'w') as f:
        json.dump(test_result, f, indent=4)

    print()

Reusing dataset snli (/root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)
Loading cached processed dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-c5e7cc0489dda538.arrow
Reusing dataset snli (/root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)
Loading cached processed dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-677b99ebf191dd17.arrow


Train size: 20, Validation size: 7000, Test size: 9824
{'hypothesis': 'A girl is jumping into the pool.', 'label': 2, 'premise': 'A young boy in a red life jacket is swimming in a pool.'}
{'hypothesis': 'A group of people are riding a roller coaster.', 'label': 0, 'premise': 'A group of people riding a yellow roller coaster.'}
{'hypothesis': 'The church has cracks in the ceiling.', 'label': 1, 'premise': 'This church choir sings to the masses as they sing joyous songs from the book at a church.'}


Epoch,Training Loss,Validation Loss,Accuracy,Micro F1,Micro Precision,Micro Recall,Macro F1
1,No log,1.104674,0.333,0.333,0.333,0.333,0.166542
2,No log,1.113841,0.333,0.333,0.333,0.333,0.166542
3,No log,1.129465,0.333,0.333,0.333,0.333,0.166542
4,No log,1.14892,0.333,0.333,0.333,0.333,0.166542
5,1.008167,1.169277,0.333,0.333,0.333,0.333,0.166542
6,1.008167,1.180557,0.333,0.333,0.333,0.333,0.166542


{'eval_loss': 1.1052757501602173, 'eval_accuracy': 0.32949918566775244, 'eval_micro_f1': 0.32949918566775244, 'eval_micro_precision': 0.32949918566775244, 'eval_micro_recall': 0.32949918566775244, 'eval_macro_f1': 0.1652247147997856, 'epoch': 6.0, 'total_flos': 4092441133248}



Reusing dataset snli (/root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)
Loading cached processed dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-c5e7cc0489dda538.arrow
Reusing dataset snli (/root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)
Loading cached processed dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-677b99ebf191dd17.arrow
Loading cached split indices for dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-48e7475b404be1e0.arrow and /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-0a7ebfe579f236a

Train size: 100, Validation size: 7000, Test size: 9824
{'hypothesis': 'A person grilling outside.', 'label': 0, 'premise': 'A man is grilling in a backyard with a large shed.'}
{'hypothesis': 'A group of people are riding a roller coaster.', 'label': 0, 'premise': 'A group of people riding a yellow roller coaster.'}
{'hypothesis': 'The church has cracks in the ceiling.', 'label': 1, 'premise': 'This church choir sings to the masses as they sing joyous songs from the book at a church.'}


Epoch,Training Loss,Validation Loss,Accuracy,Micro F1,Micro Precision,Micro Recall,Macro F1
1,No log,1.122318,0.333,0.333,0.333,0.333,0.166542
2,1.123144,1.106562,0.333,0.333,0.333,0.333,0.166542
3,1.084893,1.100224,0.338857,0.338857,0.338857,0.338857,0.183123
4,1.084893,1.094307,0.377286,0.377286,0.377286,0.377286,0.280408
5,1.048840,1.097938,0.379429,0.379429,0.379429,0.379429,0.284196
6,0.975187,1.094096,0.382571,0.382571,0.382571,0.382571,0.292035


{'eval_loss': 1.0997958183288574, 'eval_accuracy': 0.3830415309446254, 'eval_micro_f1': 0.38304153094462534, 'eval_micro_precision': 0.3830415309446254, 'eval_micro_recall': 0.3830415309446254, 'eval_macro_f1': 0.29040589995792554, 'epoch': 6.0, 'total_flos': 22056702101928}



Reusing dataset snli (/root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)
Loading cached processed dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-c5e7cc0489dda538.arrow
Reusing dataset snli (/root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)
Loading cached processed dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-677b99ebf191dd17.arrow
Loading cached split indices for dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-48e7475b404be1e0.arrow and /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-0a7ebfe579f236a

Train size: 1000, Validation size: 7000, Test size: 9824
{'hypothesis': 'The man and woman are getting into a fight in public.', 'label': 1, 'premise': 'Male and female are jumping and grabbing at each other in public place.'}
{'hypothesis': 'A group of people are riding a roller coaster.', 'label': 0, 'premise': 'A group of people riding a yellow roller coaster.'}
{'hypothesis': 'The church has cracks in the ceiling.', 'label': 1, 'premise': 'This church choir sings to the masses as they sing joyous songs from the book at a church.'}


Epoch,Training Loss,Validation Loss,Accuracy,Micro F1,Micro Precision,Micro Recall,Macro F1
1,0.955532,0.832058,0.658,0.658,0.658,0.658,0.645145
2,0.664773,0.681435,0.731857,0.731857,0.731857,0.731857,0.729509
3,0.416112,0.701479,0.761286,0.761286,0.761286,0.761286,0.763894
4,0.238719,0.794674,0.769857,0.769857,0.769857,0.769857,0.772189
5,0.227689,0.887314,0.784286,0.784286,0.784286,0.784286,0.784051
6,0.153683,0.997029,0.778857,0.778857,0.778857,0.778857,0.781597


{'eval_loss': 0.6485127210617065, 'eval_accuracy': 0.7484731270358306, 'eval_micro_f1': 0.7484731270358306, 'eval_micro_precision': 0.7484731270358306, 'eval_micro_recall': 0.7484731270358306, 'eval_macro_f1': 0.7467358784262372, 'epoch': 6.0, 'total_flos': 204640005965616}



Reusing dataset snli (/root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)
Loading cached processed dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-c5e7cc0489dda538.arrow
Reusing dataset snli (/root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)
Loading cached processed dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-677b99ebf191dd17.arrow
Loading cached split indices for dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-48e7475b404be1e0.arrow and /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-0a7ebfe579f236a

Train size: 10000, Validation size: 7000, Test size: 9824
{'hypothesis': 'A boy has been playing in the leaves.', 'label': 1, 'premise': 'A boy is covered in fall leaves in the yard.'}
{'hypothesis': 'A group of people are riding a roller coaster.', 'label': 0, 'premise': 'A group of people riding a yellow roller coaster.'}
{'hypothesis': 'The church has cracks in the ceiling.', 'label': 1, 'premise': 'This church choir sings to the masses as they sing joyous songs from the book at a church.'}


Epoch,Training Loss,Validation Loss,Accuracy,Micro F1,Micro Precision,Micro Recall,Macro F1
1,0.534528,0.499843,0.830714,0.830714,0.830714,0.830714,0.82945
2,0.364825,0.526132,0.831,0.831,0.831,0.831,0.832015
3,0.289178,0.567803,0.841571,0.841571,0.841571,0.841571,0.841281
4,0.233765,0.694643,0.844857,0.844857,0.844857,0.844857,0.844117
5,0.077844,0.819037,0.841714,0.841714,0.841714,0.841714,0.841415
6,0.170276,0.904347,0.846571,0.846571,0.846571,0.846571,0.846642


{'eval_loss': 0.4842863380908966, 'eval_accuracy': 0.8422231270358306, 'eval_micro_f1': 0.8422231270358306, 'eval_micro_precision': 0.8422231270358306, 'eval_micro_recall': 0.8422231270358306, 'eval_macro_f1': 0.8415326057804551, 'epoch': 6.0, 'total_flos': 1972652355842688}



Reusing dataset snli (/root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)
Loading cached processed dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-c5e7cc0489dda538.arrow
Reusing dataset snli (/root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)
Loading cached processed dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-677b99ebf191dd17.arrow
Loading cached split indices for dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-48e7475b404be1e0.arrow and /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-0a7ebfe579f236a

Train size: 100000, Validation size: 7000, Test size: 9824
{'hypothesis': 'A man is wearing glasses', 'label': 0, 'premise': 'A man in glasses relaxing with his feet up.'}
{'hypothesis': 'A group of people are riding a roller coaster.', 'label': 0, 'premise': 'A group of people riding a yellow roller coaster.'}
{'hypothesis': 'The church has cracks in the ceiling.', 'label': 1, 'premise': 'This church choir sings to the masses as they sing joyous songs from the book at a church.'}


Epoch,Training Loss,Validation Loss,Accuracy,Micro F1,Micro Precision,Micro Recall,Macro F1
1,0.441431,0.374956,0.868857,0.868857,0.868857,0.868857,0.869015
2,0.292187,0.379542,0.876429,0.876429,0.876429,0.876429,0.875292


Buffered data was truncated after reaching the output size limit.

## Model on all data


In [None]:
train_size = 540_000
train_dataset, val_dataset, test_dataset = get_datasets("snli", train_size, val_size=7_000)
print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}, Test size: {len(test_dataset)}")
print(train_dataset[0])
print(val_dataset[0])
print(test_dataset[0])
output_dir = os.path.join(ROOT_DIR, f"train_size_{train_size}")

num_train_epochs = 3

# https://huggingface.co/transformers/main_classes/trainer.html#trainingarguments
training_args = TrainingArguments(
    learning_rate=3e-5,
    weight_decay=0.01,
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=20,
    per_device_eval_batch_size=16,
    warmup_steps=0,  # don't have any intuition for the right value here
    logging_dir=output_dir,
    logging_steps=10,
    load_best_model_at_end=True,
    evaluation_strategy='epoch',
    remove_unused_columns=False,
    no_cuda=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics
    
)

trainer.train()

test_result = trainer.evaluate(test_dataset)

print(test_result)

with open(os.path.join(output_dir, 'test_result.json'), 'w') as f:
    json.dump(test_result, f, indent=4)

Reusing dataset snli (/root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)
Loading cached processed dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-c5e7cc0489dda538.arrow
Reusing dataset snli (/root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)
Loading cached processed dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-677b99ebf191dd17.arrow
Loading cached split indices for dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-48e7475b404be1e0.arrow and /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-0a7ebfe579f236a

Train size: 540000, Validation size: 7000, Test size: 9824
{'hypothesis': 'A blond lady is letting the bird fly away.', 'label': 2, 'premise': 'A blond lady flipping the bird to the camera.'}
{'hypothesis': 'A group of people are riding a roller coaster.', 'label': 0, 'premise': 'A group of people riding a yellow roller coaster.'}
{'hypothesis': 'The church has cracks in the ceiling.', 'label': 1, 'premise': 'This church choir sings to the masses as they sing joyous songs from the book at a church.'}


Epoch,Training Loss,Validation Loss,Accuracy,Micro F1,Micro Precision,Micro Recall,Macro F1
1,0.326758,0.314996,0.889,0.889,0.889,0.889,0.888219
2,0.335547,0.309064,0.898286,0.898286,0.898286,0.898286,0.898169
3,0.2,0.336158,0.900429,0.900429,0.900429,0.900429,0.900092


{'eval_loss': 0.27008056640625, 'eval_accuracy': 0.9108306188925082, 'eval_micro_f1': 0.9108306188925082, 'eval_micro_precision': 0.9108306188925082, 'eval_micro_recall': 0.9108306188925082, 'eval_macro_f1': 0.9106867920057496, 'epoch': 3.0, 'total_flos': 55218952216401480}


### Evaluating last model 

In [8]:
train_size = 540_000
train_dataset, val_dataset, test_dataset = get_datasets("snli", train_size, val_size=7_000)
print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}, Test size: {len(test_dataset)}")
print(train_dataset[0])
print(val_dataset[0])
print(test_dataset[0])
output_dir = os.path.join(ROOT_DIR, f"train_size_{train_size}")

last = AutoModelForSequenceClassification.from_pretrained(os.path.join(output_dir, 'checkpoint-81000')).eval()

Reusing dataset snli (/root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)
Loading cached processed dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-9d920d7ab5cbf792.arrow
Reusing dataset snli (/root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)
Loading cached processed dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-d60c790f46ec5fe2.arrow
Loading cached split indices for dataset at /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-15f93c3c052c93f3.arrow and /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-0bf97534fd10a0e

Train size: 540000, Validation size: 7000, Test size: 9824
{'hypothesis': 'A blond lady is letting the bird fly away.', 'label': 2, 'premise': 'A blond lady flipping the bird to the camera.'}
{'hypothesis': 'A group of people are riding a roller coaster.', 'label': 0, 'premise': 'A group of people riding a yellow roller coaster.'}
{'hypothesis': 'The church has cracks in the ceiling.', 'label': 1, 'premise': 'This church choir sings to the masses as they sing joyous songs from the book at a church.'}


In [9]:
training_args = TrainingArguments(
    learning_rate=3e-5,
    weight_decay=0.01,
    output_dir=output_dir,
    num_train_epochs=3,
    per_device_train_batch_size=20,
    per_device_eval_batch_size=16,
    warmup_steps=0,  # don't have any intuition for the right value here
    logging_dir=output_dir,
    logging_steps=10,
    load_best_model_at_end=True,
    evaluation_strategy='epoch',
    remove_unused_columns=False,
    no_cuda=False,
)

trainer = Trainer(
    model=last,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics
    
)

trainer.evaluate(test_dataset)

{'eval_accuracy': 0.9115431596091205,
 'eval_loss': 0.3013544976711273,
 'eval_macro_f1': 0.9112687296958534,
 'eval_micro_f1': 0.9115431596091205,
 'eval_micro_precision': 0.9115431596091205,
 'eval_micro_recall': 0.9115431596091205}