# Movement Pruning

> A partial re-implementation of _Movement Pruning: Adaptive Sparsity by Fine-Tuning_ by Victor Sanh, Thomas Wolf, and Alexander M. Rush [[arXiv:2005.07683](https://arxiv.org/abs/2005.07683)]

## References

* [_Movement Pruning: Adaptive Sparsity by Fine-Tuning_](https://arxiv.org/abs/2005.07683) by Victor Sanh, Thomas Wolf, and Alexander M. Rush
* The scripts and notebooks that accompany the paper ([link](https://github.com/huggingface/transformers/tree/master/examples/research_projects/movement-pruning))

## Load libraries

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from transformerlab.question_answering import *
from transformerlab.pruning import *

In [None]:
from pathlib import Path

import datasets
import transformers

# Only surface errors from both libraries so the notebook output stays readable.
datasets.logging.set_verbosity_error()
transformers.logging.set_verbosity_error()

# Record the library versions this notebook was run with.
print(transformers.__version__, datasets.__version__)

4.1.1 1.2.0


In [None]:
import numpy as np
import random

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, default_data_collator, AdamW, get_linear_schedule_with_warmup

import torch
import torch.nn as nn
from torch.utils.data import SequentialSampler, DataLoader
from torch.nn import init, CrossEntropyLoss
from torch import autograd
import torch.nn.functional as F
import math
# Use the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Running on device: {device}")

Running on device: cuda


## Load data

In [None]:
# Load the "squad" dataset through the `datasets` library and display its splits.
squad_ds = load_dataset("squad")
squad_ds

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

## Evaluate fine-pruned model

HuggingFace has released a PruneBERT checkpoint for SQuAD v1.1 called `prunebert-base-uncased-6-finepruned-w-distil-squad` which is described in their docs as follows:

> Pre-trained BERT-base-uncased fine-pruned with soft movement pruning on SQuAD v1.1. We use an additional distillation signal from `BERT-base-uncased` finetuned on SQuAD. The encoder counts 6% of total non-null weights and reaches 83.8 F1 score. The model can be accessed with: `pruned_bert = BertForQuestionAnswering.from_pretrained("huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad")`

In this notebook we'll focus on reproducing this model, so let's begin by simply validating that we can obtain the same F1-score. Before doing that, we first need to preprocess the data - let's get started!

## Fine-pruning without distillation

In [None]:
# Checkpoint whose tokenizer is used to preprocess the SQuAD examples.
model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

### Create trainer

In [None]:
class PruningTrainingArguments(QuestionAnsweringTrainingArguments):
    """Training arguments extended with the pruning hyper-parameters.

    All positional and remaining keyword arguments are forwarded unchanged to
    ``QuestionAnsweringTrainingArguments``. The extra keyword arguments are
    stored as attributes of the same name and read by ``PruningTrainer``:

    * ``initial_threshold`` / ``final_threshold``: threshold values at the
      start and at the end of the threshold schedule.
    * ``initial_warmup`` / ``final_warmup``: multiples of ``warmup_steps``
      during which the threshold is held at its initial / final value.
    * ``final_lambda``: scale of the regularization weight returned by the
      schedule.
    * ``mask_scores_learning_rate``: learning rate of the parameters whose
      name contains ``"mask_score"``.
    """

    def __init__(
        self,
        *args,
        initial_threshold=1.0,
        final_threshold=0.1,
        initial_warmup=1,
        final_warmup=2,
        final_lambda=0.0,
        mask_scores_learning_rate=0.0,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        # Attach every pruning setting to the instance under its own name.
        pruning_settings = {
            "initial_threshold": initial_threshold,
            "final_threshold": final_threshold,
            "initial_warmup": initial_warmup,
            "final_warmup": final_warmup,
            "final_lambda": final_lambda,
            "mask_scores_learning_rate": mask_scores_learning_rate,
        }
        for setting_name, setting_value in pruning_settings.items():
            setattr(self, setting_name, setting_value)

In [None]:
class PruningTrainer(QuestionAnsweringTrainer):
    """Question-answering trainer that feeds a scheduled pruning threshold to the model.

    On every training step the threshold for the current step is computed by
    ``_schedule_threshold`` and passed to the model as the ``threshold``
    keyword argument. Parameters whose name contains ``"mask_score"`` are
    optimized with their own learning rate
    (``args.mask_scores_learning_rate``).

    Attributes:
        t_total: Total number of optimizer steps. It drives both the threshold
            schedule and the learning-rate schedule.
    """

    def __init__(self, *args, **kwargs):
        """Build the trainer and derive the total number of optimizer steps.

        All arguments are forwarded to ``QuestionAnsweringTrainer``. The
        training dataloader is built once here to measure its length.
        """
        super().__init__(*args, **kwargs)

        # Optimizer updates per epoch. Clamped to at least 1 so that a
        # dataloader with fewer batches than `gradient_accumulation_steps`
        # cannot cause a division by zero below.
        steps_per_epoch = max(
            len(self.get_train_dataloader()) // self.args.gradient_accumulation_steps, 1
        )

        if self.args.max_steps > 0:
            self.t_total = self.args.max_steps
            self.args.num_train_epochs = self.args.max_steps // steps_per_epoch + 1
        else:
            # Round up so that a fractional `num_train_epochs` still yields an
            # integer step count.
            self.t_total = math.ceil(steps_per_epoch * self.args.num_train_epochs)

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """Create the AdamW optimizer and the linear warmup/decay LR scheduler.

        Trainable parameters are split into three groups:

        1. names containing ``"mask_score"``: ``args.mask_scores_learning_rate``;
        2. every other parameter not matched by the no-decay list:
           ``args.learning_rate`` with ``args.weight_decay``;
        3. parameters matched by the no-decay list (``bias``,
           ``LayerNorm.weight``): ``args.learning_rate`` without weight decay.

        Args:
            num_training_steps: Accepted for interface compatibility; the
                scheduler is built from ``self.t_total`` so that it spans the
                same number of steps as the threshold schedule.
        """
        no_decay = ("bias", "LayerNorm.weight")

        # Bucket the trainable parameters in a single pass over the model.
        mask_score_params, decay_params, no_decay_params = [], [], []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if "mask_score" in name:
                mask_score_params.append(param)
            elif any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer_grouped_parameters = [
            {
                "params": mask_score_params,
                "lr": self.args.mask_scores_learning_rate,
            },
            {
                "params": decay_params,
                "lr": self.args.learning_rate,
                "weight_decay": self.args.weight_decay,
            },
            {
                "params": no_decay_params,
                "lr": self.args.learning_rate,
                "weight_decay": 0.0,
            },
        ]

        self.optimizer = AdamW(
            optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon
        )
        self.lr_scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=self.t_total
        )

    def compute_loss(self, model, inputs):
        """Run the model with the threshold scheduled for the upcoming step.

        Args:
            model: The model being trained; it must accept a ``threshold``
                keyword argument and return the loss as its first output.
            inputs: Batch of keyword arguments for ``model``. The key
                ``"threshold"`` is set (or overwritten) in place.

        Returns:
            The first element of the model output, i.e. the loss.
        """
        # The regularization weight is produced by the schedule but is not
        # applied to the loss here.
        threshold, _regu_lambda = self._schedule_threshold(
            step=self.state.global_step + 1,
            total_step=self.t_total,
            warmup_steps=self.args.warmup_steps,
            final_threshold=self.args.final_threshold,
            initial_threshold=self.args.initial_threshold,
            final_warmup=self.args.final_warmup,
            initial_warmup=self.args.initial_warmup,
            final_lambda=self.args.final_lambda,
        )
        inputs["threshold"] = threshold
        outputs = model(**inputs)

        # Index instead of unpacking into a fixed number of names, so the loss
        # is found regardless of how many extra outputs the model returns.
        return outputs[0]

    def _schedule_threshold(
        self,
        step: int,
        total_step: int,
        warmup_steps: int,
        initial_threshold: float,
        final_threshold: float,
        initial_warmup: int,
        final_warmup: int,
        final_lambda: float,
    ):
        """Return ``(threshold, regu_lambda)`` for the given training step.

        The threshold is held at ``initial_threshold`` for the first
        ``initial_warmup * warmup_steps`` steps, held at ``final_threshold``
        for the last ``final_warmup * warmup_steps`` steps, and moves between
        the two along a cubic curve in between.

        Args:
            step: 1-based index of the current optimizer step.
            total_step: Total number of optimizer steps.
            warmup_steps: Length of one warmup unit, in steps.
            initial_threshold: Threshold at the start of training.
            final_threshold: Threshold at the end of training.
            initial_warmup: Number of warmup units spent at ``initial_threshold``.
            final_warmup: Number of warmup units spent at ``final_threshold``.
            final_lambda: Regularization weight reached once the threshold
                equals ``final_threshold``.

        Returns:
            Tuple ``(threshold, regu_lambda)`` where ``regu_lambda`` is
            ``final_lambda * threshold / final_threshold``.
        """
        if step <= initial_warmup * warmup_steps:
            threshold = initial_threshold
        elif step > (total_step - final_warmup * warmup_steps):
            threshold = final_threshold
        else:
            spars_warmup_steps = initial_warmup * warmup_steps
            spars_schedu_steps = (final_warmup + initial_warmup) * warmup_steps
            # Fraction of the scheduled phase still ahead: 1 at its start, 0 at its end.
            mul_coeff = 1 - (step - spars_warmup_steps) / (total_step - spars_schedu_steps)
            threshold = final_threshold + (initial_threshold - final_threshold) * (mul_coeff ** 3)

        # With regularization disabled the weight is 0 whatever the thresholds
        # are; skipping the division keeps `final_threshold == 0` usable.
        if final_lambda:
            regu_lambda = final_lambda * threshold / final_threshold
        else:
            regu_lambda = 0.0
        return threshold, regu_lambda

In [None]:
# Masked BERT whose pruning masks are selected with the "topK" method;
# mask scores are initialised to a constant.
masked_config = MaskedBertConfig(pruning_method='topK', mask_init='constant', mask_scale=0.)
masked_model = MaskedBertForQuestionAnswering.from_pretrained('bert-base-uncased', config=masked_config).to(device)

batch_size = 16

# Work on a small subset of SQuAD to keep the experiment short.
num_train_examples = 1600
num_eval_examples = 320

train_ds, eval_ds, eval_examples = convert_examples_to_features(squad_ds, tokenizer, num_train_examples, num_eval_examples)
num_train_epochs = 10

# pruning params
warmup_steps = int(num_train_examples / batch_size * num_train_epochs * .1)  # 10% of total steps
initial_threshold = 1.
final_threshold = 0.3
initial_warmup = 1
final_warmup = 2
final_lambda = 0
mask_scores_learning_rate = 0  # alternative: 1e-2

print(f"Number of training examples: {train_ds.num_rows}")
print(f"Number of validation examples: {eval_ds.num_rows}")
print(f"Number of raw validation examples: {eval_examples.num_rows}")

# Log once per pass over the training features.
logging_steps = len(train_ds) // batch_size

print(f"Number of warmup steps: {warmup_steps}")
print(f"Number of logging steps: {logging_steps}")

pruning_training_args = PruningTrainingArguments(
    output_dir="checkpoints",
    evaluation_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_train_epochs,
    weight_decay=0.0,
    logging_steps=logging_steps,
    disable_tqdm=False,
    warmup_steps=warmup_steps,
    seed=42,
    # Pass every pruning setting explicitly so that editing the variables
    # above always takes effect instead of silently falling back to defaults.
    initial_threshold=initial_threshold,
    final_threshold=final_threshold,
    initial_warmup=initial_warmup,
    final_warmup=final_warmup,
    final_lambda=final_lambda,
    mask_scores_learning_rate=mask_scores_learning_rate,
)

data_collator = default_data_collator

Number of training examples: 1611
Number of validation examples: 325
Number of raw validation examples: 320
Number of warmup steps: 100
Number of logging steps: 100


In [None]:
# Add a constant `threshold` column, equal to `final_threshold`, to every
# evaluation feature. NOTE(review): presumably this is how the model receives
# its `threshold` argument during evaluation; confirm against the trainer.
eval_ds = eval_ds.map(lambda x : {'threshold': final_threshold})

In [None]:
# Assemble the trainer from the masked model, the pruning arguments and the
# preprocessed train/eval features; `eval_examples` holds the raw validation
# examples and `squad_metrics` computes the reported metrics.
pruning_trainer = PruningTrainer(
    model=masked_model,
    args=pruning_training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    eval_examples=eval_examples,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=squad_metrics
)

In [None]:
# Evaluate once before any training step to get a baseline score.
pruning_trainer.evaluate()

HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))




{'eval_loss': 'No log',
 'eval_exact_match': 0.3125,
 'eval_f1': 3.122511825187888}

In [None]:
# Start fine-pruning; evaluation runs at the end of every epoch
# (`evaluation_strategy="epoch"` in the training arguments).
pruning_trainer.train()

Epoch,Training Loss,Validation Loss,Exact Match,F1
1.0,4.870819,No log,0.3125,4.831887


HBox(children=(FloatProgress(value=0.0, max=320.0), HTML(value='')))




KeyboardInterrupt: 