In [1]:
import numpy as np
from scipy.stats import ortho_group
from embedding.schema import Schema
from embedding.encoder import Encoder
from embedding.structure import Struct

In [170]:
# Vocabulary for the demo: the first four labels double as attribute names,
# the remaining single-letter labels are the symbols stored in structures.
schema = Schema(
    labels=['A1', 'A2', 'A3', 'next', 'a', 'b', 'c', 'd', 'e'],
    attributes=['A1', 'A2', 'A3', 'next'],
)

# Width of every embedding vector used below.
n_emb = 1024
encoder = Encoder(schema, dim=n_emb)

# Token embedding table taken from the encoder; indexed by token index
# (see the use of schema.token_to_ind further down).
E = encoder.token_emb

In [171]:

def attention_impl(seq, q, k, v, target, temp=100):
    """Causal attention over a sequence of vectors.

    For every position ``i > 0`` the slot ``seq[i][target]`` is overwritten in
    place with a softmax-weighted sum of the values at the earlier positions
    ``0 .. i-1``. The weight of position ``j`` is proportional to
    ``exp(temp * q[i].T @ k[j])``. Position 0 has nothing earlier to attend
    to, so ``seq[0][target]`` is left untouched.

    Args:
        seq: indexable sequence of mutable per-position containers.
        q: per-position query vectors, aligned with ``seq``.
        k: per-position key vectors, aligned with ``seq``.
        v: per-position value vectors, aligned with ``seq``.
        target: index of the slot in each ``seq[i]`` that receives the result.
        temp: factor the raw scores are multiplied by before the softmax;
            larger values make the attention sharper.

    Returns:
        None. ``seq`` is modified in place.
    """
    for i in range(1, len(seq)):
        scores = temp * np.array([q[i].T @ k[j] for j in range(i)], dtype=float)

        # Shift by the maximum before exponentiating: the softmax is unchanged
        # mathematically, but exp() can no longer overflow to inf (which would
        # turn every weight into NaN for large temp * score).
        scores -= scores.max()
        weights = np.exp(scores)
        weights /= weights.sum()

        seq[i][target] = np.sum([weights[j] * v[j] for j in range(i)], axis=0)


def relu(x):
    """Rectified linear unit: element-wise ``max(x, 0)``."""
    rectified = np.maximum(x, 0)
    return rectified


def ff_decode(x, C=100):
    """Feed forward decoder layer.

    Scores ``x`` against every row of the global embedding table ``E`` and
    turns each score into a gate: ``relu(c + 1) - relu(c)`` with
    ``c = C * (score - 0.5)`` is 0 for ``c <= -1``, 1 for ``c >= 0`` and linear
    in between. The result is the gate-weighted sum of the rows of ``E``.

    Args:
        x: vector to decode; must be compatible with ``E @ x``.
        C: steepness of the gate around the 0.5 threshold.

    Returns:
        ``E.T @ gate``.
    """
    ones = np.ones(E.shape[0])
    scaled = C * (E @ x - .5 * ones)
    # Difference of two ReLUs = a ramp clipped to [0, 1].
    gate = np.maximum(scaled + ones, 0) - np.maximum(scaled, 0)
    return E.T @ gate


def ff_path_decode(x, p, C=100):
    """Feed forward path decoder layer.

    For each attribute ``i`` a scalar gate ``C * (E[i] @ p - 0.5)`` is computed
    from the path vector ``p`` and a candidate ``encoder.attr_emb[i].T @ x`` is
    computed from ``x``. The function returns

        x + sum_i( relu(gate_i + candidate_i - x) - relu(gate_i) )

    so a strongly negative gate contributes nothing, while a large positive
    gate contributes ``candidate_i - x``.

    Args:
        x: vector to transform (length ``n_emb``).
        p: path vector scored against the first ``n_attr`` rows of ``E``.
        C: scale applied to the gate.

    Returns:
        The transformed vector, same shape as ``x``.
    """
    attr_count = encoder.attr_emb.shape[0]
    ones = np.ones(attr_count)

    # NOTE(review): assumes the first attr_count rows of E are the attribute
    # tokens (the schema lists the attributes first) -- confirm.
    attr_tokens = E[:attr_count]

    # One gate value per attribute, copied across all n_emb columns.
    gate = attr_tokens @ p - .5 * ones
    cond = C * np.repeat(gate, n_emb).reshape(attr_count, n_emb)

    # Row i holds x pushed through the transpose of attribute i's matrix.
    candidates = np.array([encoder.attr_emb[a].T @ x for a in range(attr_count)])

    delta = np.maximum(cond + candidates - x, 0) - np.maximum(cond, 0)
    return x + np.sum(delta, axis=0)




