# RETRO Transformer

In [1]:
import numpy as np
import pandas as pd
from einops import rearrange

n = 12 #Sequence length
m = 4 #Chunk length
k = 2 #Amount of neighbours
r = 5 #Retrieval length
d = 2 #Embedding size
l = n // m #Number of chunks
t = 50 #Amount of tokens in db

observation = np.random.randint(20, size=n) #Input tokens
print(observation)

[13 11 14 19  2 10 14 17 11  7 16 16]


## Simple Cross Attention

In [2]:
# Parameters
Q = np.random.rand(d, d)
K = np.random.rand(d, d)
V = np.random.rand(d, d)

def cross_attention(chunk, neighbour):
    m, d = chunk.shape
    r, d = neighbour.shape
    queries = chunk @ Q
    keys = neighbour @ K
    logits = queries @ keys.T
    values = neighbour @ V
    return logits, values

## Encoding Retieval Neighbours

In [27]:
chunks = rearrange(observation, '(l m) -> l m', l=l)
#print(chunks.shape)

encoder = np.random.rand(l, r, d) #Encoder for db
#print(encoder.shape)

chunks = np.einsum('l m, l r d -> l r d', chunks, encoder)
#print(chunks.shape)

e_db = []
for emb in np.random.rand(t, r, d):
    e_db.append([emb.tolist()])
e_db = pd.DataFrame(e_db)

neighbours = []
for chunk in chunks:
    e_db['L2'] = e_db.apply(lambda x:np.linalg.norm(chunk - np.array(x[0])), axis=1)
    neighbours.append(np.array(e_db.nsmallest(k, ['L2'])[0].tolist()))
neighbours = np.array(neighbours)
#print(neighbours.shape)

print(neighbours)

[[[[0.79976294 0.92299711]
   [0.89373073 0.9868527 ]
   [0.75857937 0.42092545]
   [0.59648677 0.42693725]
   [0.86557815 0.86846375]]

  [[0.94922426 0.52133696]
   [0.87750427 0.45189972]
   [0.67481455 0.36332083]
   [0.83368245 0.28567799]
   [0.49969988 0.82094618]]]


 [[[0.79976294 0.92299711]
   [0.89373073 0.9868527 ]
   [0.75857937 0.42092545]
   [0.59648677 0.42693725]
   [0.86557815 0.86846375]]

  [[0.77412505 0.21061206]
   [0.99260561 0.79865708]
   [0.39562119 0.82705474]
   [0.86144231 0.55268223]
   [0.85843361 0.29968387]]]


 [[[0.79976294 0.92299711]
   [0.89373073 0.9868527 ]
   [0.75857937 0.42092545]
   [0.59648677 0.42693725]
   [0.86557815 0.86846375]]

  [[0.77412505 0.21061206]
   [0.99260561 0.79865708]
   [0.39562119 0.82705474]
   [0.86144231 0.55268223]
   [0.85843361 0.29968387]]]]


## Chunked Cross Attention

In [29]:
emb = np.random.rand(n, d) #Embedded input tokens

attending_chunks = np.pad(emb[m-1:], ((0, m - 1), (0, 0)), mode='constant').reshape(l, m, d)

chunked_output = []
for u in range(l):
    chunk = attending_chunks[u]
    c_neighbours = neighbours[u]
    logits = []
    values = []
    for neighbour in c_neighbours:
        logit, value = cross_attention(chunk, neighbour)
        logits.append(logit)
        values.append(value)
    logits = np.array(logits)
    values = np.array(values)
    #logits += relative_positional_encodings(m, r)[None, :, :]
    logits = np.moveaxis(logits, 0, -1).reshape((m, r * k))
    values = np.moveaxis(values, 0, 1).reshape((r * k, d))
    chunked_output.append(logits @ values)
chunked_output = np.array(chunked_output)

output = np.pad(chunked_output.reshape(n, d),((m - 1, 0), (0, 0)), mode='constant')[:n]
#First chunk
output[:m-1] = emb[:m-1]


#Last row of final chunk
logits = []
values = []
for neighbour in neighbours[l - 2]: #Use last neighbour
    logit, value = cross_attention(np.array([emb[-1]]), neighbour)
    logits.append(logit)
    values.append(value)
logits = np.moveaxis(np.array(logits), 0, -1).reshape((1, r * k))
values = np.moveaxis(np.array(values), 0, 1).reshape((r * k, d))
output[-1] = logits @ values
#print(output[-1])


#print(hash(str(output)))
print(output)

[[0.37655074 0.69084134]
 [0.77894148 0.19423186]
 [0.12194683 0.49915636]
 [1.52714816 5.52665238]
 [1.52287772 5.50376214]
 [0.72356996 2.61723092]
 [1.52918793 5.53094029]
 [1.43724882 5.11947121]
 [2.13782346 7.6177986 ]
 [0.94567979 3.36352919]
 [0.91118843 3.2356434 ]
 [1.31682943 4.68828308]]
