# Augmentation pipeline: MLM Insertion
This notebook serves as an example how to do build data processing and model training pipeline to run experiments on data augmentation. We consider MLM insertion augmentation as an example. The augmentation we use is simply augmentation on the fly when retrieving samples from torch's Dataset (like in Computer Vision augmentations). The disadvantage is clearly the speed of computation: model-based augmentation is computationally (and memory) expensive.

On AGNews dataset of size 1000 it runs in 3.5 minutes compared to 2 without augmentation. I think it could be sped up when using more jobs for data loading (that would probably require using script and not notebook environment).

In [1]:
import json
import os
import random
from typing import List

%pip install -U datasets
%pip install transformers
import torch
from datasets import load_dataset
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import DataLoader
from transformers import AutoModelForMaskedLM, AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments


ROOT_DIR = "drive/My Drive/Colab Notebooks/nlp/results/ag_news_mlm_insertion"
if not os.path.exists(ROOT_DIR):
    os.mkdir(ROOT_DIR)

Requirement already up-to-date: datasets in /usr/local/lib/python3.6/dist-packages (1.1.2)


## Defining augmentation function

In [2]:
class MLMInsertionAugmenter:
    def __init__(self, model, tokenizer, p: float, min_mask: int = 1, device=None):
        self.model = model.eval()
        self.tokenizer = tokenizer
        self.mask_token = tokenizer.mask_token
        self.mask_token_id = tokenizer.mask_token_id
        self.min_mask = min_mask
        self.p = p
        self.device = device or torch.device('cpu')
        
    def __call__(self, text: str):
        words = np.array(text.split(), dtype='>U6')
        n_mask = max(self.min_mask, int(len(words) * self.p))
        masked_indices = np.sort(np.random.choice(len(words) + 1, size=n_mask))

        masked_words = np.insert(words, masked_indices, self.mask_token)
        masked_text = " ".join(masked_words)
        
        tokenizer_output = self.tokenizer([masked_text])
        input_ids = torch.tensor(tokenizer_output['input_ids']).to(self.device)
        attention_mask = torch.tensor(tokenizer_output['attention_mask']).to(self.device)
        with torch.no_grad():
            output = self.model(input_ids)

        predicted_logits = output.logits[input_ids == self.mask_token_id]
        predicted_tokens = predicted_logits.argmax(1)
        predicted_words = [tokenizer.decode(token.item()).strip() for token in predicted_tokens]  # stripping to avoid multiple spaces
        
        new_words = np.insert(words, masked_indices, predicted_words)
        new_text = " ".join(new_words)
        return new_text

In [3]:
class DatasetWithAugmentation(torch.utils.data.Dataset):
    def __init__(self, dataset, augmenter, augmentation_prob: float = 0.9):
        self.dataset = dataset
        self.augmenter = augmenter
        self.augmentation_prob = augmentation_prob

    def __getitem__(self, i):
        item = self.dataset[i]
        if random.random() < self.augmentation_prob:
            item['text'] = self.augmenter(item['text'])
        return item

    def __len__(self):
        return len(self.dataset)

In [4]:
def get_datasets(dataset_name, augmenter, train_size, val_size=5_000, test_size=None, augmentation_prob = 0.9, random_seed: int = 42):
    """Returns """
    dataset = load_dataset(dataset_name, split="train")
    test_dataset = load_dataset(dataset_name, split="test")
    # We want test and validation data to be the same for every experiment
    if test_size:
        test_dataset = test_dataset.train_test_split(test_size=test_size, seed=random_seed)["test"]
    train_val_split = dataset.train_test_split(test_size=val_size, seed=random_seed)
    # Validation and test sets
    train_dataset = train_val_split["train"].train_test_split(train_size=train_size, seed=random_seed)["train"]
    train_dataset = DatasetWithAugmentation(train_dataset, augmenter, augmentation_prob=augmentation_prob)
    val_dataset = train_val_split["test"]
    return train_dataset, val_dataset, test_dataset


class DataCollator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        
    def __call__(self, examples: List[dict]):
        labels = [example['label'] for example in examples]
        texts = [example['text'] for example in examples]
        tokenizer_output = self.tokenizer(texts, truncation=True, padding=True)
        return {
            'labels': torch.tensor(labels), 
            'input_ids': torch.tensor(tokenizer_output['input_ids']), 
            'attention_mask': torch.tensor(tokenizer_output['attention_mask'])
            }


