In [24]:
import os

import matplotlib
%matplotlib inline

import math
import random

import numpy as np
import pandas as pd
import tensorflow as tf
import _pickle as cPickle
from collections import defaultdict

In [25]:
# Expose only CUDA device 0 to this process; set here, before the TensorFlow
# session is created further down the notebook.
os.environ.update(CUDA_VISIBLE_DEVICES='0')

In [27]:
class DoublyRNNCell:
    """Doubly-recurrent cell that merges a parent (ancestral) state and a
    previous-sibling (fraternal) state into a single hidden state.

    Each incoming state passes through its own tanh ``Dense`` layer. The two
    projections are summed, and the sum goes through a third tanh ``Dense``
    layer to give the hidden state. If an ``output_layer`` was supplied, it is
    applied to the hidden state to produce the cell output.

    Args:
        dim_hidden: number of units in each of the three ``Dense`` layers, and
            the width of the ``[1, dim_hidden]`` states built by
            ``get_initial_state`` / ``get_zero_state``.
        output_layer: optional callable applied to the hidden state; when
            ``None`` the cell output is the hidden state itself.
    """

    def __init__(self, dim_hidden, output_layer=None):
        self.dim_hidden = dim_hidden

        def make_dense(layer_name):
            # All three projections share width and activation; only the name differs.
            return tf.layers.Dense(units=dim_hidden, activation=tf.nn.tanh, name=layer_name)

        self.ancestral_layer, self.fraternal_layer, self.hidden_layer = (
            make_dense(layer_name) for layer_name in ('ancestral', 'fraternal', 'hidden')
        )

        self.output_layer = output_layer

    def __call__(self, state_ancestral, state_fraternal, reuse=True):
        """Run one step of the cell.

        Args:
            state_ancestral: state arriving from the parent node.
            state_fraternal: state arriving from the previous sibling.
            reuse: forwarded to the ``'input'`` and ``'output'`` variable
                scopes that the layers are called under.

        Returns:
            ``(output, state_hidden)``. The two are the same object when no
            ``output_layer`` was given.
        """
        with tf.variable_scope('input', reuse=reuse):
            ancestral_proj = self.ancestral_layer(state_ancestral)
            fraternal_proj = self.fraternal_layer(state_fraternal)

        with tf.variable_scope('output', reuse=reuse):
            state_hidden = self.hidden_layer(ancestral_proj + fraternal_proj)
            if self.output_layer is None:
                output = state_hidden
            else:
                output = self.output_layer(state_hidden)

        return output, state_hidden

    def get_initial_state(self, name):
        """Return a float32 variable of shape ``[1, dim_hidden]`` called ``name``."""
        return tf.get_variable(name, [1, self.dim_hidden], dtype=tf.float32)

    def get_zero_state(self, name):
        """Return a float32 zeros tensor of shape ``[1, dim_hidden]`` called ``name``."""
        return tf.zeros([1, self.dim_hidden], dtype=tf.float32, name=name)

