# Introduction 

This tutorial walks through the steps of running the [Prune Once For All](https://arxiv.org/abs/2111.05754) example.

# Prerequisite

## Install packages

* Follow [installation](https://github.com/intel-innersource/frameworks.ai.nlp-toolkit.intel-nlp-toolkit#installation) to install **nlp-toolkit**. 

In [None]:
# install model dependency
!pip install datasets>=1.8.0 torch>=1.10.0 transformers>=4.12.0 wandb

## Import packages

In [None]:
import datasets
import functools
import logging
import os
import numpy as np
import random
import sys
import torch
import transformers
from dataclasses import dataclass, field
from datasets import load_dataset, load_metric
from intel_extension_for_transformers.transformers import (
    metrics,
    PrunerConfig,
    PruningConfig,
    DistillationConfig,
    QuantizationConfig,
    OptimizedModel,
    objectives
)
from intel_extension_for_transformers.transformers.trainer import NLPTrainer
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    PretrainedConfig,
    TrainingArguments,
    default_data_collator,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.fx import symbolic_trace
from typing import Optional



# Process-wide environment switches, set before any model or trainer is created.
# An empty CUDA_VISIBLE_DEVICES is presumably meant to keep the run on CPU
# (the TrainingArguments below also pass no_cuda=True).
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Presumably turns off Weights & Biases reporting -- TODO confirm against the installed version.
os.environ["WANDB_DISABLED"] = "true"


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.12.0")


# Maps each GLUE task name to the dataset column(s) holding its text input(s).
# The second entry is None for single-sentence tasks.
task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

# Module-level logger used by the cells below.
logger = logging.getLogger(__name__)

## Define arguments

In [None]:
# ========== Define arguments =========
@dataclass
class ModelArguments:
    """Selects the model checkpoint that is fine-tuned.

    Attributes are documented through each field's ``help`` metadata.
    """

    # Required: there is no default, so it must be supplied at construction time.
    model_name_or_path: str = field(
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models",
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Every field has a default, so the class can be built with keyword arguments only.
    Each field is described by its ``help`` metadata.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input test data file to evaluate the perplexity on (a text file)."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_seq_length: int = field(
        default=384,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
            "be faster on GPU but will be slower on TPU)."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
            "value if set."
        },
    )
    # NOTE(review): the fields from here through max_answer_length are not read anywhere in the
    # visible code; they look like leftovers from a question-answering script -- confirm before removing.
    version_2_with_negative: bool = field(
        default=False, metadata={"help": "If true, some of the examples do not have an answer."}
    )
    null_score_diff_threshold: float = field(
        default=0.0,
        metadata={
            "help": "The threshold used to select the null answer: if the best answer has a score that is less than "
            "the score of the null answer minus this threshold, the null answer is selected for this example. "
            "Only useful when `version_2_with_negative=True`."
        },
    )
    doc_stride: int = field(
        default=128,
        metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
    )
    n_best_size: int = field(
        default=20,
        metadata={"help": "The total number of n-best predictions to generate when looking for an answer."},
    )
    max_answer_length: int = field(
        default=30,
        metadata={
            "help": "The maximum length of an answer that can be generated. This is needed because the start "
            "and end predictions are not conditioned on one another."
        },
    )
    # Declared last so the positional order of the pre-existing fields is unchanged.
    # The rest of the file constructs this class with ``task_name=...`` and reads
    # ``data_args.task_name``, so the field has to exist.
    task_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the GLUE task to train on (e.g. mrpc, sst2)."},
    )


