# Quantizing GPT2 to reduce costs and latency 💵💪

## System config ⚙️
Install the required dependencies.

In [None]:
!pip install -q torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
!pip install -q onnxruntime==1.8.0
!pip install -q transformers==4.3.1 datasets
!pip install -q onnx onnxconverter_common psutil pytz pandas py-cpuinfo py3nvml coloredlogs

## Imports and general settings 🔧

In [None]:
import torch
from torch import Tensor
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss

In [None]:
# In this example we will be quantizing the Dutch GPT2-small model.
# `model_ckpt` is the checkpoint identifier that the later cells load the config,
# model and tokenizer from.

model_ckpt = "ml6team/gpt2-small-dutch-finetune-oscar"
# Every model and tensor in this notebook is moved to this CPU device.
device = torch.device("cpu")

In [None]:
import os

# Create a cache directory to store the pretrained model.
# `exist_ok=True` keeps the cell idempotent on re-runs and avoids the
# check-then-create race of a separate os.path.exists() test.
cache_dir = os.path.join(".", "cache_models")
os.makedirs(cache_dir, exist_ok=True)

## Quantization 🤏

### Convert HF model to ONNX

In [None]:
from onnxruntime.transformers.gpt2_helper import Gpt2Helper, MyGPT2LMHeadModel

In [None]:
# Load the model config for the chosen checkpoint
from transformers import AutoConfig

# Later cells refer to the checkpoint under the name `model_name_or_path`
model_name_or_path = model_ckpt
config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)

In [None]:
# Instantiate the model with the MyGPT2LMHeadModel class imported from onnxruntime's
# gpt2_helper; this is the object handed to Gpt2Helper.export_onnx below
model_regular = MyGPT2LMHeadModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir)

In [None]:
# Switch to eval mode (this disables e.g. Dropout) and move the model to the target device
model_regular.eval().to(device)

In [None]:
# Read the architecture sizes from the model config; get_example_inputs uses them
# to build the empty past-state tensors and the generation loop iterates over num_layer
num_attention_heads = model_regular.config.n_head
hidden_size = model_regular.config.n_embd
num_layer = model_regular.config.n_layer

In [None]:
# Export the model to an ONNX binary named "gpt2_regular.onnx" in the working directory
Gpt2Helper.export_onnx(model_regular, device, "gpt2_regular.onnx")

### Optimize before quantization
This step can, for example, fuse operators in the model graph.

In [None]:
# Optimize the exported graph and write the result to "gpt2_regular_opt.onnx"
Gpt2Helper.optimize_onnx(
    "gpt2_regular.onnx",
    "gpt2_regular_opt.onnx",
    False,  # presumably the float16 switch (False keeps float32) — confirm against Gpt2Helper.optimize_onnx
    model_regular.config.num_attention_heads,
    model_regular.config.hidden_size)

### Quantize the models

In [None]:
from onnxruntime.transformers.quantize_helper import QuantizeHelper

In [None]:
# Quantize the optimized graph; the second path ("gpt2_regular_opt_int8.onnx") is the file
# that the evaluation cells below load into their inference sessions
QuantizeHelper.quantize_onnx_model(
    "gpt2_regular_opt.onnx",
    "gpt2_regular_opt_int8.onnx")

## Evaluate the quantized model 🔎

### Sampling code
Since we don't want to perform greedy decoding but rather use a more sophisticated sampling strategy, we had to *coughs* borrow some code from [this HF repo page](https://github.com/huggingface/transformers/blob/main/src/transformers/generation_utils.py).

