# Roberta Classifier on AG News: baseline
Topic classification dataset:
* 120000 training examples
* 7600 test examples

http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html . The AG's news topic classification dataset is constructed by Xiang Zhang (xiang.zhang@nyu.edu) from the dataset above. It is used as a text classification benchmark in the following paper: Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances in Neural Information Processing Systems 28 (NIPS 2015).

In [1]:
import json
import os
from typing import List

%pip install -U datasets
%pip install transformers
import torch
from datasets import load_dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments


ROOT_DIR = "drive/My Drive/Colab Notebooks/nlp/results/ag_news_baseline"
if not os.path.exists(ROOT_DIR):
    os.mkdir(ROOT_DIR)

Requirement already up-to-date: datasets in /usr/local/lib/python3.6/dist-packages (1.1.2)


In [2]:
dataset = load_dataset("ag_news", split="train")
print(dataset[0])
from collections import Counter
labels = list()
for d in dataset:
    labels.append(d['label'])
Counter(labels)

Using custom data configuration default
Reusing dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)


{'label': 2, 'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again."}


Counter({0: 30000, 1: 30000, 2: 30000, 3: 30000})

F1 calculation:
* `'micro'`:
            Calculate metrics globally by counting the total true positives,
            false negatives and false positives.
* 'macro':
            Calculate metrics for each label, and find their unweighted
            mean.  This does not take label imbalance into account.
* 'weighted':
            Calculate metrics for each label, and find their average weighted
            by support (the number of true instances for each label). This
            alters 'macro' to account for label imbalance; it can result in an
            F-score that is not between precision and recall.

In [3]:
def get_datasets(dataset_name, train_size, val_size=5_000, test_size=None, random_seed: int = 42):
    """Returns """
    dataset = load_dataset(dataset_name, split="train")
    test_dataset = load_dataset(dataset_name, split="test")
    # We want test and validation data to be the same for every experiment
    if test_size:
        test_dataset = test_dataset.train_test_split(test_size=test_size, seed=random_seed)["test"]
    train_val_split = dataset.train_test_split(test_size=val_size, seed=random_seed)
    # Validation and test sets
    train_dataset = train_val_split["train"].train_test_split(train_size=train_size, seed=random_seed)["train"]
    val_dataset = train_val_split["test"]
    return train_dataset, val_dataset, test_dataset


class DataCollator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        
    def __call__(self, examples: List[dict]):
        labels = [example['label'] for example in examples]
        texts = [example['text'] for example in examples]
        tokenizer_output = self.tokenizer(texts, truncation=True, padding=True)
        return {
            'labels': torch.tensor(labels), 
            'input_ids': torch.tensor(tokenizer_output['input_ids']), 
            'attention_mask': torch.tensor(tokenizer_output['attention_mask'])
            }


def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, micro_f1, _ = precision_recall_fscore_support(labels, preds, average='micro', zero_division=0)
    _, _, macro_f1, _ = precision_recall_fscore_support(labels, preds, average='macro', zero_division=0)
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'micro_f1': micro_f1,
        'micro_precision': precision,
        'micro_recall': recall,
        'macro_f1': macro_f1
    }

In [4]:
tokenizer = AutoTokenizer.from_pretrained('roberta-base', use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained('roberta-base', return_dict=True, num_labels=4)
data_collator = DataCollator(tokenizer)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifie

In [5]:
TRAIN_SIZES = [20, 100, 1_000, 10_000, 100_000]
for train_size in TRAIN_SIZES:
    train_dataset, val_dataset, test_dataset = get_datasets("ag_news", train_size, val_size=5_000)
    print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}, Test size: {len(test_dataset)}")
    print(train_dataset[0])
    print(val_dataset[0])
    print(test_dataset[0])
    output_dir = os.path.join(ROOT_DIR, f"train_size_{train_size}")

    num_train_epochs = 7 if train_size <= 10_000 else 3

    # https://huggingface.co/transformers/main_classes/trainer.html#trainingarguments
    training_args = TrainingArguments(
        learning_rate=3e-5,
        weight_decay=0.01,
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        warmup_steps=0,  # don't have any intuition for the right value here
        logging_dir=output_dir,
        logging_steps=10,
        load_best_model_at_end=True,
        evaluation_strategy='epoch',
        remove_unused_columns=False,
        no_cuda=False,
        metric_for_best_model="eval_accuracy"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics
        
    )

    trainer.train()

    test_result = trainer.evaluate(test_dataset)

    print(test_result)

    with open(os.path.join(output_dir, 'test_result.json'), 'w') as f:
        json.dump(test_result, f, indent=4)

