# Introduction 

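This tutorial walks through the steps of applying [Prune Once For All](https://arxiv.org/abs/2111.05754) (Prune OFA): a sparse pre-trained DistilBERT is fine-tuned on SST-2 with pattern-lock pruning, knowledge distillation from a dense teacher, and quantization-aware training.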

# Prerequisite

## Install Packages

* Follow the [installation guide](https://github.com/intel-innersource/frameworks.ai.nlp-toolkit.intel-nlp-toolkit#installation) to install **intel-extension-for-transformers**.

In [None]:
# Quote the version specifiers: unquoted, the shell treats `>` as output redirection.
!pip install accelerate "torch>=1.10" "datasets>=1.1.3" "sentencepiece!=0.1.92" "transformers>=4.12.0" protobuf wandb

## Import Packages

In [None]:
import functools
import logging
import os
from dataclasses import dataclass, field

import numpy as np
import torch
from datasets import load_dataset, load_metric
from intel_extension_for_transformers.transformers import (
    metrics,
    PrunerConfig,
    PruningConfig,
    DistillationConfig,
    QuantizationConfig,
    objectives
)
from intel_extension_for_transformers.transformers.trainer import NLPTrainer
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    PretrainedConfig
)
from transformers.utils import check_min_version

# Hide all CUDA devices so PyTorch runs on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Disable Weights & Biases logging for this run.
os.environ["WANDB_DISABLED"] = "true"

# Fail fast if the installed transformers is older than this example requires.
check_min_version("4.12.0")


# Input text column(s) for each GLUE task; None marks a single-sentence task.
task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

# The root logger defaults to WARNING, so without this configuration every
# logger.info(...) call in this notebook would be silently dropped.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

## Download Dataset from the Hub

In [None]:
# Fetch the GLUE SST-2 dataset from the Hugging Face Hub.
raw_datasets = load_dataset(path="glue", name="sst2")
# Class names are read from the `names` of the training split's label feature.
label_list = raw_datasets["train"].features["label"].names
num_labels = len(label_list)

## Download fp32 Model from the Hub

In [None]:
# Student checkpoint: per its name, a 90% unstructured-sparse DistilBERT from Prune OFA.
model_name_or_path = "Intel/distilbert-base-uncased-sparse-90-unstructured-pruneofa"

config = AutoConfig.from_pretrained(
    model_name_or_path, revision="main", num_labels=num_labels, finetuning_task="sst2"
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, revision="main", use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name_or_path,
    config=config,
    revision="main",
    # Only a TensorFlow ".ckpt" path needs conversion from TF weights.
    from_tf=".ckpt" in model_name_or_path,
)

## Preprocessing the Dataset

In [None]:
# A floating-point label dtype marks a regression task (single output).
is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
if is_regression:
    num_labels = 1
else:
    # Keep the label feature's class names when it has them, so they can be
    # matched against the model's label2id keys below; otherwise fall back to
    # the sorted unique label values.
    label_list = getattr(raw_datasets["train"].features["label"], "names", None)
    if label_list is None:
        label_list = sorted(raw_datasets["train"].unique("label"))
    num_labels = len(label_list)

sentence1_key, sentence2_key = task_to_keys["sst2"]

# Remap dataset label ids to the model's ids only when the checkpoint carries a
# non-default label2id whose (lower-cased) names match the dataset's labels;
# otherwise indexing label_name_to_id would raise KeyError.
label_to_id = None
if (
    model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
    and not is_regression
):
    label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
    if sorted(label_name_to_id) == sorted(label_list):
        label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
    else:
        logger.warning(
            "Model label2id %s does not match dataset labels %s; "
            "ignoring the model's label mapping.",
            model.config.label2id,
            label_list,
        )

if label_to_id is not None:
    model.config.label2id = label_to_id
    # Invert the mapping just stored on model.config (not the standalone `config`,
    # which from_pretrained may have copied and thus not updated).
    model.config.id2label = {idx: label for label, idx in model.config.label2id.items()}

# Truncate to 128 tokens, or to the tokenizer's own limit if that is smaller.
max_seq_length = min(128, tokenizer.model_max_length)


def preprocess_function(examples, tokenizer=tokenizer):
    """Tokenize a batch of examples and, if needed, remap their labels.

    The ``tokenizer`` default is bound when this function is defined; callers can
    substitute another tokenizer (the teacher cell does so via ``functools.partial``).
    Sequences are truncated to ``max_seq_length`` and left unpadded; padding is
    deferred to the data collator.
    """
    texts = [examples[sentence1_key]]
    if sentence2_key is not None:
        texts.append(examples[sentence2_key])
    encoded = tokenizer(*texts, padding=False, max_length=max_seq_length, truncation=True)

    # Label remapping is only required when label_to_id was built (not for plain GLUE).
    if label_to_id is not None and "label" in examples:
        encoded["label"] = [
            -1 if label == -1 else label_to_id[label] for label in examples["label"]
        ]
    return encoded


# Tokenize every split; caching is disabled so the map always re-runs.
raw_datasets = raw_datasets.map(preprocess_function, batched=True, load_from_cache_file=False)

train_dataset, eval_dataset = raw_datasets["train"], raw_datasets["validation"]
# NOTE(review): `load_metric` is deprecated in newer `datasets` releases in favour of
# the `evaluate` library — confirm the installed version still provides it.
metric = load_metric("glue", "sst2")


def compute_metrics(p: EvalPrediction):
    """Compute GLUE evaluation metrics for the trainer.

    Args:
        p: Model predictions (logits, or a tuple whose first element holds them)
            together with the reference ``label_ids``.

    Returns:
        dict: Every metric reported by the GLUE ``metric``, plus a
        ``combined_score`` (their mean) when it reports more than one.
    """
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    # Regression: drop the singleton output dim; classification: pick the top class.
    preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
    result = metric.compute(predictions=preds, references=p.label_ids)
    if len(result) > 1:
        result["combined_score"] = np.mean(list(result.values())).item()
    # Always return the task metric itself; recomputing accuracy/mse here would
    # replace task-specific single metrics (e.g. Matthews correlation for CoLA).
    return result


# Pad each batch dynamically to its longest sequence. `pad_to_multiple_of` must be
# passed by keyword: the collator's second positional parameter is `padding`, and
# padding=None is not a valid padding strategy.
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=None)

## Prepare the Teacher Model, Its Datasets, and Its Logits

In [None]:
# Dense teacher (per its name, DistilBERT fine-tuned on SST-2); it is later passed
# to orchestrate_optimizations as the distillation teacher.
teacher_model_name_or_path = "distilbert-base-uncased-finetuned-sst-2-english"
teacher_config = AutoConfig.from_pretrained(
    teacher_model_name_or_path, num_labels=num_labels, finetuning_task="sst2"
)
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name_or_path, use_fast=True)
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_model_name_or_path,
    from_tf=bool(".ckpt" in teacher_model_name_or_path),
    config=teacher_config,
)
# Re-tokenize with the teacher's tokenizer, dropping every existing column so only
# the teacher's model inputs remain (functools is imported at the top of the file).
teacher_processed_datasets = raw_datasets.map(
    functools.partial(preprocess_function, tokenizer=teacher_tokenizer),
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
)
teacher_train_dataset = teacher_processed_datasets["train"]
teacher_eval_dataset = teacher_processed_datasets["validation"]


def dict_tensor_to_model_device(batch, model):
    """Move every tensor in ``batch`` to the device of ``model``'s first parameter.

    ``batch`` is updated in place (same keys, values replaced); returns None.
    """
    target_device = next(model.parameters()).device
    for key, tensor in batch.items():
        batch[key] = tensor.to(target_device)


def get_logits(teacher_model, train_dataset, teacher_train_dataset,
               batch_size=8, collate_fn=None, cache_dir=None):
    """Attach the teacher's logits to ``train_dataset`` as a ``teacher_logits`` column.

    Logits are loaded from a cached ``.npy`` file when one exists. Otherwise they
    are computed by running ``teacher_model`` over ``teacher_train_dataset`` and
    saved to that file for later runs.

    Args:
        teacher_model: Module whose forward returns a logits tensor (e.g. the
            ``BertModelforLogitsOutputOnly`` wrapper) or a mapping with a
            ``"logits"`` entry.
        train_dataset: Student training split that receives the new column.
        teacher_train_dataset: The same examples tokenized for the teacher; must be
            in the same order as ``train_dataset`` (logits are aligned by position).
        batch_size: Batch size used when the logits have to be computed.
        collate_fn: Collator for teacher batches; defaults to dynamic padding with
            ``teacher_tokenizer``.
        cache_dir: Directory holding the ``.npy`` cache; defaults to the current
            working directory because ``__file__`` is undefined inside a notebook.

    Returns:
        ``train_dataset`` with an added ``teacher_logits`` column.
    """
    logger.info("***** Getting logits of teacher model *****")
    logger.info(f"  Num examples = {len(train_dataset) }")
    teacher_model.eval()
    if cache_dir is None:
        cache_dir = os.getcwd()
    npy_file = os.path.join(
        cache_dir, '{}.{}.npy'.format("sst2", teacher_model_name_or_path.replace('/', '.')))
    if os.path.exists(npy_file):
        teacher_logits = list(np.load(npy_file))
    else:
        # No cache yet: compute the logits once, in dataset order, then save them.
        if collate_fn is None:
            collate_fn = DataCollatorWithPadding(teacher_tokenizer)
        dataloader = torch.utils.data.DataLoader(
            teacher_train_dataset, batch_size=batch_size, collate_fn=collate_fn)
        teacher_logits = []
        with torch.no_grad():
            for batch in dataloader:
                dict_tensor_to_model_device(batch, teacher_model)
                outputs = teacher_model(**batch)
                logits = outputs if torch.is_tensor(outputs) else outputs["logits"]
                teacher_logits.extend(logits.detach().cpu().numpy())
        np.save(npy_file, np.array(teacher_logits))
    return train_dataset.add_column('teacher_logits', teacher_logits)


class BertModelforLogitsOutputOnly(torch.nn.Module):
    """Wrap a model so that calling the wrapper returns only its ``'logits'`` output."""

    def __init__(self, model):
        super().__init__()
        # Assigned as a submodule, so .eval() and .parameters() reach the wrapped model.
        self.model = model

    def forward(self, *args, **kwargs):
        """Forward all arguments to the wrapped model and return its ``'logits'`` entry."""
        return self.model(*args, **kwargs)['logits']
    

# Add the teacher's logits to the student's training set (no gradients needed).
with torch.no_grad():
    train_dataset = get_logits(
        BertModelforLogitsOutputOnly(teacher_model), train_dataset, teacher_train_dataset
    )


def para_counter(model):
    """Return the total number of parameters in ``model``."""
    return sum(p.numel() for p in model.parameters())


logger.info("***** Number of teacher model parameters: {:.2f}M *****".format(
    para_counter(teacher_model) / 10**6))
logger.info("***** Number of student model parameters: {:.2f}M *****".format(
    para_counter(model) / 10**6))

# Orchestrate Optimizations & Benchmark

## Orchestrate Optimizations

In [None]:
trainer = NLPTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Metric that all three optimization configs are tuned against.
metric_name = "eval_accuracy"
# NOTE(review): is_relative=True with criterion=0.01 presumably tolerates a 1%
# relative accuracy drop — confirm against the metrics.Metric documentation.
tune_metric = metrics.Metric(name=metric_name, is_relative=True, criterion=0.01)

# Pattern-lock pruner; the target sparsity ratio is deliberately left unset (None).
target_sparsity_ratio = None
pruner_config = PrunerConfig(prune_type='PatternLock', target_sparsity_ratio=target_sparsity_ratio)
pruning_conf = PruningConfig(framework="pytorch_fx", pruner_config=[pruner_config], metrics=tune_metric)
distillation_conf = DistillationConfig(framework="pytorch_fx", metrics=tune_metric)

objective = objectives.performance
quantization_conf = QuantizationConfig(
    approach="QuantizationAwareTraining",
    max_trials=600,
    metrics=[tune_metric],
    objectives=[objective],
)
# Pruning, distillation (from teacher_model) and QAT are handed to the trainer in one call.
conf_list = [pruning_conf, distillation_conf, quantization_conf]
model = trainer.orchestrate_optimizations(config_list=conf_list, teacher_model=teacher_model)

## Run Benchmark after Orchestrating Optimizations

In [None]:
results = trainer.evaluate()
throughput = results.get("eval_samples_per_second")
eval_acc = results.get("eval_accuracy")
# Report the batch size evaluation actually used instead of a hard-coded 8.
print('Batch size = {}'.format(trainer.args.eval_batch_size))
print("Finally Eval eval_accuracy Accuracy: {:.5f}".format(eval_acc))
# Average per-sample latency in milliseconds, derived from the throughput.
print("Latency: {:.5f} ms".format(1000 / throughput))
print("Throughput: {:.5f} samples/sec".format(throughput))