In [60]:
def doubly_rnn(dim_hidden, tree_idxs, doubly_rnn_cell, initial_state_parent=None, initial_state_sibling=None, output_layer=None, name='', root_idx=0):
    """Unroll a doubly-recurrent cell over a tree.

    The root is computed from ``initial_state_parent`` and ``initial_state_sibling``.
    Every other node is computed from its parent's state (ancestral input) and
    its previous sibling's state (fraternal input). The first child of each
    parent uses ``initial_state_sibling`` as its fraternal input. The state a
    node produces is used both as the ancestral input of its children and as
    the fraternal input of its next sibling.

    Args:
        dim_hidden: not used by this function; the cell's own ``dim_hidden``
            sets the state width. Kept for backward compatibility.
        tree_idxs: dict mapping a parent node index to the list of its child
            indices, in sibling order. Keys may appear in any order.
        doubly_rnn_cell: object providing ``get_initial_state(name)``,
            ``get_zero_state(name)`` and
            ``__call__(state_parent, state_sibling, reuse=...)`` returning
            ``(output, state)`` (e.g. ``DoublyRNNCell``).
        initial_state_parent: ancestral input of the root. Defaults to the
            cell's ``get_initial_state('init_state_parent')``.
        initial_state_sibling: fraternal input of the root and of every first
            child. Defaults to the cell's ``get_zero_state('init_state_sibling')``.
        output_layer: NOTE(review): ignored here; only the cell's own
            ``output_layer`` is applied. Kept for backward compatibility.
        name: variable scope wrapping the whole unrolled graph.
        root_idx: index of the root node (default 0).

    Returns:
        ``(outputs, states_parent)``: two dicts keyed by node index, holding
        each node's cell output and cell state respectively.

    Raises:
        KeyError: if ``tree_idxs`` has a parent that is not reachable from
            ``root_idx``.
    """
    from collections import deque

    outputs, states_parent = {}, {}

    with tf.variable_scope(name, reuse=False):
        if initial_state_parent is None:
            initial_state_parent = doubly_rnn_cell.get_initial_state('init_state_parent')
        if initial_state_sibling is None:
            initial_state_sibling = doubly_rnn_cell.get_zero_state('init_state_sibling')

        # The first call builds the cell's variables; every later call reuses them.
        output, state_root = doubly_rnn_cell(initial_state_parent, initial_state_sibling, reuse=False)
        outputs[root_idx], states_parent[root_idx] = output, state_root

        # Breadth-first from the root guarantees a parent's state exists before
        # its children are unrolled, independent of the key order of tree_idxs.
        pending = deque([root_idx])
        expanded = set()
        while pending:
            parent_idx = pending.popleft()
            # Leaves have no entry; the `expanded` guard keeps malformed
            # (cyclic) input from looping forever.
            if parent_idx in expanded or parent_idx not in tree_idxs:
                continue
            expanded.add(parent_idx)

            state_parent = states_parent[parent_idx]
            state_sibling = initial_state_sibling
            for child_idx in tree_idxs[parent_idx]:
                output, state_sibling = doubly_rnn_cell(state_parent, state_sibling)
                outputs[child_idx], states_parent[child_idx] = output, state_sibling
                pending.append(child_idx)

    unreachable = set(tree_idxs) - expanded
    if unreachable:
        raise KeyError('tree_idxs has parents not reachable from root {!r}: {}'.format(
            root_idx, sorted(unreachable, key=repr)))

    return outputs, states_parent

In [61]:
# Toy tree: root 0 has children 1, 2, 3, and each of those has two leaf
# children (1 -> 10, 11, and so on). Keys are parent indices; values list the
# children in sibling order.
tree_idxs = {
    0: [1, 2, 3],
    1: [10, 11],
    2: [20, 21],
    3: [30, 31],
}

In [63]:
# Start from an empty default graph so re-running this cell rebuilds everything
# from scratch, then unroll a 2-unit cell over the toy tree.
tf.reset_default_graph()

doubly_rnn_cell = DoublyRNNCell(dim_hidden=2)
tree, _ = doubly_rnn(
    dim_hidden=2,
    tree_idxs=tree_idxs,
    doubly_rnn_cell=doubly_rnn_cell,
)
# `tree` maps each node index to its output tensor; the per-node states are not kept.

In [64]:
# Fresh session; close the one left over from a previous run of this cell.
if 'sess' in globals():
    sess.close()
sess = tf.Session()

# Overwrite each trainable variable with a hand-picked constant so the printed
# node values can be checked by hand. This covers every variable the cell and
# doubly_rnn create, so no initializer op is run.
# (Named `trainable_vars` rather than `vars` to avoid shadowing the builtin.)
trainable_vars = {v.name: v for v in tf.trainable_variables()}
manual_values = {
    'init_state_parent:0': [[1., 1.]],
    'input/ancestral/kernel:0': [[2., 2.], [2., 2.]],
    'input/ancestral/bias:0': [2., 2.],
    'input/fraternal/kernel:0': [[3., 3.], [3., 3.]],
    'input/fraternal/bias:0': [3., 3.],
    'output/hidden/kernel:0': [[2., 2.], [2., 2.]],
    'output/hidden/bias:0': [2., 2.],
}
sess.run([
    tf.assign(trainable_vars[var_name], np.asarray(value, dtype=np.float32))
    for var_name, value in manual_values.items()
])

# Fetch every node output in a single run instead of one `eval` per node.
node_values = sess.run(tree)
for idx, value in node_values.items():
    print(idx, value)

0 [[1. 1.]]
1 [[1. 1.]]
2 [[1. 1.]]
3 [[1. 1.]]
10 [[1. 1.]]
11 [[1. 1.]]
20 [[1. 1.]]
21 [[1. 1.]]
30 [[1. 1.]]
31 [[1. 1.]]


<tf.Variable 'init_state_parent:0' shape=(1, 2) dtype=float32_ref>

<tf.Tensor 'Assign:0' shape=(1, 2) dtype=float32_ref>