In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
import json
import collections
import nltk
import matplotlib.pyplot as plt

%matplotlib inline

In [3]:
import tensorflow as tf
from tensorflow.python.layers.core import Dense
from tensorflow.contrib.rnn import RNNCell

In [4]:
def oneplus(x):
    return 1 + tf.log(1 + tf.exp(x))

def content_weights(memory, keys, strengths):
    similarity = tf.matmul(
        tf.nn.l2_normalize(keys, -1),
        tf.nn.l2_normalize(memory, -1),
        transpose_b = True
    )

    return tf.nn.softmax(tf.expand_dims(strengths, -1) * similarity, -1)

def allocation_weights(usage):
    sorted_usage_values = -tf.nn.top_k(-usage, k=tf.shape(usage)[-1]).values
    return (1 - usage) * tf.cumprod(sorted_usage_values, -1, exclusive=True)

def set_diag_to_zero(matrix):
    return tf.matrix_set_diag(matrix, tf.zeros(tf.shape(matrix)[:-1]))

class Controller:
    
    def __init__(self, memory_shape, output_size, n_read_keys):
        self._memory_shape = memory_shape
        self._output_size = output_size
        self._n_read_keys = n_read_keys
        
        self._output = Dense(self._output_size, use_bias=False, name='controller_output')
        self._interface = Dense(self._interface_size, use_bias=False, name='controller_interface')
        
    @property
    def _interface_size(self):
        return (self._n_read_keys + 3)*self._memory_shape[1] + 2*self._n_read_keys + 6
    
    def __call__(self, inputs, previous_read_values):
        inputs = tf.concat([
            inputs,
            tf.reshape(previous_read_values, [tf.shape(inputs)[0], self._n_read_keys*self._memory_shape[1]])
        ], 1)

        return (
            self._output(inputs),
            self._parse_interface(self._interface(inputs))
        )
    
    def _parse_interface(self, interface):
        (read_keys, read_strengths,
         write_key, write_strength,
         erase_vector, write_vector,
         free_gates, write_gate, allocation_gate,
         read_modes) = tf.split(interface, [
            self._n_read_keys * self._memory_shape[1], # Read keys
            self._n_read_keys, # Read strengths
            self._memory_shape[1], # Write key
            1, # Write strength
            self._memory_shape[1], # Erase vector
            self._memory_shape[1], # Write vector
            self._n_read_keys, # Free gates
            1, # Write gate
            1, # Allocation gate
            3, # Read modes
        ], -1)
        
        return {
            'read_keys': tf.reshape(read_keys, [-1, self._n_read_keys, self._memory_shape[1]]),
            'read_strengths': oneplus(read_strengths),
            'write_key': write_key,
            'write_strength': oneplus(write_strength),
            'erase_vector': tf.sigmoid(erase_vector),
            'write_vector': write_vector,
            'free_gates': tf.sigmoid(free_gates),
            'write_gate': tf.sigmoid(write_gate),
            'allocation_gate': tf.sigmoid(allocation_gate),
            'read_modes': tf.nn.softmax(read_modes, -1)
        }

DNCState = collections.namedtuple('DNCState', ('memory', 'read_values', 'read_weights', 'write_weights', 'usage', 'precedence', 'linkage'))
    
