## Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import math
from pprint import pprint
from time import perf_counter

import sys
sys.path.append('../nn_pruning')
from nn_pruning.sparse_trainer import SparseTrainer
from nn_pruning.sparse_xp import SparseXP
from nn_pruning.patch_coordinator import SparseTrainingArguments
from nn_pruning.examples.xp import XP, DataTrainingArguments, ModelArguments, XPTrainingArguments, XPTrainer

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import pandas as pd
import datasets
import transformers
datasets.logging.set_verbosity_error()
transformers.logging.set_verbosity_error()

from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(
    f"Using transformers v{transformers.__version__} and datasets v{datasets.__version__}"
)
print(f"Running on device: {device}")

Using transformers v4.1.1 and datasets v1.2.1
Running on device: cuda


## Load and inspect data

In [None]:
# BoolQ is loaded as the "boolq" configuration of the super_glue dataset
boolq = load_dataset(path="super_glue", name="boolq")
boolq

DatasetDict({
    train: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 9427
    })
    validation: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 3270
    })
    test: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 3245
    })
})

In [None]:
# The model's forward() takes the target as `labels`. rename_column_ mutates the DatasetDict
# in place, so guard it: without the check, re-running this cell errors because "label"
# no longer exists.
if "label" in boolq["train"].column_names:
    boolq.rename_column_("label", "labels")

In [None]:
# Inspect the final training example to confirm the "labels" column is in place
boolq["train"][boolq["train"].num_rows - 1]

{'idx': 9426,
 'labels': 0,
 'passage': "Margin of error -- The margin of error is usually defined as the ``radius'' (or half the width) of a confidence interval for a particular statistic from a survey. One example is the percent of people who prefer product A versus product B. When a single, global margin of error is reported for a survey, it refers to the maximum margin of error for all reported percentages using the full sample from the survey. If the statistic is a percentage, this maximum margin of error can be calculated as the radius of the confidence interval for a reported percentage of 50%.",
 'question': 'is margin of error the same as confidence interval'}

## Time the fine-tuned model

In [None]:
# BoolQ fine-tuned checkpoint used as the unpruned latency baseline, benchmarked on the CPU
bert_ckpt = "lewtun/bert-base-uncased-finetuned-boolq"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_ckpt)
bert_model = AutoModelForSequenceClassification.from_pretrained(bert_ckpt)
bert_model = bert_model.to("cpu")

