quantization: add quant for conv+add fusion (#1309) (#1318)
jiayisunx committed Dec 14, 2022
1 parent dca2d35 · commit 5dd3a6e
Showing 1 changed file with 12 additions and 4 deletions.
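
For context, the change below targets the conv+add pattern that shows up in residual blocks: a convolution whose output is summed with a skip connection. A minimal, hypothetical PyTorch module exhibiting the pattern (for illustration only; not part of the commit):

import torch
import torch.nn as nn

class ConvAdd(nn.Module):
    # A conv2d whose output is added to a tensor produced outside the conv;
    # the recipe below decides how fake quant is placed around that add.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x, skip):
        return self.conv(x) + skip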
16 changes: 12 additions & 4 deletions intel_extension_for_pytorch/quantization/_recipe.py
@@ -16,6 +16,9 @@
 conv_gemm_ops = [str(F.conv2d), str(nn.Conv2d), str(F.conv3d), str(nn.Conv3d), str(torch.conv2d), str(torch.conv3d), \
     str(F.conv_transpose2d), str(torch.nn.ConvTranspose2d), str(F.conv_transpose3d), str(torch.nn.ConvTranspose3d),
     str(torch.conv_transpose2d), str(torch.conv_transpose2d), str(F.linear), str(nn.Linear), str(torch.matmul), str(torch.Tensor.matmul)]
+conv_ops = [str(F.conv2d), str(nn.Conv2d), str(F.conv3d), str(nn.Conv3d), str(torch.conv2d), str(torch.conv3d), \
+    str(F.conv_transpose2d), str(torch.nn.ConvTranspose2d), str(F.conv_transpose3d), str(torch.nn.ConvTranspose3d),
+    str(torch.conv_transpose2d), str(torch.conv_transpose2d)]
 rnn_ops = [str(torch.nn.LSTM)]
 
 # Those ops only support the s8->s8 path, and also require the qscheme to be per_tensor_symmetric.
@@ -233,6 +236,7 @@ def reset_input_inf_dtype_to_orig_dtype(node, input_idx):
             node.input_tensor_force_inf_dtype[input_idx] = node.input_tensor_infos[input_idx].inf_dtype
 
     conv_gemm_node = _find_fused_node_with_cur_add(node, conv_gemm_ops)
+    conv_node = _find_fused_node_with_cur_add(node, conv_ops)
     if conv_gemm_node is None:
         # If pre_nodes don't have a gemm node, we need to check whether there is a quantizable node before it;
         # if there isn't, we will not insert a fake quant before the add.
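
In effect, the new conv_node lookup splits the old "fused with conv/gemm" case in two. A hypothetical one-liner summarizing the gating that this hunk and the next implement (the helper below is illustrative, not IPEX code):

def may_drop_other_inputs_fake_quant(conv_gemm_node, conv_node):
    # The fake quant on the add's other input may only be removed when the
    # fused producer is gemm-like (linear/matmul): for conv+add, the extra
    # input must stay quantized until oneDNN supports the fused kernel.
    return conv_gemm_node is not None and conv_node is None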
@@ -255,13 +259,17 @@ def reset_input_inf_dtype_to_orig_dtype(node, input_idx):
         if node.input_tensor_infos[0] is not None and node.input_tensor_infos[0] in conv_gemm_node.output_tensor_infos:
             node.input_tensor_infos[0].inf_dtype = node.input_tensor_infos[0].orig_dtype
             node.input_tensor_force_inf_dtype[0] = node.input_tensor_infos[0].inf_dtype
-            # Set the other input's dtype; if the other input comes from a non-quantizable op, we can remove the fake quant.
-            reset_input_inf_dtype_to_orig_dtype(node, 1)
+            # TODO: set the other input's dtype for conv nodes when oneDNN is ready.
+            if conv_node is None:
+                # Set the other input's dtype; if the other input comes from a non-quantizable op, we can remove the fake quant.
+                reset_input_inf_dtype_to_orig_dtype(node, 1)
         elif node.input_tensor_infos[1] is not None and node.input_tensor_infos[1] in conv_gemm_node.output_tensor_infos:
             node.input_tensor_infos[1].inf_dtype = node.input_tensor_infos[1].orig_dtype
             node.input_tensor_force_inf_dtype[1] = node.input_tensor_infos[1].inf_dtype
-            # Set the other input's dtype; if the other input comes from a non-quantizable op, we can remove the fake quant.
-            reset_input_inf_dtype_to_orig_dtype(node, 0)
+            # TODO: set the other input's dtype for conv nodes when oneDNN is ready.
+            if conv_node is None:
+                # Set the other input's dtype; if the other input comes from a non-quantizable op, we can remove the fake quant.
+                reset_input_inf_dtype_to_orig_dtype(node, 0)
 
 # get a default recipe
 def get_default_recipe(nodes):
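For reference, this recipe runs while a model is being prepared and calibrated for static quantization. A minimal usage sketch, assuming the ipex.quantization prepare/convert API of this era and reusing the hypothetical ConvAdd module above:

import torch
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert

model = ConvAdd().eval()
x = torch.randn(1, 3, 32, 32)
skip = torch.randn(1, 3, 32, 32)

qconfig = ipex.quantization.default_static_qconfig
prepared = prepare(model, qconfig, example_inputs=(x, skip))
prepared(x, skip)  # calibration pass; the recipe decides fake-quant placement here
quantized = convert(prepared)
traced = torch.jit.freeze(torch.jit.trace(quantized, (x, skip)))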
