Commit 731c30b

[Fixbug] Fix a bug in trace_from when the inputs are directly used as outputs (#76)

* fix a bug in trace_from when the inputs are directly used as outputs

* .
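For context, the failure this fixes can be reproduced by tracing a graph whose symbolic input is returned directly as an output. A minimal sketch (the `hidet.symbol` call and shapes are illustrative, not taken from the commit):

```python
import hidet

# A symbolic tensor used directly as a graph output, with no operator in
# between. _analyze only discovers inputs while walking operator edges, so
# such a tensor was never collected, and the old check in update_nodes
# (number of found symbol inputs must equal number given) rejected the graph.
x = hidet.symbol([1, 3], dtype='float32')
graph = hidet.trace_from(x, inputs=[x])  # failed before this fix
```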
yaoyaoding committed Jan 20, 2023
1 parent ae3e4b4 commit 731c30b
Showing 2 changed files with 41 additions and 17 deletions.
52 changes: 38 additions & 14 deletions python/hidet/graph/ir/flow_graph.py
@@ -331,20 +331,24 @@ def load(model_file: str) -> FlowGraph:
         return ret

     def update_nodes(self):
-        inputs, self.nodes, self.usage_count = self._analyze(self.outputs)
+        free_vars, self.nodes, self.usage_count = self._analyze(self.outputs)
         if self.inputs:
-            if len(inputs) != len(self.inputs):
-                raise ValueError('Found {} symbol inputs, but {} given'.format(len(inputs), len(self.inputs)))
-            if any(all(a is not v for v in self.inputs) for a in inputs):
-                raise ValueError('There is a symbol tensor not given in inputs')
+            non_bound_free_vars: Set[Tensor] = set(free_vars) - set(self.inputs)
+            if len(non_bound_free_vars) > 0:
+                msg = ['There are free variable(s) not given in inputs:']
+                for v in non_bound_free_vars:
+                    msg.append('  {}'.format(v.signature()))
+                raise ValueError('\n'.join(msg))
         else:
-            if len(inputs) > 1:
+            if len(free_vars) > 1:
                 raise ValueError(
-                    f'The traced graph has found {len(inputs)} symbol inputs. When there are multiple symbol inputs, \n'
-                    'It is mandatory to specify the "inputs" argument explicitly when calling hidet.trace_from(...):\n'
-                    '    hidet.trace_from(..., inputs=[tensor1, tensor2, ...])\n'
+                    f'The traced graph has found {len(free_vars)} free variables. '
+                    f'When there are multiple free '
+                    f'variables, it is mandatory to specify the "inputs" argument explicitly when calling '
+                    f'hidet.trace_from(...):\n'
+                    '    hidet.trace_from(..., inputs=[tensor1, tensor2, ...])\n'
                 )
-        self.inputs = inputs
+        self.inputs = free_vars
         return self

     def cuda_graph(self):
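The rewritten check treats the given inputs as a binding set: every free variable discovered by `_analyze` must appear among `self.inputs`, but extra given inputs (such as an input that is only used as an output) are now allowed. A sketch of the error path under the new code, with illustrative shapes:

```python
import hidet

x = hidet.symbol([4], dtype='float32')
w = hidet.symbol([4], dtype='float32')
y = x + w

# w is a free variable of the trace but is not bound by inputs=[x], so
# update_nodes now raises an error listing each missing tensor's signature
# instead of a bare count mismatch.
try:
    graph = hidet.trace_from(y, inputs=[x])
except ValueError as e:
    print(e)
```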
@@ -409,7 +413,27 @@ def latency(

     @staticmethod
     def _analyze(outputs: List[Tensor]) -> Tuple[List[Tensor], List[Operator], Dict[Tensor, int]]:
-        inputs = []
+        """
+        Analyze the implicit flow graph by traversing it backwards from the given outputs.
+
+        Parameters
+        ----------
+        outputs: List[Tensor]
+            The outputs of the flow graph to traverse from.
+
+        Returns
+        -------
+        free_vars, nodes, usage_count: Tuple[List[Tensor], List[Operator], Dict[Tensor, int]]
+            The free variables, nodes, and usage count of the flow graph.
+
+            The free variables are the free symbolic tensors that are not produced by any operator
+            and have no storage attached (i.e., their storage attribute is None).
+
+            The nodes are the operators used to produce the outputs, in topological order.
+
+            The usage count records the number of times each tensor is used.
+        """
+        free_vars = []
         nodes: List[Operator] = []
         # find out all nodes
         all_nodes: Set[Operator] = set()
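For readers unfamiliar with `_analyze`, the traversal it performs can be summarized with a standalone sketch. The `Tensor` and `Operator` classes below are toy stand-ins for hidet's, and the post-order DFS replaces the out-degree bookkeeping in the real code; the free-variable criterion (no producing op, no storage) is the one the diff uses:

```python
from typing import Dict, List, Optional, Set, Tuple


class Tensor:
    def __init__(self, op: Optional['Operator'] = None, storage: object = None):
        self.op = op            # producing operator; None for leaf tensors
        self.storage = storage  # non-None for constants, None for symbols


class Operator:
    def __init__(self, inputs: List[Tensor]):
        self.inputs = inputs
        self.output = Tensor(op=self)


def analyze(outputs: List[Tensor]) -> Tuple[List[Tensor], List[Operator], Dict[Tensor, int]]:
    free_vars: List[Tensor] = []
    nodes: List[Operator] = []
    usage_count: Dict[Tensor, int] = {}
    visited: Set[Operator] = set()

    def visit(t: Tensor) -> None:
        usage_count[t] = usage_count.get(t, 0) + 1
        if t.op is None:
            # a free variable: not produced by any operator, no storage
            if t.storage is None and t not in free_vars:
                free_vars.append(t)
        elif t.op not in visited:
            visited.add(t.op)
            for it in t.op.inputs:
                visit(it)
            nodes.append(t.op)  # post-order emission => topological order

    for out in outputs:
        visit(out)
    return free_vars, nodes, usage_count


# Example: b = f(a); the graph outputs are [a, b], mirroring the fixed case.
a = Tensor()
op = Operator(inputs=[a])
b = op.output
print(analyze([a, b]))  # free_vars=[a], nodes=[op], a used twice
```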
Expand Down Expand Up @@ -449,9 +473,9 @@ def find_all_nodes(u: Operator):
             nodes.append(op)
             for it in op.inputs:
                 if it.op is None:
-                    if it.storage is None and all(it is not v for v in inputs):
+                    if it.storage is None and all(it is not v for v in free_vars):
                         # input
-                        inputs.append(it)
+                        free_vars.append(it)
                 else:
                     out_degree[it.op] -= 1
                     if out_degree[it.op] == 0:
@@ -467,7 +491,7 @@
         for graph_output in outputs:
             usage_count[graph_output] += 1

-        return inputs, nodes, usage_count
+        return free_vars, nodes, usage_count


 def trace_from(tensor: Union[Tensor, List[Tensor]], inputs: Optional[Union[Tensor, List[Tensor]]] = None) -> FlowGraph:
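Note that trace_from's keyword for binding free variables is `inputs`, as the unchanged signature above shows. With more than one free variable the argument is mandatory, since the trace alone cannot fix an input order. An illustrative call (`hidet.ops.matmul` stands in for any operator):

```python
import hidet

a = hidet.symbol([2, 3], dtype='float32')
b = hidet.symbol([3, 4], dtype='float32')
c = hidet.ops.matmul(a, b)

# Two free variables: the order of graph.inputs must be pinned explicitly.
graph = hidet.trace_from(c, inputs=[a, b])
```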
6 changes: 3 additions & 3 deletions python/hidet/ir/dtypes/floats.py
@@ -105,7 +105,7 @@ def finfo(self) -> FloatInfo:
     np.finfo(np.float16).min,
     np.finfo(np.float16).max,
     np.finfo(np.float16).eps,
-    np.finfo(np.float16).smallest_normal,
+    np.finfo(np.float16).tiny,
 )
 float32 = FloatType(
     'float32',
@@ -114,7 +114,7 @@ def finfo(self) -> FloatInfo:
     np.finfo(np.float32).min,
     np.finfo(np.float32).max,
     np.finfo(np.float32).eps,
-    np.finfo(np.float32).smallest_normal,
+    np.finfo(np.float32).tiny,
 )
 float64 = FloatType(
     'float64',
@@ -123,7 +123,7 @@ def finfo(self) -> FloatInfo:
     np.finfo(np.float64).min,
     np.finfo(np.float64).max,
     np.finfo(np.float64).eps,
-    np.finfo(np.float64).smallest_normal,
+    np.finfo(np.float64).tiny,
 )
 bfloat16 = FloatType('bfloat16', 'bf16', 2, -3.4e38, 3.4e38, None, None)  # TODO: find correct values
 tfloat32 = FloatType('tfloat32', 'tf32', 4, -3.4e38, 3.4e38, None, None)
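The floats.py change replaces `np.finfo(...).smallest_normal` with `np.finfo(...).tiny`. The two report the same value (the smallest positive normal number), but `smallest_normal` only exists in NumPy >= 1.22, so `tiny` presumably keeps hidet compatible with older NumPy releases; the commit message does not state the motivation. A version-tolerant pattern, for reference:

```python
import numpy as np

def smallest_normal(dtype) -> float:
    # .smallest_normal requires NumPy >= 1.22; .tiny is the older
    # name for the same quantity and works on both old and new NumPy.
    info = np.finfo(dtype)
    return getattr(info, 'smallest_normal', info.tiny)

print(smallest_normal(np.float32))  # 1.1754944e-38
```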