@dataclass
class OptimizationArguments:
    """
    Arguments pertaining to what type of optimization we are going to apply on the model.

    Every field has a default; each one is described by its ``help`` metadata.
    """

    prune: bool = field(
        default=False,
        metadata={"help": "Whether or not to apply prune."},
    )
    pruning_approach: Optional[str] = field(
        default="BasicMagnitude",
        metadata={"help": "Pruning approach. Supported approach is basic_magnite."},
    )
    target_sparsity_ratio: Optional[float] = field(
        default=None,
        metadata={"help": "Targeted sparsity when pruning the model."},
    )
    # This field used to be declared twice (defaults None, then "eval_f1"); the later declaration
    # silently won. It is now declared once, at its original position, with the effective default.
    metric_name: Optional[str] = field(
        default="eval_f1",
        metadata={"help": "Metric used for the tuning strategy."},
    )
    tolerance_mode: Optional[str] = field(
        default="absolute",
        metadata={"help": "Metric tolerance model, expected to be relative or absolute."},
    )
    tune: bool = field(
        default=False,
        metadata={"help": "Whether or not to apply quantization."},
    )
    quantization_approach: Optional[str] = field(
        default="QuantizationAwareTraining",
        metadata={"help": "Quantization approach. Supported approach are PostTrainingStatic, "
                  "PostTrainingDynamic and QuantizationAwareTraining."},
    )
    is_relative: Optional[bool] = field(
        default=True,
        metadata={"help": "Metric tolerance model, expected to be relative or absolute."},
    )
    perf_tol: Optional[float] = field(
        default=0.01,
        metadata={"help": "Performance tolerance when optimizing the model."},
    )
    # NOTE(review): this help text is identical to the one on `benchmark` -- looks copy-pasted.
    int8: bool = field(
        default=False,
        metadata={"help":"run benchmark."}
    )
    distillation: bool = field(
        default=False,
        metadata={"help": "Whether or not to apply distillation."},
    )
    # The default used to be False on a field annotated `str`; None is the matching
    # "not provided" sentinel and is equally falsy.
    teacher_model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    run_teacher_logits: bool = field(
        default=False,
        metadata={"help": "Whether or not to obtain teacher model's logits on train dataset before training."},
    )
    orchestrate_optimizations: bool = field(
        default=False,
        metadata={"help":"for one shot."}
    )
    benchmark: bool = field(
        default=False,
        metadata={"help": "run benchmark."}
    )
    accuracy_only: bool = field(
        default=False,
        metadata={"help":"Whether to only test accuracy for model tuned by Neural Compressor."}
    )

In [None]:
# Build the four argument objects used by every later cell.
# Checkpoint that the later cells load as `model`.
model_args = ModelArguments(
    model_name_or_path="textattack/bert-base-uncased-MRPC",
)
# GLUE task and tokenization settings.
data_args = DataTrainingArguments(
    task_name="mrpc",
    max_seq_length=128,
    overwrite_cache=True
)
# Hugging Face training settings; no_cuda=True requests a CPU run.
training_args = TrainingArguments(
    output_dir="./saved_result",
    do_eval=True,
    do_train=True,
    no_cuda=True,
    overwrite_output_dir=True,
    per_device_train_batch_size=8,
)
# NOTE(review): `teacher_model_name_or_path` is left at its default here, yet the
# teacher cell further down passes it to `from_pretrained` -- a teacher checkpoint
# most likely needs to be supplied.
# NOTE(review): `orchestrate_optimizations` is also left at its default (False), so
# the `if optim_args.orchestrate_optimizations:` block later in the file is skipped
# -- confirm whether it should be enabled for this tutorial.
optim_args = OptimizationArguments(
    tune=True,
    quantization_approach="PostTrainingStatic"
)
# Align this module's logger with the log level chosen by the training arguments.
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)

## Download dataset from the hub

In [None]:
# download the dataset.
raw_datasets = load_dataset("glue", data_args.task_name)
# Labels
if data_args.task_name == "stsb":
    # "stsb" is handled as a regression task by the rest of this file, so there is a
    # single output and no class-name list to read from the label feature.
    label_list = None
    num_labels = 1
else:
    label_list = raw_datasets["train"].features["label"].names
    num_labels = len(label_list)

## Download fp32 model from the hub