def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, micro_f1, _ = precision_recall_fscore_support(labels, preds, average='micro', zero_division=0)
    _, _, macro_f1, _ = precision_recall_fscore_support(labels, preds, average='macro', zero_division=0)
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'micro_f1': micro_f1,
        'micro_precision': precision,
        'micro_recall': recall,
        'macro_f1': macro_f1
    }

In [5]:
AUGMENTATION_PROB = 0.5
AUGMENTATION_FRACTION = 0.1

tokenizer = AutoTokenizer.from_pretrained('roberta-base', use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained('roberta-base', return_dict=True, num_labels=4)
data_collator = DataCollator(tokenizer)

device = torch.device('cuda')
mlm_model = AutoModelForMaskedLM.from_pretrained('roberta-base', return_dict=True).eval().to(device)
augmenter = MLMInsertionAugmenter(mlm_model, tokenizer, AUGMENTATION_FRACTION, device=device)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifie

In [6]:
train_size = 1000
train_dataset, val_dataset, test_dataset = get_datasets("ag_news", augmenter, train_size, val_size=5_000, augmentation_prob=AUGMENTATION_PROB)
print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}, Test size: {len(test_dataset)}")
print(train_dataset[0])
print(val_dataset[0])
print(test_dataset[0])
output_dir = os.path.join(ROOT_DIR, f"train_size_{train_size}")

num_train_epochs = 7

# https://huggingface.co/transformers/main_classes/trainer.html#trainingarguments
training_args = TrainingArguments(
    learning_rate=3e-5,
    weight_decay=0.01,
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_steps=0,  # don't have any intuition for the right value here
    logging_dir=output_dir,
    logging_steps=10,
    load_best_model_at_end=True,
    evaluation_strategy='epoch',
    remove_unused_columns=False,
    no_cuda=False,
    metric_for_best_model="eval_accuracy"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics
    
)

trainer.train()

test_result = trainer.evaluate(test_dataset)

print(test_result)

with open(os.path.join(output_dir, 'test_result.json'), 'w') as f:
    json.dump(test_result, f, indent=4)

Using custom data configuration default
Reusing dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Using custom data configuration default
Reusing dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Loading cached split indices for dataset at /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-2db6642321dfdef3.arrow and /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-6ff55a4f79810bf1.arrow
Loading cached split indices for dataset at /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-96a0b3f12072cafc.arrow and /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e90

Train size: 1000, Validation size: 5000, Test size: 7600
{'label': 3, 'text': 'Study: High-t firms praise for online custom respec While many high-t firms scored well overal in a new study of the how they treat custom online more than a third of respon the survey compan still share person data withou permis'}
{'label': 0, 'text': 'Bangladesh paralysed by strikes Opposition activists have brought many towns and cities in Bangladesh to a halt, the day after 18 people died in explosions at a political rally.'}
{'label': 2, 'text': "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul."}


Epoch,Training Loss,Validation Loss,Accuracy,Micro F1,Micro Precision,Micro Recall,Macro F1
1,0.450826,0.329737,0.8918,0.8918,0.8918,0.8918,0.890798
2,0.200697,0.349792,0.8938,0.8938,0.8938,0.8938,0.892341
3,0.246281,0.385103,0.886,0.886,0.886,0.886,0.883295
4,0.209226,0.377906,0.8968,0.8968,0.8968,0.8968,0.895536
5,0.048192,0.441861,0.8946,0.8946,0.8946,0.8946,0.893212
6,0.063677,0.488496,0.8912,0.8912,0.8912,0.8912,0.890379
7,0.086549,0.479927,0.8938,0.8938,0.8938,0.8938,0.893098


{'eval_loss': 0.3638724684715271, 'eval_accuracy': 0.9040789473684211, 'eval_micro_f1': 0.9040789473684211, 'eval_micro_precision': 0.9040789473684211, 'eval_micro_recall': 0.9040789473684211, 'eval_macro_f1': 0.9037342459041019, 'epoch': 7.0, 'total_flos': 499592021664000}
