# Introduction 

This tutorial walks through the steps of the [Prune Once For All](https://arxiv.org/abs/2111.05754) example.

# Prerequisites

## Install packages

* Follow [installation](https://github.com/intel-innersource/frameworks.ai.nlp-toolkit.intel-nlp-toolkit#installation) to install **intel-extension-for-transformers**. 

In [None]:
!pip install datasets>=1.8.0 torch>=1.10.0 transformers>=4.12.0 wandb

## Import packages

In [None]:
import logging
import os
from dataclasses import dataclass, field
import timeit
from datasets import load_dataset, load_metric
import functools
from trainer_qa import QuestionAnsweringTrainer

from transformers import (
    AutoConfig,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from typing import Optional
from utils_qa import postprocess_qa_predictions


# Fail fast when the installed Transformers release is older than the minimum
# this tutorial was written against. Remove this guard at your own risk.
check_min_version("4.12.0")

require_version(
    "datasets>=1.8.0",
    "To fix: pip install -r examples/pytorch/question-answering/requirements.txt",
)

logger = logging.getLogger(__name__)

# Set the WANDB_DISABLED switch (intended to turn off Weights & Biases reporting).
os.environ["WANDB_DISABLED"] = "true"

## Download Dataset from the Hub

In [None]:
# Load SQuAD; its "train" and "validation" splits are what the tutorial uses.
raw_datasets = load_dataset(path="squad")

## Download fp32 Model from the Hub

In [None]:
# Checkpoint of the sparse DistilBERT model that is optimized in this tutorial.
model_name_or_path = "Intel/distilbert-base-uncased-sparse-90-unstructured-pruneofa"

config = AutoConfig.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
model = AutoModelForQuestionAnswering.from_pretrained(
    model_name_or_path,
    config=config,
    # Load as a TensorFlow checkpoint only when the name contains ".ckpt".
    from_tf=".ckpt" in model_name_or_path,
    use_auth_token=None,
)

## Preprocessing the Dataset for Training

In [None]:
# Resolve the question/context/answers columns of the training split, falling
# back to the first three columns by position when the expected names are absent.
column_names = raw_datasets["train"].column_names
question_column_name = column_names[0] if "question" not in column_names else "question"
context_column_name = column_names[1] if "context" not in column_names else "context"
answer_column_name = column_names[2] if "answers" not in column_names else "answers"

# True when the tokenizer pads on the right; this decides whether the question
# or the context is passed first to the tokenizer.
pad_on_right = tokenizer.padding_side == "right"
# Never exceed the tokenizer's own maximum length.
max_seq_length = min(384, tokenizer.model_max_length)

def prepare_train_features(examples, tokenizer=tokenizer):
    """Tokenize a batch of training examples and label each feature's answer span.

    Only the context is truncated; contexts that exceed ``max_seq_length`` are
    split into several overlapping features (stride 128), so a single example
    may produce more than one feature.

    Args:
        examples: Batch mapping column names to lists of values. The question
            column is left-stripped in place.
        tokenizer: Tokenizer used for encoding. Defaults to the module-level
            ``tokenizer`` that existed when this function was defined.

    Returns:
        The tokenizer output (without ``overflow_to_sample_mapping`` and
        ``offset_mapping``) extended with ``start_positions`` and
        ``end_positions``. Both are the CLS token index when the example has no
        answer or the answer does not lie fully inside the feature.
    """
    # Derive the question/context order from the tokenizer actually in use, so
    # that an explicitly passed tokenizer is not paired with the padding side
    # of the module-level one.
    pad_on_right = tokenizer.padding_side == "right"
    # Sequence id that marks context tokens in ``sequence_ids``.
    context_index = 1 if pad_on_right else 0

    examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]

    tokenized_examples = tokenizer(
        examples[question_column_name if pad_on_right else context_column_name],
        examples[context_column_name if pad_on_right else question_column_name],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_seq_length,
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding=False,
    )

    # Maps every produced feature back to the example it was cut from.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # Per-token (start_char, end_char) offsets into the original text.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    start_positions = []
    end_positions = []

    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)
        sequence_ids = tokenized_examples.sequence_ids(i)
        answers = examples[answer_column_name][sample_mapping[i]]

        # No annotated answer: point both labels at the CLS token.
        if len(answers["answer_start"]) == 0:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        # Character span of the first annotated answer.
        start_char = answers["answer_start"][0]
        end_char = start_char + len(answers["text"][0])

        # First and last token of the context inside this feature.
        token_start_index = 0
        while sequence_ids[token_start_index] != context_index:
            token_start_index += 1

        token_end_index = len(input_ids) - 1
        while sequence_ids[token_end_index] != context_index:
            token_end_index -= 1

        answer_in_feature = (
            offsets[token_start_index][0] <= start_char
            and offsets[token_end_index][1] >= end_char
        )
        if not answer_in_feature:
            # The answer falls outside this feature: label it with CLS.
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        # Walk inwards from both ends until the tokens bracket the answer.
        while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
            token_start_index += 1
        start_positions.append(token_start_index - 1)
        while offsets[token_end_index][1] >= end_char:
            token_end_index -= 1
        end_positions.append(token_end_index + 1)

    tokenized_examples["start_positions"] = start_positions
    tokenized_examples["end_positions"] = end_positions
    return tokenized_examples


# Tokenize the training split in batches; the raw columns are replaced by the
# features returned from prepare_train_features.
train_examples = raw_datasets["train"]
train_dataset = train_examples.map(
    function=prepare_train_features,
    batched=True,
    remove_columns=column_names,
    num_proc=None,
    load_from_cache_file=False,
    desc="Running tokenizer on train dataset",
)

## Preprocessing the Dataset for Validation

In [None]:
# Resolve the question/context/answers columns of the validation split, falling
# back to the first three columns by position when the expected names are absent.
column_names = raw_datasets["validation"].column_names
question_column_name = column_names[0] if "question" not in column_names else "question"
context_column_name = column_names[1] if "context" not in column_names else "context"
answer_column_name = column_names[2] if "answers" not in column_names else "answers"

# True when the tokenizer pads on the right; this decides whether the question
# or the context is passed first to the tokenizer.
pad_on_right = tokenizer.padding_side == "right"
# Never exceed the tokenizer's own maximum length.
max_seq_length = min(384, tokenizer.model_max_length)

def prepare_validation_features(examples, tokenizer=tokenizer):
    """Tokenize a batch of validation examples and keep what post-processing needs.

    Only the context is truncated; contexts that exceed ``max_seq_length`` are
    split into several overlapping features (stride 128).

    Args:
        examples: Batch mapping column names to lists of values; must contain
            an ``id`` column. The question column is left-stripped in place.
        tokenizer: Tokenizer used for encoding. Defaults to the module-level
            ``tokenizer`` that existed when this function was defined.

    Returns:
        The tokenizer output (without ``overflow_to_sample_mapping``) extended
        with ``example_id``, and with every ``offset_mapping`` entry that does
        not belong to the context replaced by ``None``.
    """
    # Derive the question/context order from the tokenizer actually in use, so
    # that an explicitly passed tokenizer is not paired with the padding side
    # of the module-level one.
    pad_on_right = tokenizer.padding_side == "right"
    # Sequence id that marks context tokens in ``sequence_ids``.
    context_index = 1 if pad_on_right else 0

    examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]

    tokenized_examples = tokenizer(
        examples[question_column_name if pad_on_right else context_column_name],
        examples[context_column_name if pad_on_right else question_column_name],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_seq_length,
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding=False,
    )

    # Maps every produced feature back to the example it was cut from.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

    example_ids = []
    offset_mapping = tokenized_examples["offset_mapping"]
    for i, offsets in enumerate(offset_mapping):
        sequence_ids = tokenized_examples.sequence_ids(i)

        # Remember which example this feature came from.
        example_ids.append(examples["id"][sample_mapping[i]])

        # Blank out offsets of non-context tokens so context tokens can be
        # recognised by a non-None offset.
        offset_mapping[i] = [
            (offset if sequence_ids[k] == context_index else None)
            for k, offset in enumerate(offsets)
        ]

    tokenized_examples["example_id"] = example_ids
    return tokenized_examples

