In [1]:
import tensorflow as tf
from memory import Memory
import numpy as np
import math

In [2]:
# Dimensions of the toy problem used by every cell below.
batch_size, read_heads, num_keys = 2, 1, 1
mem_size, vector_size = 10, 5

# Build the project-local Memory helper and fetch its initial state.
# The recorded output shows a 4-tuple of float32 tensors with shapes
# (2, 10, 5), (2, 10), (2, 10, 1) and (2, 5, 1).
memory = Memory(mem_size, vector_size, read_heads, batch_size)
memory_state = memory.init_memory()
memory_state

(<tf.Tensor 'Fill:0' shape=(2, 10, 5) dtype=float32>,
 <tf.Tensor 'Fill_1:0' shape=(2, 10) dtype=float32>,
 <tf.Tensor 'Fill_2:0' shape=(2, 10, 1) dtype=float32>,
 <tf.Tensor 'Fill_3:0' shape=(2, 5, 1) dtype=float32>)

In [3]:
# Synthetic stand-ins for the controller outputs that drive one write step.
key_shape = [batch_size, vector_size, num_keys]
per_batch_scalar_shape = [batch_size, 1]

keys = tf.fill(key_shape, 1.0)
strengths = tf.fill(per_batch_scalar_shape, 1.0)
interpolation_gate = tf.fill(per_batch_scalar_shape, 1.0)
# The raw values are passed through a softmax, so the resulting shift
# weights sum to 1 but are not literally 0.2 / 0.6 / 0.2.
raw_shift = tf.constant([0.2, 0.6, 0.2])
shift_weighting = tf.nn.softmax(raw_shift + 1e-12)
gamma = tf.fill(per_batch_scalar_shape, 1.0)

# One add vector and one erase vector per batch element (vector_size wide).
add = tf.constant([
    [2.0, 1.6, 4.0, 1.0, 0.1],
    [1.0, 1.1, 2.0, 1.0, 0.1],
])
erase = tf.constant([
    [0.0, 0.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.2, 0.0, 0.0],
])

## Get content addressing

In [4]:
# Content addressing: cosine similarity between every memory row and every
# key (both are L2-normalised along the vector axis before the matmul),
# scaled by the key strength and normalised with a softmax over memory rows.
normalized_memory = tf.nn.l2_normalize(memory_state[0], 2)
normalized_keys = tf.nn.l2_normalize(keys, 1)

similiarity = tf.batch_matmul(normalized_memory, normalized_keys)
print(similiarity)

# Keep the expanded tensor under its own name. Overwriting `strengths` made
# this cell non-idempotent (every re-run added another axis) and changed the
# tensor that is later handed to memory.write().
expanded_strengths = tf.expand_dims(strengths, 1)
print(expanded_strengths)

# Softmax along axis 1, i.e. across the memory rows of each batch element.
content_weights = tf.nn.softmax(similiarity * expanded_strengths, 1)
print(content_weights)

Tensor("BatchMatMul:0", shape=(2, 10, 1), dtype=float32)
Tensor("ExpandDims:0", shape=(2, 1, 1), dtype=float32)
Tensor("transpose_1:0", shape=(2, 10, 1), dtype=float32)


In [5]:
# Preview only: drop the trailing size-1 axis so the weights are rank 2.
tf.squeeze(content_weights, axis=[2])

<tf.Tensor 'Squeeze:0' shape=(2, 10) dtype=float32>

## Test Interpolation

In [5]:
# Interpolation: blend the fresh content weights with memory_state[1] using
# the interpolation gate. memory_state[1] has shape (2, 10) in the recorded
# output; presumably it is the previous weighting -- confirm against Memory.
#
# The squeezed tensor gets its own name so that re-running this cell does not
# try to squeeze axis 2 of an already rank-2 tensor.
content_weights_2d = tf.squeeze(content_weights, axis=2)
gated_weighting = (
    interpolation_gate * content_weights_2d
    + (1 - interpolation_gate) * memory_state[1]
)
print(gated_weighting)

Tensor("add_1:0", shape=(2, 10), dtype=float32)


## Circular convolution

In [6]:
# Replace the computed weighting with hand-picked, non-uniform values so the
# effect of the shift and sharpening steps below is visible.
gated_weighting = tf.constant([
    [2.0, 1.6, 4.0, 1.0, 0.1, 0.2, 0.4, 2.0, 3.5, 10.0],
    [1.0, 0.5, 2.0, 1.1, 3.1, 0.25, 0.33, 0.14, 0.1, 0.0],
])
gated_weighting

<tf.Tensor 'Const_3:0' shape=(2, 10) dtype=float32>

In [7]:
# Static sizes needed to unroll the circular convolution in plain Python.
weighting_dims = gated_weighting.get_shape().as_list()
kernel_dims = shift_weighting.get_shape().as_list()

