In [1]:
import numpy as np
from scipy.stats import ortho_group
from embedding.schema import Schema
from embedding.encoder import Encoder
from embedding.structure import Struct

In [2]:
# Vocabulary for the demo: every entry of `labels` is a token, and the four
# entries repeated under `attributes` are additionally declared as attributes.
schema = Schema(labels=['A1','A2', 'A3', 'next', 'a','b','c','d','e'], attributes=['A1', 'A2', 'A3', 'next'])

# Embedding dimension handed to the encoder (also used to size the zero
# vectors in init_seq and the broadcast in ff_path_decode).
n_emb = 1024
encoder = Encoder(schema, dim=n_emb)
# Token embedding table; indexed by token index in encode_path and
# multiplied against vectors in ff_decode / ff_path_decode.
# presumably one row per token, each of length n_emb -- confirm against Encoder.
E = encoder.token_emb
# Dimension of the positional vectors built in init_seq.
p_dim = 16
# Random orthogonal p_dim x p_dim matrix; init_seq advances the position
# vector with `p = Z @ p` and attention() builds queries with `Z.T @ p`.
# NOTE(review): drawn without a seed, so Z differs on every kernel restart;
# consider passing random_state=... for reproducible runs.
Z = ortho_group.rvs(p_dim) # positional embedding operator
# Embedding of the 'next' attribute; encode_path applies powers of it and
# attention() applies its transpose to the path slot.
A_next = encoder.attr_emb[schema.attr_to_ind['next']]    


In [3]:

def attention_impl(seq, q, k, v, temp=100):
    """Causal attention over a sequence of vectors.

    Position ``i`` attends only to the strictly earlier positions ``j < i``.
    Its output is the softmax-weighted sum of ``v[j]`` where the logits are
    ``temp * (q[i].T @ k[j])``. Position 0 has nothing to attend to and
    yields a zero vector shaped like ``v[0]``.

    Args:
        seq: The sequence being processed; only its length is used.
        q: Per-position query vectors (indexable, one per position).
        k: Per-position key vectors.
        v: Per-position value vectors.
        temp: Multiplier applied to the logits before the softmax. Larger
            values push the softmax towards a hard arg-max.

    Returns:
        A list with one output array per position; ``[]`` when ``seq`` is
        empty.
    """
    n = len(seq)
    if n == 0:
        # Nothing to attend over (indexing v[0] would raise on empty input).
        return []

    # One independent zero array per slot rather than n aliases of one array.
    out = [np.zeros_like(v[0]) for _ in range(n)]

    for i in range(1, n):
        logits = np.zeros((i,))
        for j in range(i):
            logits[j] = q[i].T @ k[j]
        logits *= temp

        # Softmax is invariant to a constant shift; subtracting the max keeps
        # exp() from overflowing to inf (which would give inf/inf = nan).
        w = np.exp(logits - np.max(logits))
        w /= np.sum(w)

        out[i] = np.sum([w[j] * v[j] for j in range(i)], axis=0)

    return out


def relu(x):
    """Rectified linear unit: element-wise maximum of ``x`` and 0."""
    rectified = np.maximum(x, 0)
    return rectified


def ff_decode(x, C=100):
    """Feed forward decoder layer.

    Scores ``x`` against every row of the module-level embedding table ``E``,
    turns each score ``s`` into a gate ``relu(c + 1) - relu(c)`` with
    ``c = C * (s - 0.5)`` (0 for ``c <= -1``, 1 for ``c >= 0``, linear in
    between), and maps the gates back through ``E.T``.

    Args:
        x: Vector to decode.
        C: Slope of the gate; larger values make the threshold sharper.

    Returns:
        ``E.T`` applied to the per-row gate values.
    """
    ones = np.ones(E.shape[0])
    scores = E @ x
    cond = C * (scores - .5 * ones)
    gate = relu(cond + ones) - relu(cond)
    return E.T @ gate


def ff_path_decode(x, p, C=100):
    """Feed forward path decoder layer.

    For each attribute ``a`` the path vector ``p`` is scored against row ``a``
    of ``E``. An attribute whose score is well above 0.5 contributes
    ``attr_emb[a].T @ x - x``; one whose score is well below 0.5 contributes
    nothing. The contributions are summed and added to ``x``.

    Args:
        x: Vector to transform.
        p: Path vector used to select the attribute.
        C: Slope of the gate; larger values make the selection sharper.

    Returns:
        ``x`` plus the gated per-attribute corrections.
    """
    n_attr = encoder.attr_emb.shape[0]
    ones = np.ones(n_attr)
    # assumes the first n_attr rows of E are the attribute tokens -- TODO confirm
    attr_scores = E[:n_attr] @ p

    # One pre-activation per attribute, copied across all n_emb columns.
    margin = np.repeat(attr_scores - .5 * ones, n_emb).reshape(n_attr, n_emb)
    cond = C * margin

    # Row a holds x transformed by attribute a.
    trans = np.array([encoder.attr_emb[a].T @ x for a in range(n_attr)])

    delta = relu(cond + trans - x) - relu(cond)
    return x + np.sum(delta, axis=0)



def attention(seq):
    """Attention sub-layer over the 5-slot sequence representation.

    Queries are ``Z.T @ s[0]`` and keys are ``s[0]``. Two value streams are
    attended with the same queries and keys: slot 2 as-is, and slot 3 after
    applying ``A_next.T``.

    Args:
        seq: List of 5-slot elements.

    Returns:
        One ``[0, a1, 0, a2, 0]`` update per position, where ``a1`` is the
        attended slot-2 value and ``a2`` the attended, shifted slot-3 value.
    """
    queries, keys, vec_values, path_values = [], [], [], []
    for s in seq:
        queries.append(Z.T @ s[0])
        keys.append(s[0])
        vec_values.append(s[2])
        path_values.append(A_next.T @ s[3])

    attended_vec = attention_impl(seq, queries, keys, vec_values)
    attended_path = attention_impl(seq, queries, keys, path_values)

    return [[0, av, 0, ap, 0] for av, ap in zip(attended_vec, attended_path)]

