# Pruning Transformers

> A partial re-implementation of _Movement Pruning: Adaptive Sparsity by Fine-Tuning_ by Victor Sanh, Thomas Wolf, and Alexander M. Rush [[arXiv:2005.07683](https://arxiv.org/abs/2005.07683)]

## References

* [_Movement Pruning: Adaptive Sparsity by Fine-Tuning_](https://arxiv.org/abs/2005.07683) by Victor Sanh, Thomas Wolf, and Alexander M. Rush
* The scripts and notebooks that accompany the paper ([link](https://github.com/huggingface/transformers/tree/master/examples/research_projects/movement-pruning))

## Load libraries

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#export
from transformerlab.question_answering import *

In [None]:
from pathlib import Path

import datasets
import transformers

print(transformers.__version__, datasets.__version__)

4.1.1 1.2.0


In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, default_data_collator

import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

Running on device: cuda


## Load data

In [None]:
squad = load_dataset("squad")
squad

Reusing dataset squad (/root/.cache/huggingface/datasets/squad/plain_text/1.0.0/4c81550d83a2ac7c7ce23783bd8ff36642800e6633c1f18417fb58c3ff50cdd7)


DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})
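Each SQuAD example stores its answers as a dict of parallel `text` and `answer_start` lists, where `answer_start` is a character offset into `context`. As a minimal sketch, the made-up example below (not a real SQuAD row) mimics that structure and shows how the character span of an answer is recovered; this is the span that the preprocessing step later maps to token positions. The helper `answer_char_span` is hypothetical and used only for illustration.

```python
# Hypothetical example mimicking the structure of a SQuAD row (not real data)
example = {
    "context": "Movement pruning was proposed by Sanh, Wolf, and Rush in 2020.",
    "question": "Who proposed movement pruning?",
    "answers": {"text": ["Sanh, Wolf, and Rush"], "answer_start": [33]},
}

def answer_char_span(example, i=0):
    """Hypothetical helper: (start, end) character offsets of the i-th answer, end exclusive."""
    start = example["answers"]["answer_start"][i]
    end = start + len(example["answers"]["text"][i])
    return start, end

start, end = answer_char_span(example)
print(example["context"][start:end])  # Sanh, Wolf, and Rush
```

On the real data, `squad["train"][0]` returns a dict with the same `context`, `question`, and `answers` keys (plus `id` and `title`).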

## Evaluate fine-pruned model

HuggingFace has released a PruneBERT checkpoint for SQuAD v1.1 called `prunebert-base-uncased-6-finepruned-w-distil-squad`, which is described in their docs as follows:

> Pre-trained BERT-base-uncased fine-pruned with soft movement pruning on SQuAD v1.1. We use an additional distillation signal from `BERT-base-uncased` finetuned on SQuAD. The encoder counts 6% of total non-null weights and reaches 83.8 F1 score. The model can be accessed with: `pruned_bert = BertForQuestionAnswering.from_pretrained("huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad")`

In this notebook we'll focus on reproducing this model, so let's begin by validating that we can obtain the same F1 score. Before doing that, we need to preprocess the data, so let's get started!

### Preprocess data

In [None]:
pruned_model_name = "huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad"
pruned_tokenizer = AutoTokenizer.from_pretrained(pruned_model_name)

In [None]:
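max_length = 384  # maximum length of a feature (question + context), in tokens
doc_stride = 128  # number of overlapping tokens between consecutive context windows
pad_on_right = pruned_tokenizer.padding_side == "right"  # True if the question comes before the context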

fn_kwargs = {
    "tokenizer": pruned_tokenizer,
    "max_length": max_length,
    "doc_stride": doc_stride,
    "pad_on_right": pad_on_right
}
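The `max_length` and `doc_stride` settings control how contexts that are too long for a single 384-token feature get split: the context is cut into overlapping windows, with consecutive windows sharing `doc_stride` tokens. The function `sliding_windows` below is a hypothetical, simplified sketch of that windowing over token indices. It ignores the question and special tokens that share the `max_length` budget, and the tokenizer's own overflow handling may differ at the edges. Because one example can yield several windows, the processed splits below contain more rows than the raw ones.

```python
def sliding_windows(num_tokens, window, overlap):
    """Hypothetical helper: split token positions [0, num_tokens) into windows
    of at most `window` tokens, where consecutive windows share `overlap` tokens.

    Returns (start, end) pairs with `end` exclusive.
    """
    if not 0 <= overlap < window:
        raise ValueError("overlap must satisfy 0 <= overlap < window")
    spans = []
    start = 0
    while True:
        end = min(start + window, num_tokens)
        spans.append((start, end))
        if end >= num_tokens:
            break
        # Advance by the non-overlapping part of the window
        start += window - overlap
    return spans

# A 700-token context, ~350 context tokens per feature, 128 tokens of overlap
print(sliding_windows(700, window=350, overlap=128))
# [(0, 350), (222, 572), (444, 700)]
```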

#### Preprocess training set

In [None]:
train_enc = squad["train"].map(
    prepare_train_features,
    fn_kwargs=fn_kwargs,
    batched=True,
    remove_columns=squad["train"].column_names,
)
train_enc

Loading cached processed dataset at /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/4c81550d83a2ac7c7ce23783bd8ff36642800e6633c1f18417fb58c3ff50cdd7/cache-4376b2c43352894d.arrow


Dataset({
    features: ['attention_mask', 'end_positions', 'input_ids', 'start_positions', 'token_type_ids'],
    num_rows: 88524
})
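The 87,599 training examples become 88,524 features, each carrying `start_positions` and `end_positions` labels that are token indices rather than character offsets. The sketch below shows one way to derive them from a tokenizer's offset mapping. It assumes, as in HuggingFace's question-answering example notebook (which `prepare_train_features` likely resembles), that an answer not fully contained in the window is labeled with the `[CLS]` index. The helper name and the hand-written offsets are hypothetical, with `None` marking question and special tokens.

```python
def char_span_to_token_span(offsets, answer_start, answer_end, cls_index=0):
    """Hypothetical helper: map a character span [answer_start, answer_end)
    to (start_token, end_token).

    `offsets` holds one (char_start, char_end) pair per token, or None for
    tokens that are not part of the context. If the answer is not fully
    inside this window, both labels point at `cls_index`.
    """
    context_tokens = [i for i, o in enumerate(offsets) if o is not None]
    first, last = context_tokens[0], context_tokens[-1]
    if offsets[first][0] > answer_start or offsets[last][1] < answer_end:
        return cls_index, cls_index
    # First context token that ends after the answer starts
    start_token = next(i for i in context_tokens if offsets[i][1] > answer_start)
    # Last context token that starts before the answer ends
    end_token = next(i for i in reversed(context_tokens) if offsets[i][0] < answer_end)
    return start_token, end_token

# Tokens: [CLS] who ? [SEP] hello world again [SEP]
offsets = [None, None, None, None, (0, 5), (6, 11), (12, 17), None]
print(char_span_to_token_span(offsets, 6, 11))   # (5, 5): answer "world"
print(char_span_to_token_span(offsets, 20, 25))  # (0, 0): answer outside this window
```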

#### Preprocess validation set

In [None]:
valid_enc = squad["validation"].map(
    prepare_validation_features,
    fn_kwargs=fn_kwargs,
    batched=True,
    remove_columns=squad["validation"].column_names,
)
valid_enc

Loading cached processed dataset at /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/4c81550d83a2ac7c7ce23783bd8ff36642800e6633c1f18417fb58c3ff50cdd7/cache-7c20f97d2bd5149b.arrow


Dataset({
    features: ['attention_mask', 'example_id', 'input_ids', 'offset_mapping', 'token_type_ids'],
    num_rows: 10784
})
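Unlike the training features, the validation features keep `example_id` and `offset_mapping`, because predictions have to be mapped from token logits back to text spans of the original examples (10,784 features for 10,570 examples). The sketch below is a simplified, hypothetical version of that post-processing for a single feature: it scores candidate spans by the sum of start and end logits, keeping only spans that lie in the context, do not end before they start, and are not too long. The logic inside `QuestionAnsweringTrainer` may differ, and the limits `n_best=20` and `max_answer_len=30` are assumptions.

```python
def best_answer(start_logits, end_logits, offsets, context, n_best=20, max_answer_len=30):
    """Hypothetical helper: highest-scoring answer text for one feature.

    `offsets` holds (char_start, char_end) per token, or None for tokens
    outside the context.
    """
    def top_k(logits):
        return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:n_best]

    best_score, best_text = None, ""
    for s in top_k(start_logits):
        for e in top_k(end_logits):
            # Skip spans outside the context, reversed spans, and overly long spans
            if offsets[s] is None or offsets[e] is None:
                continue
            if e < s or e - s + 1 > max_answer_len:
                continue
            score = start_logits[s] + end_logits[e]
            if best_score is None or score > best_score:
                best_score = score
                best_text = context[offsets[s][0]:offsets[e][1]]
    return best_text

context = "hello world again"
offsets = [None, (0, 5), (6, 11), (12, 17), None]
start_logits = [3.0, 0.5, 2.0, 0.1, 0.0]  # index 0 ([CLS]) scores high but is skipped
end_logits = [3.0, 0.2, 1.5, 1.0, 0.0]
print(best_answer(start_logits, end_logits, offsets, context))  # world
```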

### Initialize the trainer

Now that the data is preprocessed, let's instantiate a custom trainer and evaluate the model on the validation set:

In [None]:
pruned_model = AutoModelForQuestionAnswering.from_pretrained(pruned_model_name).to(device)
batch_size = 8

eval_ds = valid_enc
eval_raw_ds = squad["validation"]

pruned_args = QuestionAnsweringTrainingArguments(
    output_dir="checkpoints",
    per_device_eval_batch_size=batch_size)

data_collator = default_data_collator
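With the pruned model loaded, we can also sanity-check the sparsity claim from the description above. The function below is a sketch that counts the fraction of nonzero entries in the weight matrices of every `nn.Linear` layer inside a module. It assumes the pruned weights are stored as explicit zeros in dense tensors, ignores biases and embeddings, and may not count exactly the same parameters HuggingFace used for their 6% figure. The function name is hypothetical.

```python
import torch
from torch import nn

def linear_weight_density(module: nn.Module) -> float:
    """Hypothetical helper: fraction of nonzero weights across all nn.Linear
    layers in `module`. Biases are excluded; pruned weights are assumed to be
    stored as explicit zeros.
    """
    nonzero, total = 0, 0
    for submodule in module.modules():
        if isinstance(submodule, nn.Linear):
            nonzero += int((submodule.weight != 0).sum().item())
            total += submodule.weight.numel()
    return nonzero / total if total else 0.0

# Toy check: zero out the first layer entirely and fill the second with ones
toy = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
with torch.no_grad():
    toy[0].weight.zero_()
    toy[2].weight.fill_(1.0)
print(linear_weight_density(toy))  # 8 nonzero out of 24 weights
```

Calling `linear_weight_density(pruned_model.bert.encoder)` restricts the count to the encoder; if the description counts the same parameters, the result should be close to the quoted 6%.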

In [None]:
pruned_trainer = QuestionAnsweringTrainer(
    model=pruned_model,
    args=pruned_args,
    eval_dataset=eval_ds,
    eval_examples=eval_raw_ds,
    tokenizer=pruned_tokenizer,
    data_collator=data_collator,
    compute_metrics=squad_metrics)

In [None]:
pruned_trainer.evaluate()

Trainer is attempting to log a value of "No log" of type <class 'str'> for key "eval/loss" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.


{'eval_loss': 'No log',
 'eval_exact_match': 74.98580889309366,
 'eval_f1': 83.78464399985475}

Great: we get an F1 score of 83.78, which matches the 83.8 quoted by HuggingFace!
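The two numbers above are the standard SQuAD v1.1 metrics: exact match checks whether the normalized prediction equals a normalized reference answer, while F1 measures token overlap between them. Both take the best score over all reference answers of an example. The sketch below follows the normalization of the official SQuAD v1.1 evaluation script (lowercasing, stripping punctuation and the articles "a", "an", "the", and collapsing whitespace). We assume `squad_metrics` wraps an equivalent implementation, such as the `squad` metric in `datasets`, but that is not confirmed here.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)

def normalize_answer(s):
    """Lowercase, drop punctuation and articles, and collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in PUNCTUATION)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def exact_match(prediction, references):
    """1.0 if the normalized prediction equals any normalized reference, else 0.0."""
    return max(float(normalize_answer(prediction) == normalize_answer(r)) for r in references)

def f1(prediction, references):
    """Best token-level F1 between the prediction and any reference."""
    def single_f1(pred, ref):
        pred_tokens = normalize_answer(pred).split()
        ref_tokens = normalize_answer(ref).split()
        common = Counter(pred_tokens) & Counter(ref_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0.0
        precision = num_same / len(pred_tokens)
        recall = num_same / len(ref_tokens)
        return 2 * precision * recall / (precision + recall)
    return max(single_f1(prediction, r) for r in references)

print(exact_match("The cat sat", ["cat sat"]))  # 1.0
print(f1("the cat sat", ["a cat sat down"]))    # approximately 0.8
```

Averaging these per-example scores over the 10,570 validation examples and multiplying by 100 gives numbers on the same scale as `eval_exact_match` and `eval_f1`.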

## Prune-tuning without distillation