diff --git a/mmrazor/models/mutators/channel_mutator/channel_mutator.py b/mmrazor/models/mutators/channel_mutator/channel_mutator.py
index e0b77bbc3..f608231e8 100644
--- a/mmrazor/models/mutators/channel_mutator/channel_mutator.py
+++ b/mmrazor/models/mutators/channel_mutator/channel_mutator.py
@@ -6,7 +6,7 @@ from torch.nn import Module
 
 from mmrazor.registry import MODELS, TASK_UTILS
-from mmrazor.structures import ConcatNode, DepthWiseConvNode, PathList
+from mmrazor.structures import PathConcatNode, PathDepthWiseConvNode, PathList
 from ...mutables import MutableChannel
 from ..base_mutator import BaseMutator
 from ..utils import DEFAULT_MODULE_CONVERTERS
@@ -61,14 +61,14 @@ def add_link(self, path_list: PathList) -> None:
         for path in path_list:
             pre_node = None
             for node in path:
-                if isinstance(node, DepthWiseConvNode):
+                if isinstance(node, PathDepthWiseConvNode):
                     module = self.name2module[node.name]
                     # The in_channels and out_channels of a depth-wise conv
                     # should be the same
                     module.mutable_out.register_same_mutable(module.mutable_in)
                     module.mutable_in.register_same_mutable(module.mutable_out)
 
-                if isinstance(node, ConcatNode):
+                if isinstance(node, PathConcatNode):
                     if pre_node is not None:
                         module_names = node.get_module_names()
                         concat_modules = [
diff --git a/mmrazor/structures/graph/__init__.py b/mmrazor/structures/graph/__init__.py
new file mode 100644
index 000000000..b22fe57d8
--- /dev/null
+++ b/mmrazor/structures/graph/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_graph import BaseGraph, BaseNode
+from .module_graph import ModuleGraph, ModuleNode
+
+__all__ = ['BaseGraph', 'BaseNode', 'ModuleNode', 'ModuleGraph']
diff --git a/mmrazor/structures/graph/base_graph.py b/mmrazor/structures/graph/base_graph.py
new file mode 100644
index 000000000..a7dba7e4f
--- /dev/null
+++ b/mmrazor/structures/graph/base_graph.py
@@ -0,0 +1,223 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""This module defines BaseNode and BaseGraph, which are used to model a
+Directed Acyclic Graph (DAG)."""
+from collections import OrderedDict
+from copy import copy
+from typing import Any, Callable, Generic, Iterator, List, TypeVar
+
+# BaseNode && BaseGraph
+
+
+class BaseNode:
+    """A single node in a graph."""
+
+    def __init__(self, name: str, val: Any) -> None:
+        """
+        Args:
+            name (str): name of the node.
+            val (Any): content of the node.
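+
+        Note:
+            A minimal sketch of linking two nodes (the names and values
+            below are illustrative, not from the repo):
+
+            >>> a = BaseNode('node_a', 'val_a')
+            >>> b = BaseNode('node_b', 'val_b')
+            >>> a.add_next_node(b)
+            >>> assert b in a.next_nodes and a in b.prev_nodes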
+ """ + self.val = val + self.name = name + self.prev_nodes: List = [] + self.next_nodes: List = [] + + # node operation + + def add_prev_node(self, node: 'BaseNode'): + """add previous node.""" + if node not in self.prev_nodes: + self.prev_nodes.append(node) + if self not in node.next_nodes: + node.next_nodes.append(self) + + def add_next_node(self, node: 'BaseNode'): + """add next node.""" + if node not in self.next_nodes: + self.next_nodes.append(node) + if self not in node.prev_nodes: + node.prev_nodes.append(self) + + @classmethod + def copy_from(cls, node: 'BaseNode'): + """Copy a node, and generate a new node with current node type.""" + return cls(node.name, node.val) + + # compare operation + + def __hash__(self) -> int: + """Hash the node.""" + return hash((self.val, self.name)) + + def __eq__(self, other): + """Compare two nodes.""" + return self.val is other.val and self.name == other.name + + # other + + def __repr__(self) -> str: + return self.name + + +BASENODE = TypeVar('BASENODE', bound=BaseNode) + + +class BaseGraph(Generic[BASENODE]): + """A Directed Acyclic Graph(DAG)""" + + def __init__(self) -> None: + super().__init__() + self.nodes: OrderedDict[str, BASENODE] = OrderedDict() + + # graph operations + + @classmethod + def copy_from(cls, + graph: 'BaseGraph', + node_converter: Callable = BaseNode.copy_from): + """Copy a graph, and generate a new graph of the current class. + + Args: + graph (BaseGraph): the graph to be copied. + node_converter (Callable): a function that converts node, + when coping graph. + """ + old2new = {} + new_graph = cls() + # copy nodes + for old in graph: + old2new[old] = new_graph.add_or_find_node(node_converter(old)) + + # connect + for old in graph: + for pre in old.prev_nodes: + new_graph.connect(old2new[pre], old2new[old]) + return new_graph + + # node operations + + def add_or_find_node(self, node: BASENODE): + """Add a node to the graph. + + If the node has exsited in the graph, the function will return the node + recorded in the graph. 
+ """ + find = self.find_node(node) + if find is not None: + return find + else: + self.add_node(node) + return node + + def find_node(self, node: BaseNode): + """Find a node and return.""" + if node.name in self.nodes and node.val == self.nodes[node.name].val: + return self.nodes[node.name] + else: + return None + + def add_node(self, node: BASENODE): + """Add a node.""" + if node.name not in self.nodes: + self.nodes[node.name] = node + else: + raise BaseException(f'{node.name} already exists in graph') + + def connect(self, pre_node: BASENODE, next_node: BASENODE): + """Add an edge from pre_node to next_node.""" + assert pre_node in self and next_node in self + pre_node.add_next_node(next_node) + next_node.add_prev_node(pre_node) + + def disconnect(self, pre_node: BASENODE, next_node: BASENODE): + """Remove the edge form pre_node to next_node.""" + assert pre_node in self and next_node in self + if next_node in pre_node.next_nodes: + pre_node.next_nodes.remove(next_node) + if pre_node in next_node.prev_nodes: + next_node.prev_nodes.remove(pre_node) + + def delete_node(self, node: BASENODE): + """Delete a node with its related edges.""" + node = self.find_node(node) + assert node is not None + + if len(node.prev_nodes) == 0: + for next in copy(node.next_nodes): + self.disconnect(node, next) + elif len(node.next_nodes) == 0: + for pre in copy(node.prev_nodes): + self.disconnect(pre, node) + elif len(node.prev_nodes) == 1: + pre_node = node.prev_nodes[0] + self.disconnect(pre_node, node) + for next in copy(node.next_nodes): + self.disconnect(node, next) + self.connect(pre_node, next) + elif len(node.next_nodes) == 1: + next_node = node.next_nodes[0] + self.disconnect(node, next_node) + for pre in copy(node.prev_nodes): + self.disconnect(pre, node) + self.connect(pre, next_node) + else: + raise Exception(f'not delete {node}, \ + as it has more than one inputs and outputs') + self.nodes.pop(node.name) + + # work as a collection + + def __iter__(self) -> Iterator[BASENODE]: + """Traverse all nodes in the graph.""" + for x in self.nodes.values(): + yield x + + def __contains__(self, node: BASENODE) -> bool: + """Check if a node is contained in the graph.""" + return node.name in self.nodes + + def __len__(self) -> int: + """Number of nodes in the graph.""" + return len(self.nodes) + + # other + + def __repr__(self): + res = f'Graph with {len(self)} nodes:\n' + for node in self: + res += '{0:<40} -> {1:^40} -> {2:<40}\n'.format( + str(node.prev_nodes), node.__repr__(), str(node.next_nodes)) + return res + + # traverse + + def topo_traverse(self) -> Iterator[BASENODE]: + """Traverse the graph in topologitcal order.""" + + def _in_degree(graph: BaseGraph): + degree = {} + for name, node in graph.nodes.items(): + degree[name] = len(node.prev_nodes) + return degree + + def find_zero_degree_node(in_degree): + for node_name in in_degree: + if in_degree[node_name] == 0: + return node_name + return None + + in_degree = _in_degree(self) + + while len(in_degree) > 0: + node_name = find_zero_degree_node(in_degree) # visit the node + in_degree.pop(node_name) + yield self.nodes[node_name] + for next in self.nodes[node_name].next_nodes: + in_degree[next.name] -= 1 + + def topo_sort(self): + """Sort all node in topological order.""" + sorted_nodes = OrderedDict() + for node in self.topo_traverse(): + sorted_nodes[node.name] = node + self.nodes = sorted_nodes diff --git a/mmrazor/structures/graph/module_graph.py b/mmrazor/structures/graph/module_graph.py new file mode 100644 index 000000000..ddbea0966 --- 
+++ b/mmrazor/structures/graph/module_graph.py
@@ -0,0 +1,478 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""This module defines ModuleNode and ModuleGraph.
+
+They model the computation graph of a model based on BaseNode and BaseGraph.
+"""
+import copy
+from collections import OrderedDict
+from typing import Dict, List, TypeVar, Union
+
+import torch.nn as nn
+from torch.nn import Module
+
+from ..tracer.backward_tracer import BackwardTracer
+# from ..tracer.fx_tracer import FxBaseNode, FxTracer
+from ..tracer.loss_calculator import ImageClassifierPseudoLoss
+from ..tracer.path import Path, PathConcatNode, PathList, PathNode
+from .base_graph import BaseGraph, BaseNode
+
+# ModuleNode && ModuleGraph
+
+
+class ModuleNode(BaseNode):
+    """A node in a computation graph.
+
+    All nodes are divided into four types; the detailed definitions can be
+    found in the ``self.is_{xxx}_node`` methods.
+    """
+
+    pre_defined_node_val_str = [
+        'cat_placeholder', 'bind_placeholder', 'pass_placeholder'
+    ]
+
+    def __init__(self,
+                 name: str,
+                 val: Union[Module, str],
+                 expand_ratio: int = 1) -> None:
+        """
+        Args:
+            name (str): the name of the node.
+            val (Module | str): content of the node. It can be a Module or
+                a string. If val is a string, it can only be one of
+                self.pre_defined_node_val_str.
+            expand_ratio (int): expand_ratio is used in pass nodes,
+                where the out_channels is always a multiple of the
+                in_channels.
+        Note:
+            Here, we give an example of expand_ratio.
+            >>> class Pool(nn.Module):
+            ...     def forward(self, x):
+            ...         return F.adaptive_avg_pool2d(x, 2).flatten(1)
+            >>> node = ModuleNode('pass_0', Pool(), expand_ratio=4)
+            >>> assert node.out_channels == node.in_channels * 4
+        """
+
+        assert (isinstance(val, Module)
+                or val in self.__class__.pre_defined_node_val_str
+                ), f'{val} node is not allowed'
+        if expand_ratio != 1:
+            assert val == 'pass_placeholder', \
+                'expand_ratio != 1 is only valid when val=="pass_placeholder"'
+        super().__init__(name, val)
+        self.expand_ratio = expand_ratio
+
+    # channel
+
+    @property
+    def in_channels(self) -> int:
+        """int: the in_channels of the node."""
+        if isinstance(self.val, nn.Module):
+            MAPPING = {
+                nn.Conv2d: 'in_channels',
+                nn.modules.batchnorm._BatchNorm: 'num_features',
+                nn.modules.Linear: 'in_features',
+            }
+            for basetype in MAPPING:
+                if isinstance(self.val, basetype):
+                    return getattr(self.val, MAPPING[basetype])
+            raise NotImplementedError(f'unsupported module: {self.val}')
+        elif self.is_bind_node() or self.is_pass_node():
+            if len(self.prev_nodes) > 0:
+                return self.prev_nodes[0].out_channels
+            else:
+                return 0
+        elif self.is_cat_node():
+            return sum([
+                node.out_channels if node.out_channels is not None else 0
+                for node in self.prev_nodes
+            ])
+        else:
+            raise NotImplementedError(
+                f'unsupported node type: {self.basic_type}')
+
+    @property
+    def out_channels(self) -> int:
+        """int: the out_channels of the node."""
+        if isinstance(self.val, nn.Module):
+            MAPPING = {
+                nn.Conv2d: 'out_channels',
+                nn.modules.batchnorm._BatchNorm: 'num_features',
+                nn.modules.Linear: 'out_features',
+            }
+            for basetype in MAPPING:
+                if isinstance(self.val, basetype):
+                    return getattr(self.val, MAPPING[basetype])
+            raise NotImplementedError(f'unsupported module: {self.val}')
+        elif self.is_bind_node():
+            if len(self.prev_nodes) > 0:
+                return self.prev_nodes[0].out_channels
+            else:
+                return 0
+        elif self.is_pass_node():
+            return self.in_channels * self.expand_ratio
+        elif self.is_cat_node():
+            return sum([
+                node.out_channels if node.out_channels is not None else 0
+                for node in self.prev_nodes
+            ])
+        else:
+            raise NotImplementedError(
+                f'unsupported node type: {self.basic_type}')
+
+    # other
+
+    def __repr__(self) -> str:
+        return f'{self.name}_({self.in_channels},{self.out_channels})'
+
+    # node type
+
+    @property
+    def basic_type(self) -> str:
+        """The basic type of the node.
+
+        Basic types are divided into several major types, detailed in
+        the ``self.is_{xxx}_node`` methods.
+        """
+        if isinstance(self.val, Module):
+            if isinstance(self.val, nn.Conv2d):
+                if self.val.groups == 1:
+                    return 'conv2d'
+                elif self.val.groups == self.val.in_channels == \
+                        self.val.out_channels:
+                    return 'dwconv2d'
+                else:
+                    return 'gwconv2d'
+            elif isinstance(self.val, nn.modules.batchnorm._BatchNorm):
+                return 'bn'
+            elif isinstance(self.val, nn.Linear):
+                return 'linear'
+            else:
+                raise NotImplementedError(f'{self}')
+        else:
+            if self.val in [
+                    'cat_placeholder', 'bind_placeholder', 'pass_placeholder'
+            ]:
+                return self.val
+            else:
+                raise NotImplementedError()
+
+    def is_pass_node(self):
+        """A pass node represents a module whose in_channels correspond to
+        its out_channels one-to-one."""
+        return self.basic_type in ['bn', 'dwconv2d', 'pass_placeholder']
+
+    def is_cat_node(self):
+        """A cat node represents a concatenation operation."""
+        return self.basic_type == 'cat_placeholder'
+
+    def is_bind_node(self):
+        """A bind node represents a node that has multiple inputs whose
+        channels are bound one-to-one."""
+        return self.basic_type == 'bind_placeholder'
+
+    def is_mix_node(self):
+        """A mix node represents a module that mixes all input channels and
+        generates new output channels, such as conv and linear."""
+        return self.basic_type in ['conv2d', 'linear', 'gwconv2d']
+
+    # check
+
+    def check_channel(self):
+        """Check if the channels of the node match its previous nodes and
+        next nodes."""
+        if self.is_cat_node():
+            pass
+        else:
+            for pre in self.prev_nodes:
+                assert pre.out_channels == self.in_channels, \
+                    f'{self} has channel error'
+
+    def check_type(self):
+        """Check if the node has the right number of previous nodes
+        according to its type."""
+        if self.is_pass_node():
+            assert len(self.prev_nodes) <= 1, f'{self.name} pass node error'
+        elif self.is_cat_node():
+            pass
+        elif self.is_bind_node():
+            assert len(self.prev_nodes) > 1, f'{self.name} bind node error'
+        elif self.is_mix_node():
+            assert len(self.prev_nodes) <= 1, f'{self.name} mix node error'
+        else:
+            raise NotImplementedError(f'{self}')
+
+
+MODULENODE = TypeVar('MODULENODE', bound=ModuleNode)
+
+
+class ModuleGraph(BaseGraph[MODULENODE]):
+    """Computation graph of a model."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._model = None
+
+    # functions to generate a module graph.
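+    # A minimal usage sketch (assuming a CNN that the backward tracer
+    # supports):
+    #
+    #     graph = ModuleGraph.init_using_backward_tracer(model)
+    #     graph.check()  # verify channel matching and node types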
+
+    @staticmethod
+    def init_using_backward_tracer(
+        model: Module,
+        backward_tracer=BackwardTracer(
+            loss_calculator=ImageClassifierPseudoLoss()),
+    ):
+        """Init a module graph using the backward tracer."""
+        path_lists = backward_tracer.trace(model)
+        converter = PathToGraphConverter(path_lists, model)
+        return converter.graph
+
+    @staticmethod
+    def init_using_fx_tracer(model: Module, is_extra_leaf_module=None):
+        """Init a module graph using the torch fx tracer."""
+        pass
+
+    @staticmethod
+    def init_from_model(model: Module):
+        """Init a module graph from a model which uses connect_module to
+        record the relation among modules."""
+        pass
+
+    # check
+
+    def check(self):
+        """Check if the graph is valid."""
+        for node in self:
+            node.check_channel()
+            node.check_type()
+
+    # static method for models that can't use tracer
+
+    @staticmethod
+    def connect_module(pre: Module, next: Module):
+        """This function is used to hardcode the connections between modules
+        so that a Graph object can be generated by init_from_model."""
+        if hasattr(pre, '_next'):
+            _next = getattr(pre, '_next')
+            assert isinstance(_next, set)
+        else:
+            pre._next = set()
+        pre._next.add(next)
+
+        if hasattr(next, '_pre'):
+            _pre = getattr(next, '_pre')
+            assert isinstance(_pre, set)
+        else:
+            next._pre = set()
+        next._pre.add(pre)
+
+
+# Converter
+
+
+class GraphConverter:
+    """Base class for converters for ModuleGraph."""
+
+    def __init__(self) -> None:
+        self.graph = ModuleGraph[ModuleNode]()
+        self.cat_placeholder_num = 0
+        self.bind_placeholder_num = 0
+        self.pass_placeholder_num = 0
+
+    # add node
+
+    def _new_placeholder_node(self, type: str, expand_ratio=1):
+        """New cat/bind/pass node."""
+        assert type in [
+            'cat_placeholder', 'pass_placeholder', 'bind_placeholder'
+        ]
+        if expand_ratio != 1:
+            assert type == 'pass_placeholder'
+        if type == 'cat_placeholder':
+            num = self.cat_placeholder_num
+            self.cat_placeholder_num += 1
+        elif type == 'pass_placeholder':
+            num = self.pass_placeholder_num
+            self.pass_placeholder_num += 1
+        elif type == 'bind_placeholder':
+            num = self.bind_placeholder_num
+            self.bind_placeholder_num += 1
+        else:
+            pass
+        node = ModuleNode(f'{type}_{num}', type, expand_ratio=expand_ratio)
+        self.graph.add_or_find_node(node)
+        return node
+
+    # insert nodes
+
+    def _insert_node_before(self, node: ModuleNode, new_node: ModuleNode):
+        """Insert a new node before a node."""
+        for pre in node.prev_nodes:
+            self.graph.connect(pre, new_node)
+        for pre in new_node.prev_nodes:
+            self.graph.disconnect(pre, node)
+        self.graph.connect(new_node, node)
+
+    def _insert_bind_nodes(self):
+        """Add bind nodes before the nodes which only need one previous node
+        but have more than one."""
+
+        need_bind_nodes = []
+        for node in self.graph:
+            if (isinstance(node.val, nn.Conv2d)
+                    or isinstance(node.val, nn.Linear)
+                    or isinstance(node.val, nn.modules.batchnorm._BatchNorm)):
+                if len(node.prev_nodes) > 1:
+                    need_bind_nodes.append(node)
+        for node in need_bind_nodes:
+            bind_node = self._new_placeholder_node('bind_placeholder')
+            self._insert_node_before(node, bind_node)
+
+    def _insert_pass_nodes(self):
+        """Add pass nodes where the channels conflict."""
+        for node in copy.copy(list(self.graph.nodes.values())):
+            if len(node.prev_nodes) == 1:
+                pre: ModuleNode = node.prev_nodes[0]
+                if node.in_channels != pre.out_channels:
+                    assert node.in_channels % pre.out_channels == 0
+                    pass_node = self._new_placeholder_node(
+                        'pass_placeholder',
+                        node.in_channels // pre.out_channels)
+                    self._insert_node_before(node, pass_node)
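+
+    # Illustration (assumed shapes): if a conv with out_channels=8 feeds a
+    # linear with in_features=32, _insert_pass_nodes places a pass node with
+    # expand_ratio = 32 // 8 = 4 between them, mirroring the
+    # pool-then-flatten example in the ModuleNode docstring.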
+    def _remove_redundant_pass_nodes(self):
+        """Remove redundant pass nodes, which do not change the number of
+        channels and do not represent any module."""
+        for node in copy.copy(list(self.graph.nodes.values())):
+            if (node.is_pass_node() and len(node.prev_nodes) == 1
+                    and len(node.next_nodes) == 1
+                    and not isinstance(node.val, nn.Module)
+                    and node.in_channels == node.out_channels):
+                self.graph.delete_node(node)
+
+    # topo_rename_nodes
+    def _topo_rename(self):
+        """Rename cat, bind, pass nodes in topological order."""
+        self.cat_placeholder_num = 0
+        self.bind_placeholder_num = 0
+        self.pass_placeholder_num = 0
+        sorted_nodes = OrderedDict()
+        for node in self.graph.topo_traverse():
+            node: ModuleNode
+            if isinstance(node.val, Module):
+                pass
+            elif node.is_pass_node():
+                node.name = f'pass_{self.pass_placeholder_num}'
+                self.pass_placeholder_num += 1
+            elif node.is_cat_node():
+                node.name = f'cat_{self.cat_placeholder_num}'
+                self.cat_placeholder_num += 1
+            elif node.is_bind_node():
+                node.name = f'bind_{self.bind_placeholder_num}'
+                self.bind_placeholder_num += 1
+            else:
+                pass
+            sorted_nodes[node.name] = node
+        self.graph.nodes = sorted_nodes
+
+    # other
+    def _post_process(self):
+        """Some post-processing after initializing a basic module graph."""
+        self._remove_redundant_pass_nodes()
+        self._insert_bind_nodes()
+        self._insert_pass_nodes()
+        self._topo_rename()
+
+
+class PathToGraphConverter(GraphConverter):
+    """This class converts a path list, which is generated by the backward
+    tracer, to a module graph."""
+
+    def __init__(self, path_list: PathList, model: Module) -> None:
+        """
+        Args:
+            path_list (PathList): path_list generated by the backward tracer.
+            model (Module): the model corresponding to the path_list.
+        """
+        super().__init__()
+        self.path_list = path_list
+        self.cat_dict: Dict[str, str] = {}
+        self.name2module = dict(model.named_modules())
+        self._pass(self.path_list)
+
+        self._post_process()
+
+    def _pass(self, path_list: PathList):
+        """Parse the path list."""
+        self._parse_helper(path_list, [])
+
+    def _parse_helper(self, path_unit: Union[PathList, Path, PathNode],
+                      next_nodes: List[ModuleNode]):
+        """Parse a node (unit) in the path list."""
+        current_node = None
+        # path_list
+        if isinstance(path_unit, PathList):
+            for single_path in path_unit:  # sibling
+                self._parse_helper(single_path, next_nodes)
+
+        # path:
+        elif isinstance(path_unit, Path):
+            current_nexts = next_nodes
+            for node in path_unit:  # parent -> children
+                current_node = self._parse_helper(node, current_nexts)
+                current_nexts = [current_node]
+
+        # Node
+        elif isinstance(path_unit, PathNode):
+
+            # cat node: [cat_path_lists]
+            if isinstance(path_unit, PathConcatNode):
+                current_node = self._add_or_find_node(path_unit)
+                self._connect_nexts(current_node, next_nodes)
+                for catpath in path_unit.path_lists:  # sibling
+                    self._parse_helper(catpath, [current_node])
+
+            # single node
+            else:
+                current_node = self._add_or_find_node(path_unit)
+                self._connect_nexts(current_node, next_nodes)
+        return current_node
+
+    def _add_or_find_cat_node(self, pathnode: PathConcatNode):
+        """Receive a cat node.
+
+        If the cat node exists in the graph, the corresponding node is
+        returned; otherwise a new cat node is added to the graph.
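+
+        Note:
+            Concat names are unified, so that, e.g., ``cat_op2_op1`` and
+            ``cat_op1_op2`` (illustrative names) map to the same graph node.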
+ """ + + def unify_cat_name(name: str): + cat_name = name.split('_') + inputs = sorted(cat_name[1:]) + return f"cat_{'_'.join(inputs)}" + + name_id = pathnode.name + name_id = unify_cat_name(name_id) + if name_id in self.cat_dict: + name = self.cat_dict[name_id] + else: + name = f'cat_{self.cat_placeholder_num}' + self.cat_placeholder_num += 1 + self.cat_dict[name_id] = name + node = self.graph.add_or_find_node(ModuleNode(name, 'cat_placeholder')) + return node + + def _add_or_find_node(self, pathnode: PathNode) -> Module: + """Receive a cat-node. + + If the cat-node exists in the graph, the corresponding node is + returned, or a new cat node is added to the graph. + """ + if isinstance(pathnode, PathConcatNode): + return self._add_or_find_cat_node(pathnode) + else: + name = pathnode.name + assert name in self.name2module, f"{name} doesn't exist in model" + module = self.name2module[name] + return self.graph.add_or_find_node(ModuleNode(name, module)) + + def _connect_nexts(self, node, nexts: List[ModuleNode]): + """Connext the node and the nodes in nexts.""" + for next in nexts: + self.graph.connect(node, next) diff --git a/mmrazor/structures/tracer/__init__.py b/mmrazor/structures/tracer/__init__.py index 4b2868cc5..a9a6fde52 100644 --- a/mmrazor/structures/tracer/__init__.py +++ b/mmrazor/structures/tracer/__init__.py @@ -2,10 +2,10 @@ from .backward_tracer import BackwardTracer from .loss_calculator import * # noqa: F401,F403 from .parsers import * # noqa: F401,F403 -from .path import (ConcatNode, ConvNode, DepthWiseConvNode, LinearNode, Node, - NormNode, Path, PathList) +from .path import (Path, PathConcatNode, PathConvNode, PathDepthWiseConvNode, + PathLinearNode, PathList, PathNode, PathNormNode) __all__ = [ - 'BackwardTracer', 'ConvNode', 'LinearNode', 'NormNode', 'ConcatNode', - 'Path', 'PathList', 'Node', 'DepthWiseConvNode' + 'BackwardTracer', 'PathConvNode', 'PathLinearNode', 'PathNormNode', + 'PathConcatNode', 'Path', 'PathList', 'PathNode', 'PathDepthWiseConvNode' ] diff --git a/mmrazor/structures/tracer/parsers.py b/mmrazor/structures/tracer/parsers.py index 55d0ec7fd..c5994a27a 100644 --- a/mmrazor/structures/tracer/parsers.py +++ b/mmrazor/structures/tracer/parsers.py @@ -2,8 +2,8 @@ import copy from typing import Callable, Dict -from .path import (ConcatNode, ConvNode, DepthWiseConvNode, LinearNode, - NormNode, Path, PathList) +from .path import (Path, PathConcatNode, PathConvNode, PathDepthWiseConvNode, + PathLinearNode, PathList, PathNormNode) def _is_leaf_grad_fn(grad_fn): @@ -40,9 +40,9 @@ def parse_conv(tracer, grad_fn, module2name, param2module, cur_path, name = module2name[module] parent = grad_fn.next_functions[0][0] if module.in_channels == module.groups: - cur_path.append(DepthWiseConvNode(name)) + cur_path.append(PathDepthWiseConvNode(name)) else: - cur_path.append(ConvNode(name)) + cur_path.append(PathConvNode(name)) # If a module is not a shared module and it has been visited during # forward, its parent modules must have been traced already. # However, a shared module will be visited more than once during @@ -87,7 +87,7 @@ def parse_linear(tracer, grad_fn, module2name, param2module, cur_path, name = module2name[module] parent = grad_fn.next_functions[-2][0] - cur_path.append(LinearNode(name)) + cur_path.append(PathLinearNode(name)) # If a module is not a shared module and it has been visited during # forward, its parent modules must have been traced already. 
# However, a shared module will be visited more than once during @@ -135,7 +135,7 @@ def parse_cat(tracer, grad_fn, module2name, param2module, cur_path, tracer.backward_trace(parent, module2name, param2module, Path(), sub_path_list, visited, shared_module) sub_path_lists.append(sub_path_list) - cur_path.append(ConcatNode(name, sub_path_lists)) + cur_path.append(PathConcatNode(name, sub_path_lists)) result_paths.append(copy.deepcopy(cur_path)) cur_path.pop(-1) @@ -165,7 +165,7 @@ def parse_norm(tracer, grad_fn, module2name, param2module, cur_path, module = param2module[param_id] name = module2name[module] parent = grad_fn.next_functions[0][0] - cur_path.append(NormNode(name)) + cur_path.append(PathNormNode(name)) visited[name] = True tracer.backward_trace(parent, module2name, param2module, cur_path, diff --git a/mmrazor/structures/tracer/path.py b/mmrazor/structures/tracer/path.py index 25cac1322..c6597703f 100644 --- a/mmrazor/structures/tracer/path.py +++ b/mmrazor/structures/tracer/path.py @@ -25,7 +25,7 @@ def _merge_node_parents(node2parents, _node2parents): node2parents[node] = parents -class Node: +class PathNode: """``Node`` is the data structure that represents individual instances within a ``Path``. It corresponds to a module or an operation such as concatenation in the model. @@ -61,24 +61,24 @@ def __repr__(self): return f'{self._get_class_name()}(\'{self.name}\')' -class ConvNode(Node): +class PathConvNode(PathNode): """A `ConvNode` corresponds to a Conv module in the original model.""" pass -class DepthWiseConvNode(Node): +class PathDepthWiseConvNode(PathNode): """A `DepthWiseConvNode` corresponds to a depth-wise conv module in the original model.""" pass -class NormNode(Node): +class PathNormNode(PathNode): """A `NormNode` corresponds to a normalization module in the original model.""" pass -class LinearNode(Node): +class PathLinearNode(PathNode): """A `LinearNode` corresponds to a linear module in the original model.""" pass @@ -92,14 +92,15 @@ class Path: Default to None. """ - def __init__(self, nodes: Optional[Union[Node, List[Node]]] = None): - self._nodes: List[Node] = list() + def __init__(self, + nodes: Optional[Union[PathNode, List[PathNode]]] = None): + self._nodes: List[PathNode] = list() if nodes is not None: - if isinstance(nodes, Node): + if isinstance(nodes, PathNode): nodes = [nodes] assert isinstance(nodes, (list, tuple)) for node in nodes: - assert isinstance(node, Node) + assert isinstance(node, PathNode) self._nodes.append(node) def get_root_names(self) -> List[str]: @@ -117,11 +118,12 @@ def find_nodes_parents(self, non_pass (Tuple): Ancestor nodes whose types are one of `non_pass` are the parents of a specific node. Default to None. 
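+
+        Returns:
+            Dict[str, List[PathNode]]: mapping from the name of each found
+                node to a list of its parent nodes.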
""" - node2parents: Dict[str, List[Node]] = dict() + node2parents: Dict[str, List[PathNode]] = dict() for i, node in enumerate(self._nodes): - if isinstance(node, ConcatNode): - _node2parents: Dict[str, List[Node]] = node.find_nodes_parents( - target_nodes, non_pass) + if isinstance(node, PathConcatNode): + _node2parents: Dict[str, + List[PathNode]] = node.find_nodes_parents( + target_nodes, non_pass) _merge_node_parents(node2parents, _node2parents) continue @@ -140,9 +142,9 @@ def nodes(self) -> List: """Return a list of nodes in the current path.""" return self._nodes - def append(self, x: Node) -> None: + def append(self, x: PathNode) -> None: """Add a node to the end of the current path.""" - assert isinstance(x, Node) + assert isinstance(x, PathNode) self._nodes.append(x) def pop(self, *args, **kwargs): @@ -227,7 +229,7 @@ def find_nodes_parents(self, non_pass (Tuple): Ancestor nodes whose types are one of `non_pass` are the parents of a specific node. Default to None. """ - node2parents: Dict[str, List[Node]] = dict() + node2parents: Dict[str, List[PathNode]] = dict() for p in self._paths: _node2parents = p.find_nodes_parents(target_nodes, non_pass) _merge_node_parents(node2parents, _node2parents) @@ -278,7 +280,7 @@ def __repr__(self): return main_str -class ConcatNode(Node): +class PathConcatNode(PathNode): """``ConcatNode`` is the data structure that represents the concatenation operation in a model. @@ -317,7 +319,7 @@ def find_nodes_parents(self, non_pass (Tuple): Ancestor nodes whose types are one of `non_pass` are the parents of a specific node. Default to None. """ - node2parents: Dict[str, List[Node]] = dict() + node2parents: Dict[str, List[PathNode]] = dict() for p in self._path_lists: _node2parents = p.find_nodes_parents(target_nodes, non_pass) _merge_node_parents(node2parents, _node2parents) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..ddce77790 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .test_core.test_graph.test_graph import TestGraph + +__all__ = ['TestGraph'] diff --git a/tests/data/models.py b/tests/data/models.py new file mode 100644 index 000000000..dd328b516 --- /dev/null +++ b/tests/data/models.py @@ -0,0 +1,399 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch.nn import Module +from torch import Tensor +import torch.nn as nn +import torch + +# this file includes models for tesing. 
+ + +class MultiConcatModel(Module): + """ + x---------------- + |op1 |op2 |op4 + x1 x2 x4 + | | | + |cat----- | + cat1 | + |op3 | + x3 | + |cat------------- + cat2 + |avg_pool + x_pool + |fc + output + """ + + def __init__(self) -> None: + super().__init__() + + self.op1 = nn.Conv2d(3, 8, 1) + self.op2 = nn.Conv2d(3, 8, 1) + self.op3 = nn.Conv2d(16, 8, 1) + self.op4 = nn.Conv2d(3, 8, 1) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(16, 1000) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.op1(x) + x2 = self.op2(x) + cat1 = torch.cat([x1, x2], dim=1) + x3 = self.op3(cat1) + x4 = self.op4(x) + cat2 = torch.cat([x3, x4], dim=1) + x_pool = self.avg_pool(cat2).flatten(1) + output = self.fc(x_pool) + + return output + + +class MultiConcatModel2(Module): + """ + x--------------- + |op1 |op2 |op3 + x1 x2 x3 + | | | + |cat----- | + cat1 | + |cat------------- + cat2 + |op4 + x4 + |avg_pool + x_pool + |fc + output + """ + + def __init__(self) -> None: + super().__init__() + + self.op1 = nn.Conv2d(3, 8, 1) + self.op2 = nn.Conv2d(3, 8, 1) + self.op3 = nn.Conv2d(3, 8, 1) + self.op4 = nn.Conv2d(24, 8, 1) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(8, 1000) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.op1(x) + x2 = self.op2(x) + x3 = self.op3(x) + cat1 = torch.cat([x1, x2], dim=1) + cat2 = torch.cat([cat1, x3], dim=1) + x4 = self.op4(cat2) + + x_pool = self.avg_pool(x4).reshape([x4.shape[0], -1]) + output = self.fc(x_pool) + + return output + + +class ConcatModel(Module): + """ + x------------ + |op1,bn1 |op2,bn2 + x1 x2 + |cat--------| + cat1 + |op3 + x3 + |avg_pool + x_pool + |fc + output + """ + + def __init__(self) -> None: + super().__init__() + + self.op1 = nn.Conv2d(3, 8, 1) + self.bn1 = nn.BatchNorm2d(8) + self.op2 = nn.Conv2d(3, 8, 1) + self.bn2 = nn.BatchNorm2d(8) + self.op3 = nn.Conv2d(16, 8, 1) + + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(8, 1000) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.bn1(self.op1(x)) + x2 = self.bn2(self.op2(x)) + cat1 = torch.cat([x1, x2], dim=1) + x3 = self.op3(cat1) + + x_pool = self.avg_pool(x3).flatten(1) + output = self.fc(x_pool) + + return output + + +class ResBlock(Module): + """ + x + |op1,bn1 + x1----------- + |op2,bn2 | + x2 | + +------------ + |op3 + x3 + |avg_pool + x_pool + |fc + output + """ + + def __init__(self) -> None: + super().__init__() + + self.op1 = nn.Conv2d(3, 8, 1) + self.bn1 = nn.BatchNorm2d(8) + self.op2 = nn.Conv2d(8, 8, 1) + self.bn2 = nn.BatchNorm2d(8) + self.op3 = nn.Conv2d(8, 8, 1) + + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(8, 1000) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.bn1(self.op1(x)) + x2 = self.bn2(self.op2(x1)) + x3 = self.op3(x2 + x1) + x_pool = self.avg_pool(x3).flatten(1) + output = self.fc(x_pool) + return output + + +class LineModel(Module): + """ + x + |net0,net1 + |net2 + |net3 + x1 + |fc + output + """ + + def __init__(self) -> None: + super().__init__() + self.net = nn.Sequential( + nn.Conv2d(3, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.ReLU(), + nn.Conv2d(8, 16, 3, 1, 1), nn.BatchNorm2d(16), + nn.AdaptiveAvgPool2d(1)) + self.linear = nn.Linear(16, 1000) + + def forward(self, x): + x1 = self.net(x) + x1 = x1.reshape([x1.shape[0], -1]) + return self.linear(x1) + + +class AddCatModel(Module): + """ + x------------------------ + |op1 |op2 |op3 |op4 + x1 x2 x3 x4 + | | | | + |cat----- |cat----- + cat1 cat2 + | | + +---------------- + x5 + |avg_pool + x_pool + |fc + y + """ + + def __init__(self) -> 
None:
+        super().__init__()
+        self.op1 = nn.Conv2d(3, 2, 3)
+        self.op2 = nn.Conv2d(3, 6, 3)
+        self.op3 = nn.Conv2d(3, 4, 3)
+        self.op4 = nn.Conv2d(3, 4, 3)
+        self.op5 = nn.Conv2d(8, 16, 3)
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Linear(16, 1000)
+
+    def forward(self, x):
+        x1 = self.op1(x)
+        x2 = self.op2(x)
+        x3 = self.op3(x)
+        x4 = self.op4(x)
+
+        cat1 = torch.cat((x1, x2), dim=1)
+        cat2 = torch.cat((x3, x4), dim=1)
+        x5 = self.op5(cat1 + cat2)
+        x_pool = self.avg_pool(x5).flatten(1)
+        y = self.fc(x_pool)
+        return y
+
+
+class GroupWiseConvModel(nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.op1 = nn.Conv2d(3, 8, 3, 1, 1)
+        self.bn1 = nn.BatchNorm2d(8)
+        self.op2 = nn.Conv2d(8, 16, 3, 1, 1, groups=2)
+        self.bn2 = nn.BatchNorm2d(16)
+        self.op3 = nn.Conv2d(16, 32, 3, 1, 1)
+        # avg_pool and fc are used in forward below and were missing here
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Linear(32, 1000)
+
+    def forward(self, x):
+        x1 = self.op1(x)
+        x1 = self.bn1(x1)
+        x2 = self.op2(x1)
+        x2 = self.bn2(x2)
+        x3 = self.op3(x2)
+        x_pool = self.avg_pool(x3).flatten(1)
+        return self.fc(x_pool)
+
+
+class Xmodel(nn.Module):
+    """
+    x--------
+    |op1     |op2
+    x1       x2
+    |        |
+    +--------
+    x12------
+    |op3     |op4
+    x3       x4
+    |        |
+    +--------
+    x34
+    |avg_pool
+    x_pool
+    |fc
+    y
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.op1 = nn.Conv2d(3, 8, 3, 1, 1)
+        self.op2 = nn.Conv2d(3, 8, 3, 1, 1)
+        self.op3 = nn.Conv2d(8, 16, 3, 1, 1)
+        self.op4 = nn.Conv2d(8, 16, 3, 1, 1)
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Linear(16, 1000)
+
+    def forward(self, x):
+        x1 = self.op1(x)
+        x2 = self.op2(x)
+        x12 = x1 * x2
+        x3 = self.op3(x12)
+        x4 = self.op4(x12)
+        x34 = x3 + x4
+        x_pool = self.avg_pool(x34).flatten(1)
+        return self.fc(x_pool)
+
+
+class MultipleUseModel(nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv0 = nn.Conv2d(3, 8, 3, 1, 1)
+        self.conv1 = nn.Conv2d(3, 8, 3, 1, 1)
+        self.conv2 = nn.Conv2d(3, 8, 3, 1, 1)
+        self.conv3 = nn.Conv2d(3, 8, 3, 1, 1)
+        self.conv_multiple_use = nn.Conv2d(8, 16, 3, 1, 1)
+        self.conv_last = nn.Conv2d(16, 32, 3, 1, 1)
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.linear = nn.Linear(32, 1000)
+
+    def forward(self, x):
+        xs = [
+            conv(x)
+            for conv in [self.conv0, self.conv1, self.conv2, self.conv3]
+        ]
+        xs_ = [self.conv_multiple_use(x_) for x_ in xs]
+        x_sum = 0
+        for x_ in xs_:
+            x_sum = x_sum + x_
+        feature = self.conv_last(x_sum)
+        pool = self.avg_pool(feature).flatten(1)
+        return self.linear(pool)
+
+
+default_models = [
+    LineModel, ResBlock, AddCatModel, ConcatModel, MultiConcatModel,
+    MultiConcatModel2, GroupWiseConvModel, Xmodel, MultipleUseModel
+]
+
+
+class ModelLibrary:
+
+    # includes = [
+    #     'alexnet',  # pass
+    #     'densenet',  # pass
+    #     # 'efficientnet', # pass
+    #     # 'googlenet', # pass.
+ # # googlenet return a tuple when training, + # # so it should trace in eval mode + # # 'inception', # failed + # # 'mnasnet', # pass + # # 'mobilenet', # pass + # # 'regnet', # failed + # # 'resnet', # pass + # # 'resnext', # failed + # # 'shufflenet', # failed + # # 'squeezenet', # pass + # # 'vgg', # pass + # # 'wide_resnet', # pass + # ] + + def __init__(self, include=[]) -> None: + + self.include_key = include + + self.model_creator = self.get_torch_models() + + def __repr__(self) -> str: + s = f'model: {len(self.model_creator)}\n' + for creator in self.model_creator: + s += creator.__name__ + '\n' + return s + + def get_torch_models(self): + from inspect import isfunction + + import torchvision + + attrs = dir(torchvision.models) + models = [] + for name in attrs: + module = getattr(torchvision.models, name) + if isfunction(module): + models.append(module) + return models + + def export_models(self): + models = [] + for creator in self.model_creator: + if self.is_include(creator.__name__): + models.append(creator) + return models + + def is_include(self, name): + for key in self.include_key: + if key in name: + return True + return False + + def include(self): + include = [] + for creator in self.model_creator: + for key in self.include_key: + if key in creator.__name__: + include.append(creator) + return include + + def uninclude(self): + include = self.include() + uninclude = [] + for creator in self.model_creator: + if creator not in include: + uninclude.append(creator) + return uninclude diff --git a/tests/test_core/test_graph/test_graph.py b/tests/test_core/test_graph/test_graph.py new file mode 100644 index 000000000..0bb49e434 --- /dev/null +++ b/tests/test_core/test_graph/test_graph.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
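+# These tests build a ModuleGraph for each toy model in tests/data/models.py
+# and verify the resulting graph (channel consistency, node types and node
+# count).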
+import sys +from unittest import TestCase + +import torch + +from mmrazor.structures.graph import ModuleGraph +from ...data.models import (AddCatModel, ConcatModel, LineModel, + MultiConcatModel, MultiConcatModel2, ResBlock) + +sys.setrecursionlimit(int(1e8)) + + +class ToyCNNPseudoLoss: + + def __call__(self, model): + pseudo_img = torch.rand(2, 3, 16, 16) + pseudo_output = model(pseudo_img) + return pseudo_output.sum() + + +DATA = [ + { + 'model': LineModel, + 'num_nodes': 5, + }, + { + 'model': ResBlock, + 'num_nodes': 7, + }, + { + 'model': ConcatModel, + 'num_nodes': 7, + }, + { + 'model': MultiConcatModel2, + 'num_nodes': 7, + }, + { + 'model': MultiConcatModel, + 'num_nodes': 7, + }, + { + 'model': AddCatModel + }, +] + + +class TestGraph(TestCase): + + def test_graph_init(self) -> None: + + for data in DATA: + with self.subTest(data=data): + model = data['model']() + # print(model) + graphs = [ + ModuleGraph.init_using_backward_tracer(model), + ] + + unit_num = len(graphs[0].nodes) + + for graph in graphs: + + # check channels + try: + graph.check() + except Exception as e: + self.fail(str(e) + '\n' + str(graph)) + + # check number of nodes + self.assertEqual(unit_num, len(graph.nodes)) + if 'num_nodes' in data: + self.assertEqual( + len(graph), + data['num_nodes'], + msg=f'{graph.nodes}') diff --git a/tests/test_core/test_tracer/test_backward_tracer.py b/tests/test_core/test_tracer/test_backward_tracer.py index 9b7bda7ac..f63ff0370 100644 --- a/tests/test_core/test_tracer/test_backward_tracer.py +++ b/tests/test_core/test_tracer/test_backward_tracer.py @@ -6,12 +6,12 @@ from torch import Tensor, nn from torch.nn import Module -from mmrazor.structures import (BackwardTracer, ConcatNode, ConvNode, - DepthWiseConvNode, LinearNode, NormNode, Path, - PathList) +from mmrazor.structures import (BackwardTracer, Path, PathConcatNode, + PathConvNode, PathDepthWiseConvNode, + PathLinearNode, PathList, PathNormNode) -NONPASS_NODES = (ConvNode, LinearNode, ConcatNode) -PASS_NODES = (NormNode, DepthWiseConvNode) +NONPASS_NODES = (PathConvNode, PathLinearNode, PathConcatNode) +PASS_NODES = (PathNormNode, PathDepthWiseConvNode) class MultiConcatModel(Module): @@ -98,23 +98,23 @@ def test_trace_resblock(self) -> None: nonpass2parents = path_list.find_nodes_parents(NONPASS_NODES) assert len(nonpass2parents) == 3 assert nonpass2parents['op1'] == list() - assert nonpass2parents['op2'] == list({NormNode('bn1')}) + assert nonpass2parents['op2'] == list({PathNormNode('bn1')}) assert nonpass2parents['op3'] == list( - {NormNode('bn2'), NormNode('bn1')}) + {PathNormNode('bn2'), PathNormNode('bn1')}) nonpass2nonpassparents = path_list.find_nodes_parents( NONPASS_NODES, non_pass=NONPASS_NODES) assert len(nonpass2parents) == 3 assert nonpass2nonpassparents['op1'] == list() - assert nonpass2nonpassparents['op2'] == list({ConvNode('op1')}) + assert nonpass2nonpassparents['op2'] == list({PathConvNode('op1')}) assert nonpass2nonpassparents['op3'] == list( - {ConvNode('op2'), ConvNode('op1')}) + {PathConvNode('op2'), PathConvNode('op1')}) pass2nonpassparents = path_list.find_nodes_parents( PASS_NODES, non_pass=NONPASS_NODES) assert len(pass2nonpassparents) == 2 - assert pass2nonpassparents['bn1'] == list({ConvNode('op1')}) - assert pass2nonpassparents['bn2'] == list({ConvNode('op2')}) + assert pass2nonpassparents['bn1'] == list({PathConvNode('op1')}) + assert pass2nonpassparents['bn2'] == list({PathConvNode('op2')}) def test_trace_multi_cat(self) -> None: loss_calculator = ToyCNNPseudoLoss() @@ -129,11 +129,11 @@ 
def test_trace_multi_cat(self) -> None: assert len(nonpass2parents) == 4 assert nonpass2parents['op1'] == list() assert nonpass2parents['op2'] == list() - path_list1 = PathList(Path(ConvNode('op1'))) - path_list2 = PathList(Path(ConvNode('op2'))) + path_list1 = PathList(Path(PathConvNode('op1'))) + path_list2 = PathList(Path(PathConvNode('op2'))) # only one parent assert len(nonpass2parents['op3']) == 1 - assert isinstance(nonpass2parents['op3'][0], ConcatNode) + assert isinstance(nonpass2parents['op3'][0], PathConcatNode) assert len(nonpass2parents['op3'][0]) == 2 assert nonpass2parents['op3'][0].get_module_names() == ['op1', 'op2'] assert nonpass2parents['op3'][0].path_lists == [path_list1, path_list2] @@ -152,29 +152,31 @@ def test_trace_multi_cat(self) -> None: assert nonpass2parents['op3'] == list() # only one parent assert len(nonpass2parents['op4']) == 1 - assert isinstance(nonpass2parents['op4'][0], ConcatNode) + assert isinstance(nonpass2parents['op4'][0], PathConcatNode) assert nonpass2parents['op4'][0].get_module_names() == [ 'op1', 'op2', 'op3' ] def test_repr(self): - toy_node = ConvNode('op1') - assert repr(toy_node) == 'ConvNode(\'op1\')' + toy_node = PathConvNode('op1') + assert repr(toy_node) == 'PathConvNode(\'op1\')' - toy_path = Path([ConvNode('op1'), ConvNode('op2')]) + toy_path = Path([PathConvNode('op1'), PathConvNode('op2')]) assert repr( - toy_path) == 'Path(\n ConvNode(\'op1\'),\n ConvNode(\'op2\')\n)' + toy_path + ) == 'Path(\n PathConvNode(\'op1\'),\n PathConvNode(\'op2\')\n)' - toy_path_list = PathList(Path(ConvNode('op1'))) - assert repr(toy_path_list - ) == 'PathList(\n Path(\n ConvNode(\'op1\')\n )\n)' + toy_path_list = PathList(Path(PathConvNode('op1'))) + assert repr( + toy_path_list + ) == 'PathList(\n Path(\n PathConvNode(\'op1\')\n )\n)' - path_list1 = PathList(Path(ConvNode('op1'))) - path_list2 = PathList(Path(ConvNode('op2'))) - toy_concat_node = ConcatNode('op3', [path_list1, path_list2]) + path_list1 = PathList(Path(PathConvNode('op1'))) + path_list2 = PathList(Path(PathConvNode('op2'))) + toy_concat_node = PathConcatNode('op3', [path_list1, path_list2]) assert repr( toy_concat_node - ) == 'ConcatNode(\n PathList(\n Path(\n ConvNode(\'op1\')\n )\n ),\n PathList(\n Path(\n ConvNode(\'op2\')\n )\n )\n)' # noqa: E501 + ) == 'PathConcatNode(\n PathList(\n Path(\n PathConvNode(\'op1\')\n )\n ),\n PathList(\n Path(\n PathConvNode(\'op2\')\n )\n )\n)' # noqa: E501 def test_reset_bn_running_stats(self): _test_reset_bn_running_stats(False) @@ -182,17 +184,17 @@ def test_reset_bn_running_stats(self): _test_reset_bn_running_stats(True) def test_node(self): - node1 = ConvNode('conv1') - node2 = ConvNode('conv2') + node1 = PathConvNode('conv1') + node2 = PathConvNode('conv2') assert node1 != node2 - node1 = ConvNode('conv1') - node2 = ConvNode('conv1') + node1 = PathConvNode('conv1') + node2 = PathConvNode('conv1') assert node1 == node2 def test_path(self): - node1 = ConvNode('conv1') - node2 = ConvNode('conv2') + node1 = PathConvNode('conv1') + node2 = PathConvNode('conv2') path1 = Path([node1]) path2 = Path([node2]) @@ -205,8 +207,8 @@ def test_path(self): assert path1[0] == node1 def test_path_list(self): - node1 = ConvNode('conv1') - node2 = ConvNode('conv2') + node1 = PathConvNode('conv1') + node2 = PathConvNode('conv2') path1 = Path([node1]) path2 = Path([node2])