Using custom data configuration default
Reusing dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Using custom data configuration default
Reusing dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Loading cached split indices for dataset at /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-4e9052b6731fe9c9.arrow and /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-f67b3d4078cb9cb4.arrow
Loading cached split indices for dataset at /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-5658d33989a98a69.arrow and /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e90

Train size: 20, Validation size: 5000, Test size: 7600
{'label': 0, 'text': "San Diego's Incumbent Mayor to Be Sworn In (AP) AP - A state appeals court on Tuesday lifted an order that had stopped San Diego Mayor Dick Murphy from being sworn in for a second term."}
{'label': 0, 'text': 'Bangladesh paralysed by strikes Opposition activists have brought many towns and cities in Bangladesh to a halt, the day after 18 people died in explosions at a political rally.'}
{'label': 2, 'text': "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul."}


Epoch,Training Loss,Validation Loss,Accuracy,Micro F1,Micro Precision,Micro Recall,Macro F1
1,No log,1.38836,0.23,0.23,0.23,0.23,0.093496
2,No log,1.386693,0.23,0.23,0.23,0.23,0.093496
3,No log,1.384825,0.2426,0.2426,0.2426,0.2426,0.11867
4,No log,1.383097,0.2526,0.2526,0.2526,0.2526,0.136689
5,1.341811,1.381434,0.252,0.252,0.252,0.252,0.135455
6,1.341811,1.379814,0.253,0.253,0.253,0.253,0.137221
7,1.341811,1.378942,0.2534,0.2534,0.2534,0.2534,0.137919


{'eval_loss': 1.3726370334625244, 'eval_accuracy': 0.27157894736842103, 'eval_micro_f1': 0.27157894736842103, 'eval_micro_precision': 0.27157894736842103, 'eval_micro_recall': 0.27157894736842103, 'eval_macro_f1': 0.14159861733755555, 'epoch': 7.0, 'total_flos': 7918683121824}


Using custom data configuration default
Reusing dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Using custom data configuration default
Reusing dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Loading cached split indices for dataset at /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-4e9052b6731fe9c9.arrow and /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-f67b3d4078cb9cb4.arrow
Loading cached split indices for dataset at /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-2154b26f70248072.arrow and /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e90

Train size: 100, Validation size: 5000, Test size: 7600
{'label': 0, 'text': 'Deserter Returns to Face Charges After 39 Years Nearly 40 years after he allegedly defected to communist North Korea, US Army Sgt. Charles Jenkins was back in uniform Saturday, billeted on this American '}
{'label': 0, 'text': 'Bangladesh paralysed by strikes Opposition activists have brought many towns and cities in Bangladesh to a halt, the day after 18 people died in explosions at a political rally.'}
{'label': 2, 'text': "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul."}


Epoch,Training Loss,Validation Loss,Accuracy,Micro F1,Micro Precision,Micro Recall,Macro F1
1,No log,1.090899,0.603,0.603,0.603,0.603,0.565862
2,1.148792,0.799287,0.7948,0.7948,0.7948,0.7948,0.786575
3,0.619679,0.578996,0.851,0.851,0.851,0.851,0.850341
4,0.619679,0.476257,0.8462,0.8462,0.8462,0.8462,0.841437
5,0.270238,0.418265,0.8624,0.8624,0.8624,0.8624,0.860406
6,0.116711,0.424284,0.8596,0.8596,0.8596,0.8596,0.857517
7,0.116711,0.428579,0.8596,0.8596,0.8596,0.8596,0.857298


{'eval_loss': 0.4224748909473419, 'eval_accuracy': 0.8617105263157895, 'eval_micro_f1': 0.8617105263157895, 'eval_micro_precision': 0.8617105263157895, 'eval_micro_recall': 0.8617105263157895, 'eval_macro_f1': 0.8612831644049799, 'epoch': 7.0, 'total_flos': 54673914897792}


Using custom data configuration default
Reusing dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Using custom data configuration default
Reusing dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Loading cached split indices for dataset at /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-4e9052b6731fe9c9.arrow and /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-f67b3d4078cb9cb4.arrow
Loading cached split indices for dataset at /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-cf0aa2ff9e909091.arrow and /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e90

Train size: 1000, Validation size: 5000, Test size: 7600
{'label': 3, 'text': 'Study: High-tech firms praised for online customer respect While many high-tech firms scored well in a new study of how they treat customers online, more than a third of the surveyed companies still share personal data without permission.'}
{'label': 0, 'text': 'Bangladesh paralysed by strikes Opposition activists have brought many towns and cities in Bangladesh to a halt, the day after 18 people died in explosions at a political rally.'}
{'label': 2, 'text': "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul."}


