In [88]:
import networkx as nx


class RandomGraph(object):
    def __init__(self, node_num, p, seed, k=4, m=5, graph_mode="ER"):
        self.node_num = node_num
        self.p = p
        self.k = k
        self.m = m
        self.seed = seed
        self.graph_mode = graph_mode

        self.graph = self.make_graph()

    def make_graph(self):
        # reference
        # https://networkx.github.io/documentation/networkx-1.9/reference/generators.html

        if self.graph_mode is "ER":
            graph = nx.random_graphs.erdos_renyi_graph(self.node_num, self.p, self.seed)
        elif self.graph_mode is "WS":
            graph = nx.random_graphs.watts_strogatz_graph(self.node_num, self.k, self.p, self.seed)
        elif self.graph_mode is "BA":
            graph = nx.random_graphs.barabasi_albert_graph(self.node_num, self.m, self.seed)
            
        return graph

    def get_graph_info(self):
        in_edges = {}
        in_edges[0] = []
        nodes = [0]
        end = []
        for node in self.graph.nodes():
            neighbors = list(self.graph.neighbors(node))
            neighbors.sort()
#             print(node, neighbors)

            edges = []
            check = []
            for neighbor in neighbors:
                if node > neighbor:
                    edges.append(neighbor + 1)
                    check.append(neighbor)
            if not edges:
                edges.append(0)
            in_edges[node + 1] = edges
            if check == neighbors:
                end.append(node + 1)
            nodes.append(node + 1)
        in_edges[self.node_num + 1] = end
        nodes.append(self.node_num + 1)
        
#         print("node, edges: ", nodes, in_edges)
        return nodes, in_edges

In [89]:
import torch
import torch.nn as nn

from graph import RandomGraph


def weights_init(m):
    if isinstance(m, nn.Conv2d):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


# reference, Thank you.
# https://github.com/tstandley/Xception-PyTorch/blob/master/xception.py
# Reporting 1,
# I don't know which one is better, between 'bias=False' and 'bias=True'
class SeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=bias)
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)

        self.apply(weights_init)

    def forward(self, x):
        x = self.conv(x)
        x = self.pointwise(x)
        return x