In [None]:
# Load the configuration, tokenizer and weights of the model named by
# `model_args.model_name_or_path`, all pinned to the "main" revision.
config = AutoConfig.from_pretrained(
    model_args.model_name_or_path,
    num_labels=num_labels,
    finetuning_task=data_args.task_name,
    revision="main",
)
tokenizer = AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
    use_fast=True,
    revision="main",
)
model = AutoModelForSequenceClassification.from_pretrained(
    model_args.model_name_or_path,
    # `in` already yields a bool: True when the name contains ".ckpt".
    from_tf=".ckpt" in model_args.model_name_or_path,
    config=config,
    revision="main",
)

## Preprocessing the dataset

In [None]:
# Labels
is_regression = data_args.task_name == "stsb"

# Preprocessing the raw_datasets
# Column name(s) holding the text input(s) of the selected task; the second is None for single-sentence tasks.
sentence1_key, sentence2_key = task_to_keys[data_args.task_name]

# Padding strategy
if data_args.pad_to_max_length:
    padding = "max_length"
else:
    # We will pad later, dynamically at batch creation, to the max sequence length in each batch
    padding = False

# Some models have set the order of the labels to use, so let's make sure we do use it.
label_to_id = None
if (
    model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
    and data_args.task_name is not None
    and not is_regression
):
    # Some have all caps in their config, some don't.
    label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
    if sorted(label_name_to_id) == sorted(label_list):
        label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
    else:
        # The pieces are concatenated into ONE message string. Passing them as separate
        # arguments would make logging treat the extras as %-format args with no placeholder.
        logger.warning(
            "Your model seems to have been trained with labels, but they don't match the dataset: "
            f"model labels: {sorted(label_name_to_id)}, dataset labels: {sorted(label_list)}."
            "\nIgnoring the model labels as a result."
        )
elif data_args.task_name is None and not is_regression:
    label_to_id = {v: i for i, v in enumerate(label_list)}

if label_to_id is not None:
    model.config.label2id = label_to_id
    # `idx` rather than `id`, to avoid shadowing the builtin inside the comprehension.
    model.config.id2label = {idx: label for label, idx in config.label2id.items()}

if data_args.max_seq_length > tokenizer.model_max_length:
    logger.warning(
        f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
        f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
    )
# Never tokenize beyond what the tokenizer reports as its maximum length.
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

def preprocess_function(examples, tokenizer=tokenizer):
    """Tokenize a batch of examples and, when a label mapping is set, remap its labels.

    Args:
        examples: Batch mapping column names to lists of values.
        tokenizer: Tokenizer to apply; defaults to the module-level tokenizer
            bound when this function was defined.

    Returns:
        The tokenizer output, with a ``"label"`` entry added when ``label_to_id``
        is set and the batch has a ``"label"`` column.
    """
    # Single-sentence tasks pass one text column, sentence-pair tasks pass two.
    if sentence2_key is None:
        texts = (examples[sentence1_key],)
    else:
        texts = (examples[sentence1_key], examples[sentence2_key])
    encoded = tokenizer(*texts, padding=padding, max_length=max_seq_length, truncation=True)

    # Map labels to IDs (not necessary for GLUE tasks); -1 is passed through unchanged.
    if label_to_id is not None and "label" in examples:
        encoded["label"] = [(-1 if raw == -1 else label_to_id[raw]) for raw in examples["label"]]
    return encoded

# Tokenize every split. The cache is bypassed when `overwrite_cache` is set.
with training_args.main_process_first(desc="dataset map pre-processing"):
    raw_datasets = raw_datasets.map(
        preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache
    )
# Training split, optionally truncated to the first `max_train_samples` rows.
if training_args.do_train:
    if "train" not in raw_datasets:
        raise ValueError("--do_train requires a train dataset")
    train_dataset = raw_datasets["train"]
    if data_args.max_train_samples is not None:
        train_dataset = train_dataset.select(range(data_args.max_train_samples))

# Evaluation split: "validation_matched" for mnli, "validation" for every other task.
if training_args.do_eval:
    if "validation" not in raw_datasets and "validation_matched" not in raw_datasets:
        raise ValueError("--do_eval requires a validation dataset")
    eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
    if data_args.max_eval_samples is not None:
        eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))

