In [3]:
import numpy as np
import matplotlib.pyplot as plt
import keras
import keras.backend as K
import tensorflow as tf

In [4]:
X_train = np.random.randint(3, high=32, size=(2000,16))
X_target = np.sort(X_train)
Y_train = np.copy(X_target)
Y_train[:,1:] = Y_train[:,:-1]
Y_train[:,0] = 1
X_train = np.concatenate([X_train, X_target], axis=-1)
X_train,X_test = X_train[1000:],X_train[:1000]
Y_train,Y_test = Y_train[1000:],Y_train[:1000]
X_train.shape,Y_train.shape,X_test.shape,Y_test.shape

((1000, 32), (1000, 16), (1000, 32), (1000, 16))

In [5]:
data_input_dim = max(np.max(X_train), np.max(X_test)) + 1
data_input_dim

32

In [6]:
data_output_dim = max(np.max(Y_train), np.max(Y_test)) + 1
data_output_dim

32

In [7]:
class DNCCell(keras.layers.Layer):
    
    def __init__(self,
                 output_size,
                 word_size,
                 memory_size,
                 register_size=2,
                 write_head_count=2,
                 read_head_count=2,
                 enable_temporal=True,
                 enable_final_bias=True,
                 bypass_dropout_factor=None,
                 **kwargs):
        self.output_size = output_size
        self.word_size = word_size
        self.memory_size = memory_size
        self.register_size = register_size
        self.write_head_count = write_head_count
        self.read_head_count = read_head_count
        self.enable_temporal = enable_temporal
        self.enable_final_bias = enable_final_bias
        self.bypass_dropout_factor = bypass_dropout_factor
        self.state_size = (output_size,)
        self.state_size += (register_size * word_size,)
        self.state_size += (register_size * word_size,)
        self.state_size += (memory_size * word_size,)
        self.state_size += (read_head_count * memory_size,)
        self.state_size += (read_head_count * word_size,)
        self.state_size += (write_head_count * memory_size,)
        self.state_size += (memory_size,)
        if enable_temporal:
            self.state_size += (write_head_count * memory_size,)
            self.state_size += (write_head_count * memory_size**2,)
        super().__init__(**kwargs)
    
    def build(self, input_shape):
        read_vec_size = self.word_size * self.read_head_count
        controller_input_size = input_shape[-1]
        controller_input_size += read_vec_size
        controller_input_size += self.register_size*self.word_size
        controller_kernel_size = self.register_size*self.word_size*4
        controller_hidden_size = self.register_size*self.word_size
        interface_vec_size = read_vec_size
        interface_vec_size += self.read_head_count*2
        if self.enable_temporal:
            interface_vec_size += self.read_head_count*3
        interface_vec_size += self.write_head_count*self.word_size*3
        interface_vec_size += self.write_head_count*3
        self.kernel_controller_hidden = self.add_weight(
            shape=(controller_input_size, controller_kernel_size),
            initializer='glorot_normal', name='kernel_controller_hidden')
        self.bias_controller_hidden = self.add_weight(
            shape=(controller_kernel_size,),
            initializer='zeros', name='bias_controller_hidden')
        self.kernel_controller_output = self.add_weight(
            shape=(controller_hidden_size, self.output_size + interface_vec_size),
            initializer='glorot_normal', name='kernel_controller_output')
        self.kernel_read_vec_to_output = self.add_weight(
            shape=(read_vec_size, self.output_size),
            initializer='glorot_normal', name='kernel_read_vec_to_output')
        if self.enable_final_bias:
            self.bias_final_output = self.add_weight(
                shape=(self.output_size,),
                initializer='zeros', name='bias_final_output')
        super().build(input_shape)
    
    def call(self, inputs, states, training=None):
        
        def oneplus(x):
            return K.softplus(x) + 1.
        
        def similarity(m, k, b):
            dot = K.batch_dot(k, m, axes=2)
            m_len = K.sqrt(K.sum(K.square(m), axis=-1))
            k_len = K.sqrt(K.sum(K.square(k), axis=-1))
            mk_len = K.expand_dims(k_len, axis=-1) @ K.expand_dims(m_len, axis=-2)
            mk_len = K.switch(K.not_equal(mk_len, 0.), mk_len, K.ones_like(mk_len))
            cos_sim = dot / mk_len
            return K.softmax(cos_sim * K.expand_dims(b))
        
        def batch_invert_permutation(permutations):
            # from https://github.com/deepmind/dnc/blob/master/dnc/util.py
            perm = tf.cast(permutations, tf.float32)
            dim = int(perm.get_shape()[-1])
            size = tf.cast(tf.shape(perm)[0], tf.float32)
            delta = tf.cast(tf.shape(perm)[-1], tf.float32)
            rg = tf.range(0, size * delta, delta, dtype=tf.float32)
            rg = tf.expand_dims(rg, 1)
            rg = tf.tile(rg, [1, dim])
            perm = tf.add(perm, rg)
            flat = tf.reshape(perm, [-1])
            perm = tf.invert_permutation(tf.cast(flat, tf.int32))
            perm = tf.reshape(perm, [-1, dim])
            return tf.subtract(perm, tf.cast(rg, tf.int32))

        def batch_gather(values, indices):
            # from https://github.com/deepmind/dnc/blob/master/dnc/util.py
            idx = tf.expand_dims(indices, -1)
            size = tf.shape(indices)[0]
            rg = tf.range(size, dtype=tf.int32)
            rg = tf.expand_dims(rg, -1)
            rg = tf.tile(rg, [1, int(indices.get_shape()[-1])])
            rg = tf.expand_dims(rg, -1)
            gidx = tf.concat([rg, idx], -1)
            return tf.gather_nd(values, gidx)
        
        _, register_s_last, register_h_last, memory_last, \
            read_weights_last, read_vec_last, write_weights_last, \
            usage_last, *temporal_states = states
        memory_last = K.reshape(memory_last,
            (-1,self.memory_size,self.word_size))
        read_weights_last = K.reshape(read_weights_last,
            (-1,self.read_head_count,self.memory_size))
        write_weights_last = K.reshape(write_weights_last,
            (-1,self.write_head_count,self.memory_size))
        if not self.enable_temporal:
            assert not temporal_states
        else:
            preced_last, link_last = temporal_states
            preced_last = K.reshape(preced_last,
                (-1,self.write_head_count,self.memory_size))
            link_last = K.reshape(link_last,
                (-1,self.write_head_count,self.memory_size,self.memory_size))
        
        # feeding controller with current input and last read vecs
        output = K.concatenate([inputs, read_vec_last, register_h_last])
        output = output @ self.kernel_controller_hidden
        output = output + self.bias_controller_hidden
        ctr_input_gate = output[:,:self.register_size*self.word_size]
        ctr_input_gate = K.sigmoid(ctr_input_gate)
        ctr_forget_gate = output[:,
            self.register_size*self.word_size:
            self.register_size*self.word_size*2]
        ctr_forget_gate = K.sigmoid(ctr_forget_gate)
        register_s = output[:,
            self.register_size*self.word_size*2:
            self.register_size*self.word_size*3]
        register_s = ctr_input_gate * K.tanh(register_s)
        register_s = register_s + ctr_forget_gate*register_s_last
        ctr_output_gate = output[:,
            self.register_size*self.word_size*3:
            self.register_size*self.word_size*4]
        ctr_output_gate = K.sigmoid(ctr_output_gate)
        register_h = ctr_output_gate * K.tanh(register_s)
        output = register_h @ self.kernel_controller_output
        
        # break down controller output into semi final output and interface vec
        interface_vec = output[:,self.output_size:]
        output = output[:,:self.output_size]
        if self.bypass_dropout_factor is not None:
            output = K.in_train_phase(K.dropout(
                output,self.bypass_dropout_factor),output,training=training)
        
        # break down interface vec
        interface_pos_last = 0
        interface_partition = []
        for interface_part_len in [
            self.read_head_count * self.word_size,
            self.read_head_count,
            self.write_head_count * self.word_size,
            self.write_head_count,
            self.write_head_count * self.word_size,
            self.write_head_count * self.word_size,
            self.read_head_count,
            self.write_head_count,
            self.write_head_count,
            *([self.read_head_count * 3] if self.enable_temporal else [])]:
            interface_pos = interface_pos_last + interface_part_len
            interface_partition.append(
                interface_vec[:,interface_pos_last:interface_pos])
            interface_pos_last = interface_pos
        read_keys, read_stre, write_keys, write_stre, erase_vecs, \
            write_vecs, free_gates, alloc_gates, write_gates, \
            *temporal_interface_partition = interface_partition
        read_keys = K.reshape(read_keys,(-1,self.read_head_count,self.word_size))
        read_stre = oneplus(read_stre)
        write_keys = K.reshape(write_keys,(-1,self.write_head_count,self.word_size))
        write_stre = oneplus(write_stre)
        erase_vecs = K.reshape(erase_vecs,(-1,self.write_head_count,self.word_size))
        erase_vecs = K.sigmoid(erase_vecs)
        write_vecs = K.reshape(write_vecs,(-1,self.write_head_count,self.word_size))
        free_gates = K.expand_dims(free_gates)
        free_gates = K.sigmoid(free_gates)
        alloc_gates = K.expand_dims(alloc_gates)
        alloc_gates = K.sigmoid(alloc_gates)
        write_gates = K.expand_dims(write_gates)
        write_gates = K.sigmoid(write_gates)
        if not self.enable_temporal:
            assert not temporal_interface_partition
        else:
            read_modes, = temporal_interface_partition
            read_modes = K.reshape(read_modes,(-1,self.read_head_count,3))
            read_modes = K.softmax(read_modes,axis=-1)
        
        # compute allocation vector
        retention = K.prod(1.-(free_gates*read_weights_last), axis=-2)
        # https://github.com/deepmind/dnc/blob/master/dnc/addressing.py
        # according to the deepmind implementation,
        # only write weight is not differentiable
        write_weights_last_nograd = K.stop_gradient(write_weights_last)
        # reduce for multi-write-head, not presented in original papaer
        write_weights_last_nograd = 1.-K.prod(1.-write_weights_last_nograd, axis=-2)
        usage = ((usage_last+write_weights_last_nograd) - \
                usage_last*write_weights_last_nograd) * retention
        # loop for multi-write-head support
        mwh_write_gates = write_gates * alloc_gates
        mwh_usage = usage
        mwh_alloc = []
        for i in range(self.write_head_count):
            # quickfix for tf.cumprod grad bug
            # https://github.com/tensorflow/tensorflow/issues/3862
            usage_qfixed = 1e-6 + (1-1e-6) * mwh_usage
            usage_asc,usage_perm = tf.nn.top_k(1.-usage_qfixed,k=self.memory_size)
            usage_asc = 1.-usage_asc
            alloc_asc = (1.-usage_asc) * tf.cumprod(usage_asc,axis=-1,exclusive=True)
            alloc_perm = batch_invert_permutation(usage_perm)
            alloc = batch_gather(alloc_asc, alloc_perm)
            mwh_alloc.append(alloc)
            mwh_usage += (1.-mwh_usage) * mwh_write_gates[:,i,:] * alloc
        alloc = K.stack(mwh_alloc, axis=-2)
        
        # compute write weight
        write_sims = similarity(memory_last,write_keys,write_stre)
        write_weights = write_gates*(alloc_gates*alloc + (1.-alloc_gates)*write_sims)
        
        # compute precedence and temporal links
        if self.enable_temporal:
            write_weights_rep_h =  K.expand_dims(write_weights,axis=-1)
            write_weights_rep_v =  K.expand_dims(write_weights,axis=-2)
            link = (1.-(write_weights_rep_h+write_weights_rep_v))*link_last + \
                write_weights_rep_h * K.expand_dims(preced_last,axis=-2)
            link = link * (1.-tf.eye(self.memory_size))
            read_weights_last_rep = K.repeat_elements(
                K.expand_dims(read_weights_last,axis=-3),self.write_head_count,-3)
            # reduce sum for multi-write-head
            link_forward = K.sum(read_weights_last_rep @ link,axis=-3)
            link_backward = K.sum(K.permute_dimensions(link @ K.permute_dimensions(
                read_weights_last_rep,(0,1,3,2)),(0,1,3,2)),axis=-3)
            preced = (1.-K.sum(
                write_weights,axis=-1,keepdims=True))*preced_last + write_weights
        
        # update memory
        m_reset = K.expand_dims(write_weights,axis=-1) @ \
            K.expand_dims(erase_vecs,axis=-2)
        # reduce prod for multi-write-head
        m_keep = K.prod(1.-m_reset,axis=-3)
        m_new = K.expand_dims(write_weights,axis=-1) @ \
            K.expand_dims(write_vecs,axis=-2)
        # reduce sum for multi-write-head
        m_new = K.sum(m_new, axis=-3)
        memory = memory_last*m_keep + m_new
        
        # compute read weights and read vectors
        read_sims = similarity(memory,read_keys,read_stre)
        if not self.enable_temporal:
            read_weights = read_sims
        else:
            read_weights = read_modes[:,:,0:1] * read_sims + \
                            read_modes[:,:,1:2] * link_forward + \
                            read_modes[:,:,2:3] * link_backward
        read_vecs = read_weights @ memory
        read_vec = K.reshape(read_vecs, (-1,self.read_head_count*self.word_size))
        
        # compute final output from controller output and current read vecs
        output = output + read_vec @ self.kernel_read_vec_to_output
        if self.enable_final_bias:
            output = output + self.bias_final_output
        
        return output, [
            output,
            register_s,
            register_h,
            K.reshape(memory, (-1,self.memory_size*self.word_size)),
            K.reshape(read_weights, (-1,self.read_head_count*self.memory_size)),
            read_vec,
            K.reshape(write_weights, (-1,self.write_head_count*self.memory_size)),
            usage,
            *([
                K.reshape(preced, (-1,self.write_head_count*self.memory_size)),
                K.reshape(link, (-1,self.write_head_count*self.memory_size**2))
            ] if self.enable_temporal else [])]
    
    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_size)