# Tokenize the validation split in batches; the raw columns are replaced by the
# features returned from prepare_validation_features.
eval_examples = raw_datasets["validation"]
eval_dataset = eval_examples.map(
    function=prepare_validation_features,
    batched=True,
    remove_columns=column_names,
    num_proc=None,
    load_from_cache_file=False,
    desc="Running tokenizer on validation dataset",
)

## Define Metrics

In [None]:
# The features are tokenized with padding=False, so each batch is padded when
# it is collated. ``pad_to_multiple_of`` is passed by keyword because the
# second positional field of DataCollatorWithPadding is ``padding``; a
# positional None would replace the padding strategy instead.
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=None)

def post_processing_function(examples, features, predictions, stage="eval"):
    """Convert raw model predictions into answer texts paired with references.

    Args:
        examples: Original examples; each must expose ``id`` and the answers
            column named by ``answer_column_name``.
        features: Features forwarded to ``postprocess_qa_predictions``.
        predictions: Raw predictions forwarded to ``postprocess_qa_predictions``.
        stage: Forwarded to ``postprocess_qa_predictions`` as ``prefix``.

    Returns:
        An ``EvalPrediction`` whose ``predictions`` is a list of
        ``{"id", "prediction_text"}`` dicts and whose ``label_ids`` is a list
        of ``{"id", "answers"}`` dicts.
    """
    # NOTE(review): 'passive' is forwarded unchanged as log_level -- confirm
    # that utils_qa.postprocess_qa_predictions accepts a string level.
    predictions = postprocess_qa_predictions(
        examples=examples,
        features=features,
        predictions=predictions,
        version_2_with_negative=False,
        n_best_size=20,
        max_answer_length=30,
        null_score_diff_threshold=0.0,
        output_dir="./tmp/squad_output",
        log_level='passive',
        prefix=stage,
    )
    # Reshape both sides into the list-of-dicts layout handed to the metric.
    formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]

    references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
    return EvalPrediction(predictions=formatted_predictions, label_ids=references)

