Skip to content

Commit

Permalink
add BaseNode/BaseGraph, ModuleNode/ModuleGraph (#217)
Browse files Browse the repository at this point in the history
* add BaseNode/BaseGraph,  ModuleNode/ModuleGraph

* add docstring, redesign some functions

* add 'placeholder' after cat/bind/pass

* change type to a property from a method

* add test model s

* rename XXXNode in path to PathXXXNode

* 'xxconv' -> 'xxconv2d' in type

* ToGraph -> GraphConverter, PathToGraph -> PathToGraphConverter

* rm init_from_path_list

* convert some public methods to private methods in GraphConverter

* type -> basic_type

* fix some error

Co-authored-by: liukai <liukai@pjlab.org.cn>
  • Loading branch information
LKJacky and liukai committed Aug 19, 2022
1 parent d190037 commit 83240dc
Show file tree
Hide file tree
Showing 11 changed files with 1,256 additions and 66 deletions.
6 changes: 3 additions & 3 deletions mmrazor/models/mutators/channel_mutator/channel_mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.nn import Module

from mmrazor.registry import MODELS, TASK_UTILS
from mmrazor.structures import ConcatNode, DepthWiseConvNode, PathList
from mmrazor.structures import PathConcatNode, PathDepthWiseConvNode, PathList
from ...mutables import MutableChannel
from ..base_mutator import BaseMutator
from ..utils import DEFAULT_MODULE_CONVERTERS
Expand Down Expand Up @@ -61,14 +61,14 @@ def add_link(self, path_list: PathList) -> None:
for path in path_list:
pre_node = None
for node in path:
if isinstance(node, DepthWiseConvNode):
if isinstance(node, PathDepthWiseConvNode):
module = self.name2module[node.name]
# The in_channels and out_channels of a depth-wise conv
# should be the same
module.mutable_out.register_same_mutable(module.mutable_in)
module.mutable_in.register_same_mutable(module.mutable_out)

if isinstance(node, ConcatNode):
if isinstance(node, PathConcatNode):
if pre_node is not None:
module_names = node.get_module_names()
concat_modules = [
Expand Down
5 changes: 5 additions & 0 deletions mmrazor/structures/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_graph import BaseGraph, BaseNode
from .module_graph import ModuleGraph, ModuleNode

__all__ = ['BaseGraph', 'BaseNode', 'ModuleNode', 'ModuleGraph']
223 changes: 223 additions & 0 deletions mmrazor/structures/graph/base_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""This module defines BaseNode and BaseGraph, which are used to model Directed
Acyclic Graph(DAG)"""
from collections import OrderedDict
from copy import copy
from typing import Any, Callable, Generic, Iterator, List, TypeVar

# BaseNode && BaseGraph


class BaseNode:
"""A single node in a graph."""

def __init__(self, name: str, val: Any) -> None:
"""
Args:
name (str): name of the node.
val (any): content of the node.
"""
self.val = val
self.name = name
self.prev_nodes: List = []
self.next_nodes: List = []

# node operation

def add_prev_node(self, node: 'BaseNode'):
"""add previous node."""
if node not in self.prev_nodes:
self.prev_nodes.append(node)
if self not in node.next_nodes:
node.next_nodes.append(self)

def add_next_node(self, node: 'BaseNode'):
"""add next node."""
if node not in self.next_nodes:
self.next_nodes.append(node)
if self not in node.prev_nodes:
node.prev_nodes.append(self)

@classmethod
def copy_from(cls, node: 'BaseNode'):
"""Copy a node, and generate a new node with current node type."""
return cls(node.name, node.val)

# compare operation

def __hash__(self) -> int:
"""Hash the node."""
return hash((self.val, self.name))

def __eq__(self, other):
"""Compare two nodes."""
return self.val is other.val and self.name == other.name

# other

def __repr__(self) -> str:
return self.name


BASENODE = TypeVar('BASENODE', bound=BaseNode)


class BaseGraph(Generic[BASENODE]):
"""A Directed Acyclic Graph(DAG)"""

def __init__(self) -> None:
super().__init__()
self.nodes: OrderedDict[str, BASENODE] = OrderedDict()

# graph operations

@classmethod
def copy_from(cls,
graph: 'BaseGraph',
node_converter: Callable = BaseNode.copy_from):
"""Copy a graph, and generate a new graph of the current class.
Args:
graph (BaseGraph): the graph to be copied.
node_converter (Callable): a function that converts node,
when coping graph.
"""
old2new = {}
new_graph = cls()
# copy nodes
for old in graph:
old2new[old] = new_graph.add_or_find_node(node_converter(old))

# connect
for old in graph:
for pre in old.prev_nodes:
new_graph.connect(old2new[pre], old2new[old])
return new_graph

# node operations

def add_or_find_node(self, node: BASENODE):
"""Add a node to the graph.
If the node has exsited in the graph, the function will return the node
recorded in the graph.
"""
find = self.find_node(node)
if find is not None:
return find
else:
self.add_node(node)
return node

def find_node(self, node: BaseNode):
"""Find a node and return."""
if node.name in self.nodes and node.val == self.nodes[node.name].val:
return self.nodes[node.name]
else:
return None

def add_node(self, node: BASENODE):
"""Add a node."""
if node.name not in self.nodes:
self.nodes[node.name] = node
else:
raise BaseException(f'{node.name} already exists in graph')

def connect(self, pre_node: BASENODE, next_node: BASENODE):
"""Add an edge from pre_node to next_node."""
assert pre_node in self and next_node in self
pre_node.add_next_node(next_node)
next_node.add_prev_node(pre_node)

def disconnect(self, pre_node: BASENODE, next_node: BASENODE):
"""Remove the edge form pre_node to next_node."""
assert pre_node in self and next_node in self
if next_node in pre_node.next_nodes:
pre_node.next_nodes.remove(next_node)
if pre_node in next_node.prev_nodes:
next_node.prev_nodes.remove(pre_node)

def delete_node(self, node: BASENODE):
"""Delete a node with its related edges."""
node = self.find_node(node)
assert node is not None

if len(node.prev_nodes) == 0:
for next in copy(node.next_nodes):
self.disconnect(node, next)
elif len(node.next_nodes) == 0:
for pre in copy(node.prev_nodes):
self.disconnect(pre, node)
elif len(node.prev_nodes) == 1:
pre_node = node.prev_nodes[0]
self.disconnect(pre_node, node)
for next in copy(node.next_nodes):
self.disconnect(node, next)
self.connect(pre_node, next)
elif len(node.next_nodes) == 1:
next_node = node.next_nodes[0]
self.disconnect(node, next_node)
for pre in copy(node.prev_nodes):
self.disconnect(pre, node)
self.connect(pre, next_node)
else:
raise Exception(f'not delete {node}, \
as it has more than one inputs and outputs')
self.nodes.pop(node.name)

# work as a collection

def __iter__(self) -> Iterator[BASENODE]:
"""Traverse all nodes in the graph."""
for x in self.nodes.values():
yield x

def __contains__(self, node: BASENODE) -> bool:
"""Check if a node is contained in the graph."""
return node.name in self.nodes

def __len__(self) -> int:
"""Number of nodes in the graph."""
return len(self.nodes)

# other

def __repr__(self):
res = f'Graph with {len(self)} nodes:\n'
for node in self:
res += '{0:<40} -> {1:^40} -> {2:<40}\n'.format(
str(node.prev_nodes), node.__repr__(), str(node.next_nodes))
return res

# traverse

def topo_traverse(self) -> Iterator[BASENODE]:
"""Traverse the graph in topologitcal order."""

def _in_degree(graph: BaseGraph):
degree = {}
for name, node in graph.nodes.items():
degree[name] = len(node.prev_nodes)
return degree

def find_zero_degree_node(in_degree):
for node_name in in_degree:
if in_degree[node_name] == 0:
return node_name
return None

in_degree = _in_degree(self)

while len(in_degree) > 0:
node_name = find_zero_degree_node(in_degree) # visit the node
in_degree.pop(node_name)
yield self.nodes[node_name]
for next in self.nodes[node_name].next_nodes:
in_degree[next.name] -= 1

def topo_sort(self):
"""Sort all node in topological order."""
sorted_nodes = OrderedDict()
for node in self.topo_traverse():
sorted_nodes[node.name] = node
self.nodes = sorted_nodes
Loading

0 comments on commit 83240dc

Please sign in to comment.