<a href="https://colab.research.google.com/github/nineTailsKurama/mathtokenizer/blob/main/notebooks/RETRO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install retro-pytorch
!pip install huggingface_hub

Collecting retro-pytorch
  Downloading retro_pytorch-0.2.7-py3-none-any.whl (19 kB)
Collecting einops>=0.3
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Collecting autofaiss
  Downloading autofaiss-2.13.2-py3-none-any.whl (60 kB)
[K     |████████████████████████████████| 60 kB 4.4 MB/s 
Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 23.9 MB/s 
Collecting fsspec>=2022.1.0
  Downloading fsspec-2022.2.0-py3-none-any.whl (134 kB)
[K     |████████████████████████████████| 134 kB 63.9 MB/s 
[?25hCollecting embedding-reader<2,>=1.2.0
  Downloading embedding_reader-1.3.0-py3-none-any.whl (16 kB)
Collecting faiss-cpu<2,>=1.7.2
  Downloading faiss_cpu-1.7.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.6 MB)
[K     |████████████████████████████████| 8.6 MB 20.9 MB/s 
[?25hCollecting dataclasses<1.0.0,>=0.6
  Downloading dataclasses-0.6-py3-no

In [None]:
import torch
from retro_pytorch import RETRO

In [None]:
retro = RETRO(
    # --- sequence / retrieval chunking ---
    # size of the chunks that are indexed and retrieved (also needed for proper
    # relative positions and causal chunked cross attention)
    chunk_size = 64,
    max_seq_len = 2048,

    # --- retrieval encoder ---
    enc_dim = 896,
    enc_depth = 2,

    # --- decoder ---
    dec_dim = 796,
    dec_depth = 12,
    dec_cross_attn_layers = (3, 6, 9, 12),   # decoder layers using causal chunked cross attention
    dec_attn_dropout = 0.25,
    dec_ff_dropout = 0.25,

    # --- attention heads (shared settings) ---
    heads = 8,
    dim_head = 64,

    # post-normalization with DeepNet residual scaling and initialization
    # (intended for scaling to very deep models)
    use_deepnet = True
)




In [None]:
import torch

In [None]:
# Mock batch of random token ids in [0, 20000), standing in for real training data.
# seq has shape (batch, seq_len + 1): the extra token is there because the sequence
# is split into inputs and labels for training.
seq = torch.randint(low = 0, high = 20000, size = (2, 2048 + 1))
# retrieved: (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation)
retrieved = torch.randint(low = 0, high = 20000, size = (2, 32, 2, 128))
# (the forward/backward pass on this batch is run in a later cell)

In [None]:
# Put the mock batch on the GPU, where the model is moved in the next cell.
seq, retrieved = seq.to('cuda'), retrieved.to('cuda')

In [None]:
# Move every parameter/buffer of the model to the GPU; as the last expression, the
# returned module is displayed (its repr is the output below).
retro.to('cuda')

RETRO(
  (token_emb): Embedding(28996, 896)
  (pos_emb): Embedding(2048, 896)
  (to_decoder_model_dim): Linear(in_features=896, out_features=796, bias=True)
  (encoder_output_to_decoder_dim): Linear(in_features=896, out_features=796, bias=True)
  (encoder): Encoder(
    (layers): ModuleList(
      (0): ModuleList(
        (0): PostNorm(
          (fn): Attention(
            (dropout): Dropout(p=0.0, inplace=False)
            (to_q): Linear(in_features=896, out_features=512, bias=False)
            (to_k): Linear(in_features=896, out_features=512, bias=False)
            (to_v): Linear(in_features=896, out_features=512, bias=False)
            (to_out): Linear(in_features=512, out_features=896, bias=True)
          )
          (norm): LayerNorm((896,), eps=1e-05, elementwise_affine=True)
        )
        (1): PostNorm(
          (fn): Attention(
            (dropout): Dropout(p=0.0, inplace=False)
            (to_q): Linear(in_features=896, out_features=512, bias=False)
            (

In [None]:
# One forward pass with return_loss=True, then backpropagate the loss.
# NOTE(review): the saved output of this cell only shows "RuntimeError: ignored";
# the real cause is not visible here (possibly CUDA out-of-memory — confirm).
loss = retro(
    seq,
    retrieved,
    return_loss = True
)
loss.backward()

RuntimeError: ignored

In [None]:
# !nvidia-smi

## RETRO training wrapper

In [None]:
import torch
from retro_pytorch import RETRO, TrainingWrapper

# Instantiate RETRO, then fit it into the TrainingWrapper with matching settings.

retro = RETRO(
    chunk_size = 64,                         # must match the wrapper's chunk_size below
    max_seq_len = 2048,                      # max sequence length
    enc_dim = 896,                           # encoder model dimension
    enc_depth = 3,                           # encoder depth
    dec_dim = 768,                           # decoder model dimensions
    dec_depth = 12,                          # decoder depth
    dec_cross_attn_layers = (1, 3, 6, 9),    # decoder cross attention layers (with causal chunk cross attention)
    heads = 8,                               # attention heads
    dim_head = 64,                           # dimension per head
    dec_attn_dropout = 0.25,                 # decoder attention dropout
    dec_ff_dropout = 0.25                    # decoder feedforward dropout
).cuda()

# The wrapper has to actually be constructed: the dataloader, optimizer and
# generation calls below are all made on it.
# NOTE(review): it is pointed at the text files matching `glob` under
# `documents_path` — make sure './text_folder' is populated before running.
wrapper = TrainingWrapper(
    retro = retro,                                 # the RETRO instance to train
    knn = 2,                                       # knn (2 in paper was sufficient)
    chunk_size = 64,                               # chunk size (64 in paper)
    documents_path = './text_folder',              # path to folder of text
    glob = '**/*.txt',                             # text glob
    chunks_memmap_path = './train.chunks.dat',     # path to chunks
    seqs_memmap_path = './train.seq.dat',          # path to sequence data
    doc_ids_memmap_path = './train.doc_ids.dat',   # path to document ids per chunk (used for filtering neighbors belonging to same document)
    max_chunks = 1_000_000,                        # maximum cap to chunks
    max_seqs = 100_000,                            # maximum seqs
    knn_extra_neighbors = 100,                     # num extra neighbors to fetch
    max_index_memory_usage = '100m',
    current_memory_available = '1G'
)

# get the dataloader and optimizer (AdamW with all the correct settings)

train_dl = iter(wrapper.get_dataloader(batch_size = 2, shuffle = True))
optim = wrapper.get_optimizer(lr = 3e-4, wd = 0.01)

# now do your training
# ex. one gradient step

seq, retrieved = map(lambda t: t.cuda(), next(train_dl))

# seq       - (2, 2049)         - 1 extra token since split by seq[:, :-1], seq[:, 1:]
# retrieved - (2, 32, 2, 128)   - 128 since chunk + continuation, each 64 tokens

loss = retro(
    seq,
    retrieved,
    return_loss = True
)

# one gradient step

loss.backward()
optim.step()
optim.zero_grad()

# do above for many steps, then ...

# topk sampling with retrieval at chunk boundaries

sampled = wrapper.generate(filter_thres = 0.9, temperature = 1.0) # (1, <2049) terminates early if all <eos>

# or you can generate with a prompt, knn retrieval for initial chunks all taken care of

prompt = torch.randint(0, 1000, (1, 128))  # start with two chunks worth of sequence
sampled = wrapper.generate(prompt, filter_thres = 0.9, temperature = 1.0) # (1, <2049) terminates early if all <eos>


## RETRO Datasets

In [None]:
import torch
from torch.utils.data import DataLoader
from retro_pytorch import RETRO, RETRODataset

# mock data constants

import numpy as np

# Sizes of the mock corpus: chunk count, tokens per chunk, sequence count, neighbors per chunk.
NUM_CHUNKS, CHUNK_SIZE, NUM_SEQS, NUM_NEIGHBORS = 1000, 64, 100, 2

def save_memmap(path, tensor):
    """Write array data to ``path`` as a raw ``np.memmap`` file.

    The file holds only the element bytes (no shape/dtype header), so readers
    must reopen it with the same dtype and shape, e.g.
    ``np.memmap(path, dtype=..., mode='r', shape=...)``.

    Args:
        path: destination file; created or overwritten (``mode='w+'``).
        tensor: array-like data. It is converted with ``np.asarray``, so NumPy
            arrays, nested lists and CPU torch tensors are all accepted; the
            file's dtype and shape are those of the converted array.
    """
    array = np.asarray(tensor)
    f = np.memmap(path, dtype = array.dtype, mode = 'w+', shape = array.shape)
    f[:] = array
    # Flush explicitly rather than relying on garbage collection of the memmap
    # to push the written pages to disk.
    f.flush()
    del f

# generate mock chunk data (each row is CHUNK_SIZE tokens plus one extra token)

save_memmap(
    './train.chunks.dat',
    np.int32(np.random.randint(0, 8192, size = (NUM_CHUNKS, CHUNK_SIZE + 1)))
)

# generate nearest neighbors for each chunk
# (neighbor ids are drawn from [0, NUM_CHUNKS) so they index existing chunk rows)

save_memmap(
    './train.chunks.knn.dat',
    np.int32(np.random.randint(0, NUM_CHUNKS, size = (NUM_CHUNKS, NUM_NEIGHBORS)))
)

# generate seq data
# NOTE(review): presumably one starting chunk index per sequence — values in
# [0, 128) keep start + seq_len / CHUNK_SIZE within NUM_CHUNKS; confirm against RETRODataset.

save_memmap(
    './train.seq.dat',
    np.int32(np.random.randint(0, 128, size = (NUM_SEQS,)))
)

# instantiate dataset class
# which constructs the sequence and neighbors from memmapped chunk and neighbor information.
# This cell must build its own loader: otherwise `next(train_dl)` below silently
# reuses the TrainingWrapper loader from the previous section (or raises NameError
# on a fresh kernel).

train_ds = RETRODataset(
    num_sequences = NUM_SEQS,
    num_chunks = NUM_CHUNKS,
    num_neighbors = NUM_NEIGHBORS,
    chunk_size = CHUNK_SIZE,
    seq_len = 2048,
    chunk_memmap_path = './train.chunks.dat',
    chunk_nn_memmap_path = './train.chunks.knn.dat',
    seq_memmap_path = './train.seq.dat'
)

train_dl = iter(DataLoader(train_ds, batch_size = 2))

# one forward and backward pass

retro = RETRO(
    chunk_size = CHUNK_SIZE,                 # must match the dataset's chunk_size
    max_seq_len = 2048,                      # max sequence length
    enc_dim = 896,                           # encoder model dimension
    enc_depth = 3,                           # encoder depth
    dec_dim = 768,                           # decoder model dimensions
    dec_depth = 12,                          # decoder depth
    dec_cross_attn_layers = (1, 3, 6, 9),    # decoder cross attention layers (with causal chunk cross attention)
    heads = 8,                               # attention heads
    dim_head = 64,                           # dimension per head
    dec_attn_dropout = 0.25,                 # decoder attention dropout
    dec_ff_dropout = 0.25                    # decoder feedforward dropout
).cuda()

seq, retrieved = map(lambda t: t.cuda(), next(train_dl))

# seq       - (2, 2049)         - 1 extra token since split by seq[:, :-1], seq[:, 1:]
# retrieved - (2, 32, 2, 128)   - 128 since chunk + continuation, each 64 tokens

loss = retro(
    seq,
    retrieved,
    return_loss = True
)

loss.backward()

