# Implementing the DNC in PyTorch
See archive/dnc/mem_ops.py:

## TODO: implement test code for these files


In [None]:
import torch
import numpy as np

In [None]:
def init_memory(N, W, R):
    """
    Build the initial state of the DNC memory.

    Parameters:
        N: number of memory locations (rows of the memory matrix)
        W: width of each memory word
        R: number of read heads

    Returns a 7-tuple (M0, u0, p0, L0, wr0, ww0, r0):
        M0  -- [N, W] memory matrix, filled with 1e-6
        u0  -- [N]    usage vector, zeros
        p0  -- [N]    precedence vector, zeros
        L0  -- [N, N] link matrix, zeros
        wr0 -- [N, R] read weightings, filled with 1e-6
        ww0 -- [N]    write weightings, filled with 1e-6
        r0  -- [W, R] read vectors, filled with 1e-6
    """
    # torch.full(shape, value) is the constant-fill constructor;
    # torch.fill does not accept a shape.
    M0 = torch.full((N, W), 1e-6)
    u0 = torch.zeros(N)
    p0 = torch.zeros(N)
    L0 = torch.zeros(N, N)
    wr0 = torch.full((N, R), 1e-6)  # initial read weightings
    ww0 = torch.full((N,), 1e-6)  # initial write weightings
    r0 = torch.full((W, R), 1e-6)  # initial read vectors

    return M0, u0, p0, L0, wr0, ww0, r0

In [None]:
def parse_interface(zeta, N, W, R):
    """
    Split the flat interface vector into its individual components and map
    each one into its valid range.

    Parameters:
        zeta: 1-D tensor with W*R + 3*W + 5*R + 3 elements
        N: number of memory locations (unused here; kept so the signature
           matches the other memory operations)
        W: width of a memory word
        R: number of read heads

    Returns (kr, br, kw, bw, e, v, f, ga, gw, pi):
        kr -- [W, R] read keys
        br -- [R]    read strengths, mapped through oneplus (>= 1)
        kw -- [W, 1] write key
        bw -- scalar write strength, mapped through oneplus (>= 1)
        e  -- [W]    erase vector, sigmoid
        v  -- [W]    write vector (left untransformed)
        f  -- [R]    free gates, sigmoid
        ga -- scalar allocation gate, sigmoid
        gw -- scalar write gate, sigmoid
        pi -- [3, R] read modes, softmax over the 3 modes (dim 0)

    Raises:
        ValueError: if zeta does not have the expected number of elements
    """
    expected = W * R + 3 * W + 5 * R + 3
    if zeta.numel() != expected:
        raise ValueError(
            "interface vector has %d elements, expected %d"
            % (zeta.numel(), expected)
        )

    cursor = 0  # keeps track of how far we parsed into zeta
    kr, cursor = torch.reshape(zeta[cursor:cursor + W * R], [W, R]), cursor + W * R
    br, cursor = zeta[cursor:cursor + R], cursor + R
    kw, cursor = torch.reshape(zeta[cursor:cursor + W], [W, 1]), cursor + W
    bw, cursor = zeta[cursor], cursor + 1
    e, cursor = zeta[cursor:cursor + W], cursor + W
    v, cursor = zeta[cursor:cursor + W], cursor + W
    f, cursor = zeta[cursor:cursor + R], cursor + R
    ga, cursor = zeta[cursor], cursor + 1
    gw, cursor = zeta[cursor], cursor + 1
    pi = torch.reshape(zeta[cursor:], [3, R])

    # transforming the parsed components into their correct values
    def oneplus(z):
        # maps any real value into [1, inf)
        return 1 + torch.nn.functional.softplus(z)

    e = torch.sigmoid(e)
    f = torch.sigmoid(f)
    ga = torch.sigmoid(ga)
    gw = torch.sigmoid(gw)
    br = oneplus(br)
    bw = oneplus(bw)
    pi = torch.softmax(pi, dim=0)

    return kr, br, kw, bw, e, v, f, ga, gw, pi

In [None]:
def C(M, k, b):
    """
    Content-based addressing weightings.

    Computes the cosine similarity between every memory row and every key,
    scales it by the key strength and normalizes over memory locations.

    Parameters:
        M: [N, W] memory matrix
        k: [W, R] keys, one per column (a single write key is [W, 1])
        b: key strength(s); scalar or [R], broadcast against the [N, R]
           similarity matrix

    Returns:
        [N, R] tensor whose columns each sum to 1.
    """
    # F.normalize divides by max(norm, eps), so an all-zero row or key
    # yields zeros instead of NaN.
    M_normalized = torch.nn.functional.normalize(M, p=2, dim=1)
    k_normalized = torch.nn.functional.normalize(k, p=2, dim=0)
    similarity = torch.matmul(M_normalized, k_normalized)

    return torch.softmax(similarity * b, dim=0)

In [None]:
def ut(u, f, wr, ww):
    """
    Returns the updated usage vector.

    Parameters:
        u:  [N]    previous usage vector
        f:  [R]    free gates
        wr: [N, R] previous read weightings
        ww: [N]    previous write weightings

    Returns:
        [N] updated usage vector: (u + ww - u * ww) * psi, where psi is the
        retention vector, the product over read heads of (1 - f * wr).
    """
    # retention: how much of each location is kept after the read heads
    # free what they have just read
    psi_t = torch.prod(1 - f * wr, dim=1)
    return (u + ww - u * ww) * psi_t