def ff1(seq):
    """First feed-forward sub-layer.

    Produces, per element, an update that is zero everywhere except slot 2,
    which gets ``ff_path_decode(s[1], s[3]) - s[2]``. Added residually, this
    replaces slot 2 with the path-decoded vector.
    """
    return [
        [0, 0, ff_path_decode(s[1], s[3]) - s[2], 0, 0]
        for s in seq
    ]

def ff2(seq):
    """Second feed-forward sub-layer.

    Produces, per element, an update with ``-s[1]`` in slot 1 (so a residual
    add clears that slot) and ``ff_decode(s[2])`` in slot 4; all other slots
    are zero.
    """
    return [
        [0, -s[1], 0, 0, ff_decode(s[2])]
        for s in seq
    ]

def add(s1, s2):
    """Slot-wise sum of two sequences (the residual connection).

    Elements are paired with ``zip``, so the result is as long as the shorter
    sequence. Within a pair, every slot of the first element is added to the
    slot with the same index in the second.
    """
    return [
        [left + ss2[i] for i, left in enumerate(ss1)]
        for ss1, ss2 in zip(s1, s2)
    ]


def transformer(seq):
    """Run the hand-built transformer over ``seq``.

    Applies ``len(seq)`` identical layers. Each layer applies ff1, ff2 and
    attention in that order, and each sub-layer's output is added residually
    to the running state.

    Args:
        seq: Initial sequence, e.g. from ``init_seq``.

    Returns:
        The sequence state after the last layer.
    """
    state = seq
    n_layers = len(seq)
    for _layer in range(n_layers):
        for sublayer in (ff1, ff2, attention):
            state = add(state, sublayer(state))
    return state

def encode_path(path):
    """Encode a list of labels as a single path vector.

    The label at 0-based position ``i`` contributes
    ``matrix_power(A_next, i + 1) @ E[schema.token_to_ind[label]]``; the
    contributions are summed in path order.

    Args:
        path: Sequence of labels present in ``schema.token_to_ind``.

    Returns:
        The summed path vector. An empty path yields a zero vector (not the
        integer 0 that a bare ``sum([])`` would give), so callers can always
        treat the result as an array.

    Raises:
        KeyError: If a label is not in ``schema.token_to_ind``.
    """
    # Start from an explicit zero vector so the empty path is well-typed;
    # adding zeros first leaves the non-empty result unchanged.
    total = np.zeros(A_next.shape[0], dtype=np.result_type(A_next, E))
    for i, label in enumerate(path):
        token = E[schema.token_to_ind[label]]
        total = total + np.linalg.matrix_power(A_next, i + 1) @ token
    return total

def init_seq(path, v, rng=None):
    """Build the initial sequence for following ``path`` starting from ``v``.

    The sequence has ``len(path) + 1`` elements. Each element is a 5-slot
    list: ``[position, vector, transformed vector, path, token]``. The
    position of element 0 is a random unit vector of length ``p_dim`` and
    each later position is the previous one multiplied by ``Z``. All other
    slots start as zero vectors of length ``n_emb``, except that element 0
    carries ``v`` in the vector slot and ``encode_path(path)`` in the path
    slot.

    Args:
        path: List of labels to follow.
        v: Encoded structure to start from.
        rng: Optional ``numpy.random.Generator`` (or ``RandomState``) used to
            draw the initial position vector. When ``None`` the global
            ``np.random`` state is used, as before; pass a seeded generator
            for reproducible runs.

    Returns:
        The initialized sequence as a list of 5-slot lists.
    """
    ## Initialize seq with positional embedding
    draw_normal = np.random.normal if rng is None else rng.normal
    p = draw_normal(size=p_dim)
    p /= np.linalg.norm(p)
    n_seq = len(path) + 1
    seq = []
    for _ in range(n_seq):
        seq.append([p,  # position
                    np.zeros(n_emb),  # vector
                    np.zeros(n_emb),  # transformed vector
                    np.zeros(n_emb),  # path
                    np.zeros(n_emb)   # token
                    ])
        # Rebinds p to a new array, so the stored position is not mutated.
        p = Z @ p

    # Initialize with encoded vector
    seq[0][1] = v

    # Encode and initialize path
    seq[0][3] = encode_path(path)

    return seq

In [4]:
# Nested structure used for the demo:
#   a -A1-> b -A2-> c, and from c:  A1 -> d -A3-> e   and   A2 -> a -A1-> b
x = Struct.create(
    schema,
    ('a', {'A1': ('b', {'A2': ('c', {
        'A1': ('d', {'A3': 'e'}),
        'A2': ('a', {'A1': 'b'}),
    })})}),
)
v = encoder.encode(x)

# Two attribute paths to follow from the root of the structure.
path1 = ['A1', 'A2', 'A2', 'A1']
path2 = ['A1', 'A2', 'A1', 'A3']

# Run the transformer once per path and print what the token slot of every
# position decodes to (None where decode returns something falsy).
for query_path in (path1, path2):
    x = transformer(init_seq(query_path, v))
    out = [encoder.decode(s[4]) for s in x]
    print([(o.to_strings()) if o else None for o in out])

["a", "b", "c", "a", "b"]
["a", "b", "c", "d", "e"]
