# RETRO Transformer

In [1]:
import numpy as np
import pandas as pd
from einops import rearrange

n = 12 #Sequence length
m = 4 #Chunk length
k = 2 #Amount of neighbours
r = 5 #Retrieval length
d = 2 #Embedding size
l = n // m #Number of chunks
t = 50 #Amount of tokens in db

observation = np.random.randint(20, size=n) #Input tokens
print(observation)

[ 1 17 17 13 18  6 16 14 10 19  6  8]


## Simple Cross Attention

In [2]:
# Parameters
Q = np.random.rand(d, d)
K = np.random.rand(d, d)
V = np.random.rand(d, d)

def cross_attention(chunk, neighbour):
    m, d = chunk.shape
    r, d = neighbour.shape
    queries = chunk @ Q
    keys = neighbour @ K
    logits = queries @ keys.T
    values = neighbour @ V
    return logits, values

## Encoding Retieval Neighbours

In [3]:
chunks = rearrange(observation, '(l m) -> l m', l=l)
#print(chunks.shape)

encoder = np.random.rand(l, r, d) #Encoder for db
#print(encoder.shape)

chunks = np.einsum('l m, l r d -> l r d', chunks, encoder)
#print(chunks.shape)

e_db = []
for emb in np.random.rand(t, r, d):
    e_db.append([emb.tolist()])
e_db = pd.DataFrame(e_db)

neighbours = []
for chunk in chunks:
    e_db['L2'] = e_db.apply(lambda x:np.linalg.norm(chunk - np.array(x[0])), axis=1)
    neighbours.append(np.array(e_db.nsmallest(k, ['L2'])[0].tolist()))
neighbours = np.array(neighbours)
#print(neighbours.shape)

print(neighbours)
print(neighbours.shape)

[[[[0.72627626 0.69640847]
   [0.79771818 0.96784303]
   [0.41696126 0.23104722]
   [0.4148379  0.95539144]
   [0.68394057 0.88011938]]

  [[0.95287077 0.54908301]
   [0.68765086 0.69967441]
   [0.69018987 0.73420044]
   [0.89662106 0.60559892]
   [0.11098833 0.87013054]]]


 [[[0.29107286 0.9738282 ]
   [0.71244298 0.83815869]
   [0.91200024 0.61318898]
   [0.95785755 0.33630849]
   [0.47283201 0.47253173]]

  [[0.63384602 0.64709807]
   [0.44488637 0.5859764 ]
   [0.60126837 0.01433879]
   [0.97907421 0.67971735]
   [0.57075785 0.90273128]]]


 [[[0.72627626 0.69640847]
   [0.79771818 0.96784303]
   [0.41696126 0.23104722]
   [0.4148379  0.95539144]
   [0.68394057 0.88011938]]

  [[0.99519497 0.74562685]
   [0.50458975 0.41662138]
   [0.03121673 0.26845577]
   [0.93366058 0.28078022]
   [0.9971926  0.80670763]]]]
(3, 2, 5, 2)


## Chunked Cross Attention

In [4]:
emb = np.random.rand(n, d) #Embedded input tokens

attending_chunks = np.pad(emb[m-1:], ((0, m - 1), (0, 0)), mode='constant').reshape(l, m, d)

chunked_output = []
for u in range(l):
    chunk = attending_chunks[u]
    c_neighbours = neighbours[u]
    logits = []
    values = []
    for neighbour in c_neighbours:
        logit, value = cross_attention(chunk, neighbour)
        logits.append(logit)
        values.append(value)
    logits = np.array(logits)
    values = np.array(values)
    #logits += relative_positional_encodings(m, r)[None, :, :]
    logits = np.moveaxis(logits, 0, -1).reshape((m, r * k))
    values = np.moveaxis(values, 0, 1).reshape((r * k, d))
    print((logits @ values).shape)
    chunked_output.append(logits @ values)
chunked_output = np.array(chunked_output)

output = np.pad(chunked_output.reshape(n, d),((m - 1, 0), (0, 0)), mode='constant')[:n]
#First chunk
output[:m-1] = emb[:m-1]


#Last row of final chunk
logits = []
values = []
for neighbour in neighbours[l - 2]: #Use last neighbour
    logit, value = cross_attention(np.array([emb[-1]]), neighbour)
    logits.append(logit)
    values.append(value)
logits = np.moveaxis(np.array(logits), 0, -1).reshape((1, r * k))
values = np.moveaxis(np.array(values), 0, 1).reshape((r * k, d))
output[-1] = logits @ values
#print(output[-1])


#print(hash(str(output)))
print(output)

(4, 2)
(4, 2)
(4, 2)
[[0.30434409 0.99506272]
 [0.06797578 0.54649592]
 [0.18044656 0.05797994]
 [1.27454794 0.41996431]
 [1.23899826 0.40984806]
 [0.9377841  0.3097083 ]
 [1.68848659 0.55777549]
 [0.86294254 0.2952999 ]
 [1.37600583 0.4725025 ]
 [1.2810158  0.4418134 ]
 [1.12157343 0.38658226]
 [0.6683455  0.22990195]]