In [None]:
def top_k_top_p_filtering(
    logits: Tensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> Tensor:
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    The input tensor is never modified: whenever a filter is active, a new tensor
    with the filtered values is returned. When neither filter is active the input
    tensor itself is returned.

    Args:
        logits: logits distribution of shape (batch size, vocabulary size).
        top_k: if > 0, keep only the top k tokens with the highest logit
            (top-k filtering).
        top_p: if < 1.0, keep the top tokens with cumulative probability >= top_p
            (nucleus filtering). Nucleus filtering is described in
            Holtzman et al. (http://arxiv.org/abs/1904.09751).
        filter_value: value written to every filtered-out position.
        min_tokens_to_keep: make sure at least this many tokens per batch example
            survive the filtering.

    Returns:
        Tensor of the same shape as ``logits`` with filtered positions set to
        ``filter_value``.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a logit lower than the last token of the top-k
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        # masked_fill returns a new tensor, so the caller's logits stay untouched
        logits = logits.masked_fill(logits < kth_best, filter_value)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Always keep the leading tokens of the sorted distribution
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the mask to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the mask from sorted order back to the original vocabulary order
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits

### Inference helper code
These helper methods have been copied from [this notebook](https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/notebooks/Inference_GPT2_with_OnnxRuntime_on_CPU.ipynb) from the ONNXRuntime github repo.

In [None]:
import numpy
from transformers import AutoTokenizer

In [None]:
def get_tokenizer(model_name_or_path, cache_dir):
    """Load the tokenizer for ``model_name_or_path`` (cached in ``cache_dir``).

    The returned tokenizer pads on the left and reuses its EOS token as the pad token.
    """
    tok = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
    # Reuse EOS as the padding token and pad on the left-hand side
    tok.pad_token = tok.eos_token
    tok.padding_side = "left"
    return tok

def get_example_inputs(prompt_text):
    """Tokenize a batch of prompts and build the tensors for the first decoding step.

    Relies on the notebook globals ``model_name_or_path``, ``cache_dir``,
    ``num_attention_heads``, ``hidden_size``, ``num_layer`` and ``device``.

    Args:
        prompt_text: list of prompt strings; they are padded to a common length.

    Returns:
        Tuple ``(input_ids, attention_mask, position_ids, empty_past)`` where
        ``input_ids`` is int64, ``attention_mask`` is float32, ``position_ids``
        is derived from the attention mask, and ``empty_past`` is a list of
        ``num_layer`` float32 tensors with a past sequence length of 0.
    """
    # Prepare the input text for further processing
    tokenizer = get_tokenizer(model_name_or_path, cache_dir)
    encodings_dict = tokenizer.batch_encode_plus(prompt_text, padding=True)

    input_ids = torch.tensor(encodings_dict['input_ids'], dtype=torch.int64)
    attention_mask = torch.tensor(encodings_dict['attention_mask'], dtype=torch.float32)
    # Positions count attended tokens only; slots before the first attended token are clamped to 0
    position_ids = (attention_mask.long().cumsum(-1) - 1)
    position_ids.masked_fill_(position_ids < 0, 0)

    # Empty past state for generating the first token
    batch_size = input_ids.size(0)
    past_shape = [2, batch_size, num_attention_heads, 0, hidden_size // num_attention_heads]
    empty_past = [
        torch.empty(past_shape, dtype=torch.float32, device=device)
        for _ in range(num_layer)
    ]

    return input_ids, attention_mask, position_ids, empty_past

In [None]:
def regular_inference_with_io_binding(session, config, input_ids, position_ids, attention_mask, past):
    """Run one ORT inference pass with IO binding and return the session outputs.

    Output shapes are derived from the batch size and sequence length of
    ``input_ids`` and from the past sequence length (dimension 3 of ``past[0]``).
    Output buffers are allocated on the notebook-global ``device``.
    """
    batch = input_ids.size(0)
    past_len = past[0].size(3)
    cur_len = input_ids.size(1)
    shapes = Gpt2Helper.get_output_shapes(
        batch_size=batch,
        past_sequence_length=past_len,
        sequence_length=cur_len,
        config=config,
    )
    buffers = Gpt2Helper.get_output_buffers(shapes, device)

    # Bind inputs and pre-allocated output buffers, then execute the session
    binding = Gpt2Helper.prepare_io_binding(
        session, input_ids, position_ids, attention_mask, past, buffers, shapes
    )
    session.run_with_iobinding(binding)

    return Gpt2Helper.get_outputs_from_io_binding_buffer(
        session, buffers, shapes, return_numpy=False
    )

In [None]:
def regular_test_generation(
    tokenizer,
    input_text,
    ort_session=None,
    num_tokens_to_produce=30,
    top_k=50,
    top_p=0.95,
    do_sample=False,
    temperature=1.0,
):
    """Generate continuations for a batch of prompts with an ORT session and print them.

    Relies on the notebook globals ``config``, ``num_layer`` and ``device``.

    Args:
        tokenizer: tokenizer used for the EOS id and for decoding the result.
        input_text: list of prompt strings.
        ort_session: onnxruntime inference session; required, because every step
            is executed through ``regular_inference_with_io_binding``.
        num_tokens_to_produce: maximum number of tokens generated per prompt.
        top_k: top-k cut-off used when sampling.
        top_p: nucleus (top-p) threshold used when sampling.
        do_sample: sample from the filtered distribution instead of greedy argmax.
        temperature: divisor applied to the logits before filtering
            (higher temperature => more likely to sample low probability tokens).

    Raises:
        ValueError: if ``ort_session`` is None.
    """
    if ort_session is None:
        # Without a session the first inference step would fail on `None` anyway;
        # fail early with an explicit message instead.
        raise ValueError(
            "regular_test_generation requires an onnxruntime InferenceSession "
            "(ort_session); generation with a PyTorch model is not implemented."
        )

    print("Text generation using", "OnnxRuntime", "...")
    eos_token_id = tokenizer.eos_token_id

    input_ids, attention_mask, position_ids, past = get_example_inputs(input_text)
    batch_size = input_ids.size(0)

    has_eos = torch.zeros(batch_size, dtype=torch.bool)

    all_token_ids = input_ids.clone()

    for _ in range(num_tokens_to_produce):
        outputs = regular_inference_with_io_binding(ort_session, config, input_ids, position_ids, attention_mask, past)

        # Logits of the last position decide the next token
        next_token_logits = outputs[0][:, -1, :]

        if do_sample:
            # Always derive `scores` (also for temperature == 1.0, where the division
            # changes nothing) so it is defined on every path and is a fresh tensor.
            scores = next_token_logits / temperature
            next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
            probs = F.softmax(next_token_logscores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            # Greedy decoding
            next_tokens = torch.argmax(next_token_logits, dim=-1)

        # Once a sequence has produced EOS it keeps emitting EOS
        has_eos = has_eos | (next_tokens == eos_token_id)
        tokens_to_add = next_tokens.masked_fill(has_eos, eos_token_id)
        all_token_ids = torch.cat([all_token_ids, tokens_to_add.unsqueeze(-1)], dim=-1)

        # Update input_ids, attention_mask, position_ids and past for the next step
        input_ids = tokens_to_add.clone().detach().reshape([batch_size, 1]).to(device)
        position_ids = (position_ids[:, -1] + 1).reshape(batch_size, 1)
        attention_mask = torch.cat([attention_mask, torch.ones([batch_size, 1]).type_as(attention_mask)], 1).to(device)

        # outputs[1:] hold the per-layer past state; accept numpy arrays as well as tensors
        past = []
        for i in range(num_layer):
            past_i = torch.from_numpy(outputs[i + 1]) if isinstance(outputs[i + 1], numpy.ndarray) else outputs[i + 1].clone().detach()
            past.append(past_i.to(device))

        if torch.all(has_eos):
            break

    for output in all_token_ids:
        print("------------")
        print(tokenizer.decode(output, skip_special_tokens=True))

### Basic output quality tests

In [None]:
from onnxruntime import InferenceSession

In [None]:
# Tokenizer used for generation in this section and again in the perplexity section
tokenizer = get_tokenizer(model_name_or_path, cache_dir)
# Number of new tokens to generate per prompt in the sampling test below
length=5

In [None]:
# ORT session over the quantized model file written in the quantization step
session_int8_regular = InferenceSession("gpt2_regular_opt_int8.onnx")

In [None]:
# Three copies of the same prompt, processed together as one batch
input_text = ['Dit is een test om', 'Dit is een test om', 'Dit is een test om']

In [None]:
# Sample `length` tokens per prompt from the quantized model.
# No random seed is set anywhere in this notebook, so the sampled text differs between runs.
regular_test_generation(
    tokenizer,
    input_text,
    do_sample=True,
    top_p=0.95,
    top_k=50,
    temperature=0.95,
    ort_session=session_int8_regular,
    num_tokens_to_produce=length)

### Compare the output logits

In [None]:
from transformers import GPT2LMHeadModel, AutoConfig
import onnxruntime
import numpy as np
from tqdm.notebook import tqdm

In [None]:
# Re-assigns the same checkpoint id that was already bound to this name in the ONNX conversion section
model_name_or_path= model_ckpt

In [None]:
# Build example inputs for a single prompt; the comparison loop below rebinds these names per document
input_ids, attention_mask, position_ids, empty_past = get_example_inputs(prompt_text=["Ik zie het niet meer zitten om"])

#### Testing dataset
For testing, we will select a small sample from the OSCAR Dutch corpus.

In [None]:
from datasets import load_dataset

# NOTE(review): download_mode="force_redownload" presumably bypasses the local datasets cache
# on every run — confirm this is intended, since it slows down a full Restart & Run All.
dataset = load_dataset("nthngdy/oscar-mini", "unshuffled_deduplicated_nl", download_mode="force_redownload")

#### HF logits

In [None]:
# Reference model: the plain Hugging Face GPT2LMHeadModel for the same checkpoint.
# `config` and `device` are re-created here with the same arguments as earlier in the notebook.
config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)

torch_model = GPT2LMHeadModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir)
device = torch.device("cpu")
# Eval mode and CPU placement, mirroring the setup of the exported model
torch_model.eval().to(device)

In [None]:
def get_hf_logits(input_ids, empty_past, attention_mask, position_ids):
    """Return the first output of the notebook-global ``torch_model`` for the given inputs.

    The forward pass runs with gradient tracking disabled.
    """
    forward_kwargs = dict(
        past_key_values=empty_past,
        attention_mask=attention_mask,
        position_ids=position_ids,
    )
    with torch.no_grad():
        return torch_model(input_ids, **forward_kwargs)[0]

#### ORT logits

In [None]:
# Open an ORT session on the quantized model; get_ort_logits below reads this global `session`
onnx_model_path = "gpt2_regular_opt_int8.onnx"
session = onnxruntime.InferenceSession(onnx_model_path)

def get_ort_logits(input_ids, empty_past, attention_mask, position_ids):
    """Run the notebook-global ORT ``session`` on the given inputs and return its first output.

    Every tensor is converted to a contiguous numpy array; the past tensors are
    fed under the input names ``past_0``, ``past_1``, ...
    """
    def as_contiguous_numpy(tensor):
        return np.ascontiguousarray(tensor.cpu().numpy())

    ort_inputs = {
        'input_ids': as_contiguous_numpy(input_ids),
        'attention_mask': as_contiguous_numpy(attention_mask),
        'position_ids': as_contiguous_numpy(position_ids),
    }
    ort_inputs.update(
        (f'past_{idx}', as_contiguous_numpy(layer_past))
        for idx, layer_past in enumerate(empty_past)
    )

    return session.run(None, ort_inputs)[0]

#### Compare

In [None]:
# Per-document statistics of the absolute logit difference between the HF model and the
# quantized ORT model, collected over the first 100 training texts.
all_max_logits_diff = []
all_mean_logits_diff = []
all_median_logits_diff = []

for line in tqdm(dataset['train']["text"][:100]):

    # get inputs (one document per batch)
    input_ids, attention_mask, position_ids, empty_past = get_example_inputs(prompt_text=[line])

    # hf logits
    hf_logits = get_hf_logits(input_ids, empty_past, attention_mask, position_ids)

    # ort logits
    ort_logits = get_ort_logits(input_ids, empty_past, attention_mask, position_ids)

    # compare: the difference is multiplied by the attention mask, broadcast over the last axis
    # NOTE(review): hf_logits is a torch tensor while ort_logits comes straight from session.run
    # (presumably a numpy array), so this relies on implicit mixed tensor/array arithmetic — confirm.
    logits_masked_diff = (hf_logits - ort_logits) * attention_mask.unsqueeze(2)

    max_logits_diff = logits_masked_diff.abs().max()
    mean_logits_diff = logits_masked_diff.abs().mean()
    median_logits_diff = logits_masked_diff.abs().median()

    all_max_logits_diff.append(max_logits_diff)
    all_mean_logits_diff.append(mean_logits_diff)
    all_median_logits_diff.append(median_logits_diff)

In [None]:
# Average each per-document statistic (max, mean, median absolute difference) over all documents
print(np.mean(all_max_logits_diff))
print(np.mean(all_mean_logits_diff))
print(np.mean(all_median_logits_diff))

16.129986
2.8550153
2.3409908


### Perplexity
To measure the text generation capabilities

In [None]:
# Fetch a portion of the test dataset: the first 1000 training texts joined into one long string

total_string= "\n\n".join(dataset['train']["text"][:1000])
# NOTE(review): `encodings` is not referenced by any other cell visible in this notebook
encodings = tokenizer(total_string, return_tensors="pt")

# Tokenize the whole string as a single prompt; this rebinds the global input tensors
# that both perplexity loops below slice into windows
input_ids, attention_mask, position_ids, empty_past = get_example_inputs(prompt_text=[total_string])

In [None]:
# Upper bound on the window length, taken from the model config's n_positions
max_length = torch_model.config.n_positions
# Step, in tokens, between the start offsets of consecutive evaluation windows
stride = 512

#### Torch model PPL

In [None]:
# Sliding-window perplexity of the reference PyTorch model over `input_ids`
nlls = []

for i in tqdm(range(0, input_ids.size(1), stride)):
    # Each window ends `stride` tokens after offset i and reaches back at most `max_length` tokens
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, input_ids.size(1))
    trg_len = end_loc - i  # may be different from stride on last loop
    input_ids_local = input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids_local.clone()
    # Only the last trg_len tokens keep their label; earlier context tokens are set to -100
    # (presumably the label value the model's loss ignores — confirm)
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        outputs = torch_model(input_ids_local, labels=target_ids)
        # outputs[0] is presumably the mean loss over the scored tokens; scaled by trg_len — confirm
        neg_log_likelihood = outputs[0] * trg_len

    nlls.append(neg_log_likelihood)

# After the loop end_loc equals the full sequence length, so this divides the summed NLL by it
torch_ppl = torch.exp(torch.stack(nlls).sum() / end_loc)

#### ORT quantized PPL

In [None]:
# Sliding-window perplexity of the quantized ONNX model, using the same windows
# (stride / max_length) as the PyTorch loop.
onnx_model_path = "gpt2_regular_opt_int8.onnx"
session = onnxruntime.InferenceSession(onnx_model_path)

# The loss object is stateless, so it is created once instead of once per window
loss_fct = CrossEntropyLoss()

nlls = []
for i in tqdm(range(0, input_ids.size(1), stride)):
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, input_ids.size(1))
    trg_len = end_loc - i  # may be different from stride on last loop

    # Slice it up
    input_ids_local = input_ids[:, begin_loc:end_loc].to(device)
    attention_mask_local = attention_mask[:, begin_loc:end_loc].to(device)
    # Positions restart at the beginning of every window
    position_ids_local = position_ids[:,:input_ids_local.size(1)].to(device)

    # Only the last trg_len tokens are scored; -100 is CrossEntropyLoss's default ignore_index
    target_ids = input_ids_local.clone()
    target_ids[:, :-trg_len] = -100

    ort_inputs = {
        'input_ids': numpy.ascontiguousarray(input_ids_local.cpu().numpy()),
        'attention_mask' : numpy.ascontiguousarray(attention_mask_local.cpu().numpy()),
        'position_ids': numpy.ascontiguousarray(position_ids_local.cpu().numpy())
        }

    # Use a dedicated index name here: `i` holds the window offset of the outer loop
    for layer_idx, past_i in enumerate(empty_past):
        ort_inputs[f'past_{layer_idx}'] = numpy.ascontiguousarray(past_i.cpu().numpy())

    ort_outputs = session.run(None, ort_inputs)
    ort_outputs_logits = torch.from_numpy(ort_outputs[0])

    # Calculate loss: the logits at position t are compared with the label at position t + 1
    shift_logits = ort_outputs_logits[..., :-1, :].contiguous()
    shift_labels = target_ids[..., 1:].contiguous()

    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    neg_log_likelihood = loss * trg_len

    nlls.append(neg_log_likelihood)

# After the loop end_loc equals the full sequence length
quantized_ppl = torch.exp(torch.stack(nlls).sum() / end_loc)

#### Compare

In [None]:
# Report both perplexities side by side (lower perplexity is better)
print(f"Non-quantized perplexity: {torch_ppl}")
print(f"Quantized perplexity: {quantized_ppl}")

Non-quantized perplexity: 52.991127014160156
Quantized perplexity: 75.03434753417969