size = int(weighting_dims[1])         # number of memory slots per row
kernel_size = int(kernel_dims[0])     # number of shift weights
kernel_shift = kernel_size // 2       # reach of the kernel on each side

def loop(idx):
    """Wrap ``idx`` onto the circular index range ``[0, size)``.

    Reads the module-level ``size``. Modulo is used so that an index more
    than one period out of range (e.g. when the shift kernel reaches further
    than the memory is long) still wraps correctly; an if-chain that adds or
    subtracts ``size`` once only handles a single wrap.

    Args:
        idx: integer position, possibly negative or >= ``size``.

    Returns:
        The equivalent position in ``[0, size)``.
    """
    return idx % size

# Circular convolution of every weighting row with the shift kernel.
# For slot i the neighbours i+kernel_shift ... i-kernel_shift (wrapped around
# the ends by loop()) are gathered and combined with shift_weighting.

# Loop-invariant: transpose once so tf.gather can pick slots along axis 0.
transposed_weighting = tf.transpose(gated_weighting)

kernels = []
# `range` instead of `xrange` so the cell also runs on Python 3.
for i in range(size):
    indices = [loop(i + j) for j in range(kernel_shift, -kernel_shift - 1, -1)]
    # Back to (batch, kernel_size) so it broadcasts against shift_weighting.
    v_ = tf.transpose(tf.gather(transposed_weighting, indices))
    kernels.append(tf.reduce_sum(v_ * shift_weighting, 1))

# Stacking yields (size, batch); transpose back to (batch, size).
result_after_shift = tf.transpose(tf.stack(kernels, axis=0))

## Sharpening

In [8]:
# Sharpen: raise every weight to the power gamma, then renormalise each row
# so it sums to 1 again.
powed_conv_w = tf.pow(result_after_shift, gamma)
row_totals = tf.reduce_sum(powed_conv_w, 1)
after_sharp = powed_conv_w / tf.expand_dims(row_totals, 1)

## Update memory

In [9]:
# Memory update: erase, then add.
# Insert size-1 axes so each batched matmul is an outer product
# (slots x 1) @ (1 x vector_size) per batch element.
write_weighting = tf.expand_dims(after_sharp, 2)
write_vector = tf.expand_dims(add, 1)
erase_vector = tf.expand_dims(erase, 1)

# Use the whole batched memory tensor. Indexing it with an extra [0] selected
# only batch element 0 and silently broadcast it over every batch element,
# which went unnoticed because the initial memory is uniform.
erasing = memory_state[0] * (1 - tf.batch_matmul(write_weighting, erase_vector))
writing = tf.batch_matmul(write_weighting, write_vector)
updated_memory = erasing + writing
print(updated_memory)

Tensor("add_2:0", shape=(2, 10, 5), dtype=float32)


In [4]:
# Run the same write step through Memory.write so its result can be compared
# with the step-by-step cells above.
write_inputs = (
    memory_state[0],
    memory_state[1],
    keys,
    strengths,
    interpolation_gate,
    shift_weighting,
    gamma,
    add,
    erase,
)
new_weigth, new_memory = memory.write(*write_inputs)
for written_tensor in (new_weigth, new_memory):
    print(written_tensor)


Tensor("div:0", shape=(2, 10), dtype=float32)
Tensor("add_2:0", shape=(2, 10, 5), dtype=float32)


In [5]:
# Evaluate the updated memory tensor in a throwaway session and show the
# concrete values.
with tf.Session() as session:
    values = session.run(new_memory)
print(values)

[[[ 0.200001    0.16000101  0.40000102  0.100001    0.010001  ]
  [ 0.200001    0.16000101  0.40000102  0.100001    0.010001  ]
  [ 0.200001    0.16000101  0.40000102  0.100001    0.010001  ]
  [ 0.200001    0.16000101  0.40000102  0.100001    0.010001  ]
  [ 0.200001    0.16000101  0.40000102  0.100001    0.010001  ]
  [ 0.200001    0.16000101  0.40000102  0.100001    0.010001  ]
  [ 0.200001    0.16000101  0.40000102  0.100001    0.010001  ]
  [ 0.200001    0.16000101  0.40000102  0.100001    0.010001  ]
  [ 0.200001    0.16000101  0.40000102  0.100001    0.010001  ]
  [ 0.200001    0.16000101  0.40000102  0.100001    0.010001  ]]

 [[ 0.10000096  0.11000096  0.20000099  0.100001    0.010001  ]
  [ 0.10000096  0.11000096  0.20000099  0.100001    0.010001  ]
  [ 0.10000096  0.11000096  0.20000099  0.100001    0.010001  ]
  [ 0.10000096  0.11000096  0.20000099  0.100001    0.010001  ]
  [ 0.10000096  0.11000096  0.20000099  0.100001    0.010001  ]
  [ 0.10000096  0.11000096  0.20000099