# LSTMController

- num_nodes (4개 노드/레이어 아키텍처)
- num_ops (6개 conv ops)
- num_cell_units (LSTM hidden size)
- num_input_nodes (생성 모델에서 입력 노드의 개수)

# References

https://github.com/shibuiwilliam/ENAS-Keras/blob/master/src/controller_network.py
https://github.com/thinkronize/ripenet/blob/master/controllers/enas_controller.py
https://github.com/thinkronize/ripenet/blob/master/cub.py
https://github.com/thinkronize/ENAS-pytorch/blob/master/models/controller.py

In [None]:
from keras.models import Model
from keras.layers import Embedding, Input, Reshape, Dense, Activation
from keras.layers.recurrent import LSTMCell, RNN

In [None]:
# Search-space / controller hyperparameters.
num_ops, num_nodes, num_cell_units, num_input_nodes = (
    6,   # candidate conv operations a node can choose from
    4,   # nodes generated by the controller
    32,  # LSTM hidden size
    2,   # input nodes of the generated model
)

In [None]:
def decode(inputs, num_tokens):
    """Project an RNN output onto a softmax distribution over ``num_tokens`` choices.

    Args:
        inputs: Either a single tensor, or the sequence returned by an
            ``RNN(..., return_state=True)`` call (``[output, *states]``). In the
            latter case only the first element (the RNN output) is decoded.
        num_tokens: Number of candidate tokens, i.e. the width of the new
            ``Dense`` layer.

    Returns:
        The softmax-activated output of a freshly created ``Dense(num_tokens)``.
    """
    # isinstance (not an exact `type(...) is list` check) so tuples, which some
    # Keras versions return for RNN states, and list subclasses are handled too.
    if isinstance(inputs, (list, tuple)):
        inputs = inputs[0]  # keep the RNN output, drop the returned LSTM states
    logits = Dense(num_tokens)(inputs)
    return Activation(activation="softmax")(logits)

In [None]:
# Controller input: a variable-length sequence of input_size features fed to
# the first LSTM step.
# NOTE(review): all reference implementations use input_size = num_ops + num_nodes;
# why is still an open question, as is what the input should actually be
# (the previous arc_seq?). Embedding the input into num_cell_units was also
# considered but is not used.
input_size = num_ops
encoder_inputs = Input(shape=(None, input_size))


def _controller_step(step_inputs, num_tokens, initial_state=None):
    """Run one controller LSTM step and decode it into a softmax over ``num_tokens``.

    Each call builds a fresh ``LSTMCell`` (weights are not shared between
    steps), runs it through ``RNN(..., return_state=True)``, reshapes the RNN
    output with ``Reshape((-1, num_cell_units))`` so it can be fed to the next
    step as a sequence, and decodes it with ``decode``. Layers are created in
    the order LSTMCell, RNN, Reshape, Dense, Activation.

    Args:
        step_inputs: Input tensor for this step: the encoder input for the first
            step, the reshaped previous RNN output afterwards.
        num_tokens: Number of choices this step predicts.
        initial_state: LSTM states returned by the previous step, or None for
            the first step.

    Returns:
        ``(probs, next_inputs, states)``: the softmax output, the reshaped RNN
        output to feed the next step, and the LSTM states to pass on as its
        initial state.
    """
    cell = LSTMCell(num_cell_units)
    rnn_out = RNN(cell, return_state=True)(step_inputs, initial_state=initial_state)
    next_inputs = Reshape((-1, num_cell_units))(rnn_out[0])
    probs = decode(rnn_out, num_tokens)
    return probs, next_inputs, list(rnn_out[1:])


# Controller RNN: node 1 predicts only its op; every later node first predicts
# which earlier node it connects to, then its op.
out_seqs = []
# NOTE(review): what the very first step's input should be is still undecided.
step_inputs = encoder_inputs
states = None
for node_idx in range(1, num_nodes + 1):
    if node_idx > 1:
        # Node-index prediction: choose among the node_idx - 1 nodes before this one.
        probs, step_inputs, states = _controller_step(step_inputs, node_idx - 1, states)
        out_seqs.append(probs)
    # Op prediction: choose one of num_ops candidate operations.
    probs, step_inputs, states = _controller_step(step_inputs, num_ops, states)
    out_seqs.append(probs)

rnn_controller = Model(inputs=encoder_inputs, outputs=out_seqs)

In [None]:
# Print the layer-by-layer architecture and parameter counts of the controller.
rnn_controller.summary()