# A RETRO tutorial by Carson Lam 

This is my notebook for learning and teaching the implementation of <a href="https://arxiv.org/abs/2112.04426">RETRO</a>, DeepMind's retrieval-based attention net, in PyTorch, on a small but meaningful task.

1. I am using Python 3.8. If you don't have Python 3.8, there are many ways to install it, such as pyenv or Homebrew. I used Homebrew and followed the instructions shown at the end of the installation. You have to restart your terminal for the changes to take effect.

2. Create a virtual environment for this project and activate it

```
python3.8 -m venv env
source env/bin/activate
```

3. Install this project's dependencies from requirements.txt

```
pip install --upgrade pip
pip install -r requirements.txt
```

4. Save any additional dependencies you have pip-installed inside your environment, along with their specific versions, back into requirements.txt for later use

```
pip freeze > requirements.txt
```

5. Start Jupyter and open this notebook

```
jupyter notebook
```

6. Outside of the RETRO-pytorch folder we have a folder called data/, and inside data/ we have the text_folder/ and processed_text/ folders
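
The layout in step 6 can be created by hand or with a short script like the sketch below. It assumes the notebook runs from inside the RETRO-pytorch folder, so that `../data` is the sibling folder described above. The helper name `make_data_dirs` is made up for this example.

```python
from pathlib import Path

def make_data_dirs(root="../data"):
    """Create data/text_folder and data/processed_text if they are missing."""
    root = Path(root)
    text_dir = root / "text_folder"            # raw .txt documents go here
    processed_dir = root / "processed_text"    # processed arrays are written here
    for d in (text_dir, processed_dir):
        d.mkdir(parents=True, exist_ok=True)   # no error if it already exists
    return text_dir, processed_dir
```

Calling `make_data_dirs()` with no arguments uses `../data`; calling it a second time is harmless because existing folders are left alone.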

In [1]:
import torch
from retro_pytorch import RETRO, TrainingWrapper

%load_ext autoreload
%autoreload 2

print('torch.version', torch.__version__)
print('torch.cuda.is_available()', torch.cuda.is_available())
print('torch.cuda.device_count()', torch.cuda.device_count())

torch.version 1.11.0
torch.cuda.is_available() False
torch.cuda.device_count() 0


`chunk_size` is the size of the chunks that are indexed and retrieved. The model needs it for proper relative positions as well as for causal chunked cross attention.

`dec_cross_attn_layers` lists the decoder layers that use causal chunked cross attention.

Setting `use_deepnet = True` turns on post-normalization with DeepNet residual scaling and initialization, which is intended for scaling to 1000 layers.

In [2]:
retro = RETRO(
    chunk_size = 64,                         # the chunk size that is indexed and retrieved  
    max_seq_len = 2048,                      # max sequence length
    enc_dim = 2, #896                        # encoder model dim
    enc_depth = 2,                           # encoder depth
    dec_dim = 2, #768,                       # decoder model dim
    dec_depth = 6, #12                       # decoder depth
    dec_cross_attn_layers = (3, 6), #(3, 6, 9, 12),   # decoder cross attention layers 
    heads = 1, #8                            # attention heads
    dim_head = 32, #64                       # dimension per head
    dec_attn_dropout = 0.25,                 # decoder attention dropout
    dec_ff_dropout = 0.25,                   # decoder feedforward dropout
    use_deepnet = True                       # turn on post-normalization with DeepNet residual scaling and initialization 
)

In [3]:
# plus one token, since the sequence is split into inputs and labels for training
seq = torch.randint(0, 20000, (2, 2048 + 1))   
print(seq)
print(seq.shape)

tensor([[ 5077, 13497, 19083,  ...,   602, 17115, 17246],
        [12548,  4856,  6550,  ..., 11688, 17976,  6109]])
torch.Size([2, 2049])
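
The extra token is there because next-token training shifts the sequence by one position: the inputs are every token except the last, and the labels are every token except the first. The sketch below shows that split on plain Python lists so it runs without torch. The function name `split_inputs_labels` is made up for this example, and it is an illustration of the idea rather than the library's own code; on a tensor the equivalent split would be `seq[:, :-1]` and `seq[:, 1:]`.

```python
def split_inputs_labels(batch):
    """Split each row of length n + 1 into n inputs and n next-token labels."""
    inputs = [row[:-1] for row in batch]   # drop the last token
    labels = [row[1:] for row in batch]    # drop the first token
    return inputs, labels

batch = [[10, 11, 12, 13, 14]]             # one row of 4 + 1 tokens
inputs, labels = split_inputs_labels(batch)
print(inputs)   # [[10, 11, 12, 13]]
print(labels)   # [[11, 12, 13, 14]]
```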


In [4]:
# retrieved tokens 
# - (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation)
retrieved = torch.randint(0, 20000, (2, 32, 2, 128)) 
print(retrieved[:1,:3,:1,:32])
print(retrieved.shape)

tensor([[[[ 1601, 19473,  9206, 14366, 14221,  7337,  9191, 13800, 13641, 14119,
           10225,    57,  1722,  3420,  9363,  6131, 15305, 10109,  6006,  4499,
            3986,  4643,  8430,  7113,   159, 14452, 17571,  6632,  7755,  6512,
           18741,  8702]],

         [[17144, 19027, 14726, 14819,  9710,  1291,  3069, 12937, 18709,  8912,
            7244, 17348,  9276, 17261,  7435, 12393, 18171,  3626,  6282,  4880,
           17808, 18561,  9615, 13768,  2940,  5774, 16365, 17641,  8120, 14100,
           13870,  6459]],

         [[18666,   579,   561, 18489,  7426, 18334,  1080,  2945,  6500,   632,
            1838,  2085,  8767, 14453, 12808,  1028,  4739,  7120, 11321, 12055,
           14965,  7599, 12952, 17885, 19124,  6459,   612,  9428, 10531, 16981,
            2916,  3092]]]])
torch.Size([2, 32, 2, 128])
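
The shape `(2, 32, 2, 128)` follows from the settings used so far: a 2048-token sequence cut into chunks of 64 gives 32 chunks, each chunk gets 2 neighbors, and each neighbor is a retrieved chunk plus its continuation, which is twice the chunk size. A small sketch of that arithmetic, with a made-up helper name:

```python
def retrieved_shape(batch_size, seq_len, chunk_size, num_neighbors):
    """Expected shape of the retrieved-token tensor for one batch."""
    num_chunks = seq_len // chunk_size     # 2048 // 64 = 32
    neighbor_len = 2 * chunk_size          # chunk + continuation = 128
    return (batch_size, num_chunks, num_neighbors, neighbor_len)

print(retrieved_shape(2, 2048, 64, 2))     # (2, 32, 2, 128)
```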


In [5]:
loss = retro(seq, retrieved, return_loss = True)
print(loss)

tensor(10.5088, grad_fn=<NllLoss2DBackward0>)
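
A loss near 10 is about what an untrained model should give. If a model spread its probability evenly over a vocabulary of V tokens, the cross-entropy would be ln(V). The sketch below computes that baseline. Both vocabulary sizes are assumptions for illustration: 20,000 is the upper bound used for the random tokens above, and 28,996 is the size of the bert-base-cased vocabulary, which may or may not be what this model uses by default.

```python
import math

def uniform_cross_entropy(vocab_size):
    """Cross-entropy (in nats) of a uniform guess over vocab_size tokens."""
    return math.log(vocab_size)

for v in (20000, 28996):
    print(v, round(uniform_cross_entropy(v), 2))
```

Both baselines come out at roughly 10, so a starting loss of 10.5 on random tokens is in the expected range and should not be read as a sign that something is broken.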


The aim of the `TrainingWrapper` is to process a folder of text documents into the memmapped numpy arrays needed to begin training RETRO.

`bert_embed()` automatically uses CUDA if it is available, so it is best to put the `retro` instance passed to the wrapper on the same device.



In [None]:
if torch.cuda.is_available():
    retro = retro.cuda()

wrapper = TrainingWrapper(
    retro = retro,                                 # the RETRO instance to train
    knn = 2,                                       # knn (2 in paper was sufficient)
    chunk_size = 32,                               # chunk size (64 in paper; the RETRO instance above was built with chunk_size = 64)
    documents_path = '../data/text_folder',              # path to folder of text
    glob = '**/*.txt',                             # text glob
    chunks_memmap_path = '../data/processed_text/train.chunks.dat',     # path to chunks
    seqs_memmap_path = '../data/processed_text/train.seq.dat',          # path to sequence data
    doc_ids_memmap_path = '../data/processed_text/train.doc_ids.dat',   # path to document ids per chunk (used for filtering neighbors belonging to same document)
    max_chunks = 10,                         # maximum cap to chunks
    max_seqs = 10,                            # maximum seqs
    knn_extra_neighbors = 2,                     # num extra neighbors to fetch
    max_index_memory_usage = '10m',
    current_memory_available = '0.5G'
)

processing ../data/text_folder/doc1.txt


Using cache found in /Users/carson/.cache/torch/hub/huggingface_pytorch-transformers_main


processing ../data/text_folder/doc2.txt


Using cache found in /Users/carson/.cache/torch/hub/huggingface_pytorch-transformers_main
Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
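
If the wrapper processes fewer documents than expected, it can help to check which files the `documents_path` and `glob` arguments actually match. The sketch below lists them with `pathlib`. It is a stand-alone check, not part of retro_pytorch, and the helper name `list_documents` is made up for this example.

```python
from pathlib import Path

def list_documents(documents_path, glob="**/*.txt"):
    """Return the sorted files under documents_path that match the glob."""
    return sorted(p for p in Path(documents_path).glob(glob) if p.is_file())

if __name__ == "__main__":
    # Same folder and pattern as in the wrapper call above.
    for path in list_documents("../data/text_folder"):
        print(path)
```

The `**/` prefix makes the pattern recursive, so text files in nested subfolders are matched as well as those directly inside `text_folder/`.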
