### torch.nn.Parameter

- model.py / Node_OP class 에서 input 받아올 때

In [3]:
# origin source
# self.mean_weight = nn.Parameter(torch.ones(self.input_nums))
# self.sigmoid = nn.Sigmoid()

In [5]:
import torch
import torch.nn as nn

In [6]:
input_nums = 10
temp = nn.Parameter(torch.ones(input_nums))

In [9]:
print(type(temp))
print(temp.shape)
temp

<class 'torch.nn.parameter.Parameter'>
torch.Size([10])


Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=True)

In [10]:
temp[0]

tensor(1., grad_fn=<SelectBackward>)

In [18]:
nn.Sigmoid()(temp[0])

tensor(0.7311, grad_fn=<SigmoidBackward>)

In [19]:
nn.Sigmoid()(torch.ones(1))

tensor([0.7311])

# utils_graph.py의 함수들

In [20]:
import networkx as nx
import collections

In [None]:
Node = collections.namedtuple('Node', ['id', 'inputs', 'type'])


def get_graph_info(graph):
  """Turn an undirected graph into a DAG description ordered by node id.

  Every edge is oriented from the smaller node id to the larger one, so a
  node's inputs are exactly its neighbours with a smaller id.

  Args:
    graph: object exposing ``number_of_nodes()`` and ``neighbors(node)``;
      nodes are addressed as the integers ``0 .. number_of_nodes() - 1``.

  Returns:
    A tuple ``(Nodes, input_nodes, output_nodes)``:
      * ``Nodes`` - one ``Node(id, inputs, type)`` per node, in id order.
        ``type`` is 0 for an input node, 1 for an output node and -1
        otherwise; a node that is both input and output is recorded as 1.
      * ``input_nodes`` - ids of nodes with no neighbour of smaller id.
      * ``output_nodes`` - ids of nodes with no neighbour of larger id.
  """
  input_nodes = []
  output_nodes = []
  Nodes = []
  for node in range(graph.number_of_nodes()):
    neighbors = sorted(graph.neighbors(node))
    node_type = -1
    # An isolated node has no neighbours at all: it has neither a smaller
    # nor a larger neighbour, so it counts as both input and output instead
    # of crashing on neighbors[0].
    if not neighbors or node < neighbors[0]:
      input_nodes.append(node)
      node_type = 0
    if not neighbors or node > neighbors[-1]:
      output_nodes.append(node)
      node_type = 1
    # Keep only the edges coming from smaller ids (DAG orientation).
    Nodes.append(Node(node, [n for n in neighbors if n < node], node_type))
  return Nodes, input_nodes, output_nodes


def build_graph(Nodes, args):
  """Generate a random graph with the generator selected by ``args.graph_model``.

  Args:
    Nodes: number of nodes, passed as the first argument of the generator.
    args: object with the attributes ``graph_model`` ('ER', 'BA' or 'WS'),
      ``seed``, and the model parameters: ``P`` for 'ER', ``M`` for 'BA',
      ``K`` and ``P`` for 'WS'.

  Returns:
    The graph produced by the selected networkx generator.

  Raises:
    ValueError: if ``args.graph_model`` is not one of 'ER', 'BA', 'WS'
      (previously this case silently returned ``None``).
  """
  model = args.graph_model
  if model == 'ER':
    return nx.random_graphs.erdos_renyi_graph(Nodes, args.P, seed=args.seed)
  if model == 'BA':
    return nx.random_graphs.barabasi_albert_graph(Nodes, args.M, seed=args.seed)
  if model == 'WS':
    return nx.random_graphs.connected_watts_strogatz_graph(
        Nodes, args.K, args.P, tries=200, seed=args.seed)
  raise ValueError(
      "unknown graph_model %r (expected 'ER', 'BA' or 'WS')" % (model,))


def save_graph(graph, path):
  """Write ``graph`` to ``path`` by delegating to ``nx.write_yaml``."""
  # NOTE(review): write_yaml appears to have been removed from NetworkX in
  # release 2.6 -- confirm against the installed version.
  nx.write_yaml(graph, path)


def load_graph(path):
  """Return the graph stored at ``path`` by delegating to ``nx.read_yaml``."""
  # NOTE(review): read_yaml appears to have been removed from NetworkX in
  # release 2.6, and YAML object loading is presumably unsafe on untrusted
  # files -- confirm both before relying on this.
  return nx.read_yaml(path)

In [29]:
# build_graph: reproduce the 'ER' branch by hand with fixed parameters.
nodes = 32  # number of nodes passed to the generator
P = 0.75  # second argument of erdos_renyi_graph below
K = 4  # not used in this cell
seed = 1  # third argument of erdos_renyi_graph below
save_path = './graph_ER.yaml'  # file the graph is written to / read from later

graph_ER = nx.random_graphs.erdos_renyi_graph(nodes, P, seed)

