support LSTM quantization for PackedSequence input (#1384)
XiaobingSuper committed Feb 6, 2023
1 parent 01ef647 commit a81bd70
Showing 4 changed files with 86 additions and 7 deletions.
43 changes: 41 additions & 2 deletions intel_extension_for_pytorch/quantization/_quantize_utils.py
@@ -6,6 +6,7 @@
from torch.fx.node import map_aggregate
from torch.ao.quantization import PlaceholderObserver
from torch.quantization.qconfig import QConfig
from torch.nn.utils.rnn import PackedSequence

from ._utils import get_torch_function_hook_type, HookType, get_module_hook_type, OpQuantizeabilityType, \
attach_op_convert_info_to_model, save_quant_state, attach_scale_zp_values_to_model, convert_quant_state_map_to_nodes, \
@@ -36,6 +37,25 @@ def _check_add_has_scalar_input(args):
return True
return False

def _convert_PackedSequence_to_tuple_lstm(args):
    if isinstance(args, tuple) and len(args) == 2:  # (PackedSequence, hx)
        input, batch_sizes, sorted_indices, unsorted_indices = args[0]
        args = (input, batch_sizes, sorted_indices, unsorted_indices, args[-1])
    elif isinstance(args, tuple) and len(args) == 1:  # (PackedSequence, )
        input, batch_sizes, sorted_indices, unsorted_indices = args[0]
        args = (input, batch_sizes, sorted_indices, unsorted_indices)
    else:
        assert False, "_convert_PackedSequence_to_tuple_lstm: args should be a tuple of size 1 or 2 whose first element is a PackedSequence"
    return args

def _convert_tuple_to_PackedSequence_lstm(args):
    assert isinstance(args, tuple) and 4 <= len(args) <= 5, "_convert_tuple_to_PackedSequence_lstm: input should be a tuple with 4 <= len(args) <= 5"
    if len(args) == 4:
        return (PackedSequence(*args),)
    else:
        return (PackedSequence(*args[:-1]), args[-1])
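
For reference, a minimal round-trip check of these two helpers might look like the sketch below. It assumes the functions above are in scope; the tensor shapes and the dummy hidden state are arbitrary and only serve the illustration.

import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Build a PackedSequence (batch of 2, max length 3, feature size 4) and a dummy hx.
packed = pack_padded_sequence(torch.randn(2, 3, 4), torch.tensor([3, 2]), batch_first=True)
h0 = torch.zeros(1, 2, 8)
hx = (h0, h0)

# Flatten (PackedSequence, hx) into plain tensors for the quantization hooks ...
flat = _convert_PackedSequence_to_tuple_lstm((packed, hx))
assert len(flat) == 5  # data, batch_sizes, sorted_indices, unsorted_indices, hx

# ... and rebuild the PackedSequence before the real LSTM forward runs.
restored, hx_back = _convert_tuple_to_PackedSequence_lstm(flat)
assert torch.equal(restored.data, packed.data) and hx_back is hx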


def auto_prepare(
model : torch.nn.Module,
configure: QConfig,
@@ -212,7 +232,9 @@ def _patched_module_call(self, *args, **kwargs):
old_global_disable_torch_function_override = \
global_disable_torch_function_override
global_disable_torch_function_override = True

is_lstm_packed_input = isinstance(cur_module, torch.nn.LSTM) and isinstance(args[0], PackedSequence)
if is_lstm_packed_input:
args = _convert_PackedSequence_to_tuple_lstm(args)
if first_call:
# mypy ignore is used instead of assert because this
# runs on every forward and assert has a performance cost
@@ -226,19 +248,28 @@ def _patched_module_call(self, *args, **kwargs):
args, kwargs = parent_qstate.op_prepare_before_hook(
cur_module, args, kwargs) # type: ignore[arg-type]

if is_lstm_packed_input:
args = _convert_tuple_to_PackedSequence_lstm(args)

# original forward
output = orig_module_call(self, *args, **kwargs)
# Re-enable the overrides.
global_disable_torch_function_override = \
old_global_disable_torch_function_override

# after hooks
if is_lstm_packed_input:
output = _convert_PackedSequence_to_tuple_lstm(output)
if first_call:
output = parent_qstate.first_call_op_prepare_after_hook(
cur_module, output, args, qtensor_id, OpQuantizeabilityType.QUANTIZEABLE)
else:
output = parent_qstate.op_prepare_after_hook(
cur_module, output, args, global_op_idx)

if is_lstm_packed_input:
output = _convert_tuple_to_PackedSequence_lstm(output)

parent_qstate.mark_cur_op_complete(cur_module)
elif hook_type is HookType.MODULE_IO_HOOKS:
cur_qstate = cur_module._auto_quant_state
@@ -500,17 +531,25 @@ def _patched_module_call(self, *args, **kwargs):
old_global_disable_torch_function_override = \
global_disable_torch_function_override
global_disable_torch_function_override = True
is_lstm_packed_input = isinstance(cur_module, torch.nn.LSTM) and isinstance(args[0], PackedSequence)
if is_lstm_packed_input:
args = _convert_PackedSequence_to_tuple_lstm(args)
_, args, kwargs = qstate.op_convert_before_hook(
cur_module, args, kwargs, cur_module)
if is_lstm_packed_input:
args = _convert_tuple_to_PackedSequence_lstm(args)
if type(cur_module) in quantized_modules_has_weights:
weights = qstate.op_weight_convert_before_hook(cur_module)
output = module_call_to_function_call(self, args, weights)
else:
output = orig_module_call(self, *args, **kwargs)
# after hooks
if is_lstm_packed_input:
output = _convert_PackedSequence_to_tuple_lstm(output)
output = qstate.op_convert_after_hook(
cur_module, output)

if is_lstm_packed_input:
output = _convert_tuple_to_PackedSequence_lstm(output)
# Re-enable the override.
global_disable_torch_function_override = \
old_global_disable_torch_function_override
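
As the hunks above show, the conversions bracket op_convert_before_hook and op_convert_after_hook because nn.LSTM's forward accepts a Tensor or a PackedSequence as its first argument, not the four unpacked fields. A small runnable sketch of that round trip, assuming the two helper functions from this file are in scope and using arbitrary sizes:

import torch
from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence

lstm = torch.nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
packed = pack_padded_sequence(torch.randn(2, 3, 4), torch.tensor([3, 2]), batch_first=True)

out, hx = lstm(packed)                       # PackedSequence in, PackedSequence out
assert isinstance(out, PackedSequence)

# Flatten for the hooks, then re-pack before calling the module again.
flat = _convert_PackedSequence_to_tuple_lstm((packed,))
(repacked,) = _convert_tuple_to_PackedSequence_lstm(flat)
out2, hx2 = lstm(repacked)
assert torch.allclose(out2.data, out.data)   # the round trip is transparent to the LSTM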
11 changes: 11 additions & 0 deletions intel_extension_for_pytorch/quantization/_recipe.py
@@ -63,6 +63,17 @@ def _default_recipe_init(nodes):
tensor_info.inf_dtype = tensor_info.orig_dtype
node.input_tensor_force_inf_dtype[idx] = tensor_info.inf_dtype

# For LSTM, if its input is a PackedSequence, we don't support it now.
# TODO: support PackedSequence input for quantized LSTM.
if node.type in rnn_ops and len(node.input_tensor_infos) > 2:
for idx, tensor_info in enumerate(node.input_tensor_infos):
if tensor_info is not None:
tensor_info.inf_dtype = tensor_info.orig_dtype
node.input_tensor_force_inf_dtype[idx] = tensor_info.inf_dtype
for idx, tensor_info in enumerate(node.weight_tensor_infos):
if tensor_info is not None:
tensor_info.inf_dtype = tensor_info.orig_dtype
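
The len(node.input_tensor_infos) > 2 check above treats extra recorded input tensors as the signal that the LSTM was fed a PackedSequence: a PackedSequence is a NamedTuple of four tensors, whereas a plain padded input is a single tensor. How those fields map onto input_tensor_infos is my reading of the check, not something the commit spells out. A quick illustration with arbitrary shapes:

import torch
from torch.nn.utils.rnn import pack_padded_sequence

packed = pack_padded_sequence(torch.randn(2, 3, 4), torch.tensor([2, 3]),
                              batch_first=True, enforce_sorted=False)
print(packed._fields)       # ('data', 'batch_sizes', 'sorted_indices', 'unsorted_indices')
print(len(tuple(packed)))   # 4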

#TODO: making fusion pattern check more general.
def _find_fused_node_with_cur_elt_wise(node, ops):
r"""
12 changes: 7 additions & 5 deletions intel_extension_for_pytorch/quantization/_utils.py
@@ -403,7 +403,7 @@ def set_node_output_quantized(nodes):
# output's inf dtype is not int8, set it and also set insert_fake_quant_after_output to True.
"""
def _reset_post_node_input_infos(node):
# make sure the post node will node insert fake quant if we add fake quant by cur node' output
# make sure the post node will not insert fake quant if we add fake quant at cur node's output
if len(node.post_nodes) > 0:
for post_node in node.post_nodes:
if post_node.qconfig is not None:
@@ -434,10 +434,12 @@ def _reset_post_node_input_infos(node):
node.insert_fake_quant_after_outputs[0] = True
_reset_post_node_input_infos(node)
else:
if node.input_tensor_force_inf_dtype[0] in [torch.qint8, torch.quint8] and not post_node_are_quantized:
node.output_tensor_infos[0].inf_dtype = node.input_tensor_force_inf_dtype[0]
node.insert_fake_quant_after_outputs[0] = True
_reset_post_node_input_infos(node)
# TODO: enable PackedSequence input for LSTM.
if not (node.type in [nn.LSTM] and len(node.input_tensor_infos) > 2):
if node.input_tensor_force_inf_dtype[0] in [torch.qint8, torch.quint8] and not post_node_are_quantized:
node.output_tensor_infos[0].inf_dtype = node.input_tensor_force_inf_dtype[0]
node.insert_fake_quant_after_outputs[0] = True
_reset_post_node_input_infos(node)

qscheme_dict = {
str(torch.per_tensor_affine): torch.per_tensor_affine,
27 changes: 27 additions & 0 deletions tests/cpu/test_ao_jit_ipex_quantization.py
@@ -6,6 +6,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.testing import FileCheck
import copy
import json
@@ -262,6 +263,32 @@ def _lstm_params_list():
graph = self.checkQuantizeTrace(m, [x], atol=3e-2, rtol=1e-1)
self.assertGraphContainsExactly(graph, 'ipex::quantized_lstm', 1)

    def test_lstm_PackedSequence(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.lstm = nn.LSTM(input_size=288, hidden_size=1024, num_layers=6, batch_first=True, bidirectional=True, bias=True, dropout=0.2)

            def forward(self, input, hid, mask=None):
                if mask is not None:
                    lengths = mask.sum(-1)
                    seq = pack_padded_sequence(input, lengths.cpu(), batch_first=True)
                    seq, hid = self.lstm(seq, hid)
                    seq = pad_packed_sequence(seq, batch_first=True)[0]
                    return seq, hid
                else:
                    return self.lstm(input, hid)

        model = M().eval()
        seq = torch.randn(size=(1, 211, 288), dtype=torch.float32)
        # initialize hidden states
        h0 = torch.zeros((12, 1, 1024), dtype=seq.dtype)
        hid = (h0, h0)
        mask = torch.ones(size=(1, 211), dtype=torch.uint8)

        graph = self.checkQuantizeTrace(model, [seq, hid, mask])
        self.assertGraphContainsExactly(graph, 'aten::lstm', 1)
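
For context, a user-level flow that exercises this path outside the test harness might look like the sketch below. It follows IPEX's static quantization API (prepare, calibrate, convert, then TorchScript tracing); the qconfig choice, the trace options, and the assumption that checkQuantizeTrace wraps similar steps are mine, not taken from this commit. M refers to the module defined in the test above.

import torch
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert

model = M().eval()
seq = torch.randn(1, 211, 288)
h0 = torch.zeros(12, 1, 1024)
hid = (h0, h0)
mask = torch.ones(1, 211, dtype=torch.uint8)

qconfig = ipex.quantization.default_static_qconfig        # assumed default observer pair
prepared = prepare(model, qconfig, example_inputs=(seq, hid, mask), inplace=False)
prepared(seq, hid, mask)                                   # calibration pass
converted = convert(prepared)
with torch.no_grad():
    traced = torch.jit.trace(converted, (seq, hid, mask), strict=False)
    traced = torch.jit.freeze(traced)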

class TestIpexQuantizationConvertAPI(JitLlgaTestCase):
def test_inplace_preapre(self):
class M(nn.Module):