In [None]:
def compute_latency(
    model,
    tokenizer,
    question="is margin of error the same as confidence interval",
    passage="Margin of error -- The margin of error is usually defined as the ``radius'' (or half the width) of a confidence interval for a particular statistic from a survey. One example is the percent of people who prefer product A versus product B. When a single, global margin of error is reported for a survey, it refers to the maximum margin of error for all reported percentages using the full sample from the survey. If the statistic is a percentage, this maximum margin of error can be calculated as the radius of the confidence interval for a reported percentage of 50%.",
    num_warmup=10,
    num_runs=100,
):
    """Benchmark single-example inference latency of a question/passage classifier.

    The (question, passage) pair is tokenized once. The model then runs ``num_warmup``
    untimed forward passes, followed by ``num_runs`` passes timed with ``perf_counter``.
    Measurements use eval mode with autograd disabled, and the model's original
    train/eval mode is restored afterwards.

    Args:
        model: PyTorch ``nn.Module`` called as ``model(**inputs)``.
        tokenizer: callable returning a mapping of tensors (e.g. a Hugging Face tokenizer).
        question: first sequence of the pair (never truncated).
        passage: second sequence of the pair (truncated if the pair is too long).
        num_warmup: number of untimed forward passes before measuring.
        num_runs: number of timed forward passes; must be at least 1.

    Returns:
        dict with ``time_avg_ms`` and ``time_std_ms``: mean and standard deviation of the
        per-run latency, in milliseconds.

    Raises:
        ValueError: if ``num_runs`` is smaller than 1 (the mean would be undefined).
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    # Put the inputs on the model's device so GPU-resident models work too
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    inputs = tokenizer(question, passage, truncation="only_second", return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    def _sync():
        # CUDA kernels launch asynchronously; wait so perf_counter measures the real work
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    latencies = []
    was_training = model.training
    # Time inference as it would be served: dropout off, no autograd graph recorded
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(num_warmup):
                model(**inputs)
            _sync()
            for _ in range(num_runs):
                start_time = perf_counter()
                model(**inputs)
                _sync()
                latencies.append(perf_counter() - start_time)
    finally:
        # Give the model back in whatever mode the caller had it
        model.train(was_training)

    time_avg_ms = 1000 * np.mean(latencies)
    time_std_ms = 1000 * np.std(latencies)
    print(f"Average latency (ms) - {time_avg_ms:.2f} +/- {time_std_ms:.2f}")
    return {"time_avg_ms": time_avg_ms, "time_std_ms": time_std_ms}

In [None]:
compute_latency(model=bert_model, tokenizer=bert_tokenizer)

Average latency (ms) - 441.52 +\- 81.77


{'time_avg_ms': 441.52082815766335, 'time_std_ms': 81.7672099126243}

## Tokenize and encode

In [None]:
# Switch to the pretrained base checkpoint; this is the model that gets trained below
bert_ckpt = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=bert_ckpt)

In [None]:
def tokenize_and_encode(x, tokenizer):
    """Encode BoolQ example(s) as (question, passage) sequence pairs.

    Only the passage (the second sequence) is truncated when the pair is too long;
    the question is never truncated.
    """
    questions, passages = x["question"], x["passage"]
    return tokenizer(questions, passages, truncation="only_second")

# Tokenize every split in batches, passing the tokenizer through fn_kwargs
boolq_enc = boolq.map(
    tokenize_and_encode,
    batched=True,
    fn_kwargs={"tokenizer": bert_tokenizer},
)

## Metrics

In [None]:
# Accuracy metric from the datasets library (used by compute_metrics)
accuracy_score = load_metric(path="accuracy")

In [None]:
def compute_metrics(pred):
    """Accuracy of the arg-max class against the reference labels.

    ``pred`` unpacks into ``(logits, label_ids)``, as passed in by the Trainer.
    """
    logits, label_ids = pred
    predicted_class = np.argmax(logits, axis=1)
    return accuracy_score.compute(predictions=predicted_class, references=label_ids)

## Create pruning Trainer

Use the question-answering example as a template:

```python
import nn_pruning.examples.question_answering.qa_sparse_xp as qa_sparse_xp

qa = qa_sparse_xp.QASparseXP(param_dict)
qa.run()
```

So we need something equivalent to `QASparseXP`. 

> Open question: what does "XP" stand for, and do we really need `SparseXP` to do movement pruning, or is it only needed for things like collecting metrics?

Now `QASparseXP` looks like

```python
class QASparseXP(SparseXP, QAXP):
    ARGUMENTS = {
        "model": ModelArguments,
        "data": QADataTrainingArguments,
        "training": XPTrainingArguments,
        "sparse": SparseTrainingArguments,
    }
    QA_TRAINER_CLASS = QASparseTrainer
    SHORT_NAMER = SparseQAShortNamer
    CONSTRUCTOR = AutoModelForQuestionAnswering
    LOGIT_NAMES = ["start_logits", "end_logits"]

    def __init__(self, params):
        QAXP.__init__(self, params)
        SparseXP.__init__(self)

    def create_trainer(self, *args, **kwargs):
        super().create_trainer(*args, **kwargs)
        SparseXP.setup_trainer(self)

    @classmethod
    def final_finetune(cls, src_path, dest_path, teacher):
        param_dict = {
            "model_name_or_path": src_path,
            "dataset_name": "squad",
            "do_train": 1,
            "do_eval": 1,
            "per_device_train_batch_size": 16,
            "per_device_eval_batch_size": 128,
            "max_seq_length": 384,
            "doc_stride": 128,
            "num_train_epochs": 10,
            "logging_steps": 250,
            "save_steps": 2500,
            "eval_steps": 2500,
            "save_total_limit": 50,
            "seed": 17,
            "evaluation_strategy": "steps",
            "learning_rate": 3e-5,
            "output_dir": dest_path,
            "logging_dir": dest_path,
            "overwrite_cache": 0,
            "overwrite_output_dir": 1,
            "warmup_steps": 10,
            "initial_warmup": 0,
            "final_warmup": 0,
            "regularization": "",
            "regularization_final_lambda": 0,
            "distil_teacher_name_or_path": teacher,
            "distil_alpha_ce": 0.1,
            "distil_alpha_teacher": 0.9,
            "final_finetune": 1,
            "attention_output_with_dense": 0,
        }

        qa = cls(param_dict)
        qa.run()

        cls.fix_last_checkpoint_bug(dest_path)
```

so this suggests we also need equivalents of:

* `QAXP`
* `QADataTrainingArguments`
* `XPTrainingArguments`
* `QASparseTrainer`
* `SparseQAShortNamer`

`QAXP` is a subclass of `XP` with much of the pre- and post-processing functions needed for QA:

```python
class QAXP(XP):
    ARGUMENTS = {
        "model": ModelArguments,
        "data": QADataTrainingArguments,
        "training": XPTrainingArguments,
    }
    QA_TRAINER_CLASS = QATrainer
    SHORT_NAMER = TrialShortNamer

    @classmethod
    def _model_init(self, model_args, model_config):
        model = AutoModelForQuestionAnswering.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=model_config,
            cache_dir=model_args.cache_dir,
        )
        return model
    ...
```

Some remarks:

* `XP` is a base class with many methods, such as `create_trainer`, that must be implemented by subclasses like `QAXP`
* `XPTrainingArguments` is just a subclass of `TrainingArguments`
* `QADataTrainingArguments` is a subclass of `DataTrainingArguments` where the latter controls QA pre- and post-processing args
* `QASparseTrainer` combines the `SparseTrainer` mixin with `QATrainer` via multiple inheritance:

    ```python
    class QASparseTrainer(SparseTrainer, QATrainer):
        def __init__(self, sparse_args, *args, **kwargs):
            QATrainer.__init__(self, *args, **kwargs)
            SparseTrainer.__init__(self, sparse_args)
    ```
    
    
* `SparseQAShortNamer` is a subclass of `TrialShortNamer` and seems to just collect hyperparameters, presumably for Optuna search

### SparseXP

In [None]:
class MyXP(XP):
    """Experiment wrapper whose model is a sequence classifier.

    ``ARGUMENTS`` and ``SHORT_NAMER`` are not overridden here (a GLUE-style override
    with GlueDataTrainingArguments / TrialShortNamer was sketched but left disabled),
    so whatever ``XP`` defines for them applies.
    """

    MY_TRAINER_CLASS = Trainer

    @classmethod
    def _model_init(cls, model_args, model_config):
        """Load the classifier named by ``model_args.model_name_or_path`` with ``model_config``."""
        return AutoModelForSequenceClassification.from_pretrained(
            model_args.model_name_or_path,
            config=model_config,
        )

In [None]:
class MyTrainer(SparseTrainer, Trainer):
    """Hugging Face ``Trainer`` combined with nn_pruning's ``SparseTrainer``.

    ``compute_loss`` is overridden to return only the model's own loss. The distillation
    and regularization terms are currently switched off (the disabled code is kept as a
    comment inside ``compute_loss``).
    """

    def __init__(self, sparse_args, *args, **kwargs):
        # Initialise the Hugging Face Trainer first, then the SparseTrainer side with sparse_args
        Trainer.__init__(self, *args, **kwargs)
        SparseTrainer.__init__(self, sparse_args)

    def compute_loss(self, model, inputs):
        """Run ``model`` on ``inputs`` and return its loss.

        The loss is also accumulated into ``self.metrics["ce_loss"]``, and
        ``self.loss_counter`` is incremented (presumably consumed by SparseTrainer's
        logging; confirm).
        """
        outputs = model(**inputs)

        # Keep the past state for models that expose one (mirrors Trainer.compute_loss)
        past_index = self.args.past_index
        if past_index >= 0:
            self._past = outputs[past_index]

        # Models may return plain tuples instead of ModelOutput, so don't rely on `.loss`
        if isinstance(outputs, dict):
            loss = outputs["loss"]
        else:
            loss = outputs[0]

        self.metrics["ce_loss"] += float(loss)

        # Disabled for now (kept for reference):
        #   loss, distil_loss = self.patch_coordinator.distil_loss_combine(loss, inputs, outputs)
        #   self.metrics["distil_loss"] += float(distil_loss)
        #   regu_loss, lamb, info = self.patch_coordinator.regularization_loss(model)
        #   for kind, values in info.items():
        #       suffix = "" if kind == "total" else "_" + kind
        #       for k, v in values.items():
        #           self.metrics[k + suffix] += float(v)
        #   loss = loss + regu_loss * lamb

        self.loss_counter += 1
        return loss

In [None]:
# Build the sparse-training arguments (library defaults) before displaying them, so this
# cell also runs on a fresh kernel instead of relying on a variable from a later cell.
sparse_args = SparseTrainingArguments()
sparse_args

SparseTrainingArguments(mask_scores_learning_rate=0.01, dense_pruning_method='topK', attention_pruning_method='topK', ampere_pruning_method='disabled', attention_output_with_dense=True, bias_mask=True, mask_init='constant', mask_scale=0.0, dense_block_rows=1, dense_block_cols=1, attention_block_rows=1, attention_block_cols=1, initial_threshold=1.0, final_threshold=0.5, initial_warmup=1, final_warmup=2, initial_ampere_temperature=0.0, final_ampere_temperature=20.0, regularization='disabled', regularization_final_lambda=0.0, attention_lambda=1.0, dense_lambda=1.0, distil_teacher_name_or_path=None, distil_alpha_ce=0.5, distil_alpha_teacher=0.5, distil_temperature=2.0, final_finetune=False, layer_norm_patch=False, gelu_patch=False)

In [None]:
sparse_args = SparseTrainingArguments()

# Load the classifier from `bert_ckpt` onto the training device
bert_model = AutoModelForSequenceClassification.from_pretrained(bert_ckpt).to(device)

# Optimisation settings; logging_steps equals the number of steps in one epoch
batch_size = 4
learning_rate = 2e-5
num_train_epochs = 3
logging_steps = len(boolq_enc['train']) // batch_size

args = TrainingArguments(
    output_dir='checkpoints',
    num_train_epochs=num_train_epochs,
    learning_rate=learning_rate,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy='epoch',
    logging_steps=logging_steps,
    disable_tqdm=False,
)

In [None]:
from nn_pruning.patch_coordinator import ModelPatchingCoordinator

In [None]:
# Coordinator for the sparse-training machinery used by the trainer below (threshold/
# temperature/lambda reporting, and the distillation/regularization hooks referenced in
# MyTrainer). Positional args: sparse args, device, "checkpoints" (presumably a cache/output
# dir — confirm), logit field name "logits", and the model class.
# NOTE(review): nothing in this notebook hands `bert_model` to the coordinator to be patched
# (presumably via ModelPatchingCoordinator.patch_model — verify the name), and the
# optimize_model() report further down shows 0 heads removed and sparsity 0.00 in every
# layer. The model was probably never pruned. TODO confirm and patch it before training.
mpc = ModelPatchingCoordinator(sparse_args, device, "checkpoints", "logits", AutoModelForSequenceClassification)

In [None]:
# Sparse-aware trainer: train on BoolQ train, evaluate on BoolQ validation with accuracy
trainer = MyTrainer(
    sparse_args=sparse_args,
    model=bert_model,
    args=args,
    tokenizer=bert_tokenizer,
    compute_metrics=compute_metrics,
    train_dataset=boolq_enc["train"],
    eval_dataset=boolq_enc["validation"],
)

In [None]:
# Attach the coordinator; SparseTrainer.evaluate needs `self.patch_coordinator` (evaluating a
# sparse trainer without one raises AttributeError, as the `pruning_trainer` cell shows).
trainer.set_patch_coordinator(mpc)

In [None]:
# Baseline metrics before training; the extra eval_threshold / eval_ampere_temperature /
# eval_regu_lambda keys presumably come from SparseTrainer's evaluate override.
trainer.evaluate()

{'eval_loss': 0.6939059495925903,
 'eval_accuracy': 0.6214067278287462,
 'eval_threshold': 0.5,
 'eval_ampere_temperature': 20.0,
 'eval_regu_lambda': 0.0}

In [None]:
# Train for `num_train_epochs` epochs, evaluating at the end of each epoch
# (evaluation_strategy='epoch').
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Threshold,Ampere Temperature,Regu Lambda,Loss
1,0.617738,0.599703,0.7263,0.5,20.0,0.0,0.942556
2,0.530716,0.891712,0.766972,0.5,20.0,0.0,0.401304
3,0.331953,1.161491,0.759633,0.5,20.0,0.0,0.496784


TrainOutput(global_step=7071, training_loss=0.4934706015931449)

In [None]:
# NOTE: despite the "-pruned" suffix, the optimize_model() report below shows 0.00 sparsity
trainer.save_model(output_dir="models/bert-base-uncased-finetuned-boolq-pruned")

In [None]:
# Close the active MLflow run, presumably opened by the Trainer's MLflow logging
# integration during training (NOTE(review): confirm that callback is active).
import mlflow
mlflow.end_run()

## Speed test

In [None]:
from nn_pruning.inference_model_patcher import optimize_model


In [None]:
# Turn the trained sparse model into an inference model, apparently by removing pruned heads
# and rows ("dense" selects the optimisation mode — TODO confirm its meaning in nn_pruning).
# NOTE(review): the report below shows 0 heads removed and sparsity 0.00 for every layer,
# so this model is effectively unpruned.
pruned_model = optimize_model(trainer.model, "dense")


removed heads 0, total_heads=144, percentage removed=0.0
bert.encoder.layer.0.intermediate.dense, sparsity = 0.00
bert.encoder.layer.0.output.dense, sparsity = 0.00
bert.encoder.layer.1.intermediate.dense, sparsity = 0.00
bert.encoder.layer.1.output.dense, sparsity = 0.00
bert.encoder.layer.2.intermediate.dense, sparsity = 0.00
bert.encoder.layer.2.output.dense, sparsity = 0.00
bert.encoder.layer.3.intermediate.dense, sparsity = 0.00
bert.encoder.layer.3.output.dense, sparsity = 0.00
bert.encoder.layer.4.intermediate.dense, sparsity = 0.00
bert.encoder.layer.4.output.dense, sparsity = 0.00
bert.encoder.layer.5.intermediate.dense, sparsity = 0.00
bert.encoder.layer.5.output.dense, sparsity = 0.00
bert.encoder.layer.6.intermediate.dense, sparsity = 0.00
bert.encoder.layer.6.output.dense, sparsity = 0.00
bert.encoder.layer.7.intermediate.dense, sparsity = 0.00
bert.encoder.layer.7.output.dense, sparsity = 0.00
bert.encoder.layer.8.intermediate.dense, sparsity = 0.00
bert.encoder.layer.8.o

In [None]:
# Benchmark on the CPU, like the unpruned baseline earlier in the notebook
pruned_model = pruned_model.to("cpu")
compute_latency(model=pruned_model, tokenizer=bert_tokenizer)

Average latency (ms) - 416.89 +\- 83.08


{'time_avg_ms': 416.89285065978765, 'time_std_ms': 83.08497178819044}

In [None]:
class MySparseXP(SparseXP, MyXP):
    """Sparse-training variant of ``MyXP``, modelled on nn_pruning's ``QASparseXP``.

    NOTE(review): calling ``prepare()`` on an instance fails with
    ``NotImplementedError: Implement in subclass`` (see the cell output further down).
    Neither this class nor ``MyXP`` implements every hook ``XP`` requires. TODO implement
    the missing ones.
    NOTE(review): ``MY_TRAINER_CLASS`` is a name introduced in this notebook; the QA/GLUE
    variants use ``QA_TRAINER_CLASS`` / ``GLUE_TRAINER_CLASS``. Confirm whether anything
    in ``XP`` actually reads it.
    """
    ARGUMENTS = {
        "model": ModelArguments,
        "data": DataTrainingArguments,
        "training": XPTrainingArguments,
        "sparse": SparseTrainingArguments,
    }
    MY_TRAINER_CLASS = MyTrainer
#     SHORT_NAMER = MySparseShortNamer
    CONSTRUCTOR = AutoModelForSequenceClassification
    LOGIT_NAMES = ["logits"]

    def __init__(self, params):
        # Same order as QASparseXP: initialise the base experiment, then the sparse mixin
        MyXP.__init__(self, params)
        SparseXP.__init__(self)

    def create_trainer(self, *args, **kwargs):
        # Build the trainer via the next class in the MRO, then let SparseXP set it up
        super().create_trainer(*args, **kwargs)
        SparseXP.setup_trainer(self)

In [None]:
# Hyper-parameters for the nn_pruning experiment wrapper (MySparseXP).
# The LR warmup is derived from the BoolQ training-set size. The earlier value of 12000 steps
# (from an MNLI recipe) exceeded the whole run: ceil(9427 / 32) = 295 steps/epoch x 12 epochs
# = 3540 steps on one device, so the learning rate never finished warming up.
xp_train_batch_size = 32
xp_num_train_epochs = 12
xp_total_steps = math.ceil(len(boolq_enc["train"]) / xp_train_batch_size) * xp_num_train_epochs

params = {
  "model_name_or_path": "bert-base-uncased",
  "dataset_name": "super_glue",
  # BoolQ is a configuration of super_glue; without this the config name stays None
  "dataset_config_name": "boolq",
  "dataset_cache_dir": "dataset_cache_dir",
  "do_train": 1,
  "do_eval": 1,
  "per_device_train_batch_size": xp_train_batch_size,
  "per_device_eval_batch_size": 128,
  "max_seq_length": 128,
  "doc_stride": 128,
  "num_train_epochs": xp_num_train_epochs,
  "logging_steps": 250,
  "save_steps": 5000,
  "eval_steps": 5000,
  "save_total_limit": 50,
  "seed": 17,
  "evaluation_strategy": "steps",
  "learning_rate": 3e-5,
  "mask_scores_learning_rate": 1e-2,
  # Dedicated BoolQ directory: with overwrite_output_dir=1, reusing an "mnli" directory would
  # clobber an unrelated experiment's outputs
  "output_dir": "output/boolq_sparse/",
  "logging_dir": "output/boolq_sparse/",
  "overwrite_cache": 0,
  "overwrite_output_dir": 1,
  "warmup_steps": int(0.1 * xp_total_steps),  # 10% linear warmup
  "initial_warmup": 1,
  "final_warmup": 4,
  "initial_threshold": 0,
  "final_threshold": 0.1,
  "dense_pruning_method": "sigmoied_threshold:1d_alt",
  "dense_block_rows":1,
  "dense_block_cols":1,
  "dense_lambda":1.0,
  "attention_pruning_method": "sigmoied_threshold",
  "attention_block_rows":32,
  "attention_block_cols":32,
  "attention_lambda":1.0,
  "ampere_pruning_method": "disabled",
  "mask_init": "constant",
  "mask_scale": 0.0,
  "regularization": "l1",
  "regularization_final_lambda": 20,
  # The distillation teacher must share the student's label space. An MNLI model predicts
  # 3 classes while BoolQ is binary, so distil from the BoolQ fine-tuned checkpoint instead.
  "distil_teacher_name_or_path": "lewtun/bert-base-uncased-finetuned-boolq",
  "distil_alpha_ce": 0.1,
  "distil_alpha_teacher": 0.90,
  "attention_output_with_dense": 0
}


In [None]:
# NOTE: this rebinds `trainer`, replacing the MyTrainer instance built earlier
trainer = MySparseXP(params=params)

In [None]:
# dir(trainer)

In [None]:
# Let the experiment wrapper prepare itself (it first logs its model/data/training arguments).
# NOTE(review): currently raises NotImplementedError("Implement in subclass"); MySparseXP
# does not yet implement every hook XP requires.
trainer.prepare()

03/08/2021 20:44:45 - INFO - nn_pruning.examples.xp -   Training/evaluation parameters
03/08/2021 20:44:45 - INFO - nn_pruning.examples.xp -     Model: ModelArguments(model_name_or_path='bert-base-uncased', config_name=None, tokenizer_name=None, cache_dir=None, use_fast_tokenizer=True)
03/08/2021 20:44:45 - INFO - nn_pruning.examples.xp -     Data: DataTrainingArguments(dataset_name='super_glue', dataset_config_name=None, train_file=None, validation_file=None, overwrite_cache=0, dataset_cache_dir='dataset_cache_dir', preprocessing_num_workers=None, max_seq_length=128, pad_to_max_length=True, doc_stride=128)
03/08/2021 20:44:45 - INFO - nn_pruning.examples.xp -     Training: XPTrainingArguments(output_dir='output/mnli_test2/', overwrite_output_dir=1, do_train=1, do_eval=1, do_predict=False, model_parallel=False, evaluation_strategy=<EvaluationStrategy.STEPS: 'steps'>, prediction_loss_only=False, per_device_train_batch_size=32, per_device_eval_batch_size=128, per_gpu_train_batch_size=Non

NotImplementedError: Implement in subclass

In [None]:
# NOTE(review): fails with AttributeError (no `trainer` attribute), presumably because
# prepare() raised before the wrapper created its inner trainer — confirm.
trainer.evaluate()

AttributeError: 'MySparseXP' object has no attribute 'trainer'

In [None]:
class BoolSparseXP(SparseXP, GlueXP):
    """BoolQ sparse-training experiment modelled on nn_pruning's GLUE sparse experiment.

    NOTE(review): ``GlueXP``, ``GlueDataTrainingArguments``, ``GlueSparseTrainer`` and
    ``SparseGlueShortNamer`` are not imported anywhere in this notebook, so defining this
    class raises NameError on a fresh kernel. Import them first (presumably from
    nn_pruning's GLUE text-classification example).
    NOTE(review): BoolQ is a SuperGLUE task, not a GLUE one; confirm that ``GlueXP``'s data
    handling accepts it through ``task_name``.
    """
    ARGUMENTS = {
        "model": ModelArguments,
        "data": GlueDataTrainingArguments,
        "training": XPTrainingArguments,
        "sparse": SparseTrainingArguments,
    }
    GLUE_TRAINER_CLASS = GlueSparseTrainer
    SHORT_NAMER = SparseGlueShortNamer
    CONSTRUCTOR = AutoModelForSequenceClassification
    LOGIT_NAMES = ["logits"]

    def __init__(self, params):
        # Initialise the GLUE experiment first, then the sparse mixin
        GlueXP.__init__(self, params)
        SparseXP.__init__(self)

    def create_trainer(self, *args, **kwargs):
        # Build the trainer via the next class in the MRO, then let SparseXP set it up
        super().create_trainer(*args, **kwargs)
        SparseXP.setup_trainer(self)

    @classmethod
    def final_finetune(cls, src_path, dest_path, task, teacher):
        """Run a final fine-tuning pass from ``src_path``, writing outputs to ``dest_path``.

        Warmups and regularization are disabled, and ``final_finetune`` is set to 1.
        The model is distilled from ``teacher`` (alpha_ce=0.1, alpha_teacher=0.9).

        Args:
            src_path: model checkpoint to start from.
            dest_path: output and logging directory (overwritten).
            task: value passed as ``task_name``.
            teacher: distillation teacher name or path.
        """
        param_dict = {
            "model_name_or_path": src_path,
            "task_name": task,
            "dataset_cache_dir": "dataset_cache_dir",
            "do_train": 1,
            "do_eval": 1,
            "per_device_train_batch_size": 32,
            "per_device_eval_batch_size": 128,
            "max_seq_length": 128,
            "doc_stride": 128,
            "num_train_epochs": 6,
            "logging_steps": 250,
            "save_steps": 5000,
            "eval_steps": 5000,
            "save_total_limit": 50,
            "seed": 17,
            "evaluation_strategy": "steps",
            "learning_rate": 3e-5,
            "output_dir": dest_path,
            "logging_dir": dest_path,
            "overwrite_cache": 0,
            "overwrite_output_dir": 1,
            "warmup_steps": 10,
            "initial_warmup": 0,
            "final_warmup": 0,
            "mask_init": "constant",
            "mask_scale": 0.0,
            "regularization": "",
            "regularization_final_lambda": 0,
            "distil_teacher_name_or_path":teacher,
            "distil_alpha_ce": 0.1,
            "distil_alpha_teacher": 0.90,
            "attention_output_with_dense": 0,
            "final_finetune": 1,
        }


        glue = cls(param_dict)
        glue.run()

        # Post-run checkpoint clean-up provided by the base experiment class
        cls.fix_last_checkpoint_bug(dest_path)

In [None]:
# Default sparse-training arguments, the same construction used for `trainer` earlier
sparse_args = SparseTrainingArguments()

In [None]:
# Reload from `bert_ckpt` so this run starts from fresh pretrained weights
bert_model = AutoModelForSequenceClassification.from_pretrained(bert_ckpt)
bert_model = bert_model.to(device)

loading configuration file https://huggingface.co/bert-base-uncased/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/3c61d016573b14f7f008c02c4e51a366c67ab274726fe2910691e2a761acf43e.637c6035640bacb831febcc2b7f7bee0a96f9b30c2d7e9ef84082d9f252f3170
Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "type_vocab_size": 2,
  "vocab_size": 30522
}

loading weights file https://huggingface.co/bert-base-uncased/resolve/main/pytorch_model.bin from cache at /root/.cache/huggingface/transformers/a8041bf617d7f94ea26d15e218abd04afc2004805632ab

In [None]:
batch_size = 4
learning_rate = 2e-5
num_train_epochs = 3
# One logging step per epoch
logging_steps = len(boolq_enc['train']) // batch_size

args = TrainingArguments(
    output_dir='checkpoints',
    evaluation_strategy='epoch',
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=learning_rate,
    weight_decay=0.01,
    logging_steps=logging_steps,
    disable_tqdm=False
)

# `PruningTrainer` is not defined anywhere in this notebook (presumably left over from an
# earlier kernel session), so use MyTrainer, which takes the same sparse_args + Trainer kwargs.
pruning_trainer = MyTrainer(
    sparse_args=sparse_args,
    args=args,
    model=bert_model,
    train_dataset=boolq_enc['train'],
    eval_dataset=boolq_enc['validation'],
    tokenizer=bert_tokenizer,
    compute_metrics=compute_metrics
)

# SparseTrainer.evaluate needs a patch coordinator; without one it raises
# AttributeError: ... no attribute 'patch_coordinator'. Attach one configured like `mpc`.
pruning_trainer.set_patch_coordinator(
    ModelPatchingCoordinator(sparse_args, device, "checkpoints", "logits", AutoModelForSequenceClassification)
)

In [None]:
# Baseline evaluation. Requires a patch coordinator on the trainer; without one this
# raises AttributeError (as in the recorded output).
pruning_trainer.evaluate()

AttributeError: 'PruningTrainer' object has no attribute 'patch_coordinator'

In [None]:
# List SparseTrainer's attributes to see which Trainer hooks it overrides
# (compute_loss, evaluate, log, training_step, optimizer/scheduler creation, ...)
dir(SparseTrainer)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'compute_loss',
 'create_optimizer',
 'create_optimizer_and_scheduler',
 'create_scheduler',
 'evaluate',
 'log',
 'schedule_threshold',
 'set_patch_coordinator',
 'training_step']