# ReLU-convolution-BN triplet
class Unit(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(Unit, self).__init__()

        self.unit = nn.Sequential(
            nn.ReLU(),
            SeparableConv2d(in_channels, out_channels, stride=stride),
            nn.BatchNorm2d(out_channels),
#             nn.Dropout(0.2)
        )

    def forward(self, x):
        return self.unit(x)


# Reporting 2,
# In the paper, they said "The aggregation is done by weighted sum with learnable positive weights".
class Node(nn.Module):
    def __init__(self, in_degree, in_channels, out_channels, stride=1):
        super(Node, self).__init__()
        self.in_degree = in_degree
        if len(self.in_degree) > 1:
            # self.weights = nn.Parameter(torch.zeros(len(self.in_degree), requires_grad=True))
            self.weights = nn.Parameter(torch.zeros(len(self.in_degree), requires_grad=True))
        self.unit = Unit(in_channels, out_channels, stride=stride)

    def forward(self, *input):
        if len(self.in_degree) > 1:
            x = (input[0] * torch.sigmoid(self.weights[0]))
            for index in range(1, len(input)):
                x += (input[index] * torch.sigmoid(self.weights[index]))
            out = self.unit(x)

        else:
            out = self.unit(input[0])
        return out


class RandWire(nn.Module):
    def __init__(self, node_num, p, seed, in_channels, out_channels):
        super(RandWire, self).__init__()
        self.node_num = node_num
        self.p = p
        self.seed = seed
        self.in_channels = in_channels
        self.out_channels = out_channels

        # get graph nodes and in edges
        graph = RandomGraph(self.node_num, self.p, self.seed)
        self.nodes, self.in_edges = graph.get_graph_info()
#         print(self.nodes, self.in_edges)

        # define input Node
        self.module_list = nn.ModuleList([Node(self.in_edges[0], self.in_channels, self.out_channels, stride=2)])
        # define the rest Node
        self.module_list.extend([Node(self.in_edges[node], self.out_channels, self.out_channels) for node in self.nodes if node > 0])

    def forward(self, x):
        memory = {}
        # start vertex
        out = self.module_list[0].forward(x)
        memory[0] = out

        # the rest vertex
        for node in range(1, len(self.nodes) - 1):
            # print(node, self.in_edges[node][0], self.in_edges[node])
            if len(self.in_edges[node]) > 1:
                out = self.module_list[node].forward(*[memory[in_vertex] for in_vertex in self.in_edges[node]])
            else:
                out = self.module_list[node].forward(memory[self.in_edges[node][0]])
            memory[node] = out

        # 논문과 다른 부분
        # Reporting 3,
        # How do I handle the last part?
        # It has two kinds of methods.
        # first, Think of the last module as a Node and collect the data by proceeding in the same way as the previous operation.
        # second, simply sum the data and export the output.
        # out = self.module_list[self.node_num + 1].forward(*[memory[in_vertex] for in_vertex in self.in_edges[self.node_num + 1]])


        # test code
        # print("self.in_edges: ", self.in_edges[self.node_num + 1], self.in_edges[self.node_num + 1][0])
        out = memory[self.in_edges[self.node_num + 1][0]]
        for in_vertex_index in range(1, len(self.in_edges[self.node_num + 1])):
            out += memory[self.in_edges[self.node_num + 1][in_vertex_index]]
        out = out / len(self.in_edges[self.node_num + 1])
        return out

In [90]:
import argparse
import hiddenlayer as hl
import torchvision.models

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [91]:
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self, node_num, p, seed, in_channels, out_channels):
        super(Model, self).__init__()
        self.node_num = node_num
        self.p = p
        self.seed = seed
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=self.out_channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU()
        )
        self.rand_wire1 = nn.Sequential(
            RandWire(self.node_num, self.p, self.seed, self.in_channels, self.out_channels)
        )
        self.rand_wire2 = nn.Sequential(
            RandWire(self.node_num, self.p, self.seed, self.in_channels, self.out_channels * 2)
        )
        self.rand_wire3 = nn.Sequential(
            RandWire(self.node_num, self.p, self.seed, self.in_channels * 2, self.out_channels * 4)
        )
        self.conv_output = nn.Sequential(
            nn.Conv2d(self.in_channels * 4, 1280, kernel_size=1),
            nn.BatchNorm2d(1280)
        )

        self.output = nn.Linear(1280, 10)

    def forward(self, x):
        out = self.conv1(x)
        out = self.rand_wire1(out)
        out = self.rand_wire2(out)
        out = self.rand_wire3(out)
        out = self.conv_output(out)
        print(out.shape)
        
        # You have to fixed kernel_size in avg_pool2d layer.
        # If you don't use a fixed kernel_size, you can meet a problems.
        # avg_pool2d kernel_size: out.size()[2:], example) In this case, out.shape([1, 1280, 2, 2]) So, kernel_size is [2, 2].
        # global average pooling
        out = F.avg_pool2d(out, kernel_size=[2, 2])
        out = out.view(-1, 1280)
        out = self.output(out)

        return out

In [92]:
# m = torchvision.models.resnet152()
model = Model(7, 0.4, 10, 78, 78)
k = hl.build_graph(model, torch.zeros([1, 3, 32, 32]))
k.save('./grpah_pdf')

0 [4]
1 [3, 5, 6]
2 [5]
3 [1, 5, 6]
4 [0]
5 [1, 2, 3]
6 [1, 3]
0 [4]
1 [3, 5, 6]
2 [5]
3 [1, 5, 6]
4 [0]
5 [1, 2, 3]
6 [1, 3]
0 [4]
1 [3, 5, 6]
2 [5]
3 [1, 5, 6]
4 [0]
5 [1, 2, 3]
6 [1, 3]




torch.Size([1, 1280, 2, 2])



