# Retrieval-Enhanced Transformer (ReTro)
Link to paper: https://arxiv.org/pdf/2112.04426.pdf<br/>
LabML AI implementation: https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/retro<br/>
LabML Annotated Implementation: https://nn.labml.ai/transformers/retro/index.html

**DeepMind Abstract**<br/>
    We enhance auto-regressive language models by conditioning on document chunks retrieved from a
    large corpus, based on local similarity with preceding tokens. With a 2 trillion token database, our
    Retrieval-Enhanced Transformer (Retro) obtains comparable performance to GPT-3 and Jurassic-1
    on the Pile, despite using 25× fewer parameters. After fine-tuning, Retro performance translates to
    downstream knowledge-intensive tasks such as question answering. Retro combines a frozen Bert
    retriever, a differentiable encoder and a chunked cross-attention mechanism to predict tokens based on
    an order of magnitude more data than what is typically consumed during training. We typically train
    Retro from scratch, yet can also rapidly Retrofit pre-trained transformers with retrieval and still
    achieve good performance. Our work opens up new avenues for improving language models through
    explicit memory at unprecedented scale.

![model.png](./images/retro/model.png)
**Source**: DeepMind Retro paper, page 3
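The causality rule behind chunked cross-attention can be sketched in a few lines. This is written from the paper's description, not from LabML's code. With chunk length m, the retrieved neighbors of chunk u are attended to by the last token of chunk u and by the first m - 1 tokens of chunk u + 1, so a token never sees neighbors that were retrieved using text that comes after it. The function name `cca_neighbor_chunk` is hypothetical, positions are 0-indexed, and the sketch leaves out how the end of the sequence is padded.

```python
def cca_neighbor_chunk(position: int, chunk_len: int) -> int:
    """Index of the chunk whose retrieved neighbors the token at `position`
    (0-indexed) attends to in chunked cross-attention, or -1 if none.

    Chunk u covers positions u*m .. u*m + m - 1. Its neighbors were retrieved
    using the whole chunk, so the first token allowed to see them is the chunk's
    last token (u*m + m - 1); the next m - 1 tokens (the start of chunk u + 1)
    attend to them as well.
    """
    m = chunk_len
    if position < m - 1:
        # No complete chunk precedes these tokens, so there is nothing to attend to
        return -1
    # Shift by m - 1, then group the positions into windows of m tokens
    return (position - (m - 1)) // m


if __name__ == "__main__":
    # A 12-token sequence with chunk length 4 (the final position is left out)
    print([cca_neighbor_chunk(p, 4) for p in range(11)])
    # [-1, -1, -1, 0, 0, 0, 0, 1, 1, 1, 1]
```

The first m - 1 tokens get -1 because no complete chunk precedes them, so there are no neighbors they can causally attend to.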

## Install Dependencies and Clone LabML's ReTro Git Repo

In [None]:
!pip install torch
!pip install torchvision
!pip install torchtext
!pip install labml_nn
!pip install labml
!pip install labml-helpers
!pip install numpy
!pip install matplotlib
!pip install einops
!pip install gym[atari]
!pip install opencv-python
!pip install Pillow
!pip install wget
!pip install transformers

In [None]:
import os.path
import sys

labml_repo_path = "retro/"

if not os.path.exists(labml_repo_path):
    !git clone https://github.com/labmlai/annotated_deep_learning_paper_implementations.git retro

### CUDA Setup

In [None]:
import torch

if torch.cuda.is_available():
    !pip install cupy-cuda111

In [None]:
import torch

if torch.cuda.is_available():
    !nvcc --version

In [None]:
if torch.cuda.is_available():
    !pip install faiss-gpu==1.7.0
else:
    !pip install faiss-cpu==1.7.0

## Testing the Default ReTro model

There is no official public implementation of the ReTro model, so this notebook uses LabML's annotated implementation.

This section runs LabML's default ReTro functions to check that they work correctly. It does the following:<br/>
* **database.build_database()**: compute BERT embeddings for chunks of the training data and load them into a FAISS index
* **dataset.build_dataset()**: split the training text into samples, retrieve the nearest neighbors of each chunk from the index, and save them as a dataset
* **model._test()**: run a quick test of the model on sample data
* **train.train()**: train the ReTro model for 32 epochs, sampling a text response to a prompt after each epoch

In [None]:
from retro.labml_nn.transformers.retro import model, train, dataset, database

In [None]:
# creates lab.get_data_path()/retro.index
database.build_database()

# creates lab.get_data_path()/retro_train_dataset.json
dataset.build_dataset()

In [None]:
m = model._test()

In [None]:
m = train.train()

## Prepare ConceptNet Dataset for ReTro

The data used to train the official ReTro model is not publicly available. The default LabML implementation of ReTro trains using the Tiny Shakespeare dataset to generate a Shakespearean response to a given prompt.<br/>

Because this work targets general question answering, a dataset closer to that task is more appropriate. The QA-GNN model was trained on the CommonsenseQA dataset, with concepts extracted from ConceptNet. To work with the same set of concepts, the code below uses ConceptNet sentences to build the key-value store of BERT embeddings.<br/>

The ConceptNet (Open Mind Common Sense) sentences are downloaded from https://s3.amazonaws.com/conceptnet/downloads/2018/omcs-sentences-free.txt, and the English sentences are extracted from a random half of the lines.

In [None]:
import wget
import os
import sys
import re
import numpy as np
from labml import lab

def bar_progress(current, total, width=80):
  # Progress callback for wget.download; guard against an unknown (non-positive) total size
  percent = current / total * 100 if total > 0 else 0
  progress_message = "Downloading: %d%% [%d / %d] bytes" % (percent, current, total)
  sys.stdout.write("\r" + progress_message)
  sys.stdout.flush()

def prepare_conceptnet_text_file(conceptnet_file_loc, cn_textonly_file, download_path):
  conceptnet_text = []

  # Read all lines, then keep a random half of them (sampled without replacement)
  with open(conceptnet_file_loc, 'r', encoding='utf-8') as conceptnet_file:
    cn_lines = conceptnet_file.readlines()
  half_lines = int(len(cn_lines) / 2)
  sample_lines = np.random.choice(cn_lines, size=half_lines, replace=False)

  for line in sample_lines:
    # Lines have 7 tab-separated fields: field 1 is the sentence, field 4 the language code
    line_split = line.split("\t")
    if len(line_split) == 7 and line_split[4] == "en":
      # Collapse runs of spaces and write one sentence per line
      current_string = re.sub(' +', ' ', line_split[1].strip())
      conceptnet_text.append(current_string + "\n")

  # Write the English sentences to download_path/cn_textonly_file
  with open(os.path.join(download_path, cn_textonly_file), 'w', encoding='utf-8') as cn_write_file:
    cn_write_file.writelines(conceptnet_text)


conceptnet_dataset_file = "omcs-sentences-free.txt"
conceptnet_dataset_url = "https://s3.amazonaws.com/conceptnet/downloads/2018/" + conceptnet_dataset_file
cn_textonly_file = "cn_textonly.txt"

if not os.path.exists(str(lab.get_data_path() / conceptnet_dataset_file)):
  wget.download(conceptnet_dataset_url, str(lab.get_data_path() / conceptnet_dataset_file), bar=bar_progress)

if not os.path.exists(str(lab.get_data_path() / cn_textonly_file)):
  prepare_conceptnet_text_file(str(lab.get_data_path() / conceptnet_dataset_file), cn_textonly_file, str(lab.get_data_path()))

## Building ReTro Key-Value Embedding Store

The functions in the following cells are taken from LabML's implementation of ReTro. The main modification is adding the file_path and url parameters so that other datasets can be used.<br/>
* **build_retro_database**: adapted from the build_database function in database.py
* **build_retro_dataset**: adapted from the build_dataset function in dataset.py
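The chunking step in `build_retro_database` can be tested on its own. The helper `split_into_chunks` is hypothetical. It reproduces the two list comprehensions in that function, where each chunk's character offset becomes its id in the FAISS index. A chunk is kept only if a full neighbor (the chunk plus its continuation, `2 * chunk_len` characters) can later be read back from that offset.

```python
from typing import List, Tuple


def split_into_chunks(text: str, chunk_len: int) -> Tuple[List[str], List[int]]:
    """Split `text` into non-overlapping chunks and return (chunks, offsets).

    A chunk starting at offset i is kept only if i + 2 * chunk_len < len(text),
    the same filter as in `build_retro_database`, so text[i: i + 2 * chunk_len]
    (the chunk plus its continuation) is always a full-length neighbor.
    """
    offsets = [i for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)]
    chunks = [text[i:i + chunk_len] for i in offsets]
    return chunks, offsets


if __name__ == "__main__":
    chunks, offsets = split_into_chunks("abcdefghij", 3)
    print(chunks, offsets)  # ['abc', 'def'] [0, 3]
```

Offset 6 is dropped in this example because `text[6:12]` would contain only 4 characters instead of 6.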

In [None]:
import faiss
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset as PyTorchDataset
import json
from pathlib import Path
from typing import List, Optional, Set
import math
import joblib

from labml import lab, monit
from labml_helpers.datasets.text import TextFileDataset, TextDataset
from labml_nn.transformers.retro.bert_embeddings import BERTChunkEmbeddings
from labml_nn.transformers.retro.database import RetroIndex
from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder
from labml.logger import inspect

In [None]:
def build_retro_database(file_path, url, chunk_len: int = 16, batch_size: int = 64, d_emb: int = 768, n_centeroids: int = 256, code_size: int = 64, n_probe: int = 8, n_train: int = 50_000):
  """
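  ## Build the database

  * `file_path` and `url` locate the source text file (downloaded from `url` if the file is missing)
  * `chunk_len` is the chunk length in characters
  * `batch_size` is the number of chunks embedded with BERT per iteration
  * `d_emb` is the size of the BERT embeddings
  * `n_centeroids` is the number of inverted lists in the FAISS index
  * `code_size` is the encoded vector size in bytes (number of 8-bit product-quantizer sub-vectors)
  * `n_probe` is the number of lists probed at search time
  * `n_train` is the number of chunk embeddings used to train the index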
  """
  # Load the dataset text file
  dataset = TextFileDataset(
      file_path,
      list,
      url=url)

  # Get training data (a string)
  text = dataset.train

  # Split the text into chunks of `chunk_length`
  chunks = [text[i:i + chunk_len] for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)]
  # Get the offsets of each of the chunks
  chunk_offsets = np.array([i for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)])
  # Number of chunks
  n_chunks = len(chunks)

  # Initialize BERT to get $\text{B\small{ERT}}(N)$ (falls back to the CPU when no GPU is available)
  bert = BERTChunkEmbeddings(torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'))

  # Get chunk embeddings by processing `batch_size` number of chunks on each iteration
  chunk_emb = []
  for i in monit.iterate('Get embeddings', range(0, n_chunks, batch_size)):
    chunk_emb.append(bert(chunks[i: i + batch_size]).cpu())
  # Merge them into a single tensor
  chunk_emb = torch.cat(chunk_emb, dim=0).numpy()

  # Create the [FAISS index](https://faiss.ai/cpp_api/struct/structfaiss_1_1IndexIVFPQ.html)
  quantizer = faiss.IndexFlatL2(d_emb)
  index = faiss.IndexIVFPQ(quantizer, d_emb, n_centeroids, code_size, 8)
  index.nprobe = n_probe

  # Get a random sample of the chunk indexes
  random_sample = np.random.choice(np.arange(n_chunks), size=[min(n_train, n_chunks)], replace=False)

  # Train the index to store the keys
  with monit.section('Train index'):
      index.train(chunk_emb[random_sample])

  # Add the chunks to the index in batches of size `1024`
  for s in monit.iterate('Index', range(0, n_chunks, 1024)):
      e = min(s + 1024, n_chunks)
      # Add to index
      index.add_with_ids(chunk_emb[s:e], chunk_offsets[s: e])

  # Save the index
  with monit.section('Save'):
      faiss.write_index(index, str(lab.get_data_path() / 'retro.index'))
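The FAISS `IndexIVFPQ` built above performs approximate search. Vectors are assigned to one of `n_centeroids` lists and compressed by product quantization, and only `n_probe` lists are scanned per query. With `code_size = 64` sub-quantizers of 8 bits each, a 768-dimensional float32 embedding (3,072 bytes) is stored as a 64-byte code. For comparison, the sketch below does the same lookup exactly with NumPy brute force. `exact_knn` is a hypothetical helper, not part of FAISS or LabML. It returns the stored ids (here, chunk offsets), as `add_with_ids` makes FAISS do, and like FAISS's L2 indexes it reports squared distances.

```python
import numpy as np


def exact_knn(keys: np.ndarray, ids: np.ndarray, queries: np.ndarray, k: int):
    """Exact (brute-force) L2 nearest-neighbor search returning stored ids.

    keys:    (n, d) embeddings added to the store
    ids:     (n,)   id stored with each key (here, the chunk's character offset)
    queries: (q, d) query embeddings
    Returns (squared_distances, ids), both (q, k), nearest first.
    """
    # Squared L2 distance from every query to every key, shape (q, n)
    d2 = ((queries[:, None, :] - keys[None, :, :]) ** 2).sum(axis=-1)
    # Column indices of the k smallest distances per row, in ascending order
    order = np.argsort(d2, axis=1)[:, :k]
    return np.take_along_axis(d2, order, axis=1), ids[order]


if __name__ == "__main__":
    keys = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]])
    offsets = np.array([0, 16, 32])  # chunk offsets used as ids
    queries = np.array([[0.9, 0.0], [4.0, 4.0]])
    dist, nn_ids = exact_knn(keys, offsets, queries, k=2)
    print(nn_ids.tolist())  # [[16, 0], [32, 16]]
```

The brute-force version materializes a (queries × keys × dimensions) array, so it is only practical for small stores or for spot-checking the approximate index on a sample of queries.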

In [None]:
def build_retro_dataset(file_path, url, chunk_len: int = 16, chunks_per_sample: int = 32, skip_range: int = 8):
    """
    ## Build the dataset

    * `chunk_len` is the chunk length
    * `chunks_per_sample` is the number of chunks per training sample
    * `skip_range` is the maximum number of characters to skip between two samples.
        We skip a few characters between samples to make sure the samples
        aren't aligned perfectly with the chunks in the [database](database.html)
    """
    # Load the dataset text file
    dataset = TextFileDataset(
        file_path,
        list,
        url=url)

    # Training portion of it
    text = dataset.train

    # Load the index for retrieving neighbors
    index = RetroIndex()

    # The input sample offsets
    sample_offsets = []
    # Cursor for the text
    i = 0
    while i < len(text):
        # Skip a few characters to make sure it's not aligned with the neighbors
        skip = np.random.randint(skip_range)
        i += skip

        # Stop if a full sample plus the extra target character no longer fits
        if i + chunks_per_sample * chunk_len >= len(text):
            break

        # Collect the offset
        sample_offsets.append(i)

        # Increment the cursor
        i += chunks_per_sample * chunk_len

    # For samples
    samples = []
    # Iterate through sample offsets
    for i in monit.iterate('Gather Neighbors', sample_offsets):
        # Get the sample including an extra character (for prediction)
        sample = text[i: i + chunks_per_sample * chunk_len + 1]
        # The input
        src = sample[:-1]
        # Break it into chunks
        chunks = [src[j:j + chunk_len] for j in range(0, len(src), chunk_len)]
        # The chunk offsets
        chunk_offsets = [j + i for j in range(0, len(src), chunk_len)]

        # Retrieve nearest neighbors
        neighbor_offsets = index(chunks, chunk_offsets)

        # Get neighbor texts. The neighbor length is twice the `chunk_len`
        neighbors = [[text[j: j + chunk_len * 2] for j in n_off] for n_off in neighbor_offsets]

        # Add to list of samples
        samples.append((sample[:-1], sample[1:], neighbors))

    # Save the samples in JSON.
    # We don't need to use complex dataset storage mechanisms or pre-tokenize
    # since our dataset is small.
    with open(str(lab.get_data_path() / 'retro_train_dataset.json'), 'w') as f:
        f.write(json.dumps(samples))
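The cursor walk that chooses sample offsets in `build_retro_dataset` can be isolated and tested. In this hypothetical `sample_offsets` helper, Python's `random` stands in for `np.random.randint`. A sample is emitted only if `chunks_per_sample * chunk_len + 1` characters fit, because each sample also needs one extra character for the shifted target `sample[1:]`. If that character is missing, the last chunk of the sample comes out one character short.

```python
import random
from typing import List


def sample_offsets(text_len: int, chunk_len: int, chunks_per_sample: int,
                   skip_range: int, rng: random.Random) -> List[int]:
    """Start offsets of training samples, chosen the way `build_retro_dataset` walks the text.

    Each sample needs chunks_per_sample * chunk_len input characters plus one
    extra character for the shifted target, so an offset is kept only if
    sample_len + 1 characters fit in the text.
    """
    sample_len = chunks_per_sample * chunk_len
    offsets = []
    i = 0
    while i < text_len:
        # Random skip of 0 .. skip_range - 1 characters, so samples are not aligned with database chunks
        i += rng.randrange(skip_range)
        if i + sample_len + 1 > text_len:
            break
        offsets.append(i)
        # Jump past this sample so samples never overlap
        i += sample_len
    return offsets


if __name__ == "__main__":
    # With skip_range=1 there is no random skip, which makes the boundary visible
    print(sample_offsets(129, 16, 4, 1, random.Random(0)))  # [0, 64]
    print(sample_offsets(128, 16, 4, 1, random.Random(0)))  # [0]
```

In the second call, a sample starting at 64 would need characters 64 to 128, one more than the text contains, so it is skipped.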

### Train ReTro Model

The function in this block is taken from LabML's implementation of ReTro. The main modification is adding the tds_file_name and url parameters so that other datasets can be used.<br/>
* **train_retro_model**: adapted from the train function in train.py

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, RandomSampler

from labml import monit, lab, tracker, experiment, logger
from labml.logger import Text
from labml_helpers.datasets.text import TextFileDataset
from labml_nn.optimizers.noam import Noam
from labml_nn.transformers.retro import model as retro
from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder
from labml_nn.transformers.retro.train import Trainer, Sampler

def train_retro_model(tds_file_name, url):
    """
    ## Create and train a small model
    """
    # GPU device
    device = torch.device('cuda:0')

    # Load dataset
    tds = TextFileDataset(
        tds_file_name,
        list,
        url=url)

    # Load [Retro dataset](dataset.html)
    train_dataset = Dataset(lab.get_data_path() / 'retro_train_dataset.json', tds)

    # Create dataloader
    train_dl = DataLoader(train_dataset,
                          batch_size=4,
                          sampler=RandomSampler(train_dataset, replacement=True))

    # Hyper-parameters
    chunk_len = 16
    d_model = 128
    d_ff = 512
    n_heads = 16
    d_k = 16
    
    # Create the nearest neighbor encoder
    nearest_neighbor_encoder = NearestNeighborEncoder(chunk_len, 6, {3}, d_model, n_heads, d_k, d_ff)
    # Create the model
    model = RetroModel(tds.n_tokens, d_model, 6,
                       {3, 5},
                       chunk_len, n_heads, d_k, d_ff,
                       encoder=nearest_neighbor_encoder)
    # Move the model to the device
    model = model.to(device)
    # Create the optimizer
    optimizer = Noam(model.parameters(), lr=1., d_model=d_model, warmup=2_000)
    # Create the `Trainer`
    trainer = Trainer(device, model, train_dl, optimizer)
    # Create the `Sampler`
    sampler = Sampler(device, model, tds, chunk_len)
    # Prompt sampled after each epoch (LabML's original Tiny Shakespeare prompt is kept for reference)
    # prompt = '''Second Citizen:\nOne word, good citizens.\n\nFirst Citizen:'''
    prompt = "Dear model, Please answer this question!\n\nQ: What is the course to take?\n\nA:"

    # Train for `32` epochs
    for epoch in monit.loop(32):
        # Train
        trainer()
        # Print a new line
        tracker.new_line()
        # Sample from the `prompt`
        logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
                    (sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])
    
    return model
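`Noam(model.parameters(), lr=1., d_model=d_model, warmup=2_000)` sets the learning rate on a warm-up-then-decay schedule. The sketch below assumes that LabML's `Noam` follows the schedule from *Attention Is All You Need*, scaled by `lr`. `noam_lr` is a hypothetical helper for inspecting the curve, not part of `labml_nn`.

```python
def noam_lr(step: int, d_model: int, warmup: int, lr: float = 1.0) -> float:
    """Learning rate at `step` under the Noam schedule (assumed form):
    lr * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5).
    It rises linearly for `warmup` steps, then decays as step**-0.5."""
    step = max(step, 1)  # the schedule is undefined at step 0
    return lr * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


if __name__ == "__main__":
    # Settings used in train_retro_model
    for s in (1, 500, 2_000, 8_000, 32_000):
        print(s, f"{noam_lr(s, d_model=128, warmup=2_000):.2e}")
```

With `d_model=128` and `warmup=2_000`, the rate peaks at step 2,000 at (128 × 2000)^-0.5, about 2.0e-3, and then decays in proportion to step^-0.5.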

### Wrapper Function to Build Data Stores and Train the Model

In [None]:
def build_datasets_and_model(tgs_txt, tgs_url):
    # The cached files below are not tied to a source dataset: delete retro.index,
    # retro_train_dataset.json and retro_model.pkl before switching datasets
    if not os.path.exists(str(lab.get_data_path() / 'retro.index')):
        build_retro_database(tgs_txt, tgs_url)

    if not os.path.exists(str(lab.get_data_path() / 'retro_train_dataset.json')):
        build_retro_dataset(tgs_txt, tgs_url)

    model_path = str(lab.get_data_path() / 'retro_model.pkl')
    if not os.path.exists(model_path):
        m = train_retro_model(tgs_txt, tgs_url)
        # Save the model as a pickle in the data directory, where the check above looks for it
        joblib.dump(m, model_path)
    else:
        m = joblib.load(model_path)

    return m
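`build_datasets_and_model` follows a build-if-missing pattern, which only saves work if each artifact is written to the same path that is checked. The sketch below shows a generic version of the pattern as a hypothetical `load_or_build` helper built on the standard `pickle` module. For PyTorch models, the PyTorch documentation recommends saving `model.state_dict()` with `torch.save` and rebuilding the model before loading it, rather than pickling the whole module.

```python
import os
import pickle
import tempfile
from typing import Any, Callable


def load_or_build(path: str, build: Callable[[], Any]) -> Any:
    """Return the object cached at `path`; if there is none, build it, save it to
    the same `path`, and return it."""
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)
    obj = build()
    # Save to the exact path checked above, so the next call is a cache hit
    with open(path, "wb") as f:
        pickle.dump(obj, f)
    return obj


if __name__ == "__main__":
    builds = []

    def build():
        builds.append(1)  # count how many times the expensive step runs
        return {"n_chunks": 3}

    with tempfile.TemporaryDirectory() as tmp:
        cache = os.path.join(tmp, "artifact.pkl")
        load_or_build(cache, build)
        load_or_build(cache, build)
    print(len(builds))  # 1: the second call loads the cached file
```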

### Specify the Source Dataset, Build Data Stores, and Train the Model

In [None]:
tiny_shakespeare_txt = str(lab.get_data_path() / 'tiny_shakespeare.txt')
tiny_shakespeare_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

m = build_datasets_and_model(tiny_shakespeare_txt, tiny_shakespeare_url)
# m = build_datasets_and_model(str(lab.get_data_path() / cn_textonly_file), "")

## Remaining Work

I was able to run the following functions successfully with ConceptNet data:<br/>
* prepare_conceptnet_text_file
* build_retro_database
* build_retro_dataset

While running train_retro_model on the ConceptNet data, training progressed through part of the first epoch and then failed with an error about mismatched tensor sizes. I suspect this is due to the hyperparameter choices in the training functions, but I have not yet been able to debug it.<br/>
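One way to narrow down the tensor size mismatch is to check the saved `retro_train_dataset.json` before training. A batch can only be stacked if every sample has the same source/target length, the same number of neighbor lists, and the same number and length of neighbors per chunk. The hypothetical `find_ragged_samples` helper below reports samples that break those rules. Two possibilities are worth checking, though neither is a confirmed diagnosis. First, FAISS reports missing search results with an id of `-1`, and `text[-1: -1 + 2 * chunk_len]` is an empty string. Second, a sample that ends exactly at the end of the text gets a last chunk that is one character short. Set `n_neighbors` to whatever the `RetroIndex` was configured to return.

```python
from typing import List, Tuple

# (source text, target text, neighbors per chunk), as saved by build_retro_dataset
Sample = Tuple[str, str, List[List[str]]]


def find_ragged_samples(samples: List[Sample], chunk_len: int,
                        chunks_per_sample: int, n_neighbors: int) -> List[Tuple[int, str]]:
    """Return (sample index, problem) pairs for samples whose parts would not
    stack into fixed-size tensors."""
    problems = []
    sample_len = chunk_len * chunks_per_sample
    for idx, (src, tgt, neighbors) in enumerate(samples):
        if len(src) != sample_len or len(tgt) != sample_len:
            problems.append((idx, f"src/tgt length {len(src)}/{len(tgt)} != {sample_len}"))
        if len(neighbors) != chunks_per_sample:
            problems.append((idx, f"{len(neighbors)} neighbor lists != {chunks_per_sample} chunks"))
        for c, chunk_neighbors in enumerate(neighbors):
            if len(chunk_neighbors) != n_neighbors:
                problems.append((idx, f"chunk {c}: {len(chunk_neighbors)} neighbors != {n_neighbors}"))
            for text in chunk_neighbors:
                # Each neighbor is the retrieved chunk plus its continuation
                if len(text) != 2 * chunk_len:
                    problems.append((idx, f"chunk {c}: neighbor length {len(text)} != {2 * chunk_len}"))
    return problems


if __name__ == "__main__":
    good = ("a" * 8, "b" * 8, [["x" * 8, "y" * 8], ["z" * 8, "w" * 8]])
    bad = ("a" * 8, "b" * 8, [["x" * 8, ""], ["z" * 8, "w" * 8]])
    for problem in find_ragged_samples([good, bad], chunk_len=4, chunks_per_sample=2, n_neighbors=2):
        print(problem)
```

On the real data, load the list with `json.load` and call the helper with `chunk_len=16` and `chunks_per_sample=32`, the defaults of `build_retro_dataset`. JSON turns the saved tuples into lists, which unpack the same way.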

An evaluation function is also still needed. The plan is to evaluate the ReTro model on SQuAD data so that all three models (GPT-3, QA-GNN, and ReTro) can be compared on the same evaluation.
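For the planned SQuAD comparison, the two standard metrics are exact match and token-level F1, both computed after answer normalization. The sketch below re-implements the normalization and scoring used by the official SQuAD v1.1 evaluation script. Answers are lower-cased, punctuation and the articles a/an/the are removed, and whitespace is collapsed. Each prediction is scored against every reference answer, and the maximum is kept. It is meant for quick checks; reported numbers should come from the official script.

```python
import re
import string
from collections import Counter
from typing import Callable, List

PUNCTUATION = set(string.punctuation)


def normalize_answer(s: str) -> str:
    """Lower-case, drop punctuation and the articles a/an/the, collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in PUNCTUATION)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def exact_match(prediction: str, ground_truth: str) -> float:
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def f1_score(prediction: str, ground_truth: str) -> float:
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    # Tokens shared by prediction and reference, counted with multiplicity
    common = Counter(pred_tokens) & Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def best_over_references(metric: Callable[[str, str], float],
                         prediction: str, references: List[str]) -> float:
    """Score against every reference answer and keep the best score."""
    return max(metric(prediction, ref) for ref in references)


if __name__ == "__main__":
    print(exact_match("The Eiffel Tower!", "eiffel tower"))  # 1.0
    print(f1_score("in Paris France", "Paris"))  # 0.5
```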