# Log a few random samples from the training set:
if training_args.do_train:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

# Get the metric function
metric = load_metric("glue", data_args.task_name)

# NOTE(review): `metric_name` is assigned again from `optim_args` before the trainer
# uses it further down, so this value appears to be only a placeholder.
metric_name = "eval_accuracy"

# You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.
def compute_metrics(p: EvalPrediction):
    """Turn raw model predictions into a dictionary of metric values.

    Args:
        p: Object exposing ``predictions`` (an array, or a tuple whose first
            element is used) and ``label_ids``.

    Returns:
        The task metric's results (plus ``"combined_score"`` when it reports more
        than one value) when a task name is set; otherwise ``{"mse": ...}`` for
        regression or ``{"accuracy": ...}`` for classification.
    """
    raw = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    if is_regression:
        preds = np.squeeze(raw)
    else:
        preds = np.argmax(raw, axis=1)

    if data_args.task_name is not None:
        scores = metric.compute(predictions=preds, references=p.label_ids)
        if len(scores) > 1:
            # Average of all reported metrics, stored alongside them.
            scores["combined_score"] = np.mean(list(scores.values())).item()
        return scores
    if is_regression:
        return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
    return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

# Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.
# The same collator is also handed to a plain torch DataLoader when the teacher logits are
# computed, so an explicit collator is needed whenever the samples are already padded.
if data_args.pad_to_max_length:
    # Every sample already has the same length: only stacking into tensors is needed.
    data_collator = default_data_collator
elif training_args.fp16:
    # Dynamic padding, rounded up to a multiple of 8.
    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
else:
    data_collator = None

## Prepare datasets for teacher model

In [None]:
# Fail early with an actionable message instead of an obscure error from `from_pretrained`.
if not optim_args.teacher_model_name_or_path:
    raise ValueError(
        "optim_args.teacher_model_name_or_path must be set to a model path or identifier "
        "before the teacher model can be loaded."
    )

# Load the teacher's configuration, tokenizer and weights.
teacher_config = AutoConfig.from_pretrained(
    optim_args.teacher_model_name_or_path,
    num_labels=num_labels,
    finetuning_task=data_args.task_name,
)
teacher_tokenizer = AutoTokenizer.from_pretrained(
    optim_args.teacher_model_name_or_path,
    # `ModelArguments` declares no `use_fast_tokenizer` field; use the same
    # setting as the student tokenizer (use_fast=True).
    use_fast=True,
)
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    optim_args.teacher_model_name_or_path,
    from_tf=".ckpt" in optim_args.teacher_model_name_or_path,
    config=teacher_config,
)
teacher_model.to(training_args.device)

# prepare datasets for teacher model
# Re-tokenize every split with the teacher's tokenizer, passing the existing column
# names of the train split as `remove_columns`.
teacher_processed_datasets = raw_datasets.map(
    functools.partial(preprocess_function, tokenizer=teacher_tokenizer), 
    batched=True, remove_columns=raw_datasets["train"].column_names
)
# Apply the same optional truncation as for the student's train/eval splits.
teacher_train_dataset = teacher_processed_datasets["train"]
if data_args.max_train_samples is not None:
    teacher_train_dataset = teacher_train_dataset.select(range(data_args.max_train_samples))
teacher_eval_dataset = teacher_processed_datasets["validation_matched" \
                            if data_args.task_name == "mnli" else "validation"]
if data_args.max_eval_samples is not None:
    teacher_eval_dataset = teacher_eval_dataset.select(range(data_args.max_eval_samples))
# Student and teacher datasets must have the same number of rows, since the teacher
# logits are later added to the student's train dataset as a column.
# NOTE(review): `train_dataset` / `eval_dataset` are only defined when do_train / do_eval are set.
assert train_dataset.num_rows == teacher_train_dataset.num_rows and \
    eval_dataset.num_rows == teacher_eval_dataset.num_rows, \
    "Length of train or evaluation dataset of teacher doesnot match that of student."
    