In [177]:
## Initialize seq with positional embedding
# Dedicated, seeded random stream so that Restart & Run All reproduces the
# same positional code (Z and the initial p) on every run.
POS_SEED = 0
pos_rng = np.random.RandomState(POS_SEED)

p_dim = 16
# Random orthogonal matrix; multiplying by Z steps a position vector forward,
# multiplying by Z.T steps it back.
Z = ortho_group.rvs(p_dim, random_state=pos_rng)
p = pos_rng.normal(size=p_dim)
p /= np.linalg.norm(p)  # unit-length starting position

n_seq = 5
seq = []
for i in range(n_seq):
    # Each sequence element is a list of five slots, addressed by index below.
    seq.append([p,  # position 
                np.zeros(n_emb), # vector
                np.zeros(n_emb), # transformed vector
                np.zeros(n_emb), # path
                np.zeros(n_emb)  # token
                ])
    # Z @ p builds a new array, so stored positions are never aliased.
    p = Z @ p


## Place encoded vector in first position
# Nested structure: each node is (label, {attribute: child}); leaves are
# plain labels.
x = Struct.create(
    schema,
    ('a', {
        'A1': ('b', {
            'A2': ('c', {
                'A1': ('d', {'A3': 'e'}),
                'A2': ('a', {'A1': 'b'}),
            }),
        }),
    }),
)
v = encoder.encode(x)

# Matrix associated with the 'next' attribute, and the attribute path to follow.
A_next = encoder.attr_emb[schema.attr_to_ind['next']]
path = ['A1', 'A2', 'A1', 'A3']

# Initialize with encoded vector
seq[0][1] = v

# Initialize path: the k-th path element (1-based) is stored as
# A_next**k applied to that element's token embedding; all are summed.
seq[0][3] = sum(
    np.linalg.matrix_power(A_next, depth) @ E[schema.token_to_ind[step]]
    for depth, step in enumerate(path, start=1)
)


def attention1(seq):
    """Attention pass that writes slot 1 (vector) from earlier slot 2 values.

    Keys are the position vectors (slot 0) and queries are the positions
    multiplied by ``Z.T``. Values are the transformed vectors (slot 2); the
    result is stored in slot 1 of each later element, in place.
    """
    queries, keys, values = [], [], []
    for slot in seq:
        queries.append(Z.T @ slot[0])
        keys.append(slot[0])
        values.append(slot[2])
    attention_impl(seq, queries, keys, values, 1)

def attention2(seq):
    """Attention pass that writes slot 3 (path) from earlier path values.

    Keys are the position vectors (slot 0) and queries are the positions
    multiplied by ``Z.T``. Values are the path vectors (slot 3) multiplied by
    ``A_next.T``; the result is stored in slot 3 of each later element, in
    place.
    """
    queries, keys, values = [], [], []
    for slot in seq:
        queries.append(Z.T @ slot[0])
        keys.append(slot[0])
        values.append(A_next.T @ slot[3])
    attention_impl(seq, queries, keys, values, 3)

def ff(seq):
    """Apply feed forward decoder to a sequence of vectors.

    For every element, slot 2 is set to the path-decoded version of slot 1
    (using the path in slot 3) and slot 4 to the token decoding of that
    result. The elements are updated in place.
    """
    for slot in seq:
        transformed = ff_path_decode(slot[1], slot[3])
        slot[2] = transformed
        slot[4] = ff_decode(transformed)
    
## Run transformer
# One feed-forward + two attention passes per sequence position.
for _ in range(n_seq):
    ff(seq)
    attention1(seq)
    attention2(seq)

# Decode the token slot (index 4) of every position and print it.
out = [encoder.decode(s[4]) for s in seq]
for o in out:
    if o:
        print(o.to_strings())
    else:
        print(None)

"a"
"b"
"c"
"d"
"e"
