[Engine]: fix tuple lower pass (#816)
zhenwei-intel committed Apr 14, 2023
1 parent 76125ec commit e83a51f
Showing 6 changed files with 130 additions and 97 deletions.
@@ -38,33 +38,15 @@ def removeUnusedNode(graph, unused_nodes):
out_val.replaceAllUsesWith(in_val)
remove_list.append(node)

# remove ListConstruct followed by cat/stack/einsum
for node in graph.nodes():
if node.kind() == 'prim::ListConstruct' and node.outputsAt(0).type().str() in ['Tensor[]', 'int[]']:
out_val = node.outputsAt(0)
for val_user in out_val.uses():
next_node = val_user.user
if next_node.kind() in ['aten::cat', 'aten::stack']:
for i in range(node.inputsSize()):
next_node.addInput(node.inputsAt(i))
next_node.addInput(next_node.inputsAt(1))
next_node.removeInput(0)
next_node.removeInput(0)
remove_list.append(node)
elif next_node.kind() in ['aten::einsum', 'aten::view']:
for i in range(node.inputsSize()):
next_node.addInput(node.inputsAt(i))
next_node.removeInput(1)
remove_list.append(node)

for node in remove_list:
node.destroy()

def fuse_padding_seq(graph):
old_g = """
graph(%input_ids.1, %attention_mask.1, %3, %4, %5, %6, %7, %8, %9, %10):
%11 : int = aten::size(%input_ids.1, %9)
%attention_mask0.1 : Tensor = aten::view(%attention_mask.1, %11, %8)
%12 : int[] = prim::ListConstruct(%11, %8)
%attention_mask0.1 : Tensor = aten::view(%attention_mask.1, %12)
%14 : Tensor = aten::slice(%attention_mask0.1, %9, %9, %7, %6)
%15 : Tensor = aten::unsqueeze(%14, %6)
%16 : Tensor = aten::unsqueeze(%15, %5)
@@ -99,7 +81,7 @@ def __call__(self, model):
graph, _ = torch._C._jit_pass_lower_graph(model.graph, model._c)
torch._C._jit_pass_dce(graph)
torch._C._jit_pass_remove_inplace_ops(graph)
torch._C._jit_pass_lower_all_tuples(graph)
# torch._C._jit_pass_lower_all_tuples(graph)
torch._C._jit_pass_constant_propagation(graph)
removeUnusedNode(graph, ['aten::dropout', 'prim::NumToTensor', 'aten::to', 'aten::contiguous',
'aten::alias', 'aten::Int', 'aten::ScalarImplicit'])
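For context on the change above: torch._C._jit_pass_lower_all_tuples is the TorchScript pass that flattens tuple values inside the JIT graph. With it commented out, prim::TupleConstruct/TupleUnpack nodes (and the ListConstruct nodes no longer folded away in removeUnusedNode) survive extraction and are instead handled by the new LowerAllTuples compile pattern added below. A minimal, self-contained sketch of what the pass itself does; the module here is illustrative, not from this repository:

import torch

class ReturnsTuple(torch.nn.Module):
    def forward(self, x):
        # Traced as a prim::TupleConstruct feeding the graph output.
        return x + 1, x * 2

traced = torch.jit.trace(ReturnsTuple(), torch.randn(2, 2))
graph = traced.graph
print(graph)  # shows a prim::TupleConstruct node before the return

# The pass removes the tuple node and returns the two tensors directly.
torch._C._jit_pass_lower_all_tuples(graph)
print(graph)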
@@ -561,4 +561,39 @@ class LatRange(Operator):
"""Register the Tile operator."""
def __init__(self):
"""The init function of this operator."""
super().__init__()
super().__init__()

@operator_registry(operator_type='Masked_fill')
class Masked_fill(Operator):
"""Register the Tile operator."""
def __init__(self):
"""The init function of this operator."""
super().__init__()

@operator_registry(operator_type='Floor_divide')
class Floor_divide(Operator):
"""Register the Tile operator."""
def __init__(self):
"""The init function of this operator."""
super().__init__()

@operator_registry(operator_type='Max')
class Max(Operator):
"""Register the Tile operator."""
def __init__(self):
"""The init function of this operator."""
super().__init__()

@operator_registry(operator_type='ListUnpack')
class ListUnpack(Operator):
"""Register the Tile operator."""
def __init__(self):
"""The init function of this operator."""
super().__init__()

@operator_registry(operator_type='Silu')
class Silu(Operator):
"""Register the Tile operator."""
def __init__(self):
"""The init function of this operator."""
super().__init__()
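The new classes above only need to exist and be registered: operator_registry keys each class by operator_type, so the op-type strings produced during extraction ('Masked_fill', 'ListUnpack', 'Silu', ...) resolve to an Operator subclass. A rough sketch of how such a decorator-based registry typically works; the names below are hypothetical stand-ins, since the real operator_registry lives elsewhere in the repository and may differ:

# Hypothetical stand-in for the engine's registry, for illustration only.
OPERATOR_REGISTRY = {}

def operator_registry(operator_type):
    def decorate(cls):
        OPERATOR_REGISTRY[operator_type] = cls
        return cls
    return decorate

class Operator:
    def __init__(self):
        self._op_type = None

@operator_registry(operator_type='Silu')
class Silu(Operator):
    """Register the Silu operator."""
    def __init__(self):
        super().__init__()

# Extraction produces the string 'Silu'; the registry maps it to the class.
print(OPERATOR_REGISTRY['Silu'])  # <class '__main__.Silu'>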
@@ -0,0 +1,84 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The LowerAllTuples Pattern."""

from .pattern import Pattern, pattern_registry
from collections import namedtuple, OrderedDict
from .. import graph_utils as util
import copy


@pattern_registry(pattern_type='LowerAllTuples')
class LowerAllTuples(Pattern):
"""The LowerAllTuples pattern.
LowerAllTuples
"""
def __call__(self, model):
"""The __call__ function of this pattern class."""
if model.framework_modeling_config['framework'] != 'torch':
return model
remove_list = []
for node in model.nodes:
if node.op_type in ['ListConstruct', 'TupleConstruct']:
if node.output_tensors[0].dest_op == [] and node.output_tensors[0].name in model.output_tensors_name:
idx = model.output_tensors_name.index(node.output_tensors[0].name)
del model.output_tensors_name[idx]
for tensor in node.input_tensors:
model.output_tensors_name.insert(idx, tensor.name)
idx += 1

for dest_op_name in node.output_tensors[0].dest_op:
dest_node = model.get_node_by_name(dest_op_name)
for i in range(len(dest_node.input_tensors)):
if dest_node.input_tensors[i].name == node.output_tensors[0].name:
del dest_node.input_tensors[i]
idx = i
for tensor in node.input_tensors:
if node.name in tensor.dest_op:
tensor.dest_op.remove(node.name)
tensor.dest_op.append(dest_node.name)
dest_node.input_tensors.insert(idx, copy.deepcopy(tensor))
idx += 1
remove_list.append(node.name)
node_idx = len(model.nodes) - 1
while node_idx >= 0:
node = model.nodes[node_idx]
if node.op_type in ['ListUnpack', 'TupleUnpack']:
for source_op_name in node.input_tensors[0].source_op:
source_op = model.get_node_by_name(source_op_name)
for i in range(len(source_op.output_tensors)):
if source_op.output_tensors[i].name == node.input_tensors[0].name:
del source_op.output_tensors[i]
idx = i
for tensor in node.output_tensors:
if node.name in tensor.source_op:
tensor.source_op.remove(node.name)
tensor.source_op.append(source_op.name)
if source_op.op_type == 'Input':
tensor.source_op = []
tensor.dtype = "fp32"
tensor.shape = [-1, -1, -1, -1]
source_op.output_tensors.insert(idx, copy.deepcopy(tensor))
idx += 1
remove_list.append(node.name)
node_idx -= 1

model.remove_nodes(remove_list)
return model
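Schematically, the pattern above splices construct/unpack pairs out of the engine graph: the element tensors feeding a ListConstruct/TupleConstruct are wired straight into whatever consumed the tuple, and the outputs of a ListUnpack/TupleUnpack are re-attached to the producer. A simplified toy version of that rewiring over plain dictionaries rather than the engine's Graph/Tensor classes; it only covers the construct-followed-by-unpack case, and all names are made up:

import copy

def lower_tuples(nodes):
    """Toy rewiring: drop construct/unpack pairs, connect element tensors directly."""
    tuple_kinds = ('TupleConstruct', 'ListConstruct', 'TupleUnpack', 'ListUnpack')
    # Element tensors that fed each construct node, keyed by its output tensor.
    construct_inputs = {n['outputs'][0]: n['inputs']
                        for n in nodes if n['op_type'] in ('TupleConstruct', 'ListConstruct')}
    # Alias every unpack output back to the matching construct input.
    alias = {}
    for n in nodes:
        if n['op_type'] in ('TupleUnpack', 'ListUnpack'):
            alias.update(dict(zip(n['outputs'], construct_inputs[n['inputs'][0]])))
    lowered = []
    for n in nodes:
        if n['op_type'] in tuple_kinds:
            continue
        m = copy.deepcopy(n)
        m['inputs'] = [alias.get(t, t) for t in m.get('inputs', [])]
        lowered.append(m)
    return lowered

before = [
    {'name': 'split',  'op_type': 'Split',          'inputs': ['x'],        'outputs': ['t0', 't1']},
    {'name': 'pack',   'op_type': 'TupleConstruct', 'inputs': ['t0', 't1'], 'outputs': ['tup']},
    {'name': 'unpack', 'op_type': 'TupleUnpack',    'inputs': ['tup'],      'outputs': ['u0', 'u1']},
    {'name': 'add',    'op_type': 'Add',            'inputs': ['u0', 'u1'], 'outputs': ['y']},
]
print(lower_tuples(before))
# -> 'add' now reads t0/t1 directly; the two tuple nodes are gone.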
@@ -103,6 +103,7 @@
'CastTo',

# GPT-J
'LowerAllTuples',
'TorchEmbedding',
'InnerproductReshapeFusion',
'MatMulWithTranspose',

This file was deleted.

@@ -32,7 +32,7 @@ def get_node_name(node):
node_names[node] = name
return name

op_maps = {'aten::softmax': 'Softmax', 'prim::Constant': 'Constant', 'prim::ListConstruct': 'ListConstruct',
op_maps = {'aten::softmax': 'Softmax', 'prim::Constant': 'Constant',
'aten::linear': 'InnerProduct', 'aten::slice': 'Slice', 'aten::unsqueeze': 'Unsqueeze',
'aten::embedding': 'Gather', 'aten::where': 'Where', 'aten::matmul': 'Matmul', 'aten::gelu': 'Gelu',
'aten::layer_norm': 'LayerNorm', 'aten::size': 'Shape', 'aten::view': 'View',
@@ -43,7 +43,11 @@ def get_node_name(node):
'aten::rsub': 'Rsub', 'aten::mul': 'Mul', 'aten::add': 'Add', 'aten::add_': 'Add', 'aten::div': 'Div',
'aten::sub': 'Sub', 'aten::gt': 'Greater', 'aten::lt': 'Less', 'aten::eq': 'Equal', 'aten::ne': 'NotEqual',
'aten::quantize_per_tensor': 'Quantize', 'aten::dequantize': 'Dequantize',
'aten::padding_sequence': 'PaddingSequence'}
'aten::padding_sequence': 'PaddingSequence', 'aten::expand': 'Expand', 'aten::masked_fill': 'Masked_fill',
'aten::floor_divide': 'Floor_divide', 'aten::max': 'Max', 'aten::mean': 'Mean', 'aten::reshape': 'Reshape',
'aten::rsqrt': 'Rsqrt', 'aten::silu': 'Silu',
'prim::ListUnpack': 'ListUnpack', 'prim::ListConstruct': 'ListConstruct',
'prim::TupleUnpack': 'TupleUnpack', 'prim::TupleConstruct': 'TupleConstruct'}

def torch_extract_operator(node, model, nodes_dict, engine_graph=None):
"""Decorate the operator in Torch.
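The op_maps additions above are what let prim::ListConstruct, prim::ListUnpack, prim::TupleConstruct and prim::TupleUnpack (plus the new aten ops) survive extraction as engine node types instead of being dropped. A hedged sketch of the lookup this table enables inside torch_extract_operator; the real function also handles inputs, outputs and attributes, and the fallback shown here is an assumption for illustration:

def map_node_kind(node_kind, op_maps):
    """Translate a TorchScript node kind such as 'prim::TupleUnpack' or
    'aten::silu' into the engine op_type string via op_maps."""
    # Assumed fallback: unknown kinds keep their bare name after the '::'.
    return op_maps.get(node_kind, node_kind.split('::')[-1])

op_maps = {'prim::TupleUnpack': 'TupleUnpack', 'aten::silu': 'Silu'}
print(map_node_kind('prim::TupleUnpack', op_maps))  # TupleUnpack
print(map_node_kind('aten::silu', op_maps))         # Silu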
