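# Testing ComplexCell

Checks that `ComplexCell`, run through `tf.nn.dynamic_rnn`, yields the same final generator state as a step-by-step unroll built from standalone `GRUCell`s with the same constant weights and the same random seed.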

In [1]:
from customcells import ComplexCell
from customcells import GRUCell
import tensorflow as tf
import numpy as np
from helper_funcs import linear, DiagonalGaussianFromExisting

%load_ext autoreload
%autoreload 2

In [2]:
# Switch TensorFlow 1.x into eager mode: ops run immediately, so tensors can be
# displayed directly (e.g. the `numpy=` values of the reduce_sum cells below).
tf.enable_eager_execution()

In [3]:
# Synthetic test inputs: `bs` sequences, each `T` steps long, `dim` features per step.
bs, dim, T = 300, 10, 10

# Sizes handed to ComplexCell and mirrored by the manual unroll below.
hps = dict(
    gen_dim=10,        # generator GRU units
    con_dim=5,         # controller GRU units
    co_dim=2,          # controller-output width
    factors_dim=5,     # factors read out from the generator state
    ext_input_dim=0,   # 0 => no external inputs
)

# Input tensor of shape [bs, T, dim], drawn with op-level seed 10.
complex_cell_inputs = tf.random.normal([bs, T, dim], mean=0.0, stddev=0.1, seed=10)
# NOTE(review): in eager mode the inputs above are already materialised, so this
# graph-level seed does not influence them; it only affects later random ops.
tf.random.set_random_seed(10)

In [4]:
# Widths of the segments of ComplexCell's flat state vector, in the order given by
# the original comments: gen state, con state, co mean, co variance, sampled co,
# factors. Presumably this matches ComplexCell's internal layout — confirm in customcells.
comcell_state_dims = ([hps['gen_dim'], hps['con_dim']]
                      + [hps['co_dim']] * 3      # co mean, co variance, sampled co
                      + [hps['factors_dim']])

# The cell under test. Every kernel starts at the constant 0.2, so a manual unroll
# using the same constant can be compared against it.
complexcell = ComplexCell(
    num_units_gen=hps['gen_dim'],
    num_units_con=hps['con_dim'],
    factors_dim=hps['factors_dim'],
    co_dim=hps['co_dim'],
    ext_input_dim=hps['ext_input_dim'],
    inject_ext_input_to_gen=True,
    run_type=0,      # NOTE(review): meaning of run_type 0 is defined in customcells — confirm
    keep_prob=1.0,   # keep everything, i.e. dropout presumably disabled
    kernel_initializer=tf.constant_initializer(0.2),
)


In [5]:
# Initial ComplexCell state: every component of the flat state vector set to 1.0.
complexcell_init_state = tf.ones([bs, sum(comcell_state_dims)])

In [6]:
# Run the same cell with static_rnn over a Python list of per-step [bs, dim] inputs.
# The results get their own names so the dynamic_rnn cell that follows does not
# silently overwrite them (previously both cells wrote complex_outputs /
# complex_final_state, making this run impossible to inspect).
# NOTE(review): this run seeds with 11, while the dynamic_rnn and manual-unroll
# cells seed with 10, so if ComplexCell samples, its draws will differ — confirm intent.
tf.random.set_random_seed(11)
seq = [complex_cell_inputs[:, t, :] for t in range(T)]
static_outputs, static_final_state = tf.nn.static_rnn(
    complexcell,
    inputs=seq,
    initial_state=complexcell_init_state,
    dtype=tf.float32,
)

(300, 10)


In [7]:
# Re-seed with 10 (the same value set before the manual unroll) and run the cell
# over the full [bs, T, dim] input tensor with dynamic_rnn.
tf.random.set_random_seed(10)
complex_outputs, complex_final_state = tf.nn.dynamic_rnn(
    complexcell,
    inputs=complex_cell_inputs,
    initial_state=complexcell_init_state,
    dtype=tf.float32,
)

In [8]:
# First gen_dim columns of the final ComplexCell state (presumably the generator
# state, per comcell_state_dims); its sum is a scalar fingerprint to compare
# with the manual unroll.
dynamic_gen_final = complex_final_state[:, :hps['gen_dim']]
tf.reduce_sum(dynamic_gen_final)

<tf.Tensor: id=3583, shape=(), dtype=float32, numpy=2966.2678>

In [9]:
# Reset the global seed to the value used right before the dynamic_rnn run, so the
# manual unroll starts from the same random state.
manual_unroll_seed = 10
tf.random.set_random_seed(manual_unroll_seed)

In [10]:
# Manual step-by-step re-implementation of the ComplexCell recursion using two
# standalone GRUCells (controller + generator), all weights initialised to 0.2.
# The order of every op inside the loop — including the keep_prob=1.0 dropouts and
# the posterior sample — is deliberately kept fixed so that any random draws line
# up with the seeded dynamic_rnn run (presumably; confirm against customcells).
#
# After the loop, gen_s / con_s / fac_s hold one state per time step, so index -1
# is the final state (gen_s[-1] is compared with the ComplexCell result below).

# External inputs are not implemented in this unroll; fail loudly instead of
# silently ignoring them.
assert hps['ext_input_dim'] == 0, \
    "manual unroll only supports ext_input_dim == 0 (external inputs not implemented)"

# Initial states: all ones, matching complexcell_init_state.
gen_state = tf.ones([bs, hps['gen_dim']])
con_state = tf.ones([bs, hps['con_dim']])
fac_state = tf.ones([bs, hps['factors_dim']])

gencell = GRUCell(hps['gen_dim'], kernel_initializer=tf.constant_initializer(0.2))
concell = GRUCell(hps['con_dim'], kernel_initializer=tf.constant_initializer(0.2))

# Build explicitly: the generator consumes the controller output (co_dim wide);
# the controller consumes [inputs, previous factors] (dim + factors_dim wide).
gencell.build(tf.TensorShape([bs, hps['co_dim']]))
concell.build(tf.TensorShape([bs, hps['factors_dim'] + dim]))

gen_s, con_s, fac_s = [], [], []
for t in range(T):
    con_i = complex_cell_inputs[:, t, :]

    if hps['co_dim'] > 0:
        # Controller step on [inputs at t, previous factors].
        con_inputs = tf.concat([con_i, fac_state], axis=1)
        # keep_prob=1.0, so this is an identity; kept to mirror the cell's op order.
        con_inputs = tf.nn.dropout(con_inputs, 1.0)
        con_state, _ = concell.call(con_inputs, con_state)

        with tf.name_scope("con_2_gen"):
            # Map controller state to mean / log-variance of the controller-output posterior.
            co_mean = linear(con_state, hps['co_dim'],
                             name="con_2_gen_transform_mean",
                             mat_init_value=0.2)
            co_logvar = linear(con_state, hps['co_dim'],
                               name="con_2_gen_transform_logvar",
                               mat_init_value=0.2)
            cos_posterior = DiagonalGaussianFromExisting(co_mean, co_logvar, name='co_posterior')
            # Always sample here. NOTE(review): code this replaced chose sample vs.
            # mean based on run_type; presumably ComplexCell(run_type=0) samples too.
            co_out = cos_posterior.sample
    else:
        # No controller: the generator gets a zero-width input and the controller
        # state is carried through unchanged.
        co_out = tf.zeros([tf.shape(gen_state)[0], 0])

    # Generator step; with ext_input_dim == 0 its only input is the controller output.
    gen_inputs = co_out
    gen_state, _ = gencell.call(gen_inputs, gen_state)

    # Factors: bias-free, normalized linear readout of the generator state.
    with tf.variable_scope("gen_2_fac"):
        # keep_prob=1.0 identity dropout, kept to mirror the cell's op order.
        gen_state_dropped = tf.nn.dropout(gen_state, 1.0)
        fac_state = linear(gen_state_dropped, hps['factors_dim'],
                           name="gen_2_fac_transform",
                           do_bias=False,
                           normalized=True,
                           mat_init_value=0.2)

    gen_s.append(gen_state)
    con_s.append(con_state)
    fac_s.append(fac_state)


(300, 2)


In [11]:
# Final generator state of the manual unroll. Instead of eyeballing two printed
# sums in different cells, assert elementwise agreement with the generator part
# (first gen_dim columns) of the ComplexCell dynamic_rnn final state.
manual_gen_final = gen_s[-1]
np.testing.assert_allclose(manual_gen_final.numpy(),
                           complex_final_state[:, :hps['gen_dim']].numpy(),
                           rtol=1e-5, atol=1e-5)
tf.reduce_sum(manual_gen_final)

<tf.Tensor: id=5207, shape=(), dtype=float32, numpy=2966.2678>