# get logits of teacher model
def dict_tensor_to_model_device(batch, model):
    """Move every value of ``batch`` onto the device of ``model``'s first parameter.

    The mapping is updated in place and nothing is returned.

    Args:
        batch: Mapping whose values provide a ``.to(device)`` method.
        model: Module whose first parameter determines the target device.
    """
    target_device = next(model.parameters()).device
    # Only existing keys are reassigned, so iterating over the items while updating is safe.
    for key, value in batch.items():
        batch[key] = value.to(target_device)

def get_logits(teacher_model, train_dataset, teacher_train_dataset):
    """Return ``train_dataset`` with an added ``teacher_logits`` column.

    The logits are loaded from a ``<task>.<teacher>.npy`` cache file when it exists.
    Otherwise ``teacher_model`` is run over ``teacher_train_dataset`` and, when
    ``training_args.local_rank`` is -1 or 0, the result is saved to that file.

    Args:
        teacher_model: Module whose call returns a tensor (``.cpu().numpy()`` is
            applied to its output).
        train_dataset: Dataset that receives the new column.
        teacher_train_dataset: Dataset whose batches are fed to the teacher.

    Returns:
        The result of ``train_dataset.add_column('teacher_logits', ...)``.
    """
    logger.info("***** Getting logits of teacher model *****")
    logger.info(f"  Num examples = {len(train_dataset) }")
    teacher_model.eval()
    # `__file__` is not defined when this code runs in a notebook kernel; fall back to
    # the current working directory for the cache file in that case.
    script_path = globals().get("__file__")
    cache_dir = os.path.dirname(os.path.abspath(script_path)) if script_path else os.getcwd()
    npy_file = os.path.join(
        cache_dir,
        '{}.{}.npy'.format(data_args.task_name,
                            optim_args.teacher_model_name_or_path.replace('/', '.')))
    if os.path.exists(npy_file):
        teacher_logits = list(np.load(npy_file))
    else:
        sampler = None
        if training_args.world_size > 1:
            # Multi-process run: shard the data across processes and wrap the teacher.
            from transformers.trainer_pt_utils import ShardSampler
            sampler = ShardSampler(
                teacher_train_dataset,
                batch_size=training_args.per_device_eval_batch_size,
                num_processes=training_args.world_size,
                process_index=training_args.process_index,
            )
            use_gpu = training_args._n_gpu != 0
            teacher_model = torch.nn.parallel.DistributedDataParallel(
                teacher_model,
                device_ids=[training_args.local_rank] if use_gpu else None,
                output_device=training_args.local_rank if use_gpu else None,
            )
        train_dataloader = DataLoader(teacher_train_dataset,
                                        collate_fn=data_collator,
                                        sampler=sampler,
                                        batch_size=training_args.per_device_eval_batch_size)
        train_dataloader = tqdm(train_dataloader, desc="Evaluating")
        teacher_logits = []
        for batch in train_dataloader:
            dict_tensor_to_model_device(batch, teacher_model)
            outputs = teacher_model(**batch)
            if training_args.world_size > 1:
                # Collect every process's outputs so each one holds the full batch.
                outputs_list = [None] * training_args.world_size
                torch.distributed.all_gather_object(outputs_list, outputs)
                outputs = torch.concat(outputs_list, dim=0)
            # One entry per example (row of the output array).
            teacher_logits.extend(outputs.cpu().numpy())
        if training_args.world_size > 1:
            # Trim to the dataset length (the gathered result can be longer).
            teacher_logits = teacher_logits[:len(teacher_train_dataset)]
        if training_args.local_rank in [-1, 0]:
            np.save(npy_file, np.array(teacher_logits))
    return train_dataset.add_column('teacher_logits', teacher_logits)