Epoch,Training Loss,Validation Loss,Accuracy,Micro F1,Micro Precision,Micro Recall,Macro F1
1,0.437667,0.398292,0.8734,0.8734,0.8734,0.8734,0.872989
2,0.201537,0.34555,0.8986,0.8986,0.8986,0.8986,0.898043
3,0.144803,0.469267,0.896,0.896,0.896,0.896,0.89518
4,0.158316,0.49128,0.8964,0.8964,0.8964,0.8964,0.895643
5,0.016593,0.63416,0.8806,0.8806,0.8806,0.8806,0.880448
6,0.004025,0.588814,0.8936,0.8936,0.8936,0.8936,0.892881
7,0.010996,0.587136,0.8958,0.8958,0.8958,0.8958,0.894994


{'eval_loss': 0.33451175689697266, 'eval_accuracy': 0.8996052631578947, 'eval_micro_f1': 0.8996052631578947, 'eval_micro_precision': 0.8996052631578947, 'eval_micro_recall': 0.8996052631578947, 'eval_macro_f1': 0.8996988413146645, 'epoch': 7.0, 'total_flos': 523273281804672}


Using custom data configuration default
Reusing dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Using custom data configuration default
Reusing dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Loading cached split indices for dataset at /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-4e9052b6731fe9c9.arrow and /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-f67b3d4078cb9cb4.arrow


Train size: 10000, Validation size: 5000, Test size: 7600
{'label': 3, 'text': 'Remains of hobbit-like species found Sydney: The newly discovered remains of a previously unknown species of miniature human in the Indonesian island of Flores, 600 kilometres east of Bali, by Australian and Indonesian scientists, are being hailed as the most significant scientific find of '}
{'label': 0, 'text': 'Bangladesh paralysed by strikes Opposition activists have brought many towns and cities in Bangladesh to a halt, the day after 18 people died in explosions at a political rally.'}
{'label': 2, 'text': "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul."}


Epoch,Training Loss,Validation Loss,Accuracy,Micro F1,Micro Precision,Micro Recall,Macro F1
1,0.293333,0.309321,0.9026,0.9026,0.9026,0.9026,0.902344
2,0.168494,0.317511,0.9186,0.9186,0.9186,0.9186,0.91768
3,0.092587,0.327511,0.9274,0.9274,0.9274,0.9274,0.926881
4,0.114929,0.455769,0.9198,0.9198,0.9198,0.9198,0.919035
5,0.129007,0.469331,0.9206,0.9206,0.9206,0.9206,0.919661
6,0.022995,0.52434,0.9222,0.9222,0.9222,0.9222,0.921295
7,0.059387,0.535899,0.9234,0.9234,0.9234,0.9234,0.922595


{'eval_loss': 0.3259236812591553, 'eval_accuracy': 0.9272368421052631, 'eval_micro_f1': 0.9272368421052631, 'eval_micro_precision': 0.9272368421052631, 'eval_micro_recall': 0.9272368421052631, 'eval_macro_f1': 0.927262873750888, 'epoch': 7.0, 'total_flos': 5129715148238208}


Using custom data configuration default
Reusing dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Using custom data configuration default
Reusing dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Loading cached split indices for dataset at /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-4e9052b6731fe9c9.arrow and /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-f67b3d4078cb9cb4.arrow


Train size: 100000, Validation size: 5000, Test size: 7600
{'label': 1, 'text': 'Santini reveals truth about Arnesen  quot;Right from the beginning there was a problem with responsibilities within the club, especially regarding recruitment, quot; the Frenchman said during half-time, with the scores still level at 1-1.'}
{'label': 0, 'text': 'Bangladesh paralysed by strikes Opposition activists have brought many towns and cities in Bangladesh to a halt, the day after 18 people died in explosions at a political rally.'}
{'label': 2, 'text': "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul."}


Epoch,Training Loss,Validation Loss,Accuracy,Micro F1,Micro Precision,Micro Recall,Macro F1
1,0.200244,0.224033,0.9348,0.9348,0.9348,0.9348,0.933656
2,0.130688,0.214099,0.9444,0.9444,0.9444,0.9444,0.943756
3,0.104321,0.246796,0.9458,0.9458,0.9458,0.9458,0.945288


{'eval_loss': 0.22464153170585632, 'eval_accuracy': 0.9502631578947368, 'eval_micro_f1': 0.9502631578947368, 'eval_micro_precision': 0.9502631578947368, 'eval_micro_recall': 0.9502631578947368, 'eval_macro_f1': 0.9502557592927297, 'epoch': 3.0, 'total_flos': 21705436517939712}