In [8]:
X = X_inputs = keras.layers.Input((None,), dtype='int32')
X = keras.layers.Embedding(data_input_dim, 16, mask_zero=True)(X)
X = keras.layers.RNN(DNCCell(data_output_dim, 8, 16), return_sequences=True)(X)
X = keras.layers.Lambda(lambda x,s: x[:,s:,:],
    arguments={'s':X_train.shape[1]//2})(X)
X = keras.layers.Activation('relu')(X)
X = keras.layers.TimeDistributed(keras.layers.Dense(data_output_dim))(X)
X = keras.layers.Softmax()(X)
X = [keras.layers.Lambda(lambda x,s: x[:,s,:], name=f'o{i+1}',
    arguments={'s':i}, output_shape=(data_output_dim,))(X)
    for i in range(X_train.shape[1]//2)]
M = keras.Model(X_inputs, X)
M.compile('nadam', 'sparse_categorical_crossentropy', ['acc'])
M.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, None)         0                                            
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, None, 16)     512         input_1[0][0]                    
__________________________________________________________________________________________________
rnn_1 (RNN)                     (None, None, 32)     5472        embedding_1[0][0]                
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, None, 32)     0           rnn_1[0][0]                      
__________________________________________________________________________________________________
activation

In [9]:
M.fit(X_train, np.split(Y_train,Y_train.shape[-1],axis=-1),
    validation_data=(X_test,np.split(Y_test,Y_test.shape[-1],axis=-1)),
    batch_size=8, epochs=50, callbacks=[
    keras.callbacks.ReduceLROnPlateau('loss', patience=3, verbose=1)
])

Train on 1000 samples, validate on 1000 samples
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50


Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50


Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50


Epoch 19/50
Epoch 20/50

Epoch 00020: ReduceLROnPlateau reducing learning rate to 0.00020000000949949026.
Epoch 21/50
Epoch 22/50
Epoch 23/50

Epoch 00023: ReduceLROnPlateau reducing learning rate to 2.0000000949949027e-05.
Epoch 24/50


Epoch 25/50
Epoch 26/50

Epoch 00026: ReduceLROnPlateau reducing learning rate to 2.0000001313746906e-06.
Epoch 27/50
Epoch 28/50
Epoch 29/50

Epoch 00029: ReduceLROnPlateau reducing learning rate to 2.000000222324161e-07.
Epoch 30/50


Epoch 31/50
Epoch 32/50

Epoch 00032: ReduceLROnPlateau reducing learning rate to 2.000000165480742e-08.
Epoch 33/50
Epoch 34/50
Epoch 35/50

Epoch 00035: ReduceLROnPlateau reducing learning rate to 2.000000165480742e-09.
Epoch 36/50


Epoch 37/50
Epoch 38/50

Epoch 00038: ReduceLROnPlateau reducing learning rate to 2.000000165480742e-10.
Epoch 39/50
Epoch 40/50
Epoch 41/50

Epoch 00041: ReduceLROnPlateau reducing learning rate to 2.000000165480742e-11.
Epoch 42/50


Epoch 43/50
Epoch 44/50

Epoch 00044: ReduceLROnPlateau reducing learning rate to 2.000000165480742e-12.
Epoch 45/50
Epoch 46/50
Epoch 47/50

Epoch 00047: ReduceLROnPlateau reducing learning rate to 2.000000208848829e-13.
Epoch 48/50


Epoch 49/50
Epoch 50/50

Epoch 00050: ReduceLROnPlateau reducing learning rate to 2.0000002359538835e-14.


<keras.callbacks.History at 0x7f97e579a7f0>