In [None]:
def at(ut, N):
    """
    Returns the allocation weighting given the updated usage vector.

    Parameters:
        ut: [N] usage vector
        N:  number of memory locations (length of ut)

    Returns:
        [N] allocation weighting. The least-used location gets
        (1 - usage); each following location, in order of increasing usage,
        gets (1 - usage) times the product of all smaller usages.
    """
    # top-k of the negated usages sorts the usages in ascending order;
    # free_list holds the original location of every sorted entry
    neg_sorted_ut, free_list = torch.topk(-ut, k=N)
    sorted_ut = -neg_sorted_ut  # back to the original positive values

    # exclusive cumulative product: the first element is 1 and element i is
    # the product of the first i sorted usages
    inclusive_cumprod = torch.cumprod(sorted_ut, dim=0)
    sorted_ut_cumprod = torch.cat(
        [torch.ones_like(sorted_ut[:1]), inclusive_cumprod[:-1]]
    )
    out_of_location_at = (1 - sorted_ut) * sorted_ut_cumprod

    # put every weighting back at the memory location it belongs to
    # (out-of-place scatter keeps the result differentiable)
    return torch.zeros_like(out_of_location_at).scatter(
        0, free_list, out_of_location_at
    )

In [None]:
def wwt(ct, at, gw, ga):
    """
    Returns the updated write weightings.

    The allocation gate `ga` interpolates between the allocation weighting
    `at` and the content-based weighting `ct` (squeezed to drop its
    singleton dimensions); the write gate `gw` scales the result.
    """
    allocation_part = ga * at
    content_part = (1 - ga) * torch.squeeze(ct)
    return gw * (allocation_part + content_part)

In [None]:
--------------------------------------------------------------------

In [4]:
def Lt(L, wwt, p, N):
    """
    Returns the updated temporal link matrix.

    Implements the DNC link update
        L_t[i, j] = (1 - w[i] - w[j]) * L_{t-1}[i, j] + w[i] * p[j]
    with the diagonal forced to zero (a location never links to itself).

    Parameters:
        L:   [N, N] previous link matrix
        wwt: [N]    updated write weightings
        p:   [N]    previous precedence vector
        N:   number of memory locations

    Returns:
        [N, N] updated link matrix.
    """
    # column / row views so that matmul behaves as an outer product
    write_col = torch.unsqueeze(wwt, 1)  # [N, 1]
    precedence_row = torch.unsqueeze(p, 0)  # [1, N]

    # entry (i, j) is w[i] + w[j]: the column broadcast against its own
    # transpose (the original summed the column with itself, giving 2*w[i])
    pairwise_sum = write_col + write_col.transpose(0, 1)

    # mask that zeroes the diagonal
    off_diagonal = 1 - torch.eye(N, dtype=L.dtype, device=L.device)

    return (((1 - pairwise_sum) * L +
             torch.matmul(write_col, precedence_row)) * off_diagonal)


In [None]:
def pt(wwt, p):
    """
    Returns the updated precedence vector given the new write weightings and
    the previous precedence vector.

    Parameters:
        wwt: [N] updated write weightings
        p:   [N] previous precedence vector

    Returns:
        [N] tensor: (1 - sum(wwt)) * p + wwt
    """
    # the more was written in this step, the more the old precedence decays
    return (1 - torch.sum(wwt)) * p + wwt

In [None]:
def Mt(M, wwt, e, v):
    """
    Returns the updated memory matrix given the previous one, the new write
    weightings, and the erase and write vectors.

    Parameters:
        M:   [N, W] previous memory matrix
        wwt: [N]    updated write weightings
        e:   [W]    erase vector
        v:   [W]    write vector

    Returns:
        [N, W] tensor: M * (1 - wwt e^T) + wwt v^T
    """
    # expand the dims of wwt, e, and v to make matmul
    # behave as outer product
    write_col = torch.unsqueeze(wwt, 1)  # [N, 1]
    erase_row = torch.unsqueeze(e, 0)  # [1, W]
    value_row = torch.unsqueeze(v, 0)  # [1, W]

    erased = M * (1 - torch.matmul(write_col, erase_row))
    return erased + torch.matmul(write_col, value_row)

In [None]:
def wrt(wr, Lt, ct, pi):
    """
    Returns the updated read weightings given the previous ones, the new link
    matrix, a content-based weighting, and the read modes.

    Parameters:
        wr: [N, R] previous read weightings
        Lt: [N, N] updated link matrix
        ct: [N, R] content-based read weightings
        pi: [3, R] read modes; row 0 weights the backward weighting,
            row 1 the content weighting, row 2 the forward weighting

    Returns:
        [N, R] updated read weightings.
    """
    ft = torch.matmul(Lt, wr)  # forward weighting: L w
    # torch.matmul has no transpose flag, so transpose explicitly
    bt = torch.matmul(Lt.transpose(0, 1), wr)  # backward weighting: L^T w

    return pi[0] * bt + pi[1] * ct + pi[2] * ft

In [None]:
def rt(Mt, wrt):
    """
    Returns the new read vectors given the new memory matrix and the new read
    weightings.

    Parameters:
        Mt:  [N, W] updated memory matrix
        wrt: [N, R] updated read weightings

    Returns:
        [W, R] read vectors, one per column: Mt^T wrt
    """
    # torch.matmul has no transpose flag, so transpose explicitly
    return torch.matmul(Mt.transpose(0, 1), wrt)