Skip to content

Commit

Permalink
Removed unused methods in _topology (#210)
Browse files Browse the repository at this point in the history
* Removed unused methods in _topology
* Make if-else statement concise
  • Loading branch information
Prabhat authored and xadupre committed Jul 3, 2019
1 parent 65951d5 commit a64cc5b
Showing 1 changed file with 5 additions and 133 deletions.
138 changes: 5 additions & 133 deletions skl2onnx/common/_topology.py
Expand Up @@ -112,10 +112,9 @@ def get_shape(tt):
elif elem == onnx_proto.TensorProto.INT32:
ty = Int32TensorType(shape)
else:
raise NotImplementedError("Unsupported type '{}' "
"(elem_type={}).".format(
type(obj.type.tensor_type),
elem))
raise NotImplementedError(
"Unsupported type '{}' (elem_type={}).".format(
type(obj.type.tensor_type), elem))
else:
raise NotImplementedError("Unsupported type '{}' as "
"a string ({}).".format(
Expand Down Expand Up @@ -259,16 +258,6 @@ def get_shape_calculator(self, model_type):
"""
return self.custom_shape_calculators.get(model_type, None)

def get_onnx_variable_name(self, seed):
"""
Retrieves the variable ID of the given seed or create one
if it is the first time of seeing this seed.
"""
if seed in self.variable_name_mapping:
return self.variable_name_mapping[seed][-1]
else:
return self.get_unique_variable_name(seed)

def get_unique_variable_name(self, seed):
"""
Creates a unique variable ID based on the given seed.
Expand All @@ -284,19 +273,6 @@ def get_unique_operator_name(self, seed):
"""
return Topology._generate_unique_name(seed, self.onnx_operator_names)

def find_sink_variables(self):
"""
Finds sink variables in this scope.
"""
# First we assume all variables are sinks
is_sink = {name: True for name in self.variables.keys()}
# Then, we remove those variables which are inputs of some operators
for operator in self.operators.values():
for variable in operator.inputs:
is_sink[variable.onnx_name] = False
return [variable for name, variable in self.variables.items()
if is_sink[name]]

def declare_local_variable(self, raw_name, type=None, prepend=False):
"""
This function may create a new variable in this scope. If
Expand All @@ -320,25 +296,6 @@ def declare_local_variable(self, raw_name, type=None, prepend=False):
self.variable_name_mapping[raw_name] = [onnx_name]
return variable

def get_local_variable_or_declare_one(self, raw_name, type=None):
"""
This function first checks if *raw_name* has been used to create
some variables. If yes, the latest one named in
``self.variable_name_mapping[raw_name]`` will be returned.
Otherwise, a new variable will be created and then returned.
"""
onnx_name = self.get_onnx_variable_name(raw_name)
if onnx_name in self.variables:
return self.variables[onnx_name]
else:
variable = Variable(raw_name, onnx_name, self.name, type)
self.variables[onnx_name] = variable
if raw_name in self.variable_name_mapping:
self.variable_name_mapping[raw_name].append(onnx_name)
else:
self.variable_name_mapping[raw_name] = [onnx_name]
return variable

def declare_local_operator(self, type, raw_model=None):
"""
This function is used to declare new local operator.
Expand Down Expand Up @@ -512,30 +469,6 @@ def unordered_variable_iterator(self):
for variable in scope.variables.values():
yield variable

def find_root_and_sink_variables(self):
"""
Finds root variables of the whole graph.
"""
# First we assume all variables are roots
is_root = {
name: True for scope in self.scopes
for name in scope.variables.keys()
}
# Then, we remove those variables which are outputs of some operators
for operator in self.unordered_operator_iterator():
for variable in operator.outputs:
is_root[variable.onnx_name] = False
is_sink = {
name: True for scope in self.scopes
for name in scope.variables.keys()
}
for operator in self.unordered_operator_iterator():
for variable in operator.inputs:
is_sink[variable.onnx_name] = False
return [variable for scope in self.scopes
for name, variable in scope.variables.items()
if is_root[name] or is_sink[name]]

def topological_operator_iterator(self):
"""
This is an iterator of all operators in Topology object.
Expand Down Expand Up @@ -578,62 +511,6 @@ def topological_operator_iterator(self):
if not is_evaluation_happened:
break

def rename_variable(self, old_name, new_name):
"""
Replaces the old ONNX variable name with a new ONNX variable
name. There are several fields we need to edit.
a. Topology
1. scopes (the scope where the specified ONNX variable was
declared)
2. variable_name_set
b. Scope
1. onnx_variable_names (a mirror of Topology's
variable_name_set)
2. variable_name_mapping
3. variables
:param old_name: a string
:param new_name: a string
"""
# Search for the first variable that is named as old_name.
scope, onnx_name, variable = next(
(scope, onnx_name, variable) for scope in self.scopes
for onnx_name, variable in scope.variables.items()
if onnx_name == old_name)

# Rename the variable we just found
variable.onnx_name = new_name

# Because the ONNX name of the targeted variable got changed,
# the (onnx_name, variable) pair in the associated scope's
# variable dictionary should be changed as well. We therefore
# create a new pair to replace the old pair.
scope.variables[new_name] = variable
del scope.variables[old_name]

# One original CoreML name may have several ONNX names recorded.
# To fix the record affected by renaming, we need to replace
# old_name with new_name in the record of the associated CoreML
# name (variable.raw_name). Note that derived_names contains
# all ONNX variable names derived from variable.raw_name.
derived_names = scope.variable_name_mapping[variable.raw_name]
for i in range(len(derived_names)):
# Find old_name in derived_names
if old_name != derived_names[i]:
continue
# Replace the recorded ONNX name with the new name
derived_names[i] = new_name
# Because ONNX names are unique so name replacement only
# happens once, we terminate the loop right after one name
# replacement.
break

# Finally, new_name takes the place of old_name in the set of
# all existing variable names
scope.onnx_variable_names.remove(old_name)
scope.onnx_variable_names.add(new_name)

def _check_structure(self):
"""
This function applies some rules to check if the parsed model is
Expand Down Expand Up @@ -683,13 +560,8 @@ def _initialize_graph_status_for_traversing(self):
for variable in self.unordered_variable_iterator():
# If root_names is set, we only set those variable to be
# fed. Otherwise, all roots would be fed.
if self.root_names:
if variable.onnx_name in self.root_names:
variable.is_fed = True
else:
variable.is_fed = False
else:
variable.is_fed = True
variable.is_fed = (False if self.root_names and variable.onnx_name
not in self.root_names else True)
variable.is_root = True
variable.is_leaf = True

Expand Down

0 comments on commit a64cc5b

Please sign in to comment.