In [30]:
# save_graph
nx.write_yaml(graph_ER, save_path)   

In [31]:
# load graph
graph_ER_saved = nx.read_yaml(save_path)

In [41]:
# get_graph_info
input_nodes = []
output_nodes = []
Nodes = []

print(graph_ER_saved.number_of_nodes())
for i in range(graph_ER_saved.number_of_nodes()):
    print(list(graph_ER_saved.neighbors(i)))

32
[1, 4, 5, 6, 7, 9, 10, 12, 14, 15, 16, 17, 20, 21, 22, 24, 25, 26, 27, 28, 29, 30, 31]
[0, 2, 3, 4, 5, 6, 8, 9, 10, 13, 14, 15, 16, 18, 20, 21, 22, 25, 26, 27, 28, 30, 31]
[1, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 23, 25, 26, 28, 29, 30, 31]
[1, 2, 5, 9, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
[0, 1, 2, 7, 9, 10, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
[0, 1, 2, 3, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 22, 23, 24, 25, 26, 27, 28, 29, 30]
[0, 1, 2, 8, 10, 12, 13, 14, 15, 17, 20, 21, 24, 25, 26, 27, 28, 29, 30, 31]
[0, 2, 4, 5, 8, 9, 10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 27, 28, 30, 31]
[1, 2, 5, 6, 7, 9, 10, 12, 14, 16, 17, 18, 20, 23, 24, 25, 26, 27, 31]
[0, 1, 3, 4, 5, 7, 8, 10, 12, 13, 14, 15, 16, 18, 20, 21, 24, 25, 26, 27, 29, 30]
[0, 1, 2, 4, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 18, 21, 22, 23, 24, 25, 26, 27, 29, 30, 31]
[2, 3, 4, 5, 7, 10, 12, 13, 17, 18, 19, 21, 22, 23, 24, 25

In [46]:
tmp = list(graph_ER_saved.neighbors(10))
tmp.sort()
print(tmp)

[0, 1, 2, 4, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 18, 21, 22, 23, 24, 25, 26, 27, 29, 30, 31]


In [49]:
Node = collections.namedtuple('Node', ['id', 'inputs', 'type'])

# NOTE: input_nodes, output_nodes and Nodes are created in an earlier cell and
# are not reset here, so re-running this cell appends to them again.
for node in range(graph_ER_saved.number_of_nodes()):
    # For node i:
    tmp = list(graph_ER_saved.neighbors(node))
    tmp.sort()
    # Decide the node type.
    type = -1    # Neither an input nor an output node: an intermediate node inside the graph.
    if node < tmp[0]:  # Smaller than its smallest neighbour id: takes input from outside, i.e. an input node.
        input_nodes.append(node)
        type = 0
    if node > tmp[-1]:  # Larger than its largest neighbour id: sends its output outside, i.e. an output node.
        output_nodes.append(node)
        type = 1
    # Convert to a DAG (keep only the connections to nodes with a smaller id).
    Nodes.append(Node(node, [n for n in tmp if n < node], type))

In [50]:
Nodes

[Node(id=0, inputs=[], type=0),
 Node(id=1, inputs=[0], type=-1),
 Node(id=2, inputs=[1], type=-1),
 Node(id=3, inputs=[1, 2], type=-1),
 Node(id=4, inputs=[0, 1, 2], type=-1),
 Node(id=5, inputs=[0, 1, 2, 3], type=-1),
 Node(id=6, inputs=[0, 1, 2], type=-1),
 Node(id=7, inputs=[0, 2, 4, 5], type=-1),
 Node(id=8, inputs=[1, 2, 5, 6, 7], type=-1),
 Node(id=9, inputs=[0, 1, 3, 4, 5, 7, 8], type=-1),
 Node(id=10, inputs=[0, 1, 2, 4, 6, 7, 8, 9], type=-1),
 Node(id=11, inputs=[2, 3, 4, 5, 7, 10], type=-1),
 Node(id=12, inputs=[0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], type=-1),
 Node(id=13, inputs=[1, 2, 3, 5, 6, 9, 10, 11, 12], type=-1),
 Node(id=14, inputs=[0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 12, 13], type=-1),
 Node(id=15, inputs=[0, 1, 2, 4, 5, 6, 7, 9, 10, 13, 14], type=-1),
 Node(id=16, inputs=[0, 1, 3, 4, 5, 7, 8, 9, 10, 12, 13, 15], type=-1),
 Node(id=17, inputs=[0, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16], type=-1),
 Node(id=18, inputs=[1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 15, 16, 17], type=

In [51]:
input_nodes

[0, 0]

In [52]:
output_nodes

[29, 30, 31]

# => 최종 Stage Block

In [60]:
# unpack list
# A function taking two arguments, function(a, b), can be called with *[1, 2] as its input.
a = [1, [3, 2], 4, 5]
print(*a)

1 [3, 2] 4 5
