In [1]:
import numpy as np
import tensorflow as tf

In [2]:
alpha = 1 / 2  # write rate: new memory = old + alpha * (written - old); read as a global by the jacobian helpers

def softmax(x, axis=-1):
    """Numerically stable softmax along `axis` (default: last axis).

    The per-slice max is subtracted before exponentiating to avoid overflow.
    `keepdims=True` keeps that max broadcastable against `x`; without it a
    [m,c] input would subtract a length-m vector along the c axis.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    expx = np.exp(shifted)
    return expx / np.sum(expx, axis=axis, keepdims=True)

def attention(key, data, sharpness=1.0):
    """Softmax attention of `key` over the memory slots of `data`.

    key       : [m,l]
    data      : [m,c,l]
    sharpness : scalar multiplier on the logits (inverse temperature)
    returns   : [mb, capacity] = [m,c], normalized over c
    """
    # logits z[m,c] = sharpness * <data[m,c,:], key[m,:]>, as in tf_attention;
    # attention_jacobian's trailing `sharpness *` factor assumes this scaling.
    logits = sharpness * np.einsum("mcl,ml->mc", data, key)
    return softmax(logits)

def attention_jacobian(key, data, sharpness=1.0):
    """Jacobian of `attention` with respect to the key.

    key  : [m,l]
    data : [m,c,l]
    returns [m,c,l] = dA[m,c]/dK[m,l]
    """
    a = attention(key, data, sharpness)   # [m,c]
    eye = np.eye(data.shape[1])[None, :, :]
    # softmax jacobian J[m,i,j] = a[m,j]*(delta_ij - a[m,i]); symmetric in i,j
    jac = a[:, None, :] * (eye - a[:, :, None])
    # chain rule through logits z[m,x] = sharpness * sum_l data[m,x,l] * key[m,l]
    return sharpness * np.einsum("mcx,mxl->mcl", jac, data)

# short alias: dA/dK
da_dk_jac = attention_jacobian
    
def key_r_jacobian(key, data, sharpness=1.0):
    """Jacobian of the read vector with respect to the read key.

    For a read G[m,a] = sum_c A[m,c] * data[m,c,a]:
        dG[m,a]/dKr[m,j] = sum_c data[m,c,a] * dA[m,c]/dKr[m,j]
    returns [m,l,l]
    """
    da_dkey = attention_jacobian(key, data, sharpness)   # [m,c,l]
    return np.einsum("mca,mcj->maj", data, da_dkey)

def dd_dkw_jacobian(key_w, data, p, sharpness=1.0):
    """Jacobian of the written memory with respect to the write key.

    key_w : [m,l]
    data  : [m,c,l]
    p     : [m,l]   value being written
    returns [m,c,l,l] with out[m,c,a,b] = alpha * p[m,a] * dA[m,c]/dKw[m,b]
    (dD/dKw, including the module-level write rate `alpha`)
    """
    da = da_dk_jac(key_w, data, sharpness)   # [m,c,l]
    outer = p[:, None, :, None] * da[:, :, None, :]
    return alpha * outer

def dd_dp_jacobian(key_w, data, sharpness=1.0):
    """Jacobian of the written memory with respect to the written value p.

    key_w : [m,l]
    data  : [m,c,l]
    returns [m,c,l,l] with out[m,c,i,j] = alpha * A[m,c] * delta_ij
    (uses the module-level write rate `alpha`)
    """
    weights = attention(key_w, data, sharpness)   # [m,c]
    identity = np.eye(data.shape[-1])[None, None]
    scaled = alpha * weights[:, :, None, None]
    return scaled * identity

def dg_dp_jacobian(key_r, data, key_w, p, sharpness=1.0):
    """Jacobian of the read result with respect to the written value, for one `step`.

    Mirrors `step`: memory is soft-written with `p` at `key_w` (rate: the
    module-level `alpha`), then read at `key_r`.

    key_r, key_w, p : [m,l]
    data            : [m,c,l]  memory *before* the write
    returns         : [m,l,l]  dG[m,i]/dP[m,j]
    """
    # forward write: D1 = D + alpha*(a_w p^T - D)
    a_w = attention(key_w, data, sharpness)                            # [m,c]
    data1 = data + alpha * (a_w[:, :, None] * p[:, None, :] - data)    # [m,c,l]
    dd_dp = dd_dp_jacobian(key_w, data, sharpness)                     # [m,c,l,l] = dD1[m,c,i]/dP[m,j]

    # read attention over the *updated* memory and its softmax jacobian
    a_r = attention(key_r, data1, sharpness)                           # [m,c]
    eye_c = np.eye(data.shape[1])[None, :, :]
    s_jac = a_r[:, None, :] * (eye_c - a_r[:, :, None])                # [m,c,x], symmetric

    # logits z[m,x] = sharpness * sum_l D1[m,x,l] * key_r[m,l]
    dz_dp = sharpness * np.einsum("ml,mxlj->mxj", key_r, dd_dp)        # [m,c,l]
    da_dp = np.einsum("mcx,mxj->mcj", s_jac, dz_dp)                    # [m,c,l]

    # g[m,i] = sum_c a_r[m,c] * D1[m,c,i]  ->  product rule
    return (np.einsum("mc,mcij->mij", a_r, dd_dp)
            + np.einsum("mci,mcj->mij", data1, da_dp))
    
def write(key, data, alpha, p, sharpness=1.0):
    """Soft-write `p` into memory at the slots attended by `key`.

    key  : [m,l]
    data : [m,c,l]  current memory
    p    : [m,l]    value to write
    returns [m,c,l] = data + alpha * (A[:,:,None]*p[:,None,:] - data)
    """
    # attend over the memory being written (previously referenced undefined `b`)
    a = attention(key, data, sharpness)        # [m,c]
    w = a[:, :, None] * p[:, None, :]          # [m,c,l]
    return data + alpha * (w - data)

def read(key, data, sharpness=1.0):
    """Attention-weighted read from memory.

    key  : [m,l]
    data : [m,c,l]
    returns [m,l] = sum_c A[m,c] * data[m,c,:]
    """
    a = attention(key, data, sharpness)        # [m,c]
    # weight each slot (broadcast over l) and sum over capacity, matching tf_read;
    # data[:,None,:] here would broadcast to [m,m,c,l] and return the wrong shape
    return np.sum(a[:, :, None] * data, axis=1)

def step(key_r, data, alpha, key_w, p, sharpness=1.0):
    """One memory step: write `p` at `key_w`, then read at `key_r`.

    Returns (read result, updated memory).
    """
    updated = write(key_w, data, alpha, p, sharpness)
    return read(key_r, updated, sharpness), updated




In [3]:
# Problem sizes: C = memory capacity (slots), L = slot/key length, B = minibatch.
C = 3
L = 2
B = 1
SEED = 0
np.random.seed(SEED)   # reproducible inputs under Restart & Run All
data0 = np.random.random((B, C, L))   # initial memory [m,c,l]
data = data0.copy()
key_r = np.random.random((B, L))      # read key [m,l]
key_w = np.random.random((B, L))      # write key [m,l]
p = np.random.random((B, L))          # value to write [m,l]
alpha = 0.5                           # write rate


In [4]:
def tf_attention(key, data, sharpness=1.0):
    """TF version of `attention`: softmax over slots of sharpness * <data, key>. Returns [m,c]."""
    key_b = key[:, None, :]                                          # [m,1,l]
    logits = tf.reduce_sum(data * key_b * sharpness, axis=-1)        # [m,c]
    return tf.nn.softmax(logits, axis=1)

def tf_write(key, b, alpha, p, sharpness=1.0):
    """TF version of `write`: blend `p` into memory `b` at attended slots."""
    weights = tf_attention(key, b, sharpness)          # [m,c]
    update = weights[:, :, None] * p[:, None, :]       # [m,c,l]
    return b + alpha * (update - b)

def tf_read(key, data, sharpness=1.0):
    """TF version of `read`: attention-weighted sum over slots. Returns [m,l]."""
    weights = tf_attention(key, data, sharpness)
    weighted = weights[:, :, None] * data
    return tf.reduce_sum(weighted, axis=1)

def tf_step(key_r, data, alpha, key_w, p, sharpness=1.0):
    """TF version of `step`: write, then read. Returns (read result, updated memory)."""
    written = tf_write(key_w, data, alpha, p, sharpness)
    return tf_read(key_r, written, sharpness), written



In [5]:
# Autodiff reference: TF jacobians of one write-then-read step, to check the
# hand-written numpy jacobians against.
d_t = tf.convert_to_tensor(data)
kw_t = tf.convert_to_tensor(key_w)
kr_t = tf.convert_to_tensor(key_r)
p_t = tf.convert_to_tensor(p)


with tf.GradientTape(persistent=True) as tape:
    # plain tensors (not Variables) must be watched explicitly
    tape.watch(d_t)
    tape.watch(kr_t)
    tape.watch(kw_t)
    tape.watch(p_t)

    g, data1 = tf_step(kr_t, d_t, alpha, kw_t, p_t)

print("jac dg/dKr:", tape.jacobian(g, kr_t))
print("jac dg/dKw:", tape.jacobian(g, kw_t))
print("jac dg/dP:", tape.jacobian(g, p_t))

# a persistent tape holds its resources until it is released
del tape


jac dg/dKr: tf.Tensor(
[[[[0.00230014 0.00234391]]

  [[0.00234391 0.012723  ]]]], shape=(1, 2, 1, 2), dtype=float64)
jac dg/dKw: tf.Tensor(
[[[[0.00073218 0.00261006]]

  [[0.00030701 0.00279467]]]], shape=(1, 2, 1, 2), dtype=float64)
jac dg/dP: tf.Tensor(
[[[[0.16863817 0.0004524 ]]

  [[0.00107352 0.16934636]]]], shape=(1, 2, 1, 2), dtype=float64)
