<a href="https://colab.research.google.com/github/eliranabdoo/variance-regularization/blob/main/variance_regularization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Packages

In [None]:
!pip install torch transformers datasets pydantic seqeval evaluate

In [None]:
%load_ext tensorboard

## Set mountpoint

In [None]:
from google.colab import drive
import os
drive.mount("/content/gdrive")
os.chdir('/content/gdrive/MyDrive/ColabNotebooks/variance-regularization')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


## Config

In [None]:
from pydantic import BaseModel
class Config(BaseModel):
  checkpoint: str = "distilbert-base-uncased"
  dataset: str = "wikiann"
  debug_mode: bool = False
  debug_train_size: int = 1000

config = Config(debug_mode=False)

## Load dataset

In [None]:
from datasets import load_dataset
dataset = load_dataset(config.dataset, "en")
if config.debug_mode:
  for dataset_split_name in dataset:
    dataset[dataset_split_name] = dataset[dataset_split_name].select(range(config.debug_train_size))

Downloading and preparing dataset wikiann/en (download: 223.17 MiB, generated: 8.88 MiB, post-processed: Unknown size, total: 232.05 MiB) to /root/.cache/huggingface/datasets/wikiann/en/1.1.0/4bfd4fe4468ab78bb6e096968f61fab7a888f44f9d3371c2f3fea7e74a5a354e...


Dataset wikiann downloaded and prepared to /root/.cache/huggingface/datasets/wikiann/en/1.1.0/4bfd4fe4468ab78bb6e096968f61fab7a888f44f9d3371c2f3fea7e74a5a354e. Subsequent calls will reuse this data.


## Load model & tokenizer

In [None]:
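from transformers import AutoModelForTokenClassification, AutoTokenizer

labels = dataset['train'].features['ner_tags'].feature.names
# AutoModelForTokenClassification resolves DistilBertForTokenClassification from the checkpoint.
# BertForTokenClassification would discard the DistilBERT weights and start from a random BERT.
model = AutoModelForTokenClassification.from_pretrained(config.checkpoint, num_labels=len(labels))
tokenizer = AutoTokenizer.from_pretrained(config.checkpoint)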

## Tokenization

In [None]:
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for i, label in enumerate(examples["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)  # Map tokens to their respective word.
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:  # Set the special tokens to -100.
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:  # Only label the first token of a given word.
                label_ids.append(label[word_idx])
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

In [None]:
tokenized_dataset = dataset.map(tokenize_and_align_labels, batched=True)

In [None]:
tokenized_dataset['train'][0]

{'tokens': ['R.H.',
  'Saunders',
  '(',
  'St.',
  'Lawrence',
  'River',
  ')',
  '(',
  '968',
  'MW',
  ')'],
 'ner_tags': [3, 4, 0, 3, 4, 4, 0, 0, 0, 0, 0],
 'langs': ['en', 'en', 'en', 'en', 'en', 'en', 'en', 'en', 'en', 'en', 'en'],
 'spans': ['ORG: R.H. Saunders', 'ORG: St. Lawrence River'],
 'input_ids': [101,
  1054,
  1012,
  1044,
  1012,
  15247,
  1006,
  2358,
  1012,
  5623,
  2314,
  1007,
  1006,
  5986,
  2620,
  12464,
  1007,
  102],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 'labels': [-100,
  3,
  -100,
  -100,
  -100,
  4,
  0,
  3,
  -100,
  4,
  4,
  0,
  0,
  0,
  -100,
  0,
  0,
  -100]}
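The `-100` pattern above comes from the first-sub-token rule in `tokenize_and_align_labels`. As a minimal sketch, the same rule can be run without a tokenizer. `align_labels` is a hypothetical helper, and the `word_ids` list is an assumption reconstructed by hand from the `input_ids` above (`R.H.` splits into four sub-tokens, `St.` and `968` into two each).

```python
def align_labels(word_ids, word_labels, ignore_index=-100):
    """Hypothetical helper: label only the first sub-token of each word.

    Special tokens (word id None) and continuation sub-tokens get ignore_index,
    mirroring the inner loop of tokenize_and_align_labels.
    """
    aligned = []
    previous_word_idx = None
    for word_idx in word_ids:
        if word_idx is None:                  # [CLS], [SEP] or padding
            aligned.append(ignore_index)
        elif word_idx != previous_word_idx:   # first sub-token of a new word
            aligned.append(word_labels[word_idx])
        else:                                 # continuation sub-token of the same word
            aligned.append(ignore_index)
        previous_word_idx = word_idx
    return aligned


# word_ids reconstructed by hand (assumption) for
# "R.H. Saunders ( St. Lawrence River ) ( 968 MW )"
word_ids = [None, 0, 0, 0, 0, 1, 2, 3, 3, 4, 5, 6, 7, 8, 8, 9, 10, None]
ner_tags = [3, 4, 0, 3, 4, 4, 0, 0, 0, 0, 0]
print(align_labels(word_ids, ner_tags))
```

The printed list should match the `labels` field of the tokenized example above.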

## Loss combiners


In [None]:
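import torch

# Both regularized combiners lower the total loss as the head spread grows,
# so minimizing the loss pushes attention heads apart.
def linear_combiner(loss, variance, weight, max_clip=float('inf'), min_clip=float('-inf')):
  # loss + weight * clip(-variance); min_clip bounds how much the spread can subtract.
  variance_loss = -1 * variance
  return loss + weight * torch.clip(variance_loss, max=max_clip, min=min_clip)

def inverse_combiner(loss, variance, weight, max_clip=float('inf')):
  # loss + weight / clip(variance); max_clip keeps the penalty at or above weight / max_clip.
  return loss + weight / torch.clip(variance, max=max_clip)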


loss_combiner_registry = {
    'linear': linear_combiner,
    'inverse': inverse_combiner,
    'identity': lambda loss, variance, **kwargs: loss
}
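To see which way each combiner pushes training, here is a scalar, plain-Python restatement. It is a sketch for illustration only: the `_scalar` names are hypothetical, and `min`/`max` stand in for `torch.clip`. Both regularized variants return a smaller total loss as the head spread grows; `identity` (the baseline) ignores the spread entirely.

```python
import math

def linear_combiner_scalar(loss, variance, weight, max_clip=math.inf, min_clip=-math.inf):
    # loss + weight * clip(-variance, min_clip, max_clip), same as torch.clip semantics
    variance_loss = min(max(-variance, min_clip), max_clip)
    return loss + weight * variance_loss

def inverse_combiner_scalar(loss, variance, weight, max_clip=math.inf):
    # loss + weight / clip(variance, max=max_clip): the penalty shrinks as spread grows
    return loss + weight / min(variance, max_clip)

for spread in (0.5, 1.0, 2.0):
    print(spread,
          linear_combiner_scalar(1.0, spread, weight=0.1),
          inverse_combiner_scalar(1.0, spread, weight=0.1))
```

With `min_clip`, the linear variant stops rewarding spread beyond `-min_clip`; with `max_clip`, the inverse variant never drops below `weight / max_clip`.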

## Variance loss trainer

In [None]:
from transformers import Trainer, TrainingArguments
import torch

class VarianceLossTrainer(Trainer):
  """Trainer that combines the task loss with a head-spread term.

  parameters_groups_getter(model) returns {name: sequence of per-head weight chunks};
  loss_combiner_name selects how the spread is merged into the loss.
  """
  def __init__(self, parameters_groups_getter, loss_combiner_name, loss_combiner_params, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.parameters_groups_getter = parameters_groups_getter
    self.loss_combiner_params = loss_combiner_params
    self.loss_combiner = loss_combiner_registry[loss_combiner_name]

  def compute_loss(self, model, inputs, return_outputs=False):
    if return_outputs:
      loss, outputs = super().compute_loss(model, inputs, return_outputs=True)
    else:
      loss = super().compute_loss(model, inputs, return_outputs=False)
    parameters_groups = self.parameters_groups_getter(model)
    cum_variance = 0.0
    for parameters in parameters_groups.values():
      stacked = torch.stack(parameters, 0)  # (num_heads, *chunk_shape)
      # Element-wise std across heads, reduced to a scalar with the Frobenius norm.
      # Despite the variable name, this is a standard deviation, not a variance.
      cum_variance = cum_variance + torch.std(stacked, 0).norm()
    loss = self.loss_combiner(loss, cum_variance, **self.loss_combiner_params)
    return (loss, outputs) if return_outputs else loss

In [None]:
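def get_self_attention_heads_groups(model):
  """Map each self-attention Q/K/V weight name to its per-head chunks (BERT or DistilBERT naming)."""
  res = {}
  # num_attention_heads / hidden_size exist on both BertConfig and DistilBertConfig
  # (the latter maps them to n_heads / dim); n_heads alone is DistilBERT-only.
  num_attention_heads = model.config.num_attention_heads
  head_size = model.config.hidden_size // num_attention_heads
  projection_suffixes = (
      'attention.self.query.weight', 'attention.self.key.weight', 'attention.self.value.weight',  # BERT
      'attention.q_lin.weight', 'attention.k_lin.weight', 'attention.v_lin.weight',  # DistilBERT
  )
  for name, parameter in model.named_parameters():
    if name.endswith(projection_suffixes):
      # nn.Linear weights are (out_features, in_features); head h owns output rows
      # [h * head_size, (h + 1) * head_size), so split along dim 0.
      res[name] = torch.split(parameter, head_size, 0)
  return res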
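The spread term that `VarianceLossTrainer` computes from these groups can be sketched in NumPy on toy matrices. This is an illustration under stated assumptions: `head_spread` is a hypothetical name, the weight uses the `nn.Linear` `(out_features, in_features)` layout so each head is a block of rows, and `ddof=1` mirrors the unbiased default of `torch.std`.

```python
import numpy as np

def head_spread(weight, num_heads):
    """Hypothetical helper: Frobenius norm of the element-wise std across per-head row blocks."""
    heads = np.split(weight, num_heads, axis=0)   # one (head_size, in_features) block per head
    stacked = np.stack(heads, axis=0)             # (num_heads, head_size, in_features)
    # ddof=1 matches torch.std's default (unbiased) estimator
    return float(np.linalg.norm(np.std(stacked, axis=0, ddof=1)))

identical = np.tile(np.array([[1.0, 2.0]]), (4, 1))  # 4 heads of size 1, all equal
different = np.array([[0.0, 0.0], [2.0, 2.0]])        # 2 heads of size 1
print(head_spread(identical, 4), head_spread(different, 2))
```

Identical heads give zero spread, which is the configuration the regularizer is meant to push away from.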

## Load collator

In [None]:
from transformers import DataCollatorForTokenClassification
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, return_tensors="pt")

## Metrics

In [None]:
import numpy as np
import evaluate

metric = evaluate.load("seqeval")
label_names = dataset['train'].features['ner_tags'].feature.names

def compute_metrics(eval_preds):
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    # Remove ignored index (special tokens) and convert to labels
    true_labels = [[label_names[l] for l in label if l != -100] for label in labels]
    true_predictions = [
        [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": all_metrics["overall_precision"],
        "recall": all_metrics["overall_recall"],
        "f1": all_metrics["overall_f1"],
        "accuracy": all_metrics["overall_accuracy"],
    }
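The filtering step in `compute_metrics` can be checked on toy arrays before any seqeval call. A minimal sketch: `strip_ignored` is a hypothetical helper applying the same `-100` mask and name lookup, and `WIKIANN_TAGS` assumes the WikiANN tag order (`O`, `B-PER`, `I-PER`, `B-ORG`, `I-ORG`, `B-LOC`, `I-LOC`), consistent with the `ner_tags` example above where 3/4 are `B-ORG`/`I-ORG`.

```python
import numpy as np

WIKIANN_TAGS = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]  # assumed order

def strip_ignored(logits, labels, names, ignore_index=-100):
    """Hypothetical helper: argmax the logits, drop positions whose gold label is ignore_index."""
    predictions = np.argmax(logits, axis=-1)
    true_labels = [[names[l] for l in row if l != ignore_index] for row in labels]
    true_predictions = [
        [names[p] for p, l in zip(pred_row, label_row) if l != ignore_index]
        for pred_row, label_row in zip(predictions, labels)
    ]
    return true_predictions, true_labels

# One sequence: [CLS], a B-LOC word, a continuation sub-token, [SEP]
toy_labels = np.array([[-100, 5, -100, -100]])
toy_logits = np.zeros((1, 4, len(WIKIANN_TAGS)))
toy_logits[0, :, 0] = 1.0   # every position leans towards "O"
toy_logits[0, 1, 5] = 2.0   # except position 1, which is predicted as B-LOC
preds, gold = strip_ignored(toy_logits, toy_labels, WIKIANN_TAGS)
print(preds, gold)
```

Only the one labelled position survives, so seqeval sees aligned, equal-length sequences.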

## Args Data Models

In [None]:
from datetime import datetime
from transformers import IntervalStrategy

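class CoreTrainArgs(BaseModel):
    # Every field is forwarded to transformers.TrainingArguments. Pydantic ignores
    # undeclared keyword arguments by default, so anything passed here (e.g. save_steps,
    # save_total_limit) must be declared or it is silently dropped.
    output_dir: str
    evaluation_strategy: IntervalStrategy
    eval_steps: int
    logging_steps: int
    save_steps: int
    learning_rate: float
    per_device_train_batch_size: int
    per_device_eval_batch_size: int
    num_train_epochs: int
    weight_decay: float
    metric_for_best_model: str
    load_best_model_at_end: bool
    save_total_limit: int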

class TrainerArgs(BaseModel):
    loss_combiner_params: dict
    loss_combiner_name: str
    early_stopping_patience: int
    run_name: str
    log_dir: str

    @property
    def run_dir(self):
      return os.path.join(self.log_dir, self.run_name)

class TrainHyperparameters(BaseModel):
    core_train_args: CoreTrainArgs
    trainer_args: TrainerArgs

def create_baseline_hparams(train_hparams) -> TrainHyperparameters:
  res = train_hparams.copy(deep=True)
  res.trainer_args.loss_combiner_name = "identity"
  res.trainer_args.run_name = f"baseline-{res.trainer_args.run_name}"
  res.trainer_args.loss_combiner_params = {}
  return res

## TensorBoard Callback

In [None]:
from transformers.integrations import TensorBoardCallback

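import collections.abc

def flatten(d, parent_key='', sep='.'):
    """Flatten nested dicts into {"a.b": value}; non-scalar leaves become strings
    so that SummaryWriter.add_hparams accepts them."""
    items = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        # collections.MutableMapping was removed in Python 3.10; use collections.abc.
        if isinstance(v, collections.abc.MutableMapping):
            items.extend(flatten(v, new_key, sep=sep).items())
        else:
            if type(v) not in [str, bool, int, float, torch.Tensor]:
                v = str(v)
            items.append((new_key, v))
    return dict(items)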

class CustomTensorBoardCallback(TensorBoardCallback):
  def __init__(self, train_hparams: TrainHyperparameters, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.hparams = flatten(train_hparams.dict())

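  def on_train_end(self, args, state, control, **kwargs):
    # Record this run's hyperparameters against its last logged metrics,
    # then let TensorBoardCallback close the writer.
    self.tb_writer.add_hparams(
      hparam_dict=self.hparams,
      metric_dict=state.log_history[-1],
      run_name='.'
    )
    super().on_train_end(args, state, control, **kwargs)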
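`SummaryWriter.add_hparams` expects a flat dict whose values are simple types such as `bool`, `str`, `int` or `float`, which is why the callback flattens the nested hyperparameter model first. The sketch shows the key shape this produces on a hand-written nested dict; `flatten_hparams` is a hypothetical torch-free copy of `flatten` (it drops the `torch.Tensor` case) so that it runs on its own.

```python
import collections.abc

def flatten_hparams(d, parent_key="", sep="."):
    """Hypothetical torch-free copy of flatten(): join nested keys with sep,
    convert non-scalar leaves to str."""
    items = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, collections.abc.MutableMapping):
            items.extend(flatten_hparams(v, new_key, sep=sep).items())
        else:
            if type(v) not in (str, bool, int, float):
                v = str(v)  # e.g. enums or lists become their string form
            items.append((new_key, v))
    return dict(items)

nested = {
    "core_train_args": {"learning_rate": 2e-5, "eval_steps": 250},
    "trainer_args": {"loss_combiner_name": "inverse", "loss_combiner_params": {"weight": 0.1}},
}
print(flatten_hparams(nested))
```

An empty nested dict (such as the baseline's `loss_combiner_params={}`) contributes no keys at all.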

## Set Hyperparameters

In [None]:


train_hparams = TrainHyperparameters(
    core_train_args=CoreTrainArgs(
        output_dir="./results",
        evaluation_strategy=IntervalStrategy.STEPS,
        eval_steps=250,
        logging_steps=250,
        save_steps=250,
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=10,
        weight_decay=1e-5,
        metric_for_best_model='eval_f1',
        load_best_model_at_end=True,
        save_total_limit=1
    ),
    trainer_args=TrainerArgs(
        loss_combiner_params={
            'weight': 0.1
        },
        loss_combiner_name='inverse',
        early_stopping_patience=5,
        run_name=datetime.now().strftime("%Y%m%d-%H%M%S"),
        log_dir="./runs"
    )
)

baseline_train_hparams = create_baseline_hparams(train_hparams)
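A quick sanity check of the schedule these values imply, assuming the 20,000-example WikiANN `en` training split, a single device and no gradient accumulation (the training log reports the same 12,500 optimization steps). `EarlyStoppingCallback` counts evaluations, so the patience translates into optimizer steps via `eval_steps`.

```python
import math

train_examples = 20_000          # assumed size of the WikiANN "en" train split
per_device_train_batch_size = 16
num_train_epochs = 10
eval_steps = 250
early_stopping_patience = 5

steps_per_epoch = math.ceil(train_examples / per_device_train_batch_size)
total_steps = steps_per_epoch * num_train_epochs
num_evaluations = total_steps // eval_steps
# Training stops after this many steps without an improvement in eval_f1
steps_without_improvement = early_stopping_patience * eval_steps

print(steps_per_epoch, total_steps, num_evaluations, steps_without_improvement)
```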

## Train

In [None]:
from torch.utils.tensorboard.writer import SummaryWriter
from transformers import EarlyStoppingCallback, TrainingArguments
import os

for hparams in [train_hparams, baseline_train_hparams]:
  trainer_args = hparams.trainer_args
  # Give each run its own checkpoint directory so the runs do not rotate
  # (and delete) each other's checkpoints.
  core_args = hparams.core_train_args.dict()
  core_args["output_dir"] = os.path.join(core_args["output_dir"], trainer_args.run_name)
  training_args = TrainingArguments(**core_args)
  tensorboard_writer = SummaryWriter(trainer_args.run_dir)
  # Start every run from the pretrained checkpoint. Reusing one `model` object would
  # make the baseline continue from the weights produced by the previous run.
  run_model = type(model).from_pretrained(config.checkpoint, config=model.config)
  trainer = VarianceLossTrainer(
      parameters_groups_getter=get_self_attention_heads_groups,
      loss_combiner_name=trainer_args.loss_combiner_name,
      loss_combiner_params=trainer_args.loss_combiner_params,
      args=training_args,
      model=run_model,
      train_dataset=tokenized_dataset['train'],
      eval_dataset=tokenized_dataset['validation'],
      data_collator=data_collator,
      tokenizer=tokenizer,
      callbacks=[
          EarlyStoppingCallback(early_stopping_patience=trainer_args.early_stopping_patience),
          CustomTensorBoardCallback(tb_writer=tensorboard_writer, train_hparams=hparams),
      ],
      compute_metrics=compute_metrics,
  )
  trainer.train()

The following columns in the training set don't have a corresponding argument in `BertForTokenClassification.forward` and have been ignored: ner_tags, tokens, langs, spans. If ner_tags, tokens, langs, spans are not expected by `BertForTokenClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 20000
  Num Epochs = 10
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 12500
You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


KeyboardInterrupt: ignored

In [None]:
%tensorboard --logdir_spec "./runs"

In [None]:
# !rm -r runs

In [None]:
! ls runs