In [1]:
import os
import nlp
import time
import torch
import random
import logging

import numpy as np

from typing import List
from typing import Dict, Optional
from dataclasses import dataclass, field
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data.dataset import Dataset, IterableDataset
from transformers import PreTrainedTokenizer, DataCollator, PreTrainedModel
from transformers import AutoTokenizer, EvalPrediction, Trainer, HfArgumentParser, TrainingArguments, \
    AutoModelForSequenceClassification, set_seed, AutoConfig

In [2]:
# Fix the global random seed (transformers' set_seed helper) so runs are reproducible.
set_seed(seed=42)

In [3]:
# Download GLUE/MNLI (or reuse the local cache); the result maps split names
# ('train', 'validation_matched', ...) to datasets.
dataset = nlp.load_dataset(path='glue', name='mnli')

# Look at the split layout first, then at one raw training record.
print(dataset)
print(dataset["train"][0])

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=28940.0), HTML(value='')))




HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=30329.0), HTML(value='')))


Downloading and preparing dataset glue/mnli (download: 298.29 MiB, generated: 78.65 MiB, post-processed: Unknown sizetotal: 376.95 MiB) to /home/mirac13/.cache/huggingface/datasets/glue/mnli/1.0.0/637080968c182118f006d3ea39dd9937940e81cfffc8d79836eaae8bba307fc4...


HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=312783507.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Dataset glue downloaded and prepared to /home/mirac13/.cache/huggingface/datasets/glue/mnli/1.0.0/637080968c182118f006d3ea39dd9937940e81cfffc8d79836eaae8bba307fc4. Subsequent calls will reuse this data.
{'train': Dataset(features: {'premise': Value(dtype='string', id=None), 'hypothesis': Value(dtype='string', id=None), 'label': ClassLabel(num_classes=3, names=['entailment', 'neutral', 'contradiction'], names_file=None, id=None), 'idx': Value(dtype='int32', id=None)}, num_rows: 392702), 'validation_matched': Dataset(features: {'premise': Value(dtype='string', id=None), 'hypothesis': Value(dtype='string', id=None), 'label': ClassLabel(num_classes=3, names=['entailment', 'neutral', 'contradiction'], names_file=None, id=None), 'idx': Value(dtype='int32', id=None)}, num_rows: 9815), 'validation_mismatched': Dataset(features: {'premise': Value(dtype='string', id=None), 'hypothesis': Value(dtype='string', id=None), 'label': ClassLabel(num_classes=3, names=['entailment', 'neutral', 'contradict

In [4]:
@dataclass()
class Example:
    """A single sentence pair plus its gold class id.

    For MNLI, ``text_a`` holds the premise and ``text_b`` the hypothesis.
    """

    text_a: str  # first segment (premise)
    text_b: str  # second segment (hypothesis)
    label: int  # class id; MNLI names are ['entailment', 'neutral', 'contradiction']

def to_examples(split) -> List[Example]:
    """Convert one split of the ``nlp`` MNLI dataset into a list of ``Example``.

    Each record is read through its ``premise``, ``hypothesis`` and ``label``
    keys, which is the GLUE/MNLI schema used by every split.
    """
    return [Example(text_a=item["premise"], text_b=item["hypothesis"], label=item["label"])
            for item in split]


# To simplify the code below, convert the records provided by the nlp package
# into plain Example objects.
train = to_examples(dataset["train"])
valid = to_examples(dataset["validation_matched"])

### Dynamic padding
On MNLI, the shortest sequences are under 20 tokens long. If you set the max length to 512 tokens, you add 492 pad tokens to each of those 20-token sequences and then perform computations over those 492 useless tokens.

Because learning (gradient descent) is performed at the mini-batch level, we have the opportunity to limit the padding overhead: we can first find the longest sequence length in each mini batch, and then pad the other sequences of that batch to that length only.

These operations can be performed in the `collate_fn` function.

Below, we define a custom `Dataset` class that doesn't perform any padding (if asked to), and a custom `collate_fn` (the `SmartCollator` class) that performs the dynamic padding at the mini-batch level.

In [5]:
@dataclass()
class Features:
    """Tokenized form of one example, ready to be collated into a batch."""

    input_ids: List[int]  # token ids produced by the tokenizer
    attention_mask: List[int]  # one entry per token id; padded positions get 0
    label: int  # class id carried over from the source example


class TextDataset(Dataset):
    """Map-style dataset that tokenizes sentence pairs lazily, in ``__getitem__``.

    With the default ``pad_to_max_length=True`` every item is padded to
    ``max_len`` tokens (the historical behaviour). Pass
    ``pad_to_max_length=False`` to keep each item at its natural length, so a
    dynamic-padding collate function (e.g. ``SmartCollator``) can pad every
    mini batch only up to its own longest sequence.

    Pairs longer than ``max_len`` tokens are truncated in both modes.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer,
                 max_len: int,
                 examples: List[Example],
                 pad_to_max_length: bool = True) -> None:
        """
        Args:
            tokenizer: Hugging Face tokenizer used to encode each pair.
            max_len: maximum number of tokens per encoded pair. This is the
                truncation limit and, when padding here, the padding target.
            examples: sentence pairs served by index.
            pad_to_max_length: pad every item to ``max_len`` when True; leave
                padding to the collate function when False.
        """
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.examples: List[Example] = examples
        self.pad_to_max_length = pad_to_max_length
        # Not read anywhere in this class; kept so any external code touching it keeps working.
        self.current = 0

    def encode(self, ex: Example) -> Features:
        """Tokenize one pair into ``Features`` (token ids, attention mask, label)."""
        encode_dict = self.tokenizer.encode_plus(text=ex.text_a,
                                                 text_pair=ex.text_b,
                                                 add_special_tokens=True,
                                                 max_length=self.max_len,
                                                 # False = no padding here; the collate fn pads per batch.
                                                 padding='max_length' if self.pad_to_max_length else False,
                                                 # Explicit: padding='max_length' on its own does not
                                                 # truncate, so over-long pairs could exceed max_len.
                                                 truncation=True,
                                                 # NOTE(review): segment ids are dropped, so the model
                                                 # presumably falls back to its default token_type_ids
                                                 # for both sentences; confirm this is intended.
                                                 return_token_type_ids=False,
                                                 return_attention_mask=True,
                                                 return_overflowing_tokens=False,
                                                 return_special_tokens_mask=False,
                                                 )
        return Features(input_ids=encode_dict["input_ids"],
                        attention_mask=encode_dict["attention_mask"],
                        label=ex.label)

    def __getitem__(self, idx) -> Features:
        """Return the tokenized example at position ``idx``."""
        return self.encode(ex=self.examples[idx])

    def __len__(self):
        """Number of examples in the dataset."""
        return len(self.examples)


def pad_seq(seq: List[int], max_batch_len: int, pad_value: int) -> List[int]:
    """Right-pad ``seq`` with ``pad_value`` until it is ``max_batch_len`` long.

    Returns a new list and leaves ``seq`` unmodified. If ``seq`` already has at
    least ``max_batch_len`` items, nothing is appended.
    """
    missing = max_batch_len - len(seq)
    filler = [pad_value] * missing  # empty list whenever missing <= 0
    return seq + filler


@dataclass
class SmartCollator:
    """Dynamic-padding ``collate_fn``: pads each batch only to its own longest sequence.

    ``input_ids`` are right-padded with ``pad_token_id`` and ``attention_mask``
    with 0. Returns ``torch.long`` tensors under the keys ``input_ids``,
    ``attention_mask`` and ``labels``.
    """

    pad_token_id: int  # fill value for input_ids (presumably tokenizer.pad_token_id; set by caller)

    def __call__(self, batch: List[Features]) -> Dict[str, torch.Tensor]:
        # Padding target for this batch only; an empty batch makes max() raise ValueError.
        longest = max(len(feat.input_ids) for feat in batch)

        padded_ids = [pad_seq(feat.input_ids, longest, self.pad_token_id) for feat in batch]
        padded_masks = [pad_seq(feat.attention_mask, longest, 0) for feat in batch]
        gold = [feat.label for feat in batch]

        return {
            "input_ids": torch.tensor(padded_ids, dtype=torch.long),
            "attention_mask": torch.tensor(padded_masks, dtype=torch.long),
            "labels": torch.tensor(gold, dtype=torch.long),
        }

def load_transformers_model(pretrained_model_name_or_path: str,
                            use_cuda: bool,
                            num_labels: int = 3,
                            ) -> PreTrainedModel:
    """Load a sequence-classification model from a pretrained checkpoint.

    Args:
        pretrained_model_name_or_path: model id (e.g. ``"bert-base-cased"``)
            or local checkpoint directory.
        use_cuda: move the model to the default CUDA device when True. If no
            GPU is available, a warning is logged and the model stays on CPU.
        num_labels: size of the classification head; the default of 3 matches
            the MNLI classes.

    Returns:
        The loaded model, placed on the requested device.
    """
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path,
                                        num_labels=num_labels)
    model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        config=config)
    if use_cuda:
        if torch.cuda.is_available():
            model.to(torch.device("cuda"))
        else:
            logging.warning("use_cuda=True but CUDA is not available; keeping the model on CPU.")
    return model

In [None]:
# Some MNLI pairs are longer than 256 tokens, so we use BERT's 512-token limit;
# the intent is that no example gets truncated at this length (TODO confirm).
max_sequence_len = 512

# Single source of truth for the checkpoint, so the tokenizer, config and
# weights cannot drift apart.
MODEL_NAME = "bert-base-cased"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)

config = AutoConfig.from_pretrained(pretrained_model_name_or_path=MODEL_NAME,
                                    num_labels=3)  # entailment / neutral / contradiction

model = AutoModelForSequenceClassification.from_pretrained(
    pretrained_model_name_or_path=MODEL_NAME,
    config=config)

def compute_metrics(p: EvalPrediction) -> Dict:
    """Accuracy of the arg-max predictions against the gold labels.

    ``p.predictions`` is assumed to be an (n_examples, n_classes) array of
    scores (TODO confirm); ``p.label_ids`` holds the gold class ids.
    """
    predicted_classes = np.argmax(p.predictions, axis=1)
    hits = predicted_classes == p.label_ids
    return {"acc": hits.mean()}

# A per-device batch size of 8 is the largest that avoids an OOM error, because
# every sequence may be padded to the large max token length (512).
args = TrainingArguments(
    output_dir="/tmp/test_dynamic_padding",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_steps=5000,
    save_steps=0,  # presumably disables intermediate checkpoints; confirm for the installed version
    seed=42,
)

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=433.0), HTML(value='')))




HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=213450.0), HTML(value='')))




HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=435797.0), HTML(value='')))




HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=435779157.0), HTML(value='')))