In [1]:
import tensorflow as tf
from memory import Memory
import numpy as np
import math

In [2]:
batch_size = 2
read_heads = 1
num_keys = 1

mem_size = 10
vector_size = 5

memory = Memory(mem_size, vector_size, read_heads, batch_size)
memory_state = memory.init_memory()
memory_state

(<tf.Tensor 'Fill:0' shape=(2, 10, 5) dtype=float32>,
 <tf.Tensor 'Fill_1:0' shape=(2, 10) dtype=float32>,
 <tf.Tensor 'Fill_2:0' shape=(2, 10, 1) dtype=float32>,
 <tf.Tensor 'Fill_3:0' shape=(2, 5, 1) dtype=float32>)

In [3]:
keys = tf.fill([batch_size,vector_size,num_keys], 1.0)
strengths = tf.fill([batch_size,1], 1.0)
interpolation_gate = tf.fill([batch_size,1], 1.0)
shift_weighting = tf.nn.softmax(tf.constant([0.2, 0.6, 0.2]) + 1e-12)
gamma = tf.fill([batch_size,1], 1.0)

add = tf.constant([2.0, 1.6, 4.0, 1.0, 0.1])
erase = tf.constant([0.0, 0.0, 0.0, 0.0, 0.0])

## Get content addressing

In [4]:
normalized_memory = tf.nn.l2_normalize(memory_state[0], 2)
normalized_keys = tf.nn.l2_normalize(keys, 1)

similiarity = tf.batch_matmul(normalized_memory, normalized_keys)

strengths = tf.expand_dims(strengths, 1)

content_weights = tf.nn.softmax(similiarity * strengths, 1)
print(content_weights)

Tensor("transpose_1:0", shape=(2, 10, 1), dtype=float32)


## Test Interpolation

In [5]:
content_weights = tf.squeeze(content_weights,axis=2)
gated_weighting = interpolation_gate * content_weights + (1 - interpolation_gate) * memory_state[1]
print(gated_weighting)

Tensor("add_1:0", shape=(2, 10), dtype=float32)


## Circular convolution

In [28]:
tf.expand_dims(gated_weighting,2)

<tf.Tensor 'ExpandDims_2:0' shape=(2, 10, 1) dtype=float32>

In [25]:
size = int(gated_weighting.get_shape()[0])
kernel_size = int(shift_weighting.get_shape()[0])
kernel_shift = int(math.floor(kernel_size/2.0))

print(size,kernel_size,kernel_shift)

def loop(idx):
    if idx < 0: return size + idx
    if idx >= size : return idx - size
    else: return idx

kernels = []
for i in xrange(size):
    indices = [loop(i+j) for j in xrange(kernel_shift, -kernel_shift-1, -1)]
    temp = tf.transpose(gated_weighting)
    v_ = tf.transpose(tf.gather(temp, indices))
    kernels.append(tf.reduce_sum(v_ * shift_weighting, 0))

result_after_shift = tf.dynamic_stitch([i for i in xrange(size)], kernels)
print(result_after_shift)
print(kernels)

(2, 3, 1)
Tensor("DynamicStitch_2:0", shape=(?, 3), dtype=float32)
[<tf.Tensor 'Sum_4:0' shape=(3,) dtype=float32>, <tf.Tensor 'Sum_5:0' shape=(3,) dtype=float32>]


In [16]:
with tf.Session() as session:
    values = session.run(result_after_shift)
    print(values)

[[ 0.05727664  0.08544672  0.05727664]
 [ 0.05727664  0.08544672  0.05727664]]


In [None]:
new_weigth, new_memory = memory.write(memory_state[0],memory_state[1],keys,strengths,interpolation_gate,shift_weighting,gamma,add,erase)