<a href="https://colab.research.google.com/github/eliranabdoo/variance-regularization/blob/main/variance_regularization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Packages

In [3]:
!pip install torch transformers datasets pydantic seqeval evaluate 


In [4]:
%load_ext tensorboard

## Set mountpoint

In [5]:
from google.colab import drive
import os
drive.mount("/content/gdrive")
os.chdir('/content/gdrive/MyDrive/ColabNotebooks/variance-regularization')

Mounted at /content/gdrive


## Config

In [6]:
from pydantic import BaseModel
class Config(BaseModel):
  checkpoint: str = "bert-base-uncased"
  dataset: str = "wikiann"
  debug_mode: bool = False
  debug_train_size: int = 1000
  seed: int = 42

config = Config(debug_mode=False)

## Set Seed

In [7]:
import transformers
transformers.set_seed(config.seed)


## Load dataset

In [8]:
from datasets import load_dataset
dataset = load_dataset(config.dataset, "en")
if config.debug_mode:
  for dataset_split_name in dataset:
    dataset[dataset_split_name] = dataset[dataset_split_name].select(range(config.debug_train_size))


## Load tokenizer

In [9]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(config.checkpoint)


## Tokenization

In [10]:
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for i, label in enumerate(examples["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)  # Map tokens to their respective word.
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:  # Set the special tokens to -100.
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:  # Only label the first token of a given word.
                label_ids.append(label[word_idx])
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

In [11]:
tokenized_dataset = dataset.map(tokenize_and_align_labels, batched=True)


In [12]:
dataset['train'].features

{'tokens': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
 'ner_tags': Sequence(feature=ClassLabel(num_classes=7, names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC'], id=None), length=-1, id=None),
 'langs': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
 'spans': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)}

In [13]:
tokenized_dataset['train'][0]

{'tokens': ['R.H.',
  'Saunders',
  '(',
  'St.',
  'Lawrence',
  'River',
  ')',
  '(',
  '968',
  'MW',
  ')'],
 'ner_tags': [3, 4, 0, 3, 4, 4, 0, 0, 0, 0, 0],
 'langs': ['en', 'en', 'en', 'en', 'en', 'en', 'en', 'en', 'en', 'en', 'en'],
 'spans': ['ORG: R.H. Saunders', 'ORG: St. Lawrence River'],
 'input_ids': [101,
  1054,
  1012,
  1044,
  1012,
  15247,
  1006,
  2358,
  1012,
  5623,
  2314,
  1007,
  1006,
  5986,
  2620,
  12464,
  1007,
  102],
 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 'labels': [-100,
  3,
  -100,
  -100,
  -100,
  4,
  0,
  3,
  -100,
  4,
  4,
  0,
  0,
  0,
  -100,
  0,
  0,
  -100]}
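
The alignment rule in `tokenize_and_align_labels` can be reproduced without loading a tokenizer. The sketch below is a minimal standalone version. The `word_ids` list is an assumption, inferred from the `labels` output above rather than obtained from the tokenizer, and `align_labels` is a hypothetical helper name.

```python
IGNORE_INDEX = -100  # value the token-classification loss skips

def align_labels(word_ids, word_labels):
    """Label only the first sub-token of each word; mask everything else."""
    aligned = []
    previous_word_idx = None
    for word_idx in word_ids:
        if word_idx is None:                 # special token such as [CLS] or [SEP]
            aligned.append(IGNORE_INDEX)
        elif word_idx != previous_word_idx:  # first sub-token of a word
            aligned.append(word_labels[word_idx])
        else:                                # continuation sub-token
            aligned.append(IGNORE_INDEX)
        previous_word_idx = word_idx
    return aligned

# Assumed word ids for the first training example (18 sub-tokens, 11 words).
word_ids = [None, 0, 0, 0, 0, 1, 2, 3, 3, 4, 5, 6, 7, 8, 8, 9, 10, None]
ner_tags = [3, 4, 0, 3, 4, 4, 0, 0, 0, 0, 0]
aligned = align_labels(word_ids, ner_tags)
print(aligned)
```

Under that assumption the result matches the `labels` field shown above: every word contributes exactly one supervised position, so the number of entries different from -100 equals the number of words.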

## Loss combiners


In [14]:
import torch

def linear_combiner(loss, variance, weight, max_clip=float('inf'), min_clip=float('-inf')):
  # Reward spread: add the (clipped) negated variance term to the task loss.
  variance_loss = -1*variance
  return loss + weight * torch.clip(variance_loss, max=max_clip, min=min_clip)

def inverse_combiner(loss, variance, weight, max_clip=float('inf')):
  # Penalize small spread: 1/variance grows as the groups become more alike.
  variance_loss = 1/variance
  return loss + weight * torch.clip(variance_loss, max=max_clip)


loss_combiner_registry = {
    'linear': linear_combiner,
    'inverse': inverse_combiner,
    'identity': lambda loss, variance, **kwargs: loss  # baseline: task loss only
}
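
To see how the two penalties behave, here is a float-only sketch that mirrors the combiners above, with Python's built-in `min`/`max` standing in for `torch.clip`. The `_float` function names and the sample numbers are illustrative assumptions, not values from the experiments.

```python
def clip(value, min_clip=float('-inf'), max_clip=float('inf')):
    # Scalar stand-in for torch.clip.
    return max(min_clip, min(value, max_clip))

def linear_combiner_float(loss, variance, weight,
                          max_clip=float('inf'), min_clip=float('-inf')):
    return loss + weight * clip(-variance, min_clip, max_clip)

def inverse_combiner_float(loss, variance, weight, max_clip=float('inf')):
    return loss + weight * clip(1.0 / variance, max_clip=max_clip)

loss = 0.5
inverse_plain = inverse_combiner_float(loss, variance=4.0, weight=2.0)     # 0.5 + 2 * (1/4)
inverse_capped = inverse_combiner_float(loss, variance=0.001, weight=2.0,
                                        max_clip=10.0)                     # 1/variance capped at 10
linear_plain = linear_combiner_float(loss, variance=4.0, weight=0.1)       # 0.5 - 0.1 * 4
print(inverse_plain, inverse_capped, linear_plain)
```

The capped call shows why `max_clip` is useful for the inverse form: without it, a near-zero variance makes the penalty dominate the task loss.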

## Variance loss trainer

In [15]:
from transformers import Trainer, TrainingArguments
import torch
from torch.linalg import matrix_norm, vector_norm

class VarianceLossTrainer(Trainer):
  def __init__(self, parameters_groups_getter, loss_combiner_name, loss_combiner_params, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.parameters_groups_getter = parameters_groups_getter
    self.loss_combiner_params = loss_combiner_params
    self.loss_combiner = loss_combiner_registry[loss_combiner_name]
    self.last_mean_std = None
    self.last_clean_loss = None
    self.last_combined_loss = None

  
  def compute_loss(self, model, inputs, return_outputs=False):
    # Task loss computed by the stock Trainer.
    if not return_outputs:
      loss = super().compute_loss(model, inputs, return_outputs)
    else:
      loss, outputs = super().compute_loss(model, inputs, return_outputs)

    parameters_groups = self.parameters_groups_getter(model)
    total_weight = 0
    weighted_std_sum = torch.zeros((), device=loss.device)

    for group, parameters in parameters_groups.items():
      stacked = torch.stack(parameters, 0)  # one slice per chunk in the group
      std_mat = torch.std(stacked, 0)  # element-wise std across the chunks
      # Frobenius norm for matrices (weights), 2-norm for vectors (biases).
      std = matrix_norm(std_mat) if std_mat.dim() >= 2 else vector_norm(std_mat)
      curr_weight = torch.numel(stacked)
      total_weight = total_weight + curr_weight
      weighted_std_sum = weighted_std_sum + std * curr_weight
    # Average of the per-group norms, weighted by group size.
    mean_std = weighted_std_sum / total_weight

    self.last_mean_std = mean_std.item()
    self.last_clean_loss = loss.item()
    loss = self.loss_combiner(loss, mean_std, **self.loss_combiner_params)
    self.last_combined_loss = loss.item()
    return (loss, outputs) if return_outputs else loss
  
  def log(self, logs) -> None:
    logs['eval_mean-std'] = self.last_mean_std
    logs['eval_clean-loss'] = self.last_clean_loss
    logs['eval_combined-loss'] = self.last_combined_loss
    super().log(logs)
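
The quantity logged as `mean-std` is a size-weighted average of per-group norms. Below is a plain-Python sketch of that reduction, assuming each chunk is a flat list of floats. The helper names are hypothetical, and `statistics.stdev` is used because it is the sample standard deviation, which matches the default of `torch.std`.

```python
import math
import statistics

def group_std_norm(chunks):
    """Element-wise sample std across chunks, reduced with a 2-norm."""
    per_element_std = [statistics.stdev(values) for values in zip(*chunks)]
    return math.sqrt(sum(s * s for s in per_element_std))

def weighted_mean_std(groups):
    """Average the per-group norms, weighting each group by its element count."""
    total_weight = 0
    weighted_sum = 0.0
    for chunks in groups.values():
        numel = sum(len(chunk) for chunk in chunks)
        total_weight += numel
        weighted_sum += group_std_norm(chunks) * numel
    return weighted_sum / total_weight

groups = {
    "query": [[1.0, 2.0], [3.0, 2.0]],  # two chunks that differ in one element
    "key": [[0.0], [0.0]],              # two identical chunks: zero spread
}
mean_std = weighted_mean_std(groups)
print(mean_std)
```

Groups whose chunks are identical contribute zero, so the average only grows when the chunks within a group diverge from each other.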

In [16]:
def get_self_attention_heads_groups(model):
  """Map each attention query/key/value parameter name to its tuple of chunks."""
  res = {}
  num_attention_heads = None
  for k in ['n_heads', 'num_attention_heads']:  # the config key differs across architectures
    if hasattr(model.config, k):
      num_attention_heads = getattr(model.config, k)
  if num_attention_heads is None:
    raise ValueError("model.config does not define the number of attention heads")
  hidden_size = model.config.hidden_size
  chunk_size = hidden_size // num_attention_heads
  for name, parameter in model.named_parameters():
    is_qkv = any(k in name for k in ['q_', 'query', 'k_', 'key', 'v_', 'value'])
    if 'attention' in name and is_qkv and ('weight' in name or 'bias' in name):
      last_dim = 1 if 'weight' in name else 0  # weight | bias
      res[name] = torch.split(parameter, chunk_size, last_dim)
  return res
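
The name filter and chunk size used by `get_self_attention_heads_groups` can be checked in isolation. This is a small sketch with hypothetical helper names (`is_head_parameter`, `head_chunk_size`). The parameter names follow the BERT naming scheme, and 768 and 12 are the `hidden_size` and `num_attention_heads` of `bert-base-uncased`.

```python
QKV_MARKERS = ['q_', 'query', 'k_', 'key', 'v_', 'value']

def is_head_parameter(name):
    """Same substring test as the getter: attention + q/k/v + weight or bias."""
    return ('attention' in name
            and any(marker in name for marker in QKV_MARKERS)
            and ('weight' in name or 'bias' in name))

def head_chunk_size(hidden_size, num_attention_heads):
    return hidden_size // num_attention_heads

names = [
    'bert.encoder.layer.0.attention.self.query.weight',
    'bert.encoder.layer.0.attention.self.key.bias',
    'bert.encoder.layer.0.attention.output.dense.weight',
    'bert.encoder.layer.0.intermediate.dense.weight',
]
selected = [name for name in names if is_head_parameter(name)]
chunk_size = head_chunk_size(768, 12)
print(selected)
print(chunk_size)
```

Only the query/key/value projections pass the filter; the attention output projection and the feed-forward layers are left out of the regularizer.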

In [2]:
from transformers import AutoModelForTokenClassification
labels = dataset['train'].features['ner_tags'].feature.names
model = AutoModelForTokenClassification.from_pretrained(config.checkpoint, num_labels=len(labels))
{k: len(g) for k,g in get_self_attention_heads_groups(model).items()}

## Load collator

In [60]:
from transformers import DataCollatorForTokenClassification
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, return_tensors="pt")

In [61]:
import numpy as np
import evaluate

metric = evaluate.load("seqeval")
label_names = dataset['train'].features['ner_tags'].feature.names

def compute_metrics(eval_preds):
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    # Remove ignored index (special tokens) and convert to labels
    true_labels = [[label_names[l] for l in label if l != -100] for label in labels]
    true_predictions = [
        [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": all_metrics["overall_precision"],
        "recall": all_metrics["overall_recall"],
        "f1": all_metrics["overall_f1"],
        "accuracy": all_metrics["overall_accuracy"],
    }
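
Before scoring, `compute_metrics` drops every position whose label is -100 and converts the remaining ids to tag strings. The following sketch isolates that filtering step on a toy batch. `strip_ignored` is a hypothetical helper, the ids are made up for illustration, and the tag list is the one reported by `dataset['train'].features`.

```python
label_names = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']

def strip_ignored(predictions, labels, ignore_index=-100):
    """Keep only supervised positions and map ids to tag strings."""
    true_labels = [
        [label_names[l] for l in label if l != ignore_index]
        for label in labels
    ]
    true_predictions = [
        [label_names[p] for p, l in zip(prediction, label) if l != ignore_index]
        for prediction, label in zip(predictions, labels)
    ]
    return true_predictions, true_labels

# One sequence of six sub-tokens; three of them carry a real label.
predictions = [[0, 3, 0, 0, 0, 0]]
labels = [[-100, 3, -100, 4, 0, -100]]
preds, refs = strip_ignored(predictions, labels)
print(preds)
print(refs)
```

Predictions at masked positions never reach the metric, so whatever the model outputs for special tokens and continuation sub-tokens has no effect on the reported scores.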

## Args Data Models

In [62]:
from datetime import datetime
from transformers import IntervalStrategy

class CoreTrainArgs(BaseModel):
    output_dir: str
    evaluation_strategy: IntervalStrategy
    eval_steps: int
    logging_steps: int
    save_steps: int
    save_strategy: IntervalStrategy
    learning_rate: float
    per_device_train_batch_size: int
    per_device_eval_batch_size: int
    num_train_epochs: int
    weight_decay: float
    metric_for_best_model: str
    load_best_model_at_end: bool
    save_total_limit: int

class TrainerArgs(BaseModel):
    loss_combiner_params: dict
    loss_combiner_name: str
    early_stopping_patience: int

class TrainHyperparameters(BaseModel):
    core_train_args: CoreTrainArgs
    trainer_args: TrainerArgs

def create_baseline_hparams(train_hparams) -> TrainHyperparameters:
  res = train_hparams.copy(deep=True)
  res.trainer_args.loss_combiner_name = "identity"
  res.trainer_args.loss_combiner_params = {}
  res.core_train_args.output_dir = f"{res.core_train_args.output_dir}-baseline"
  return res

## Tensorboard Callback

In [63]:
from transformers.integrations import TensorBoardCallback

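import collections.abc

def flatten(d, parent_key='', sep='.'):
    items = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        # MutableMapping lives in collections.abc; the collections alias was removed in Python 3.10.
        if isinstance(v, collections.abc.MutableMapping):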
            items.extend(flatten(v, new_key, sep=sep).items())
        else:
            if type(v) not in [str, bool, int, float, torch.Tensor]:
                v = str(v)
            items.append((new_key, v))
    return dict(items)

class CustomTensorBoardCallback(TensorBoardCallback):
  def __init__(self, train_hparams: TrainHyperparameters, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.hparams = flatten(train_hparams.dict())

  def on_train_end(self, args, state, control, **kwargs):
    self.tb_writer.add_hparams(
      hparam_dict=self.hparams,
      metric_dict=state.log_history[-1],
      run_name='.'
    )
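
As an illustration of what ends up in the TensorBoard hyperparameter panel, this standalone sketch flattens a small nested dictionary shaped like `train_hparams.dict()`. It is a simplified variant that assumes no tensor values (the `torch.Tensor` case is left out so that it runs without PyTorch), and the sample dictionary holds only a few of the real fields.

```python
from collections.abc import MutableMapping

def flatten(d, parent_key='', sep='.'):
    """Turn nested mappings into a single level with dotted keys."""
    items = {}
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, MutableMapping):
            items.update(flatten(v, new_key, sep=sep))
        else:
            # Values of any other type are stored as their string form.
            items[new_key] = v if type(v) in (str, bool, int, float) else str(v)
    return items

hparams = {
    'core_train_args': {'learning_rate': 2e-5, 'save_total_limit': 2},
    'trainer_args': {
        'loss_combiner_name': 'inverse',
        'loss_combiner_params': {'weight': 10},
    },
}
flat = flatten(hparams)
print(flat)
```

Nested combiner parameters become keys such as `trainer_args.loss_combiner_params.weight`, which keeps runs with different combiner settings distinguishable in one table.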

## Set Hyperparameters

In [64]:
import math
batch_size = 16
num_epochs = 10
num_steps = math.ceil(num_epochs * len(tokenized_dataset['train']) / batch_size)
eval_steps = 250
save_steps = num_steps - (num_steps % eval_steps)

run_name = datetime.now().strftime("%Y%m%d-%H%M%S")
output_dir = "./results/runs" 

train_hparams = TrainHyperparameters(
    core_train_args=CoreTrainArgs(
        output_dir=os.path.join(output_dir, run_name),
        evaluation_strategy=IntervalStrategy.STEPS,
        eval_steps=eval_steps,
        logging_steps=250,
        save_strategy=IntervalStrategy.STEPS,
        save_steps=save_steps,
        learning_rate=2e-5,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=32,
        num_train_epochs=num_epochs,
        weight_decay=1e-5,
        metric_for_best_model='eval_f1',
        load_best_model_at_end=True,
        save_total_limit=2
    ),
    trainer_args=TrainerArgs(
        loss_combiner_params={
            'weight': 10
        },
        loss_combiner_name='inverse',
        early_stopping_patience=20
    )
)
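
The step arithmetic at the top of this cell can be worked through by hand. The sketch below assumes a single device and no gradient accumulation. `schedule` is a hypothetical helper, 20000 is the size of the WikiANN English train split, and 1000 is `debug_train_size`.

```python
import math

def schedule(num_examples, batch_size, num_epochs, eval_steps):
    """Total optimizer steps, and the last multiple of eval_steps not exceeding it."""
    num_steps = math.ceil(num_epochs * num_examples / batch_size)
    save_steps = num_steps - (num_steps % eval_steps)
    return num_steps, save_steps

full_run = schedule(num_examples=20000, batch_size=16, num_epochs=10, eval_steps=250)
debug_run = schedule(num_examples=1000, batch_size=16, num_epochs=10, eval_steps=250)
print(full_run)
print(debug_run)
```

Rounding `save_steps` down to a multiple of `eval_steps` keeps each checkpoint aligned with an evaluation, which `load_best_model_at_end=True` relies on when it picks the best checkpoint.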

## Create Grid

In [65]:
hp_grid = [train_hparams, create_baseline_hparams(train_hparams)]

## Train

In [68]:
from torch.utils.tensorboard.writer import SummaryWriter
from transformers import AutoModelForTokenClassification, EarlyStoppingCallback

labels = dataset['train'].features['ner_tags'].feature.names
for hparams in hp_grid:
  model = AutoModelForTokenClassification.from_pretrained(config.checkpoint, num_labels=len(labels))
  trainer_args = hparams.trainer_args
  core_train_args = hparams.core_train_args
  training_args = TrainingArguments(**core_train_args.dict())
  tensorboard_writer = SummaryWriter(core_train_args.output_dir)
  trainer = VarianceLossTrainer(
      parameters_groups_getter=get_self_attention_heads_groups,
      loss_combiner_name=trainer_args.loss_combiner_name,
      loss_combiner_params=trainer_args.loss_combiner_params,
      args=training_args,
      model=model,
      train_dataset=tokenized_dataset['train'],
      eval_dataset=tokenized_dataset['validation'],
      data_collator=data_collator,
      tokenizer=tokenizer,
      callbacks=[EarlyStoppingCallback(early_stopping_patience=trainer_args.early_stopping_patience),
                CustomTensorBoardCallback(tb_writer=tensorboard_writer, train_hparams=hparams)
      ],
      compute_metrics=compute_metrics
  )
  trainer.train()

loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--bert-base-uncased/snapshots/bdb420bf56ef3f72ee07cd75ab6df1b765b6012a/config.json
Model config BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3",
    "4": "LABEL_4",
    "5": "LABEL_5",
    "6": "LABEL_6"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2,
    "LABEL_3": 3,
    "LABEL_4": 4,
    "LABEL_5": 5,
    "LABEL_6": 6
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "positi

| Step | Training Loss | Validation Loss | Precision | Recall | F1 | Accuracy | Mean-std | Clean-loss | Combined-loss |
|------|---------------|-----------------|-----------|--------|----|----------|----------|------------|---------------|
| 250 | 1.7312 | 1.461306 | 0.706307 | 0.754489 | 0.729603 | 0.897375 | 8.945827 | 0.585998 | 1.703838 |
| 500 | 1.4518 | 1.404267 | 0.7599 | 0.812597 | 0.785365 | 0.916013 | 8.955233 | 0.591123 | 1.707788 |
| 750 | 1.4353 | 1.374985 | 0.777495 | 0.82009 | 0.798225 | 0.920806 | 8.965652 | 0.543541 | 1.658909 |
| 1000 | 1.3884 | 1.368208 | 0.799876 | 0.818818 | 0.809236 | 0.922966 | 8.976004 | 0.469162 | 1.583243 |
| 1250 | 1.378 | 1.373515 | 0.808213 | 0.826382 | 0.817197 | 0.925189 | 8.986821 | 0.564336 | 1.677076 |
| 1500 | 1.3112 | 1.374591 | 0.812496 | 0.830129 | 0.821218 | 0.92612 | 8.999887 | 0.368182 | 1.479307 |
| 1750 | 1.3196 | 1.351322 | 0.813997 | 0.836208 | 0.824953 | 0.929584 | 9.012965 | 0.481476 | 1.590988 |
| 2000 | 1.3135 | 1.349764 | 0.815134 | 0.842217 | 0.828454 | 0.931981 | 9.024515 | 0.443242 | 1.551335 |
| 2250 | 1.2942 | 1.362373 | 0.798475 | 0.821504 | 0.809826 | 0.924394 | 9.036142 | 0.375156 | 1.481823 |
| 2500 | 1.3102 | 1.341957 | 0.805487 | 0.836349 | 0.820628 | 0.930193 | 9.046973 | 0.476748 | 1.58209 |
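
The logged columns of the first results table can be cross-checked against the `inverse` combiner, for which Combined-loss should equal Clean-loss + weight / Mean-std. The check below is a sketch that assumes this table belongs to the `inverse` run with `weight=10` (the first entry of `hp_grid`); the three rows are copied from the table.

```python
# (step, mean_std, clean_loss, combined_loss) copied from the first results table
rows = [
    (250, 8.945827, 0.585998, 1.703838),
    (500, 8.955233, 0.591123, 1.707788),
    (2500, 9.046973, 0.476748, 1.58209),
]
weight = 10  # 'weight' entry of loss_combiner_params

# Largest gap between the reconstructed and the logged combined loss.
max_gap = max(
    abs(clean + weight / mean_std - combined)
    for _, mean_std, clean, combined in rows
)
print(max_gap)
```

With these rows the gap stays at the level of the rounding in the logged values. It also shows that the regularizer adds roughly 1.1 to the loss throughout, because Mean-std only moves from about 8.95 to 9.05 over these steps.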


The following columns in the evaluation set don't have a corresponding argument in `BertForTokenClassification.forward` and have been ignored: tokens, langs, spans, ner_tags. If tokens, langs, spans, ner_tags are not expected by `BertForTokenClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 10000
  Batch size = 32

| Step | Training Loss | Validation Loss | Precision | Recall | F1 | Accuracy | Mean-std | Clean-loss | Combined-loss |
|------|---------------|-----------------|-----------|--------|----|----------|----------|------------|---------------|
| 250 | 0.6156 | 0.343073 | 0.7148 | 0.757953 | 0.735744 | 0.898219 | 8.93574 | 0.549429 | 0.549429 |
| 500 | 0.3309 | 0.288188 | 0.770054 | 0.813658 | 0.791256 | 0.916882 | 8.93589 | 0.55863 | 0.55863 |
| 750 | 0.3191 | 0.263298 | 0.773364 | 0.818959 | 0.795509 | 0.920632 | 8.936036 | 0.525846 | 0.525846 |
| 1000 | 0.2801 | 0.255999 | 0.79608 | 0.818253 | 0.807014 | 0.924344 | 8.936071 | 0.483973 | 0.483973 |
| 1250 | 0.2644 | 0.263157 | 0.805594 | 0.822565 | 0.813991 | 0.924692 | 8.936203 | 0.517393 | 0.517393 |
| 1500 | 0.2011 | 0.256453 | 0.811202 | 0.828291 | 0.819657 | 0.926642 | 8.936455 | 0.414479 | 0.414479 |
| 1750 | 0.2096 | 0.245114 | 0.808812 | 0.83437 | 0.821393 | 0.928901 | 8.936672 | 0.464581 | 0.464581 |
| 2000 | 0.2065 | 0.246641 | 0.812843 | 0.83847 | 0.825458 | 0.930888 | 8.936867 | 0.455305 | 0.455305 |
| 2250 | 0.1888 | 0.255769 | 0.80525 | 0.821928 | 0.813504 | 0.925673 | 8.93704 | 0.410516 | 0.410516 |
| 2500 | 0.2034 | 0.241965 | 0.802907 | 0.835713 | 0.818982 | 0.9291 | 8.937192 | 0.436775 | 0.436775 |



In [None]:
%tensorboard --logdir_spec "./results/runs"

In [None]:
# !rm -r ./results

In [None]:
! ls

In [None]:
import allennlp.training