In [None]:
# default_exp network

In [None]:
#export
from wong.imports import *
from wong.core import *
from wong.config import cfg, assert_cfg
from wong.graph import *

In [None]:
from fastcore.all import *

# Network
> CNN models built from abstract DAG descriptions.

In [None]:
#export
class NodeOP(nn.Module):
    "Inner-node operation: add up every incoming tensor, then apply the wrapped `Unit`."
    def __init__(self, ni:int, no:int, nh:int, Unit:nn.Module, **kwargs):
        super().__init__()
        # `Unit` is called with (ni, no, nh); any extra keyword arguments go straight to it.
        self.unit = Unit(ni, no, nh, **kwargs)

    def forward(self, *inputs):
        # Aggregate predecessor outputs with Python's built-in `sum` (start value 0).
        aggregated = sum(inputs)
        return self.unit(aggregated)
    

Parameters:
- ni : number of input channels
- no : number of output channels
- nh : number of hidden channels
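- Unit : the module class applied at the node; it is instantiated as `Unit(ni, no, nh, **kwargs)` and receives the element-wise sum of all inputs
- kwargs : extra keyword arguments passed to `Unit`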

> Note: the built-in `sum` may have a performance cost when there are many inputs; should we use `torch.stack(inputs, dim=0).sum(dim=0)` instead?

In [None]:
# Shape-only smoke test for NodeOP. Small batch/spatial sizes are enough here:
# at 64x16x224x224, each float32 input alone was ~205 MB.
bs, sz = 2, 32
ni, no, nh = 16, 32, 8
Unit = resnet_bottleneck
input1 = torch.rand(bs, ni, sz, sz)
input2 = torch.rand(bs, ni, sz, sz)
inputs = [input1, input2]
m = NodeOP(ni, no, nh, Unit)
out = m(*inputs)
test_eq(out.shape, torch.Size([bs, no, sz, sz]))

In [None]:
# Display the module tree: the `Unit` passed to NodeOP is stored as its `unit` submodule.
m

NodeOP(
  (unit): Sequential(
    (0): ReLU()
    (1): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
    (4): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (5): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (8): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [None]:
cfg_file = 'configs/imagenet/resnet/resnet50.yaml'
# This cell freezes `cfg` at the end, so defrost it first to keep the cell re-runnable.
# NOTE(review): assumes a yacs-style CfgNode, where merging into a frozen node raises.
cfg.defrost()
cfg.merge_from_file(cfg_file)
assert_cfg(cfg)
cfg.freeze()
cfg

CfgNode({'GRAPH': CfgNode({'NUM_STAGES': 4, 'NUM_NODES': (3, 4, 6, 3), 'NUM_CHANNELS': (64, 128, 256, 512)})})

In [None]:
# Build the DAG from `cfg.GRAPH.NUM_NODES` (4 entries above, matching NUM_STAGES=4).
G = resnet_dag(cfg.GRAPH.NUM_NODES)

In [None]:
# Inspect the generated graph: node ids and directed edges (u, v).
G.nodes, G.edges

(NodeView((0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)),
 OutEdgeView([(0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (1, 2), (1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (3, 5), (4, 6), (4, 7), (4, 8), (4, 9), (4, 10), (5, 6), (5, 7), (5, 8), (5, 9), (5, 10), (6, 7), (6, 8), (6, 9), (6, 10), (7, 8), (7, 9), (7, 10), (8, 9), (8, 10), (9, 11), (9, 12), (9, 13), (9, 14), (9, 15), (9, 16), (9, 17), (10, 11), (10, 12), (10, 13), (10, 14), (10, 15), (10, 16), (10, 17), (11, 12), (11, 13), (11, 14), (11, 15), (11, 16), (11, 17), (12, 13), (12, 14), (12, 15), (12, 16), (12, 17), (13, 14), (13, 15), (13, 16), (13, 17), (14, 15), (14, 16), (14, 17), (15, 16), (15, 17), (16, 18), (16, 19), (16, 20), (17, 18), (17, 19), (17, 20), (18, 19), (18, 20), (19, 20)]))

In [None]:
#export
class NetworkOP(nn.Module):
    "The operations along a DAG network."
    def __init__(self, G:nx.DiGraph, ni:int, no:int, Unit:nn.Module, efficient:bool=False, **kwargs):
        """Build one op per non-input node of `G`.

        Args:
            G: DAG describing the network. Node 0 is the unique input node and node
                `G.graph['n']` is the unique output node; every other node is an inner node.
                The iteration order of `G.nodes()` is used as the execution order, so it
                must be a topological order.
            ni: channels of the network input; every inner node is built as
                `NodeOP(ni, ni, ni, ...)`.
            no: channels of the network output.
            Unit: module class used at inner nodes, instantiated as `Unit(ni, ni, ni, **kwargs)`.
            efficient: if True, node ops are run through `torch.utils.checkpoint.checkpoint`
                (activations are recomputed in backward instead of being stored).
            **kwargs: forwarded to `Unit`.
        """
        super().__init__()
        self.G = G
        self.n = G.graph['n']  # id of the unique output node
        # forward() reads this flag; it used to be read without ever being set.
        self.efficient = efficient
        # Execution order; `nodeops` is laid out in this same order, skipping the input node.
        self.order = list(G.nodes())
        self.nodeops = nn.ModuleList()
        self._op_index = {}  # node id -> position in self.nodeops
        for node in self.order:
            if node == 0:  # unique input node: it has no op
                continue
            self._op_index[node] = len(self.nodeops)
            if node == self.n:
                # Unique output node: its input is the concat of all predecessors,
                # so the op is IdentityMapping(n_preds * ni, no).
                n_preds = len(list(G.predecessors(node)))
                self.nodeops.append(IdentityMapping(n_preds * ni, no))
            else:  # inner node
                self.nodeops.append(NodeOP(ni, ni, ni, Unit, **kwargs))
        # For every node with successors, the successor that executes last. After that
        # successor has run, the node's result is no longer needed and can be freed.
        rank = {node: i for i, node in enumerate(self.order)}
        self._last_consumer = {}
        for node in self.order:
            succs = list(G.successors(node))
            if succs:
                self._last_consumer[node] = max(succs, key=rank.__getitem__)

    def _run(self, op, *args):
        "Call `op(*args)`, going through gradient checkpointing when `self.efficient` is set."
        if self.efficient:
            # Local import: `cp` is not guaranteed to be in scope at module level.
            from torch.utils import checkpoint as cp
            return cp.checkpoint(op, *args)
        return op(*args)

    def forward(self, x):
        "Run `x` through the DAG in execution order and return the output node's result."
        results = {0: x}  # the input data is the result of the unique input node (id 0)
        for node in self.order:
            if node == 0:  # input node: already seeded above
                continue
            # Results of all predecessors, in `G.predecessors` order.
            inputs = [results[pred] for pred in self.G.predecessors(node)]
            op = self.nodeops[self._op_index[node]]
            if node == self.n:
                # Output node: concatenate predecessor results along dim 1
                # (channels, assuming NCHW tensors).
                return self._run(op, torch.cat(inputs, dim=1))
            results[node] = self._run(op, *inputs)
            # Free predecessor results that no later node needs.
            for pred in self.G.predecessors(node):
                if self._last_consumer[pred] == node:
                    del results[pred]
        



Parameters:

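- G   :  a NetworkX `DiGraph` representing the DAG. Node 0 is the unique input node and node `G.graph['n']` is the unique output node; all other nodes are inner nodes.
- ni  :  number of input channels of the network (each inner node also outputs `ni` channels)
- no  :  number of output channels of the network
- Unit : operation at inner nodes, instantiated as `Unit(ni, ni, ni, **kwargs)`
- kwargs : extra keyword arguments passed to `Unit`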