class DNC(RNNCell):
    
    def __init__(self, memory_shape, output_size, read_keys=3, activation=None):
        self._memory_shape = list(memory_shape)
        self._output_size = int(output_size)
        self._n_read_keys = int(read_keys)
        self._activation = activation
        
        self._controller = Controller(self._memory_shape, self._output_size, self._n_read_keys)
        self._read_proj = Dense(self._output_size, use_bias=False, name='read_values_projetion')
    
    def __call__(self, inputs, state):
        memory = state.memory
        output, interface = self._controller(inputs, state.read_values)
        
        write_weights, usage = self._write_operators(memory, interface, state)
        memory = self._write_memory(memory, interface, write_weights)
        
        read_weights, precedence, linkage = self._read_operators(memory, interface, state, write_weights)
        read_values = self._read_memory(memory, read_weights)
        
        read_values = tf.Print(read_values, [read_values])
        
        return (
            self._merge_output(output, read_values),
            DNCState(memory, read_values, read_weights, write_weights, usage, precedence, linkage)
        )

    def zero_state(self, batch_size, dtype):
        init_value = 1e-6
        return DNCState(
            memory = tf.fill([batch_size] + self._memory_shape, init_value),
            read_values = tf.fill([batch_size, self._n_read_keys, self._memory_shape[1]], init_value),
            read_weights = tf.fill([batch_size, self._n_read_keys, self._memory_shape[0]], init_value),
            write_weights = tf.fill([batch_size, self._memory_shape[0]], init_value),
            usage = tf.zeros([batch_size, self._memory_shape[0]], dtype),
            precedence = tf.zeros([batch_size, self._memory_shape[0]], dtype),
            linkage = tf.zeros([batch_size, self._memory_shape[0], self._memory_shape[0]], dtype)
        )
    
    @property
    def output_size(self):
        return self._output_size
    
    @property
    def state_size(self):
        return DNCState(
            memory = self._memory_shape,
            read_values = [self._n_read_keys, self._memory_shape[1]],
            read_weights = [self._n_read_keys, self._memory_shape[0]],
            write_weights = self._memory_shape[0],
            usage = self._memory_shape[0],
            precedence = self._memory_shape[0],
            linkage = [self._memory_shape[0], self._memory_shape[0]]
        )

    def _write_operators(self, memory, interface, state):
        retention = tf.reduce_prod(1 - tf.expand_dims(interface['free_gates'], -1) * state.read_weights, 1)
        usage = (state.usage + state.write_weights - state.usage*state.write_weights) * retention
        
        allocation_w = allocation_weights(usage)
        content_w = content_weights(memory, tf.expand_dims(interface['write_key'], -2), interface['write_strength'])
        content_w = tf.squeeze(content_w, 1)
        
        write_weights = (interface['allocation_gate']*allocation_w + (1 - interface['allocation_gate'])*content_w) * interface['write_gate']
        
        return write_weights, usage
    
    def _read_operators(self, memory, interface, state, write_weights):
        ww_rows = tf.expand_dims(write_weights, -1)
        ww_cols = tf.expand_dims(write_weights, -2)
        pp_cols = tf.expand_dims(state.precedence, -2)
        
        precedence = (1 - tf.reduce_sum(write_weights, -1, keep_dims=True))*state.precedence + write_weights
        linkage = set_diag_to_zero((1 - ww_rows - ww_cols)*state.linkage + ww_rows*pp_cols)
        
        forward_w = tf.matmul(state.read_weights, linkage)
        backward_w = tf.matmul(state.read_weights, linkage, transpose_b=True)
        content_w = content_weights(memory, interface['read_keys'], interface['read_strengths'])
        
        read_modes = tf.reshape(interface['read_modes'], [tf.shape(interface['read_modes'])[0],3,1,1])
        read_weights = tf.concat([
            tf.expand_dims(forward_w, 1),
            tf.expand_dims(content_w, 1),
            tf.expand_dims(backward_w, 1)
        ], 1)
        
        read_weights = tf.reduce_sum(read_modes * read_weights, 1)

        return read_weights, precedence, linkage
    
    def _write_memory(self, memory, interface, write_weights):
        ww_rows = tf.expand_dims(write_weights, -1)
        
        previous_memory_gate = 1 - tf.matmul(
            ww_rows,
            tf.expand_dims(interface['erase_vector'], -2)
        )
        new_memory_values = tf.matmul(
            ww_rows,
            tf.expand_dims(interface['write_vector'], -2)
        )
        
        return memory*previous_memory_gate + new_memory_values
    
    def _read_memory(self, memory, read_weights):
        return tf.matmul(read_weights, memory)
    
    def _merge_output(self, output, read_values):
        read_values = tf.reshape(read_values, [tf.shape(read_values)[0], self._n_read_keys*self._memory_shape[1]])
        
        linear = output + self._read_proj(read_values)
        
        if self._activation is not None:
            return self._activation(linear)
        return linear

