## Finetune 🤗 Transformers Models with PyTorch Lightning ⚡
  * https://nbviewer.jupyter.org/github/PyTorchLightning/pytorch-lightning/blob/master/notebooks/04-transformers-text-classification.ipynb

In [1]:
# Install the third-party packages this notebook imports (IPython shell magic).
!pip install pytorch-lightning datasets transformers --quiet
import os, time
# Exported to the environment so the shell `date` banners can expand $CURRENT_FILE.
os.environ['CURRENT_FILE'] = 'GLUE-pl.ipynb'
# Print a timestamped [INIT] banner naming the notebook and the active conda env.
!date "+[%F %R:%S] [INIT] $CURRENT_FILE (on $CONDA_DEFAULT_ENV)"
# Wall-clock start time; read again in the final cell to report elapsed time.
t0 = time.time()

from argparse import ArgumentParser
from datetime import datetime
from typing import Optional
import torch
import pytorch_lightning as pl
import datasets
from torch.utils.data import DataLoader
from transformers import (
    AdamW,
    AutoModelForSequenceClassification,
    AutoConfig,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

class GLUELightningData(pl.LightningDataModule):
    """Lightning data module that loads a GLUE task and tokenizes it.

    The task's splits are loaded with ``datasets.load_dataset('glue', ...)``,
    tokenized with the tokenizer of ``transformer_name`` and exposed as
    ``torch`` ``DataLoader`` objects.

    Args:
        task_name: GLUE task key; must be a key of ``task_text_field_map``
            and ``glue_task_num_labels`` (a ``KeyError`` is raised otherwise).
        transformer_name: Name passed to ``AutoTokenizer.from_pretrained``.
        max_seq_length: Length every encoded sequence is padded/truncated to.
        train_batch_size: Batch size of the training loader.
        eval_batch_size: Batch size of the validation and test loaders.
    """

    # Raw text column(s) to tokenize for each supported task.
    task_text_field_map = {
        'cola': ['sentence'],
        'mrpc': ['sentence1', 'sentence2'],
    }
    # Number of labels for each supported task.
    glue_task_num_labels = {
        'cola': 2,
        'mrpc': 2,
    }
    # Columns that are kept (when present) after a split is set to torch format.
    loader_columns = [
        'datasets_idx',
        'input_ids',
        'token_type_ids',
        'attention_mask',
        'start_positions',
        'end_positions',
        'labels'
    ]

    def __init__(self,
        task_name: str, transformer_name: str, max_seq_length: int = 128,
        train_batch_size: int = 32, eval_batch_size: int = 32):
        super().__init__()
        self.task_name = task_name
        self.transformer_name = transformer_name
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size

        self.text_fields = self.task_text_field_map[task_name]
        self.num_labels = self.glue_task_num_labels[task_name]
        self.tokenizer = AutoTokenizer.from_pretrained(self.transformer_name, use_fast=True)

    def prepare_data(self):
        """Load the dataset and tokenizer once, discarding the results.

        NOTE(review): presumably this exists to warm the local download
        caches before ``setup`` runs - confirm against the Lightning docs.
        """
        datasets.load_dataset('glue', self.task_name)
        AutoTokenizer.from_pretrained(self.transformer_name, use_fast=True)

    def setup(self, stage: Optional[str] = None):
        """Load every split, tokenize it and record the split names.

        ``stage`` is accepted for interface compatibility but not used: all
        splits are always processed.
        """
        self.dataset = datasets.load_dataset('glue', self.task_name)
        for split in self.dataset.keys():
            # The raw 'label' column is dropped; to_features re-adds it as 'labels'.
            self.dataset[split] = self.dataset[split].map(self.to_features,
                batched=True, remove_columns=['label'])
            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
            self.dataset[split].set_format(type="torch", columns=self.columns)
        self.eval_splits = [x for x in self.dataset.keys() if 'validation' in x]
        # Tracked separately so test_dataloader never serves validation data.
        self.test_splits = [x for x in self.dataset.keys() if 'test' in x]

    def to_features(self, example_batch, indices=None):
        """Tokenize one batch of raw examples.

        Args:
            example_batch: Mapping from column name to a list of values; must
                contain the task's text field(s) and ``'label'``.
            indices: Unused; kept so existing callers keep working.

        Returns:
            The tokenizer output with the batch's labels added under ``'labels'``.
        """
        if len(self.text_fields) > 1:
            # Sentence-pair task: encode the two fields together as pairs.
            texts = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
        else:
            texts = example_batch[self.text_fields[0]]
        # padding='max_length' replaces the deprecated pad_to_max_length=True.
        features = self.tokenizer.batch_encode_plus(texts,
            max_length=self.max_seq_length, padding='max_length', truncation=True)
        features['labels'] = example_batch['label']
        return features

    def _eval_loaders(self, splits):
        """Return one loader, a list of loaders, or None for ``splits``."""
        loaders = [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in splits]
        if not loaders:
            return None
        return loaders[0] if len(loaders) == 1 else loaders

    def train_dataloader(self):
        """Return the loader over the ``'train'`` split.

        NOTE(review): no ``shuffle=True`` is passed, so batches follow
        dataset order - confirm this is intended.
        """
        return DataLoader(self.dataset['train'], batch_size=self.train_batch_size)

    def val_dataloader(self):
        """Return the loader(s) over every split whose name contains 'validation'."""
        return self._eval_loaders(self.eval_splits)

    def test_dataloader(self):
        """Return the loader(s) over every split whose name contains 'test'.

        Previously the multi-split branch returned the validation splits.
        """
        return self._eval_loaders(self.test_splits)

class GLUELightning(pl.LightningModule):
    """Lightning module that fine-tunes a sequence-classification transformer.

    Args:
        transformer_name: Name passed to ``AutoConfig`` / ``AutoModel...from_pretrained``.
        num_labels: Number of output labels for the classification head.
        learning_rate: Learning rate given to ``AdamW``.
        adam_epsilon: ``eps`` given to ``AdamW``.
        warmup_steps: Warm-up steps of the linear schedule.
        weight_decay: Weight decay for parameters not in the no-decay group.
        train_batch_size: Used in ``setup`` to estimate the number of training steps.
        eval_batch_size: Stored as a hyperparameter; not read in this class.
        eval_splits: Stored as a hyperparameter; not read in this class.
        **kwargs: Extra hyperparameters. The code below reads ``task_name``,
            ``gpus``, ``accumulate_grad_batches`` and ``max_epochs`` from
            ``self.hparams``, so it relies on them being passed here.
    """

    def __init__(self,
        transformer_name: str,
        num_labels: int,
        learning_rate: float = 2e-5,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        eval_splits: Optional[list] = None,
        **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.config = AutoConfig.from_pretrained(transformer_name, num_labels=num_labels)
        self.model = AutoModelForSequenceClassification.from_pretrained(transformer_name, config=self.config)
        # A timestamp is used as experiment_id for the metric object.
        self.metric = datasets.load_metric('glue', self.hparams.task_name,
            experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
        )

    def forward(self, **inputs):
        """Forward all keyword inputs to the wrapped transformer model."""
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        """Run the model on ``batch`` and return the first output as the loss."""
        outputs = self(**batch)
        loss = outputs[0]
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the loss, predictions and labels for one validation batch."""
        outputs = self(**batch)
        val_loss, logits = outputs[:2]
        # Fixed: was `>= 1`, which made the squeeze branch unreachable and
        # reduced a single-output model to argmax over one column (all zeros).
        if self.hparams.num_labels > 1:
            preds = torch.argmax(logits, dim=1)
        else:
            preds = logits.squeeze()
        labels = batch["labels"]
        return {'loss': val_loss, "preds": preds, "labels": labels}

    def validation_epoch_end(self, outputs):
        """Log the mean loss and the GLUE metric over all validation batches.

        NOTE(review): ``outputs`` is treated as a flat list of step results;
        confirm the shape if several validation loaders are ever used.
        """
        preds = torch.cat([x['preds'] for x in outputs]).detach().cpu().numpy()
        labels = torch.cat([x['labels'] for x in outputs]).detach().cpu().numpy()
        loss = torch.stack([x['loss'] for x in outputs]).mean()
        self.log('val_loss', loss, prog_bar=True)
        self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)
        return loss

    def setup(self, stage):
        """Compute ``self.total_steps`` for the LR schedule when fitting."""
        if stage == 'fit':
            train_loader = self.train_dataloader()
            gpus = self.hparams.gpus
            # Fixed: max(1, None) raised TypeError on CPU runs (gpus unset),
            # and a list of device ids could not be compared with an int.
            if gpus is None:
                num_devices = 1
            elif isinstance(gpus, (list, tuple)):
                num_devices = max(1, len(gpus))
            else:
                num_devices = max(1, gpus)
            steps_per_epoch = len(train_loader.dataset) // (self.hparams.train_batch_size * num_devices)
            self.total_steps = (
                steps_per_epoch
                // self.hparams.accumulate_grad_batches * float(self.hparams.max_epochs)
            )

    def configure_optimizers(self):
        """Build AdamW with two parameter groups and a per-step linear schedule.

        Requires ``self.total_steps``, which is only set by ``setup('fit')``.
        """
        model = self.model
        # Parameters whose name contains one of these get no weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
        )
        scheduler = {
            'scheduler': scheduler,
            'interval': 'step',
            'frequency': 1
        }
        return [optimizer], [scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with the optimizer options."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--learning_rate", default=2e-5, type=float)
        parser.add_argument("--adam_epsilon", default=1e-8, type=float)
        parser.add_argument("--warmup_steps", default=0, type=int)
        parser.add_argument("--weight_decay", default=0.0, type=float)
        return parser

[2020-10-28 23:57:21] [INIT] GLUE-pl.ipynb (on lightn)


In [2]:
def parse_args(args=None):
    """Parse the combined trainer, data-module and model command-line options.

    Args:
        args: Argument strings to parse; passed straight to ``parse_args``.

    Returns:
        The parsed namespace, including ``seed`` (default 42).
    """
    cli = ArgumentParser()
    # Each contributor returns the parser it extended; order is preserved.
    contributors = (
        pl.Trainer.add_argparse_args,
        GLUELightningData.add_argparse_args,
        GLUELightning.add_model_specific_args,
    )
    for extend in contributors:
        cli = extend(cli)
    cli.add_argument('--seed', default=42, type=int)
    namespace = cli.parse_args(args)
    return namespace

def main(args):
    """Seed the run and build the data module, model and trainer from ``args``.

    Args:
        args: Parsed namespace providing ``seed`` plus the options the data
            module, model and trainer are constructed from.

    Returns:
        A ``(data module, model, trainer)`` tuple.
    """
    pl.seed_everything(args.seed)

    # The data module is prepared and set up first: the model needs its
    # num_labels and eval_splits.
    data_module = GLUELightningData.from_argparse_args(args)
    data_module.prepare_data()
    data_module.setup('fit')

    hyperparams = vars(args)
    lightning_model = GLUELightning(
        num_labels=data_module.num_labels,
        eval_splits=data_module.eval_splits,
        **hyperparams
    )
    lightning_trainer = pl.Trainer.from_argparse_args(args)
    return data_module, lightning_model, lightning_trainer

# Build everything for an MRPC run with distilbert-base-cased (1 GPU, 3 epochs),
# then fit the model on the data module.
run_config = parse_args("""--gpus 1 --max_epochs 3
    --task_name mrpc --transformer_name distilbert-base-cased""".split())
dm, model, trainer = main(run_config)
trainer.fit(model, datamodule=dm)

Reusing dataset glue (/home/chris/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)
Reusing dataset glue (/home/chris/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)
Some weights of the model checkpoint at distilbert-base-cased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a Bert

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4.0), HTML(value='')))







HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2.0), HTML(value='')))

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

1

In [3]:
# Export the seconds elapsed since t0 (3 decimals, "s" suffix) so the shell
# command below can expand it.
os.environ['ELASPED_TIME'] = f"{time.time() - t0:.3f}s"
# Print a timestamped [EXIT] banner with the notebook name, conda env and elapsed time.
!date "+[%F %R:%S] [EXIT] $CURRENT_FILE (on $CONDA_DEFAULT_ENV) (in $ELASPED_TIME)"


[2020-10-28 23:58:45] [EXIT] GLUE-pl.ipynb (on lightn) (in 83.380s)
