## Graph Mode Static Quantization

In [1]:
import functools
from tqdm import tqdm

import torch 

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

from nn_pruning.inference_model_patcher import optimize_model
from nn_pruning.modules.quantization import prepare_static, quantize

### Configuration

In [2]:
# Static shapes shared by quantization preparation and calibration-set tokenization.
BS = 1            # batch size: passed to prepare_static and to the calibration DataLoader
SEQLEN = 384      # sequence length: passed to prepare_static and used as tokenizer max_length
DOC_STRIDE = 128  # default stride handed to the tokenizer in prepare_example

### Model Preparation

In [3]:
# Hub identifier of the pruned SQuAD v1 checkpoint; reused below to load the tokenizer.
model_name = "madlag/bert-base-uncased-squadv1-x2.01-f89.2-d30-hybrid-rewind-opt-v1"

# torchscript=True is requested at load time; the quantized model is traced with
# torch.jit.trace further down in this notebook.
model = AutoModelForQuestionAnswering.from_pretrained(
    model_name,
    torchscript=True,
)

In [4]:
# Run nn_pruning's optimize_model in "dense" mode; it logs the removed heads and
# per-layer sparsity shown in the cell output.
optimized_model = optimize_model(
    model,
    mode="dense",
)

removed heads 0, total_heads=89, percentage removed=0.0
bert.encoder.layer.0.intermediate.dense, sparsity = 84.44
bert.encoder.layer.0.output.dense, sparsity = 84.44
bert.encoder.layer.1.intermediate.dense, sparsity = 82.75
bert.encoder.layer.1.output.dense, sparsity = 82.75
bert.encoder.layer.2.intermediate.dense, sparsity = 78.35
bert.encoder.layer.2.output.dense, sparsity = 78.35
bert.encoder.layer.3.intermediate.dense, sparsity = 79.56
bert.encoder.layer.3.output.dense, sparsity = 79.56
bert.encoder.layer.4.intermediate.dense, sparsity = 82.29
bert.encoder.layer.4.output.dense, sparsity = 82.29
bert.encoder.layer.5.intermediate.dense, sparsity = 81.84
bert.encoder.layer.5.output.dense, sparsity = 81.84
bert.encoder.layer.6.intermediate.dense, sparsity = 84.31
bert.encoder.layer.6.output.dense, sparsity = 84.31
bert.encoder.layer.7.intermediate.dense, sparsity = 88.05
bert.encoder.layer.7.output.dense, sparsity = 88.05
bert.encoder.layer.8.intermediate.dense, sparsity = 94.01
bert.e

In [5]:
# Prepare the optimized model for static quantization with fixed input names and
# fixed (batch_size, sequence_length) taken from the configuration cell.
# The calibration cell below feeds batches with exactly these three keys.
prepared_model = prepare_static(
    optimized_model,
    input_names=[
        "input_ids",
        "attention_mask",
        "token_type_ids",
    ],
    batch_size=BS,
    sequence_length=SEQLEN,
    qconfig_name="default",
)



### Calibration

In [6]:
def prepare_example(examples, tokenizer, max_length=SEQLEN, doc_stride=DOC_STRIDE):
    """Tokenize question/context pairs into fixed-length, padded features.

    Args:
        examples: Mapping providing "question" and "context" entries.
        tokenizer: Callable tokenizer exposing a ``padding_side`` attribute.
        max_length: Forwarded to the tokenizer as ``max_length``; every feature
            is padded to it (``padding="max_length"``).
        doc_stride: Forwarded to the tokenizer as ``stride``.

    Returns:
        Whatever the tokenizer call returns, requested with overflowing tokens
        and offset mappings included.
    """
    # The context must be the sequence that gets truncated, so the argument order
    # and the truncation strategy both depend on which side the tokenizer pads.
    if tokenizer.padding_side == "right":
        first_key, second_key, truncation = "question", "context", "only_second"
    else:
        first_key, second_key, truncation = "context", "question", "only_first"

    return tokenizer(
        examples[first_key],
        examples[second_key],
        truncation=truncation,
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

In [7]:
# Load the tokenizer from the same checkpoint identifier as the model (model_name).
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [8]:
# Number of SQuAD training examples taken as the calibration ("representative") set.
nb_representative_samples = 200

# Take the first nb_representative_samples rows of the SQuAD train split.
representative_dataset = load_dataset(
    "squad",
    split=f"train[:{nb_representative_samples}]",
)

# Tokenize in batches and drop the original columns so only tokenizer outputs remain.
tokenize_batch = functools.partial(prepare_example, tokenizer=tokenizer)
representative_dataset = representative_dataset.map(
    tokenize_batch,
    batched=True,
    remove_columns=representative_dataset.column_names,
)

# Expose only the three model inputs, as torch tensors.
representative_dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "token_type_ids"],
)
dataloader = torch.utils.data.DataLoader(representative_dataset, batch_size=BS)

Reusing dataset squad (/home/michael/.cache/huggingface/datasets/squad/plain_text/1.0.0/4fffa6cf76083860f85fa83486ec3028e7e32c342c218ff2a620fc6b2868483a)


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




In [9]:
# Calibration: run every representative batch through the prepared model.
# Outputs are discarded; gradients are not needed, so autograd is disabled.
with torch.no_grad():
    for batch in tqdm(dataloader):
        prepared_model(**batch)