In [5]:
train_data = json.load(open('./data/babi/train_embedded.json', 'r'))
valid_data = json.load(open('./data/babi/valid_embedded.json', 'r'))
word_dict = json.load(open('./data/babi/word_dict.json', 'r'))

In [6]:
graph = tf.Graph()

with graph.as_default():
    inputs_ph = tf.placeholder(tf.int32, [1, None])
    targets_ph = tf.placeholder(tf.int32, [1, None])
    inputs_one_hot = tf.one_hot(inputs_ph, len(word_dict))

    dnc_cell = DNC([256, 64], len(word_dict))
    outputs, state = tf.nn.dynamic_rnn(
        dnc_cell, inputs_one_hot, dtype=tf.float32
    )
    
    
    loss = tf.nn.softmax_cross_entropy_with_logits(
        logits = outputs,
        labels = tf.one_hot(targets_ph, len(word_dict))
    )
    loss = tf.reduce_mean(tf.reduce_sum(loss * tf.cast(tf.greater(targets_ph, 0), tf.float32), -1))
    accuracy = tf.reduce_sum(tf.cast(tf.equal(targets_ph, tf.cast(tf.argmax(outputs, -1), tf.int32)), tf.float32)) / tf.reduce_sum(tf.cast(tf.greater(targets_ph, 0), tf.float32))
    
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)
    
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-1)
    
    gradients = optimizer.compute_gradients(loss)
    for i, (grad, var) in enumerate(gradients):
        if grad is not None:
            gradients[i] = (tf.clip_by_value(grad, -10, 10), var)
    train_op = optimizer.apply_gradients(gradients)

In [7]:
with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    
    for e in range(10):
        for i, sample in enumerate(train_data):
            assert not np.any(np.isnan(sample['inputs']))
            assert not np.any(np.isnan(sample['targets']))

            _, loss_val = sess.run([train_op, loss], feed_dict={
                inputs_ph: [sample['inputs']],
                targets_ph: [sample['targets']]
            })
            
            if np.isnan(loss_val):
                raise ValueError('NaN loss.')

            print(e, i, 'Loss:', loss_val)

0 0 Loss: 25.5995
0 1 Loss: 25.042
0 2 Loss: 23.8768
0 3 Loss: 22.5398
0 4 Loss: 20.707
0 5 Loss: 100.442
0 6 Loss: 17.2951
0 7 Loss: 15.758
0 8 Loss: 16.3252
0 9 Loss: 15.4907
0 10 Loss: 15.5617
0 11 Loss: 15.4286
0 12 Loss: 14.5421
0 13 Loss: 14.2582
0 14 Loss: 13.0063
0 15 Loss: 13.0899
0 16 Loss: 12.5053
0 17 Loss: 12.7229
0 18 Loss: 11.8001
0 19 Loss: 11.6789
0 20 Loss: 12.1858
0 21 Loss: 8.21824
0 22 Loss: 9.68176
0 23 Loss: 10.1673
0 24 Loss: 12.5549
0 25 Loss: 10.3904
0 26 Loss: 10.6851
0 27 Loss: 8.92265
0 28 Loss: 8.56984
0 29 Loss: 8.52754
0 30 Loss: 8.68357
0 31 Loss: 10.3567
0 32 Loss: 9.21982
0 33 Loss: 12.1534
0 34 Loss: 11.0285
0 35 Loss: 9.50737
0 36 Loss: 10.0491
0 37 Loss: 8.45812
0 38 Loss: 7.92455
0 39 Loss: 10.5738
0 40 Loss: 7.96393
0 41 Loss: 9.77997
0 42 Loss: 10.1436
0 43 Loss: 8.9352
0 44 Loss: 9.20702
0 45 Loss: 8.20384
0 46 Loss: 10.0306
0 47 Loss: 8.07754
0 48 Loss: 9.60227
0 49 Loss: 7.06753
0 50 Loss: 9.40269
0 51 Loss: 10.3319
0 52 Loss: 9.55987
0 53 Lo

ValueError: NaN loss.

In [60]:
len(word_dict)

161