From 8f136488115985c76ff66749a52daee573c67b90 Mon Sep 17 00:00:00 2001 From: liukai Date: Wed, 10 Aug 2022 10:19:12 +0800 Subject: [PATCH] type -> basic_type --- mmrazor/structures/graph/module_graph.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/mmrazor/structures/graph/module_graph.py b/mmrazor/structures/graph/module_graph.py index b69dc59a5..23f0585e1 100644 --- a/mmrazor/structures/graph/module_graph.py +++ b/mmrazor/structures/graph/module_graph.py @@ -86,7 +86,8 @@ def in_channels(self) -> int: for node in self.prev_nodes ]) else: - raise NotImplementedError(f'unsupported node type: {self.type}') + raise NotImplementedError( + f'unsupported node type: {self.basic_type}') @property def out_channels(self) -> int: @@ -114,7 +115,8 @@ def out_channels(self) -> int: for node in self.prev_nodes ]) else: - raise NotImplementedError(f'unsupported node type: {self.type}') + raise NotImplementedError( + f'unsupported node type: {self.basic_type}') # other @@ -124,7 +126,7 @@ def __repr__(self) -> str: # node type @property - def type(self) -> str: + def basic_type(self) -> str: """The basic type of the node. Basic types are divided into seveval major types, detailed in @@ -156,21 +158,21 @@ def type(self) -> str: def is_pass_node(self): """pass node represent a module whose in-channels correspond out- channels one-to-one.""" - return self.type in ['bn', 'dwconv2d', 'pass_placeholder'] + return self.basic_type in ['bn', 'dwconv2d', 'pass_placeholder'] def is_cat_node(self): """cat node represents a cat module.""" - return self.type == 'cat_placeholder' + return self.basic_type == 'cat_placeholder' def is_bind_node(self): """bind node represent a node that has multiple inputs, and their channels are bound one-to-one.""" - return self.type == 'bind_placeholder' + return self.basic_type == 'bind_placeholder' def is_mix_node(self): """mix node represents a module that mixs all input channels and generete new output channels, such as conv and linear.""" - return self.type in ['conv2d', 'linear', 'gwconv2d'] + return self.basic_type in ['conv2d', 'linear', 'gwconv2d'] # check