From 7004e5664dd868affbbaeed035740921248ada4d Mon Sep 17 00:00:00 2001 From: liukai Date: Wed, 10 Aug 2022 10:24:29 +0800 Subject: [PATCH] fix some error --- mmrazor/structures/graph/base_graph.py | 2 +- mmrazor/structures/graph/module_graph.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mmrazor/structures/graph/base_graph.py b/mmrazor/structures/graph/base_graph.py index 7eceb1995..a7dba7e4f 100644 --- a/mmrazor/structures/graph/base_graph.py +++ b/mmrazor/structures/graph/base_graph.py @@ -90,7 +90,7 @@ def copy_from(cls, # connect for old in graph: - for pre in old.pre: + for pre in old.prev_nodes: new_graph.connect(old2new[pre], old2new[old]) return new_graph diff --git a/mmrazor/structures/graph/module_graph.py b/mmrazor/structures/graph/module_graph.py index 23f0585e1..ddbea0966 100644 --- a/mmrazor/structures/graph/module_graph.py +++ b/mmrazor/structures/graph/module_graph.py @@ -47,7 +47,7 @@ def __init__(self, >>> class Pool(nn.Module): def forward(x): return F.adaptive_avg_pool2d(x,2).flatten(1) - >>> node= ModuleNode('pass_0',Pool(),expadn_ratio=4) + >>> node= ModuleNode('pass_0',Pool(),expand_ratio=4) >>> assert node.out_channels == node.in_channels*4 """ @@ -207,6 +207,10 @@ def check_type(self): class ModuleGraph(BaseGraph[MODULENODE]): """Computatation Graph.""" + def __init__(self) -> None: + super().__init__() + self._model = None + # functions to generate module graph. @staticmethod