100%|██████████| 200/200 [00:58<00:00,  3.41it/s]


### Conversion and Export

In [10]:
# Convert the calibrated model with nn_pruning's quantize().
# NOTE(review): int8 conversion is assumed from the variable name — confirm against nn_pruning.
model_int8 = quantize(prepared_model)

In [11]:
# Sanity check: run the quantized model once on its own dummy inputs.
# The result is kept so it can be compared with the traced model's output below.
model_int8_output = model_int8(**model_int8.dummy_inputs)

In [12]:
# Trace the quantized model with TorchScript, feeding the dummy inputs positionally
# (in the order of model_int8.dummy_inputs).
traced_model_int8 = torch.jit.trace(
    model_int8,
    tuple(model_int8.dummy_inputs.values()),
    strict=True,
)

  quantize_per_tensor_1 = torch.quantize_per_tensor(sub_1, _input_scale_0, _input_zero_point_0, _input_dtype_0);  sub_1 = _input_scale_0 = _input_zero_point_0 = _input_dtype_0 = None
  quantize_per_tensor_2 = torch.quantize_per_tensor(bert_embeddings_word_embeddings, _input_scale_1, _input_zero_point_1, _input_dtype_1);  bert_embeddings_word_embeddings = _input_scale_1 = _input_zero_point_1 = _input_dtype_1 = None
  quantize_per_tensor_3 = torch.quantize_per_tensor(bert_embeddings_token_type_embeddings, _input_scale_2, _input_zero_point_2, _input_dtype_2);  bert_embeddings_token_type_embeddings = _input_scale_2 = _input_zero_point_2 = _input_dtype_2 = None
  add_1 = torch.ops.quantized.add(quantize_per_tensor_2, quantize_per_tensor_3, _scale_0, _zero_point_0);  quantize_per_tensor_2 = quantize_per_tensor_3 = _scale_0 = _zero_point_0 = None
  quantize_per_tensor_4 = torch.quantize_per_tensor(bert_embeddings_position_embeddings, _input_scale_3, _input_zero_point_3, _input_dtype_3);  bert

  quantize_per_tensor_74 = torch.quantize_per_tensor(view_47, bert_encoder_layer_11_attention_output_dense_input_scale_0, bert_encoder_layer_11_attention_output_dense_input_zero_point_0, bert_encoder_layer_11_attention_output_dense_input_dtype_0);  view_47 = bert_encoder_layer_11_attention_output_dense_input_scale_0 = bert_encoder_layer_11_attention_output_dense_input_zero_point_0 = bert_encoder_layer_11_attention_output_dense_input_dtype_0 = None
  add_60 = torch.ops.quantized.add(bert_encoder_layer_11_attention_output_dropout, add_58, _scale_82, _zero_point_82);  bert_encoder_layer_11_attention_output_dropout = add_58 = _scale_82 = _zero_point_82 = None
  quantize_per_tensor_75 = torch.quantize_per_tensor(bert_encoder_layer_11_attention_output_layer_norm_weight, _input_scale_62, _input_zero_point_62, _input_dtype_62);  bert_encoder_layer_11_attention_output_layer_norm_weight = _input_scale_62 = _input_zero_point_62 = _input_dtype_62 = None
  mul_25 = torch.ops.quantized.mul(add_60, q

In [13]:
# Run the traced model on the same dummy inputs, passed positionally as during tracing.
traced_model_int8_output = traced_model_int8(*model_int8.dummy_inputs.values())

In [14]:
# Compare the eager quantized outputs with the traced outputs, position by position.
# A mismatch is only reported, not raised; no output means every pair matched.
num_outputs = len(model_int8_output)
for i in range(num_outputs):
    eager_out, traced_out = model_int8_output[i], traced_model_int8_output[i]
    if torch.allclose(eager_out, traced_out):
        continue
    print(f"The {i}th outputs do not match")

In [15]:
# Serialize the traced quantized model to "quantized.pt" (path relative to the working directory).
torch.jit.save(traced_model_int8, "quantized.pt")

In [16]:
# Reload the file written by torch.jit.save to confirm it deserializes.
quantized_model = torch.jit.load("quantized.pt")
# Bare last expression: the notebook displays the loaded module's structure.
quantized_model

RecursiveScriptModule(
  original_name=GraphModule
  (bert): RecursiveScriptModule(
    original_name=Module
    (embeddings): RecursiveScriptModule(
      original_name=Module
      (word_embeddings): RecursiveScriptModule(
        original_name=Embedding
        (_packed_params): RecursiveScriptModule(original_name=EmbeddingPackedParams)
      )
      (token_type_embeddings): RecursiveScriptModule(
        original_name=Embedding
        (_packed_params): RecursiveScriptModule(original_name=EmbeddingPackedParams)
      )
      (position_embeddings): RecursiveScriptModule(
        original_name=Embedding
        (_packed_params): RecursiveScriptModule(original_name=EmbeddingPackedParams)
      )
      (LayerNorm): RecursiveScriptModule(original_name=Module)
      (dropout): RecursiveScriptModule(original_name=Dropout)
    )
    (encoder): RecursiveScriptModule(
      original_name=Module
      (layer): RecursiveScriptModule(
        original_name=Module
        (0): RecursiveScriptMod