# RETRO Transformer

In [5]:
import numpy as np
from einops import rearrange

# Sequence layout
n, m = 12, 4    # sequence length, chunk length
l = n // m      # number of chunks

# Retrieval layout
k, r = 2, 5     # neighbours per chunk, rows per retrieved neighbour

# Model width
d = 2           # embedding size

# Random input sequence with one d-dimensional row per position.
# The generator is not seeded, so the values change on every run.
observation = np.random.rand(n, d)
print(observation)

[[0.34436241 0.37406753]
 [0.15645653 0.6920059 ]
 [0.41509948 0.56976514]
 [0.73416264 0.80577888]
 [0.03109859 0.8786712 ]
 [0.4192683  0.5561915 ]
 [0.2903888  0.23452864]
 [0.42214147 0.41872855]
 [0.17304639 0.12566048]
 [0.18654496 0.99388621]
 [0.53282135 0.83040245]
 [0.00402069 0.48227176]]


In [2]:
# Projection matrices for queries, keys and values, each of shape (d, d).
# They are drawn in the order Q, K, V from the unseeded global NumPy generator.
Q, K, V = (np.random.rand(d, d) for _ in range(3))

def cross_attention(chunk, neighbour):
    """Compute raw attention scores and values for one chunk / neighbour pair.

    The projection matrices ``Q``, ``K`` and ``V`` are read from the
    notebook's global scope.

    Args:
        chunk: 2-D array of shape (chunk_rows, embedding); the attending rows.
        neighbour: 2-D array of shape (neighbour_rows, embedding); the rows
            being attended to.

    Returns:
        A ``(logits, values)`` tuple where ``logits`` is
        ``(chunk @ Q) @ (neighbour @ K).T`` with shape
        (chunk_rows, neighbour_rows) and ``values`` is ``neighbour @ V``.
        No scaling or softmax is applied to the scores here.

    Raises:
        ValueError: if either input is not 2-D, or if the two inputs do not
            share the same embedding size.
    """
    # Explicit rank check: a 1-D input would otherwise flow through the
    # matrix products below and produce wrongly shaped results.
    if chunk.ndim != 2 or neighbour.ndim != 2:
        raise ValueError(
            "cross_attention expects 2-D inputs, got shapes "
            f"{chunk.shape} and {neighbour.shape}"
        )
    if chunk.shape[1] != neighbour.shape[1]:
        raise ValueError(
            "chunk and neighbour must share the embedding size, got "
            f"{chunk.shape[1]} and {neighbour.shape[1]}"
        )

    queries = chunk @ Q
    keys = neighbour @ K
    logits = queries @ keys.T
    values = neighbour @ V
    return logits, values

## Encoding Retrieval Neighbours

In [6]:
# Split the sequence into l consecutive chunks of m rows each.
chunks = rearrange(observation, '(l m) d -> l m d', l=l)

# Every chunk except the final one.
n_chunk = chunks[:l - 1]

# Stand-in retrieval results: for each chunk, k neighbours of r rows each.
# They are drawn at random here, not looked up with a nearest-neighbour search.
neighbours = np.random.rand(l, k, r, d)
print(neighbours)

[[[[0.32461632 0.54865897]
   [0.92194059 0.65330148]
   [0.15434714 0.86144158]
   [0.34961148 0.40062833]
   [0.31893657 0.41166529]]

  [[0.20017308 0.64332549]
   [0.5328855  0.26695701]
   [0.09165474 0.91876226]
   [0.46785621 0.56646182]
   [0.31808915 0.45808314]]]


 [[[0.53630748 0.29439516]
   [0.85420452 0.02874575]
   [0.54641079 0.25959703]
   [0.7718861  0.25873036]
   [0.73233494 0.43678262]]

  [[0.19063246 0.61184029]
   [0.91917042 0.99544409]
   [0.27038337 0.09932637]
   [0.25144593 0.89861245]
   [0.04431888 0.78519688]]]


 [[[0.72828598 0.72217343]
   [0.51168345 0.09565324]
   [0.49437625 0.18698449]
   [0.01310741 0.37222562]
   [0.55598886 0.84819936]]

  [[0.93725932 0.28173065]
   [0.89636889 0.40609561]
   [0.09683494 0.05406874]
   [0.92592465 0.65330892]
   [0.31067811 0.76946625]]]]


## Chunked Cross Attention

In [4]:
def multi_neighbour_cross_attention(chunk, c_neighbours):
    """Attend the rows of one chunk to all rows of all of its neighbours.

    Args:
        chunk: 2-D array (rows, d) of attending rows.
        c_neighbours: 3-D array (k, r, d) holding the k neighbours of the chunk.

    Returns:
        Array of shape (rows, d): the raw scores against every neighbour row,
        multiplied by the matching value rows.
    """
    rows, width = chunk.shape
    n_neighbours, n_rows = c_neighbours.shape[:2]

    pairs = [cross_attention(chunk, neighbour) for neighbour in c_neighbours]
    logits = np.array([logit for logit, _ in pairs])    # (k, rows, r)
    values = np.array([value for _, value in pairs])    # (k, r, d)
    #logits += relative_positional_encodings(m, r)[None, :, :]

    # Flatten the neighbour axis into the key axis. Both arrays are laid out
    # with the neighbour-row index outermost and the neighbour index innermost,
    # so column j of `logits` lines up with row j of `values`.
    logits = np.moveaxis(logits, 0, -1).reshape((rows, n_rows * n_neighbours))
    values = np.moveaxis(values, 0, 1).reshape((n_rows * n_neighbours, width))

    # NOTE(review): the scores are used as-is; no softmax or other
    # normalisation is applied before weighting the values. Confirm intended.
    return logits @ values


# Drop the first m - 1 rows and zero-pad the tail so that the remaining rows
# regroup into l chunks of m rows; chunk u then starts at row u * m + m - 1.
attending_chunks = np.pad(observation[m-1:], ((0, m - 1), (0, 0)), mode='constant').reshape(l, m, d)

# Attending chunk u is paired with the neighbours of chunk u.
chunked_output = np.array([
    multi_neighbour_cross_attention(attending_chunks[u], neighbours[u])
    for u in range(l)
])

# Undo the shift: put m - 1 zero rows in front and cut back to n rows.
output = np.pad(chunked_output.reshape(n, d),((m - 1, 0), (0, 0)), mode='constant')[:n]

#First chunk: the first m - 1 rows have no attending chunk, so copy the input through.
output[:m-1] = observation[:m-1]

#Last row of final chunk: computed directly from the last input row and the
#neighbours of the final chunk. This is mathematically the same row that the
#zero-padded final attending chunk already produced above.
output[-1] = multi_neighbour_cross_attention(observation[-1:], neighbours[-1])

print(output)

[[0.07996373 0.06450442]
 [0.82423691 0.3825449 ]
 [0.64634458 0.26694494]
 [1.79499304 2.76881043]
 [1.2515339  1.93063008]
 [0.74511727 1.14694873]
 [1.9454078  3.00187902]
 [2.0371975  2.98876332]
 [1.07197414 1.57225073]
 [1.33908344 1.96359641]
 [3.13173358 4.59609942]
 [3.79941571 5.81025084]]
