# Fine-Pruning with a Sparse Trainer

> How to make sparse and fast models with a mix of structured and unstructured pruning

In this tutorial, we'll see how `nn_pruning` combines techniques from [movement pruning](https://arxiv.org/abs/2005.07683) and structured pruning to produce compact Transformers that can run inference faster than their dense counterparts, with little impact on accuracy. This tutorial is aimed at those who are familiar with the `transformers.Trainer` - if you're not, you can check out the [documentation](https://huggingface.co/transformers/main_classes/trainer.html?highlight=trainer#trainer) and `transformers` [examples](https://huggingface.co/transformers/examples.html#the-big-table-of-tasks) to see how it works. Let's get started! 

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from nn_pruning.sparse_trainer import SparseTrainer
from nn_pruning.patch_coordinator import SparseTrainingArguments, ModelPatchingCoordinator
from nn_pruning.inference_model_patcher import optimize_model

import torch
import datasets
import numpy as np
import transformers
datasets.logging.set_verbosity_error()
transformers.logging.set_verbosity_error()

from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using transformers v{transformers.__version__} and datasets v{datasets.__version__} and torch v{torch.__version__}")
print(f"Running on device: {device}")

Using transformers v4.3.3 and datasets v1.4.1 and torch v1.8.0
Running on device: cuda


## The dataset

To show `nn_pruning` in action, we'll use the [BoolQ dataset](https://arxiv.org/abs/1905.10044), which consists of naturally occurring yes/no questions about a passage of text. We can use the `datasets` library to load it from the [Hugging Face Hub](https://huggingface.co/) as part of the [SuperGLUE benchmark](https://huggingface.co/datasets/super_glue):

In [None]:
from datasets import load_dataset

boolq = load_dataset("super_glue", "boolq")
boolq

DatasetDict({
    train: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 9427
    })
    validation: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 3270
    })
    test: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 3245
    })
})

Let's take a look at one of the training examples:

In [None]:
boolq['train'][0]

{'idx': 0,
 'label': 1,
 'passage': 'Persian language -- Persian (/ˈpɜːrʒən, -ʃən/), also known by its endonym Farsi (فارسی fārsi (fɒːɾˈsiː) ( listen)), is one of the Western Iranian languages within the Indo-Iranian branch of the Indo-European language family. It is primarily spoken in Iran, Afghanistan (officially known as Dari since 1958), and Tajikistan (officially known as Tajiki since the Soviet era), and some other regions which historically were Persianate societies and considered part of Greater Iran. It is written in the Persian alphabet, a modified variant of the Arabic script, which itself evolved from the Aramaic alphabet.',
 'question': 'do iran and afghanistan speak the same language'}

Here we can see that we're given a `question` about a `passage` of text, and the answer is given a value of 0 (false) / 1 (true) in the `label` field. To help the trainer automatically detect the labels, let's rename the column as follows: 

In [None]:
boolq.rename_column_("label", "labels")

## Tokenizing the question-passage pairs

Before we can fine-prune any models, the first thing we need to do is tokenize and encode the `question` and `passage` fields of each example. Currently, `nn_pruning` supports fine-pruning for BERT models so we'll use BERT-base and load up the tokenizer as follows:

In [None]:
bert_ckpt = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(bert_ckpt)

To tokenize our inputs, we'll pass the `question` and `passage` fields to our tokenizer and set `truncation="only_second"` so that only the passage is truncated if the question-passage pair exceeds the maximum context length of 512 tokens. The following function does what we need, and we can apply it to the whole dataset via the `DatasetDict.map` method:

In [None]:
from transformers import AutoTokenizer

def tokenize_and_encode(examples): 
    return tokenizer(examples['question'], examples['passage'], truncation="only_second")

boolq_enc = boolq.map(tokenize_and_encode, batched=True)
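
To make the effect of `truncation="only_second"` concrete, here is a minimal pure-Python sketch of the idea. The helper `truncate_only_second` and its arguments are hypothetical and operate on plain token lists. The real tokenizer also performs subword tokenization, and the sketch assumes the three special tokens that BERT adds to a pair (one `[CLS]` and two `[SEP]`).

```python
def truncate_only_second(question_tokens, passage_tokens,
                         max_length=512, num_special_tokens=3):
    """Hypothetical helper: trim only the passage so the pair fits in max_length."""
    # Room left for the passage once the question and special tokens are counted
    budget = max_length - num_special_tokens - len(question_tokens)
    if budget < 0:
        raise ValueError("The question alone exceeds the maximum length.")
    # The question is returned untouched; only the passage is cut
    return question_tokens, passage_tokens[:budget]


question = ["do", "iran", "and", "afghanistan", "speak", "the", "same", "language"]
passage = ["tok"] * 600  # stand-in for a long tokenized passage

q, p = truncate_only_second(question, passage)
total_length = len(q) + len(p) + 3  # add back the special tokens
print(len(q), len(p), total_length)
```

The question keeps all of its tokens, and the passage absorbs the whole truncation.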

## Creating a Sparse Trainer

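The next thing to do is create a trainer that can handle the fine-pruning and evaluation steps for us. In `nn_pruning`, this is done by combining the `SparseTrainer` class with the `transformers.Trainer` through multiple inheritance: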

In [None]:
class PruningTrainer(SparseTrainer, Trainer):
    def __init__(self, sparse_args, *args, **kwargs):
        Trainer.__init__(self, *args, **kwargs)
        SparseTrainer.__init__(self, sparse_args)
        
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        We override the default loss in SparseTrainer because it throws an 
        error when run without distillation
        """
        outputs = model(**inputs)

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        self.metrics["ce_loss"] += float(loss)
        self.loss_counter += 1
        return (loss, outputs) if return_outputs else loss

In [None]:
sparse_args = SparseTrainingArguments()

d = {
  "initial_warmup": 1,
  "final_warmup": 3,
  "initial_threshold": 1.0, # When using topK set to 1 (initial density). With sigmoied_threshold, use 0.0 (cutoff)
  "final_threshold": 0.25, # When using topK, this is the final density. With sigmoied_threshold, use 0.1 (final cutoff, which is a bit arbitrary of course, set regularization_final_lambda to adjust final sparsity)
  "dense_pruning_method": "topK:1d_alt", #"sigmoied_threshold:1d_alt",
  "dense_block_rows":1,
  "dense_block_cols":1,
  "dense_lambda":0.25,
  "attention_pruning_method": "topK", #"sigmoied_threshold",
  "attention_block_rows":32,
  "attention_block_cols":32,
  "attention_lambda":1.0,
  "ampere_pruning_method": "disabled",
  "mask_init": "constant",
  "mask_scale": 0.0,
  "regularization": None, # "l1" when pruning_methods are sigmoied_threshold
  "regularization_final_lambda": 20, # To be tweaked to adjust sparsity : the higher, the more sparse. Try different values by multiplying by 2x several times
  "distil_teacher_name_or_path":None,
  "distil_alpha_ce": 0.1,
  "distil_alpha_teacher": 0.9,
  "attention_output_with_dense": 0,
  "layer_norm_patch" : 0,
  "gelu_patch":0
}

for k,v in d.items():
    if hasattr(sparse_args, k):
        setattr(sparse_args, k, v)
    else:
        print(f"sparse_args does not have an argument {k}")
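
With `topK`, the thresholds in the configuration above are densities: `final_threshold=0.25` means keeping the 25% highest-scoring units. The sketch below illustrates this on a toy matrix where one score per block decides whether the entire block is kept, in the spirit of the 32x32 blocks configured for attention. The helper `topk_block_mask` is hypothetical and uses plain Python lists. It is not the `nn_pruning` implementation, which learns the scores during training.

```python
def topk_block_mask(block_scores, density, block_rows, block_cols):
    """Hypothetical helper: keep the top `density` fraction of blocks by score
    and expand the block-level decision to a full 0/1 weight mask."""
    flat = sorted((s for row in block_scores for s in row), reverse=True)
    num_keep = max(1, round(density * len(flat)))
    cutoff = flat[num_keep - 1]  # lowest score that is still kept (ties are all kept)
    mask = []
    for score_row in block_scores:
        expanded = []
        for s in score_row:
            # Repeat the keep/drop decision across the columns of the block
            expanded.extend([1 if s >= cutoff else 0] * block_cols)
        # Repeat the expanded row for every row of the block
        mask.extend([list(expanded) for _ in range(block_rows)])
    return mask


# Toy example: a 4x4 weight matrix split into 2x2 blocks, one score per block
scores = [[0.9, 0.1],
          [0.3, 0.2]]
mask = topk_block_mask(scores, density=0.25, block_rows=2, block_cols=2)
for row in mask:
    print(row)
density = sum(map(sum, mask)) / 16
```

Only the block with the highest score survives, so a quarter of the weights remain, and they sit in one contiguous block rather than being scattered.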



In [None]:
batch_size = 16
learning_rate = 2e-5
num_train_epochs = 6
logging_steps = len(boolq_enc["train"]) // batch_size
# warmup for 10% of training steps (must be an integer number of steps)
warmup_steps = int(logging_steps * num_train_epochs * 0.1)

args = TrainingArguments(
    output_dir='checkpoints',
    evaluation_strategy='epoch',
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=learning_rate,
    weight_decay=0.01,
    warmup_steps=warmup_steps,
    logging_steps=logging_steps,
    disable_tqdm=False,
    report_to=None
)

bert_model = AutoModelForSequenceClassification.from_pretrained(bert_ckpt).to(device)
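
The `initial_warmup` and `final_warmup` values in the sparse arguments control how the pruning threshold moves from `initial_threshold` to `final_threshold` during training. As a rough illustration, the sketch below assumes that both are multiples of a warmup period and that the threshold follows a cubic schedule like the one described in the movement pruning paper: it stays at its initial value during the initial warmup, decays cubically, and is held at its final value for the final warmup. The function `cubic_threshold` is hypothetical, and the schedule that `nn_pruning` actually uses may differ in its details.

```python
def cubic_threshold(step, total_steps, warmup_steps,
                    initial_threshold=1.0, final_threshold=0.25,
                    initial_warmup=1, final_warmup=3):
    """Hypothetical cubic threshold schedule (an assumption, not the library code)."""
    start = initial_warmup * warmup_steps              # end of the dense phase
    end = total_steps - final_warmup * warmup_steps    # start of the final plateau
    if step <= start:
        return initial_threshold
    if step >= end:
        return final_threshold
    progress = (step - start) / (end - start)          # goes from 0 to 1
    return final_threshold + (initial_threshold - final_threshold) * (1 - progress) ** 3


# Toy numbers chosen for readability, not taken from the training run
total_steps = 1000
warmup_steps = 100
schedule = [cubic_threshold(s, total_steps, warmup_steps) for s in (0, 100, 400, 700, 1000)]
print(schedule)
```

Under these assumptions, most of the pruning happens early in the decay phase, and the model spends the last part of training adapting to its final density.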


In [None]:
mpc = ModelPatchingCoordinator(
    sparse_args=sparse_args, 
    device=device, 
    cache_dir="checkpoints", 
    logit_names="logits", 
    teacher_constructor=AutoModelForSequenceClassification)


In [None]:
mpc.patch_model(bert_model)

bert_model.save_pretrained("patched")

LAYER NORM PATCH {'patched': 72}


## Defining the metrics

In [None]:
accuracy_score = load_metric('accuracy')

def compute_metrics(pred):
    predictions, labels = pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy_score.compute(predictions=predictions, references=labels)
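
As a quick sanity check of what `compute_metrics` computes, the sketch below reproduces the same argmax-then-compare logic on a handful of made-up logits, using plain Python instead of NumPy and `load_metric`. The helper name `accuracy_from_logits` and the toy values are illustrative only.

```python
def accuracy_from_logits(logits, labels):
    """Hypothetical helper: pick the highest-scoring class per example
    and report the fraction of predictions that match the labels."""
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}


# Four toy examples with two classes each (0 = false, 1 = true)
logits = [[2.0, -1.0], [0.1, 0.3], [-0.5, 0.2], [1.5, 1.0]]
labels = [0, 1, 0, 0]
result = accuracy_from_logits(logits, labels)
print(result)
```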

In [None]:
trainer = PruningTrainer(
    sparse_args=sparse_args,
    args=args,
    model=bert_model,
    train_dataset=boolq_enc["train"],
    eval_dataset=boolq_enc["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [None]:
trainer.set_patch_coordinator(mpc)

In [None]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Runtime | Samples Per Second | Threshold | Ampere Temperature | Regu Lambda | Loss |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 0.6499 | 0.662862 | 0.621713 | 97.2997 | 33.607 | 0.25 | 20.0 | 20.0 | 0.660849 |


KeyboardInterrupt: 

In [None]:
trainer.save_model("models/bert-base-uncased-finepruned-boolq")

## Optimise for inference

In [None]:
mpc.compile_model(trainer.model)

In [None]:
prunebert_model = optimize_model(trainer.model, "dense")

removed heads 84, total_heads=144, percentage removed=0.5833333333333334
bert.encoder.layer.0.intermediate.dense, sparsity = 75.00
bert.encoder.layer.0.output.dense, sparsity = 75.00
bert.encoder.layer.1.intermediate.dense, sparsity = 75.00
bert.encoder.layer.1.output.dense, sparsity = 75.00
bert.encoder.layer.2.intermediate.dense, sparsity = 75.00
bert.encoder.layer.2.output.dense, sparsity = 75.00
bert.encoder.layer.3.intermediate.dense, sparsity = 75.00
bert.encoder.layer.3.output.dense, sparsity = 75.00
bert.encoder.layer.4.intermediate.dense, sparsity = 75.00
bert.encoder.layer.4.output.dense, sparsity = 75.00
bert.encoder.layer.5.intermediate.dense, sparsity = 75.00
bert.encoder.layer.5.output.dense, sparsity = 75.00
bert.encoder.layer.6.intermediate.dense, sparsity = 75.00
bert.encoder.layer.6.output.dense, sparsity = 75.00
bert.encoder.layer.7.intermediate.dense, sparsity = 75.00
bert.encoder.layer.7.output.dense, sparsity = 75.00
bert.encoder.layer.8.intermediate.dense, sparsi

In [None]:
prunebert_model.num_parameters() / bert_model.num_parameters()

0.46086829411385494

In [None]:
boolq["train"][-1]

{'idx': 9426,
 'labels': 0,
 'passage': "Margin of error -- The margin of error is usually defined as the ``radius'' (or half the width) of a confidence interval for a particular statistic from a survey. One example is the percent of people who prefer product A versus product B. When a single, global margin of error is reported for a survey, it refers to the maximum margin of error for all reported percentages using the full sample from the survey. If the statistic is a percentage, this maximum margin of error can be calculated as the radius of the confidence interval for a reported percentage of 50%.",
 'question': 'is margin of error the same as confidence interval'}

In [None]:
from time import perf_counter

def compute_latencies(model,
                     question="Is Saving Private Ryan based on a book?",
                     passage="""In 1994, Robert Rodat wrote the script for the film. Rodat’s script was submitted to 
                     producer Mark Gordon, who liked it and in turn passed it along to Spielberg to direct. The film is 
                     loosely based on the World War II life stories of the Niland brothers. A shooting date was set for 
                     June 27, 1997"""):
    inputs = tokenizer(question, passage, truncation="only_second", return_tensors="pt")
    latencies = []
    # Disable dropout and gradient tracking so we time pure inference
    model.eval()
    with torch.no_grad():
        # Warm-up runs, not included in the statistics
        for _ in range(10):
            _ = model(**inputs)
        # Timed runs
        for _ in range(100):
            start_time = perf_counter()
            _ = model(**inputs)
            latency = perf_counter() - start_time
            latencies.append(latency)
    # Compute run statistics once, after all timed runs
    time_avg_ms = 1000 * np.mean(latencies)
    time_std_ms = 1000 * np.std(latencies)
    print(f"Average latency (ms) - {time_avg_ms:.2f} +\\- {time_std_ms:.2f}")
    return {"time_avg_ms": time_avg_ms, "time_std_ms": time_std_ms}

In [None]:
compute_latencies(prunebert_model.to("cpu"))

Average latency (ms) - 57.58 +\- 29.91


{'time_avg_ms': 57.578903548419476, 'time_std_ms': 29.905951302571026}

In [None]:
bert_ft_model = AutoModelForSequenceClassification.from_pretrained("lewtun/bert-base-uncased-finetuned-boolq").to("cpu")

In [None]:
compute_latencies(bert_ft_model.to("cpu"))

Average latency (ms) - 129.27 +\- 37.69


{'time_avg_ms': 129.27311155945063, 'time_std_ms': 37.693078994840896}
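
To put the two measurements side by side, a small sketch like the following computes the speedup from the returned dictionaries. The numbers are copied from the outputs above. The standard deviations are large relative to the means, so the ratio should be read as a rough estimate rather than a precise benchmark.

```python
# Results copied from the two `compute_latencies` outputs above
pruned = {"time_avg_ms": 57.578903548419476, "time_std_ms": 29.905951302571026}
dense = {"time_avg_ms": 129.27311155945063, "time_std_ms": 37.693078994840896}

# Ratio of average latencies: dense fine-tuned BERT vs. fine-pruned BERT on CPU
speedup = dense["time_avg_ms"] / pruned["time_avg_ms"]
print(f"Speedup of the fine-pruned model on CPU: {speedup:.2f}x")
```

On this single example, the fine-pruned model is a little more than twice as fast as the dense fine-tuned model.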