
Relational memory Input to reshape is a tensor with 1200 values, but the requested shape has 12 #89

Open
gaceladri opened this issue Jun 25, 2018 · 2 comments

@gaceladri commented Jun 25, 2018

Hi,

I am trying to integrate your relational memory into a Meta-A3C agent. Here is the relevant part of the code:

class Agent():
    def __init__(self, a_size, scope, trainer):
        with tf.variable_scope(scope):
            # Input placeholders
            self.state = tf.placeholder(
                shape=[None, 1, 2, 3], dtype=tf.float32)
            self.prev_rewards = tf.placeholder(
                shape=[None, 1], dtype=tf.float32)
            self.prev_actions = tf.placeholder(shape=[None], dtype=tf.int32)
            self.timestep = tf.placeholder(shape=[None, 1], dtype=tf.float32)
            self.prev_actions_onehot = tf.one_hot(
                self.prev_actions, a_size, dtype=tf.float32)

            # Recurrent network for temporal dependencies
            hidden = tf.concat([tf.layers.flatten(self.state),
                                self.prev_rewards, self.prev_actions_onehot, self.timestep], 1)
            relational_cell = RelationalMemory(mem_slots=2048, head_size=12, num_heads=1, num_blocks=1,
                                               forget_bias=1.0, input_bias=0.0, gate_style='unit', attention_mlp_layers=2, key_size=None, name='relational_memory')
            state_init = np.eye(relational_cell._mem_slots, dtype=np.float32)
            state_init = state_init[np.newaxis, ...]
            state_init = state_init[:, :, :relational_cell._mem_size]
            self.state_init = state_init
            self.state_in = tf.placeholder(
                tf.float32, shape=[1, relational_cell._mem_slots, relational_cell._mem_size])
            step_size = tf.shape(self.prev_rewards)[:1]
            rnn_in = tf.expand_dims(hidden, [0])

            output_sequence, cell_state = tf.nn.dynamic_rnn(
                relational_cell, rnn_in, sequence_length=step_size, initial_state=self.state_init,
                time_major=True)
            self.state_out = cell_state
            rnn_out = tf.reshape(output_sequence, [-1, 12])

            self.actions = tf.placeholder(shape=[None], dtype=tf.int32)
            self.actions_onehot = tf.one_hot(
                self.actions, a_size, dtype=tf.float32)

            # Output layer for policy and value estimations
            self.policy = tf.contrib.layers.fully_connected(rnn_out, a_size,
                                                            activation_fn=tf.nn.softmax,
                                                            weights_initializer=normalized_columns_initializer(
                                                                0.01),
                                                            biases_initializer=None)
            self.value = tf.contrib.layers.fully_connected(rnn_out, 1,
                                                           activation_fn=None,
                                                           weights_initializer=normalized_columns_initializer(
                                                               1.0),
                                                           biases_initializer=None)

When I try to run the code, I get the following error:

Caused by op 'worker3/rnn/while/relational_memory/batch_flatten_2/Reshape', defined at:
  File "agent.py", line 377, in <module>
    main(FLAGS)
  File "agent.py", line 322, in main
    trainer, args.save_dir, global_episodes))
  File "agent.py", line 170, in __init__
    self.local_AC = Agent(a_size, self.name, trainer)
  File "agent.py", line 100, in __init__
    time_major=True)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py", line 627, in dynamic_rnn
    dtype=dtype)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py", line 824, in _dynamic_rnn_loop
    swap_memory=swap_memory)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3224, in while_loop
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2956, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2893, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3194, in <lambda>
    body = lambda i, lv: (i + 1, orig_body(*lv))
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py", line 793, in _time_step
    skip_conditionals=True)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py", line 248, in _rnn_step
    new_output, new_state = call_cell()
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py", line 781, in <lambda>
    call_cell = lambda: cell(input_t, state)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/sonnet/python/modules/base.py", line 389, in __call__
    outputs, subgraph_name_scope = self._template(*args, **kwargs)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/tensorflow/python/ops/template.py", line 455, in __call__
    result = self._call_func(args, kwargs)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/tensorflow/python/ops/template.py", line 406, in _call_func
    result = self._func(*args, **kwargs)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/sonnet/python/modules/base.py", line 246, in _build_wrapper
    output = self._build(*args, **kwargs)
  File "/media/proto/942ea58f-5397-4c08-aa32-b507c86d1c07/IA/Reinforcement/ABN_Robotics/surfer/meinemashine/Meta-Relational-A3C/memory.py", line 257, in _build
    inputs_reshape, memory)
  File "/media/proto/942ea58f-5397-4c08-aa32-b507c86d1c07/IA/Reinforcement/ABN_Robotics/surfer/meinemashine/Meta-Relational-A3C/memory.py", line 188, in _create_gates
    inputs = basic.BatchFlatten()(inputs)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/sonnet/python/modules/base.py", line 389, in __call__
    outputs, subgraph_name_scope = self._template(*args, **kwargs)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/tensorflow/python/ops/template.py", line 455, in __call__
    result = self._call_func(args, kwargs)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/tensorflow/python/ops/template.py", line 406, in _call_func
    result = self._func(*args, **kwargs)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/sonnet/python/modules/base.py", line 246, in _build_wrapper
    output = self._build(*args, **kwargs)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/sonnet/python/modules/basic.py", line 763, in _build
    output = tf.reshape(inputs, shape)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/tensorflow/python/ops/gen_array_ops.py", line 6113, in reshape
    "Reshape", tensor=tensor, shape=shape, name=name)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3392, in create_op
    op_def=op_def)
  File "/home/proto/anaconda3/envs/t/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1718, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 1200 values, but the requested shape has 12
	 [[Node: worker3/rnn/while/relational_memory/batch_flatten_2/Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](worker3/rnn/while/relational_memory/ExpandDims, worker3/rnn/while/relational_memory/batch_flatten_2/concat)]]

Am I running the cell the wrong way? Is it calling the RNNCore _build() instead of the RelationalMemory one when it runs? I build the initial state with numpy because when I try to feed it into the graph I get an error saying the initial state cannot be a tensor.

The full A3C code is here:
https://github.com/gaceladri/Meta-Relational-A3C

Thanks!

@kosklain (Contributor) commented Jul 27, 2018

It looks like BatchFlatten is getting confused about the shapes to expect. Please use relational_cell.initial_state(batch_size) as the initial state rather than the numpy array you are currently passing.
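
For illustration, a minimal sketch of that wiring, assuming the same relational_cell, rnn_in and step_size tensors as in the snippet above, a batch size of 1, and batch-major input (the default time_major=False); this only shows the initial-state hookup, not a complete fix:

    # Sketch only: let the cell build its own initial state tensor instead of
    # feeding a hand-made numpy array through a placeholder.
    init_state = relational_cell.initial_state(batch_size=1)  # [1, mem_slots, mem_size]

    output_sequence, cell_state = tf.nn.dynamic_rnn(
        relational_cell, rnn_in, sequence_length=step_size,
        initial_state=init_state)  # time_major defaults to False here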

@gaceladri (Author) commented Jul 29, 2018

@kosklain Thanks a lot for your response. I used a numpy array because when I pass relational_cell.initial_state(batch_size=1) I cannot feed it through my feed dictionary.

If I pass relational_cell.initial_state(batch_size=1) as the initial state, I get the following error:

TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles.For reference, the tensor object was Tensor("worker2/strided_slice:0", shape=(1, 2048, 12), dtype=float32, device=/device:CPU:0) which was passed to the feed with key Tensor("worker2/Placeholder_4:0", shape=(1, 2048, 12), dtype=float32, device=/device:CPU:0).

Here is the part of the code where I feed the initial state into the graph:

    # Update the global network using gradients from loss
    # Generate network statistics to periodically save
    rnn_state = self.local_AC.state_init
    feed_dict = {self.local_AC.target_v: discounted_rewards,
                 self.local_AC.state: np.stack(states, axis=0),
                 self.local_AC.prev_rewards: np.vstack(prev_rewards),
                 self.local_AC.prev_actions: prev_actions,
                 self.local_AC.actions: actions,
                 self.local_AC.timestep: np.vstack(timesteps),
                 self.local_AC.advantages: advantages,
                 self.local_AC.state_in: rnn_state}
    v_l, p_l, e_l, g_n, v_n, _ = sess.run([self.local_AC.value_loss,
                                           self.local_AC.policy_loss,
                                           self.local_AC.entropy,
                                           self.local_AC.grad_norms,
                                           self.local_AC.var_norms,
                                           self.local_AC.apply_grads],
                                          feed_dict=feed_dict)

That is why I changed the initial state to a numpy array.
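
The TypeError above is TensorFlow rejecting a tf.Tensor as a feed value: feed values have to be concrete arrays (or scalars, strings, lists). One way to reconcile the two approaches, sketched under the assumption that the Agent exposes the tensor returned by relational_cell.initial_state(batch_size=1) as a hypothetical state_init_op attribute, would be to evaluate that tensor once and feed the resulting numpy array:

    # Sketch only: `state_init_op` is a hypothetical attribute holding the
    # tensor returned by relational_cell.initial_state(batch_size=1).
    rnn_state = sess.run(self.local_AC.state_init_op)  # concrete numpy array
    feed_dict = {self.local_AC.state_in: rnn_state,
                 # ... remaining entries unchanged from the feed_dict above ...
                }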