class BertModelforLogitsOutputOnly(torch.nn.Module):
    """Wrap a sequence-classification model so that calling it returns only the logits.

    ``get_logits`` calls ``.cpu().numpy()`` directly on the model output, so the
    wrapped model's output object has to be reduced to its ``"logits"`` tensor.
    NOTE(review): this name was referenced but not defined anywhere in the visible
    notebook; it is defined here so the cell can run.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        """Forward all arguments to the wrapped model and return its ``"logits"`` entry."""
        return self.model(*args, **kwargs)["logits"]


# Compute (or load cached) teacher logits without tracking gradients and attach
# them to the student's training set.
with torch.no_grad():
    train_dataset = get_logits(BertModelforLogitsOutputOnly(teacher_model), train_dataset, teacher_train_dataset)


def para_counter(model):
    """Return the total number of elements over all parameters of ``model``."""
    return sum(p.numel() for p in model.parameters())


logger.info("***** Number of teacher model parameters: {:.2f}M *****".format(
            para_counter(teacher_model)/10**6))
logger.info("***** Number of student model parameters: {:.2f}M *****".format(
            para_counter(model)/10**6))

# Trace model
# NOTE: this import rebinds `symbolic_trace`, replacing the function of the same
# name imported from `transformers.utils.fx` at the top of the file.
from neural_compressor.adaptor.torch_utils.symbolic_trace import symbolic_trace
# The second argument is True only when the quantization approach is "QuantizationAwareTraining".
model = symbolic_trace(model, optim_args.quantization_approach=="QuantizationAwareTraining")

# Orchestrate Optimizations & Benchmark

## Orchestrate Optimizations

In [None]:
set_seed(training_args.seed)
# Initialize our Trainer
trainer = NLPTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset if training_args.do_train else None,
    eval_dataset=eval_dataset if training_args.do_eval else None,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Metric used for tuning: the explicitly configured one, otherwise a per-task default.
if optim_args.metric_name is not None:
    metric_name = optim_args.metric_name
else:
    metric_name = "eval_" + {
        "stsb": "pearson",
        "cola": "matthews_correlation",
    }.get(data_args.task_name, "accuracy")

if optim_args.orchestrate_optimizations:

    tune_metric = metrics.Metric(
        name=metric_name,
        is_relative=optim_args.is_relative,
        criterion=optim_args.perf_tol,
    )

    # NOTE(review): any truthy `pruning_approach` maps to 'PatternLock'; the configured
    # value itself is only used when it is falsy -- confirm this is intended.
    if optim_args.pruning_approach:
        prune_type = 'PatternLock'
    else:
        prune_type = optim_args.pruning_approach
    # A falsy ratio (None, 0) is normalized to None.
    target_sparsity_ratio = optim_args.target_sparsity_ratio or None
    pruner_config = PrunerConfig(prune_type=prune_type, target_sparsity_ratio=target_sparsity_ratio)
    pruning_conf = PruningConfig(framework="pytorch_fx", pruner_config=[pruner_config], metrics=tune_metric)
    distillation_conf = DistillationConfig(framework="pytorch_fx", metrics=tune_metric)

    objective = objectives.performance
    quantization_conf = QuantizationConfig(
        approach=optim_args.quantization_approach,
        max_trials=600,
        metrics=[tune_metric],
        objectives=[objective],
    )
    # Pruning, distillation and quantization configs are handed to the trainer together.
    conf_list = [pruning_conf, distillation_conf, quantization_conf]
    model = trainer.orchestrate_optimizations(config_list=conf_list, teacher_model=teacher_model)

## Run Benchmark after Orchestrate Optimizations

In [None]:
# Evaluate once and report accuracy, per-sample latency and throughput.
results = trainer.evaluate()
throughput = results.get("eval_samples_per_second")
eval_acc = results.get("eval_accuracy")
print(f"Batch size = {training_args.per_device_eval_batch_size}")
print(f"Finally Eval eval_accuracy Accuracy: {eval_acc:.5f}")
# Throughput is in samples per second, so 1000 / throughput is milliseconds per sample.
print(f"Latency: {1000 / throughput:.5f} ms")
print(f"Throughput: {throughput:.5f} samples/sec")