# Metric object used to score the post-processed predictions.
metric = load_metric("squad")


def compute_metrics(p: EvalPrediction):
    """Score the predictions in ``p`` against the references in ``p.label_ids``."""
    predicted, expected = p.predictions, p.label_ids
    return metric.compute(predictions=predicted, references=expected)

## Prepare Datasets for Teacher Model

In [None]:
# Checkpoint of the teacher model used for distillation.
teacher_model_name_or_path = "distilbert-base-uncased-distilled-squad"

# use_auth_token=None matches how the student checkpoint is loaded: passing
# True would require a stored Hugging Face token even though no private
# repository is involved.
teacher_config = AutoConfig.from_pretrained(
    teacher_model_name_or_path,
    use_auth_token=None,
)
teacher_tokenizer = AutoTokenizer.from_pretrained(
    teacher_model_name_or_path,
    use_fast=True,
    use_auth_token=None,
)
teacher_model = AutoModelForQuestionAnswering.from_pretrained(
    teacher_model_name_or_path,
    # Decide TensorFlow loading from the teacher's own name, not the student's.
    from_tf=bool(".ckpt" in teacher_model_name_or_path),
    config=teacher_config,
    use_auth_token=None,
)

# Tokenize both splits again with the teacher tokenizer. Each split drops its
# own raw columns instead of relying on whatever value the global
# ``column_names`` was last assigned in an earlier cell.
teacher_train_dataset = train_examples.map(
    functools.partial(prepare_train_features, tokenizer=teacher_tokenizer),
    batched=True,
    num_proc=None,
    remove_columns=train_examples.column_names,
    load_from_cache_file=False,
    desc="Running tokenizer on train dataset",
)

teacher_eval_dataset = eval_examples.map(
    functools.partial(prepare_validation_features, tokenizer=teacher_tokenizer),
    batched=True,
    num_proc=None,
    remove_columns=eval_examples.column_names,
    load_from_cache_file=False,
    desc="Running tokenizer on validation dataset",
)
    

def para_counter(model):
    """Return the total element count over everything in ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
# Report the size of both models in millions of parameters.
logger.info(
    "***** Number of teacher model parameters: {:.2f}M *****".format(
        para_counter(teacher_model) / 10**6
    )
)
logger.info(
    "***** Number of student model parameters: {:.2f}M *****".format(
        para_counter(model) / 10**6
    )
)

# Orchestrate Optimizations & Benchmark

## Orchestrate Optimizations

In [None]:
# Build the trainer around the student model, its tokenized datasets and the
# evaluation helpers defined earlier.
trainer = QuestionAnsweringTrainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    eval_examples=eval_examples,
    post_process_function=post_processing_function,
    compute_metrics=compute_metrics,
)

# NOTE(review): metrics, objectives, PrunerConfig, PruningConfig,
# DistillationConfig and QuantizationConfig are not imported in the visible
# cells -- TODO confirm they are brought into scope (presumably from
# intel-extension-for-transformers) before this cell runs.

# Tuning criterion shared by all three optimization configs.
# NOTE(review): is_relative=True with criterion=0.01 presumably means a 1%
# relative tolerance -- confirm against the Metric documentation.
metric_name = "eval_f1"
tune_metric = metrics.Metric(name=metric_name, is_relative=True, criterion=0.01)

# Pruning and distillation configs.
target_sparsity_ratio = None
pruner_config = PrunerConfig(
    prune_type="PatternLock",
    target_sparsity_ratio=target_sparsity_ratio,
)
pruning_conf = PruningConfig(
    framework="pytorch_fx",
    pruner_config=[pruner_config],
    metrics=tune_metric,
)
distillation_conf = DistillationConfig(framework="pytorch_fx", metrics=tune_metric)

# Quantization config.
objective = objectives.performance
quantization_conf = QuantizationConfig(
    approach="QuantizationAwareTraining",
    max_trials=600,
    metrics=[tune_metric],
    objectives=[objective],
)

# Hand all three configs to the trainer together with the teacher model.
conf_list = [pruning_conf, distillation_conf, quantization_conf]
model = trainer.orchestrate_optimizations(config_list=conf_list, teacher_model=teacher_model)

## Run Benchmark after Orchestrate Optimizations

In [None]:
# Time one full evaluation pass; latency and throughput are derived from the
# total wall-clock time divided over the number of evaluation features.
start_time = timeit.default_timer()
results = trainer.evaluate()
evalTime = timeit.default_timer() - start_time

samples = len(eval_dataset)
eval_f1_dynamic = results.get("eval_f1")

# NOTE(review): the batch size 8 is hard-coded here; presumably it mirrors the
# trainer's evaluation batch size -- confirm.
print("Batch size = {}".format(8))
print("Finally Eval eval_f1 Accuracy: {}".format(eval_f1_dynamic))
print("Latency: {:.3f} ms".format(evalTime / samples * 1000))
print("Throughput: {} samples/sec".format(samples/evalTime))