# RETRO Transformer

In [103]:
import numpy as np
from einops import rearrange

n = 12 #Sequence length
m = 4 #Chunk length
k = 2 #Amount of neighbours
r = 5 #Retrieval length
d = 2 #Embedding size
l = n // m #Number of chunks
t = 50 #Amount of tokens in db

observation = np.random.randint(20, size=n) #Input tokens
#print(hash(str(observation)))
print(observation)

[ 5 19  9 19  3  7 13  2 10 16  9 15]


## Simple Cross Attention

In [2]:
# Parameters
Q = np.random.rand(d, d)
K = np.random.rand(d, d)
V = np.random.rand(d, d)

def cross_attention(chunk, neighbour):
    m, d = chunk.shape
    r, d = neighbour.shape
    queries = chunk @ Q
    keys = neighbour @ K
    logits = queries @ keys.T
    values = neighbour @ V
    return logits, values

## Encoding Retieval Neighbours

In [161]:
chunks = rearrange(observation, '(l m) -> l m', l=l)
print(chunks.shape)

emb = np.random.rand(n, d) #Embedded input token for neighbour look up

print(emb.shape)

chunks = rearrange(emb, '(l m) d -> l m d', l=l)
print(chunks.shape)

for chunk in chunks:
    print(chunk)
    print(np.average(chunk, axis=-1))
    break

encoder = np.random.rand(r, d)
print(encoder.shape)

chunks = chunks @ encoder.T
print(chunks.shape)

print(np.random.rand(t, m, r).shape)

chunks = np.random.rand(l, r, d)

e_db = []
for emb in np.random.rand(t, r, d):
    e_db.append([emb.tolist()])
e_db = pd.DataFrame(e_db)

neighbours = []
for chunk in chunks:
    e_db['L2'] = e_db.apply(lambda x:np.linalg.norm(chunk - np.array(x[0])), axis=1)
    neighbours.append(np.array(e_db.nsmallest(k, ['L2'])[0].tolist()))
neighbours = np.array(neighbours)
#print(neighbours)
print(neighbours.shape)


neighbours = np.random.rand(l, k, r, d)
print(neighbours.shape)
#print(hash(str(neighbours)))
#print(neighbours)

(3, 4)
(12, 2)
(3, 4, 2)
[[0.00757536 0.14117046]
 [0.76278522 0.5859058 ]
 [0.46637889 0.27466621]
 [0.13247688 0.43279845]]
[0.07437291 0.67434551 0.37052255 0.28263767]
(5, 2)
(3, 4, 5)
(50, 4, 5)
(3, 2, 5, 2)
(3, 2, 5, 2)


## Chunked Cross Attention

In [131]:
emb = np.random.rand(n, d) #Embedded input tokens

attending_chunks = np.pad(emb[m-1:], ((0, m - 1), (0, 0)), mode='constant').reshape(l, m, d)

chunked_output = []
for u in range(l):
    chunk = attending_chunks[u]
    c_neighbours = neighbours[u]
    logits = []
    values = []
    for neighbour in c_neighbours:
        logit, value = cross_attention(chunk, neighbour)
        logits.append(logit)
        values.append(value)
    logits = np.array(logits)
    values = np.array(values)
    #logits += relative_positional_encodings(m, r)[None, :, :]
    logits = np.moveaxis(logits, 0, -1).reshape((m, r * k))
    values = np.moveaxis(values, 0, 1).reshape((r * k, d))
    chunked_output.append(logits @ values)
chunked_output = np.array(chunked_output)

output = np.pad(chunked_output.reshape(n, d),((m - 1, 0), (0, 0)), mode='constant')[:n]

#First chunk
output[:m-1] = emb[:m-1]

#Last row of final chunk
logits = []
values = []
for neighbour in neighbours[-1]:
    logit, value = cross_attention(np.array([emb[-1]]), neighbour)
    logits.append(logit)
    values.append(value)
logits = np.moveaxis(np.array(logits), 0, -1).reshape((1, r * k))
values = np.moveaxis(np.array(values), 0, 1).reshape((r * k, d))
output[-1] = logits @ values

#print(hash(str(output)))
print(output)

[[0.29622797 0.09046281]
 [0.45159562 0.22896399]
 [0.48413135 0.98606863]
 [2.43451523 2.03354707]
 [2.26553565 1.89220968]
 [1.7816741  1.499658  ]
 [1.54208971 1.30197663]
 [1.42007009 1.16246011]
 [2.59844625 2.21883448]
 [1.52687094 1.28589585]
 [2.01497182 1.65951107]
 [0.81804246 0.58353684]]
