# Retrieval Augmented Generation with SageMaker

Text-to-text generative AI models have a well-documented limitation: they only know about information up to the date their training data was collected. This notebook shows how to use retrieval augmented generation (RAG), also known as data augmented generation, to supplement a text generation model with up-to-date information via document search. We use two models, both loaded from the HuggingFace FLAN T5 large checkpoint: its encoder (`T5EncoderModel`) creates embeddings for the documents and the question, and the full sequence-to-sequence model (`AutoModelForSeq2SeqLM`) generates the answer.

# The Hallucination Issue

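Because the model's knowledge stops at its training date, it can confidently "hallucinate" answers about later events; for example, a model might claim that Switzerland won the 2022 World Cup, when in fact Argentina won in 2022 and France won in 2018. Herein lies the problem we need to fix with RAG. In the run below the model happens to answer "argentina", but without a source document there is nothing to show whether an answer like this is grounded or a guess.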

In [69]:
import transformers
import torch
import pandas as pd
import numpy as np

import json
import boto3
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5EncoderModel

# Note: use at least flan-t5-large; with smaller checkpoints the question embedding
#       may not match the right document in the similarity search later.

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
encoder_model = T5EncoderModel.from_pretrained("google/flan-t5-large")

Some weights of the model checkpoint at google/flan-t5-large were not used when initializing T5EncoderModel: ['decoder.block.16.layer.2.DenseReluDense.wo.weight', 'decoder.block.9.layer.1.EncDecAttention.o.weight', 'decoder.block.21.layer.0.SelfAttention.v.weight', 'decoder.block.10.layer.1.EncDecAttention.o.weight', 'decoder.block.5.layer.0.SelfAttention.o.weight', 'decoder.block.7.layer.0.SelfAttention.o.weight', 'decoder.block.18.layer.1.EncDecAttention.k.weight', 'decoder.block.13.layer.0.SelfAttention.k.weight', 'decoder.block.12.layer.0.SelfAttention.o.weight', 'decoder.block.10.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.6.layer.0.SelfAttention.q.weight', 'decoder.block.13.layer.1.EncDecAttention.k.weight', 'decoder.block.2.layer.2.layer_norm.weight', 'decoder.block.18.layer.2.DenseReluDense.wo.weight', 'decoder.block.7.layer.0.SelfAttention.q.weight', 'decoder.block.7.layer.0.SelfAttention.v.weight', 'decoder.block.5.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.2

The warning above can be ignored: `T5EncoderModel` loads only the encoder, so the decoder weights in the checkpoint are not used.

In [70]:
prompt = f'''Answer the following question.
Question: Who won the 2022 world cup?
Answer:
'''
inputs = tokenizer(prompt, return_tensors='pt')
output = tokenizer.decode(
    model.generate(
        input_ids=inputs["input_ids"], 
        max_new_tokens=200,
    )[0], 
    skip_special_tokens=True
)


print(f'{prompt}{output}')

Answer the following question.
Question: Who won the 2022 world cup?
Answer:
argentina


In this cell, a small amount of prompt engineering helps: adding the line `If you do not have the information to answer the question, say "I don't know".` to the prompt makes the model answer "I don't know", which is better than producing a wrong answer.

In [71]:
prompt = f'''Answer the following question. If you do not have the information to answer the question, say "I don't know".
Question: Who won the 2022 world cup?
Answer:
'''
inputs = tokenizer(prompt, return_tensors='pt')
output = tokenizer.decode(
    model.generate(
        input_ids=inputs["input_ids"], 
        max_new_tokens=200,
    )[0], 
    skip_special_tokens=True
)

print(f'{prompt}{output}')

Answer the following question. If you do not have the information to answer the question, say "I don't know".
Question: Who won the 2022 world cup?
Answer:
I don't know


# Get a HuggingFace Model for Embeddings

The embeddings come from the encoder of the [FLAN T5 large model](https://huggingface.co/google/flan-t5-large), which was loaded from HuggingFace in the first cell as `encoder_model`. This is the model we use to create our document search embeddings.

# Create Embedding Database

Now use the HuggingFace model to create embeddings for each of the three documents provided. Each document's title and text are joined and encoded, and the encoder's last hidden states are averaged over the tokens to give one vector per document. The documents used here are only illustrative; they could be replaced with any collection of text that supports your use case.

In [72]:
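import torch

def get_embedding(text, encoder_model, tokenizer):
    """Embed text as the mean of the encoder's last hidden states over all tokens."""
    with torch.no_grad():  # inference only, no gradients needed
        input_ids = tokenizer(
            text, return_tensors="pt", max_length=512, truncation=True
        ).input_ids  # texts longer than 512 tokens are truncated
        outputs = encoder_model(input_ids=input_ids)
        last_hidden_states = outputs.last_hidden_state  # shape: [1, num_tokens, hidden_size]
        e = last_hidden_states.mean(dim=1)  # shape: [1, hidden_size]
    return e

def create_doc_database(docs, encoder_model, tokenizer):
    """Embed each 'title - document' row and stack the results into one tensor."""
    database = []
    for i in range(docs.shape[0]):
        text = docs['title'].values[i] + ' - ' + docs['document'].values[i]
        e = get_embedding(text, encoder_model, tokenizer)
        database.append(e)
    database = torch.cat(database)  # shape: [num_docs, hidden_size]
    return database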
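`get_embedding` turns a variable-length text into one fixed-size vector by averaging the encoder's last hidden states over the token axis (`mean(dim=1)`). Because each text is encoded on its own, there is no padding and a plain mean is fine. If you later batch several documents together, shorter sequences get padded, and the pad positions would dilute a plain mean. The plain-Python sketch below, using made-up toy vectors rather than real T5 outputs, shows masked mean pooling: only positions where the attention mask is 1 are averaged. `masked_mean_pool` is a hypothetical helper, not part of the notebook.

```python
from typing import List


def masked_mean_pool(hidden_states: List[List[float]], attention_mask: List[int]) -> List[float]:
    """Average token vectors, skipping positions where attention_mask is 0 (padding)."""
    if len(hidden_states) != len(attention_mask):
        raise ValueError("hidden_states and attention_mask must have the same length")
    dim = len(hidden_states[0])
    totals = [0.0] * dim
    count = 0
    for vector, keep in zip(hidden_states, attention_mask):
        if keep:  # only real tokens contribute to the average
            for j, value in enumerate(vector):
                totals[j] += value
            count += 1
    if count == 0:
        raise ValueError("attention_mask has no non-padding positions")
    return [t / count for t in totals]


# Toy example: two real tokens followed by one pad token.
hidden = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]
mask = [1, 1, 0]
print(masked_mean_pool(hidden, mask))  # [2.0, 3.0]; a plain mean would give [1.33..., 2.0]
```

With PyTorch tensors of shape `[batch, tokens, hidden]`, the same idea is usually written as `(h * m.unsqueeze(-1)).sum(1) / m.sum(1, keepdim=True)`, with the mask `m` cast to float.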

In [73]:
import pandas as pd

docs = pd.read_csv('document-corpus.txt', delimiter="::: ", engine='python')
docs

Unnamed: 0,title,document
0,world cup 2022,"""Lionel Messi can finally be called a world ch..."
1,champions league 2022,'Real Madrid were crowned European champions f...
2,ballon dor 2022,'Real Madrid and France\'s Karim Benzema has w...


In [74]:
database = create_doc_database(docs, encoder_model, tokenizer)

In [75]:
database.shape

torch.Size([3, 1024])

# Add document-search feature

Now that we have a database of embeddings, we can search it with the text input `"Who won the 2022 world cup?"`: we embed the question the same way as the documents and pick the document whose embedding has the largest dot product with the question's embedding.

In [76]:
def search_database(search_embedding, database):
    similarities = []
    for i in range(database.shape[0]):
        similarities.append(
            float(torch.dot(search_embedding[0], database[i]))
        )
    return np.argmax(similarities), similarities
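`search_database` scores each document with a raw dot product, which grows with the length of the vectors as well as with their direction, so a document with a large-norm embedding can outrank a more relevant one. A common alternative is cosine similarity, which normalizes both vectors first. The plain-Python sketch below is a swapped-in technique, not what the notebook uses: it ranks all documents by cosine similarity and returns the top `k` indices with their scores, which also makes it easy to see how close the runner-up documents were. The function names and toy vectors are made up for illustration.

```python
import math
from typing import List, Tuple


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Dot product of a and b divided by the product of their Euclidean norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0  # treat zero vectors as unrelated
    return dot / (norm_a * norm_b)


def top_k_documents(query: List[float], database: List[List[float]], k: int = 2) -> List[Tuple[int, float]]:
    """Return (document index, cosine score) pairs for the k best matches, best first."""
    scored = [(i, cosine_similarity(query, doc)) for i, doc in enumerate(database)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy 2-d "embeddings": doc 1 has a large norm but points in a different direction.
toy_db = [[1.0, 0.1], [10.0, 10.0], [0.0, 1.0]]
query = [1.0, 0.0]
print(top_k_documents(query, toy_db, k=2))  # doc 0 ranks first; a raw dot product would pick doc 1
```

With the notebook's tensors, the equivalent is to apply `torch.nn.functional.normalize(..., dim=-1)` to both `database` and `search_embedding` before taking the dot product.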

In [77]:
search = 'Who won the 2022 world cup?'
search_embedding = get_embedding(search, encoder_model, tokenizer)
doc_index, similarities = search_database(search_embedding, database)
print(f"Input: {search}\nWas matched with document #{doc_index} which is titled \"{docs.loc[doc_index]['title']}\"")

Input: Who won the 2022 world cup?
Was matched with document #0 which is titled "world cup 2022"


# Dynamically Engineer the Prompt

Now that we have a user input matched with a relevant document, we can engineer a prompt which includes both the question and context from the document.

In [78]:
prompt_eng_base = '''Answer the following question with the following context. If you do not have the information to answer the question, say "I don't know".

Context: [PLACE DOC HERE]

Question: [PLACE QUESTION HERE]
Answer:
'''

In [79]:
def make_prompt(search, context, prompt_eng_base):
    prompt = prompt_eng_base.replace('[PLACE DOC HERE]', context)
    prompt = prompt.replace('[PLACE QUESTION HERE]', search)
    return prompt
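`make_prompt` inserts a single document into the template. If you retrieve several documents (for example with a top-k search), one simple option is to join them into one context block before substitution. The sketch below is a hypothetical extension, not part of the original notebook; `make_multi_doc_prompt` and its `separator` parameter are illustrative names. It uses the same placeholder strings as `prompt_eng_base`.

```python
from typing import List

# Same template text as prompt_eng_base in the notebook.
PROMPT_ENG_BASE = '''Answer the following question with the following context. If you do not have the information to answer the question, say "I don't know".

Context: [PLACE DOC HERE]

Question: [PLACE QUESTION HERE]
Answer:
'''


def make_multi_doc_prompt(search: str, contexts: List[str], prompt_base: str, separator: str = "\n\n") -> str:
    """Join several retrieved documents into one context block, then fill the template."""
    context = separator.join(c.strip() for c in contexts)
    prompt = prompt_base.replace('[PLACE DOC HERE]', context)
    return prompt.replace('[PLACE QUESTION HERE]', search)


prompt = make_multi_doc_prompt(
    "Who won the 2022 world cup?",
    ["Argentina won the 2022 FIFA World Cup Final.", "France lost the final on penalties."],
    PROMPT_ENG_BASE,
)
print(prompt)
```

Keep the combined length in mind: every extra document makes the prompt longer, and the FLAN T5 tokenizer warns once a sequence exceeds 512 tokens.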

In [80]:
augmented_prompt = make_prompt(search, docs.loc[doc_index]['document'], prompt_eng_base)
augmented_prompt

'Answer the following question with the following context. If you do not have the information to answer the question, say "I don\'t know".\n\nContext: "Lionel Messi can finally be called a world champion.Messi scored twice in one of the most epic soccer games anyone has ever watched as Argentina won the 2022 FIFA World Cup Final over France on penalties.The climactic match in Qatar finished 3-3 after extra time, with La Albiceleste claiming the shootout by a 4-2 margin.Argentina held a comfortable 2-0 lead until the 80th minute courtesy of a Messi penalty and a sublime team goal finished by Ángel Di María in the first half. However, Kylian Mbappé converted from the spot and finished a sumptuous volley in a span of two minutes to send the game to extra time.Eugene Omoruyi Scores 17 Points vs. Indiana PacersMessi was once again on hand to put Argentina in front in the 108th minute, but Mbappé kept his cool from the penalty spot once more to send the final to a shootout.Messi knocked home

# Wrap the RAG Flow into a Function

In [81]:
base_prompt = f'''Answer the following question. If you do not have the information to answer the question, say "I don't know".

Question: [SEARCH HERE]
Answer:
'''

def rag_demo(search, use_search=True):
    search_embedding = get_embedding(search, encoder_model, tokenizer)
    print('Returned embedding size {} for input {}'.format(search_embedding.shape, search))
    doc_index, similarities = search_database(search_embedding, database)
    if use_search:
        augmented_prompt = make_prompt(search, docs.loc[doc_index]['document'], prompt_eng_base)
    else:
        augmented_prompt = base_prompt.replace('[SEARCH HERE]', search)

    inputs = tokenizer(augmented_prompt, return_tensors='pt')
        
    output = tokenizer.decode(
        model.generate(
            input_ids=inputs["input_ids"], 
            max_new_tokens=200,
        )[0], 
        skip_special_tokens=True
    )
    return output
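`rag_demo` relies on the notebook's globals (`encoder_model`, `database`, `docs`, `model`) and runs the embedding search even when `use_search=False`. The sketch below shows the same control flow with the retrieval and generation steps passed in as functions, so retrieval only runs when its result is used and the logic can be exercised without loading FLAN T5. The name `rag_answer` and the stub retriever and generator are hypothetical stand-ins, not the notebook's objects.

```python
from typing import Callable


def rag_answer(
    question: str,
    retrieve: Callable[[str], str],
    generate: Callable[[str], str],
    rag_template: str,
    plain_template: str,
    use_search: bool = True,
) -> str:
    """Build either an augmented or a plain prompt, then pass it to the generator."""
    if use_search:
        context = retrieve(question)  # only search when the context will be used
        prompt = rag_template.replace('[PLACE DOC HERE]', context)
        prompt = prompt.replace('[PLACE QUESTION HERE]', question)
    else:
        prompt = plain_template.replace('[SEARCH HERE]', question)
    return generate(prompt)


# Hypothetical stubs standing in for the embedding search and FLAN T5.
def stub_retrieve(q: str) -> str:
    return "Argentina won the 2022 FIFA World Cup."


def stub_generate(prompt: str) -> str:
    return "Argentina" if "Argentina" in prompt else "I don't know"


rag_t = "Context: [PLACE DOC HERE]\nQuestion: [PLACE QUESTION HERE]\nAnswer:"
plain_t = "Question: [SEARCH HERE]\nAnswer:"
print(rag_answer("Who won the 2022 world cup?", stub_retrieve, stub_generate, rag_t, plain_t))  # Argentina
print(rag_answer("Who won the 2022 world cup?", stub_retrieve, stub_generate, rag_t, plain_t, use_search=False))  # I don't know
```

In the notebook, `retrieve` would wrap `get_embedding` and `search_database` and return `docs.loc[doc_index]['document']`, and `generate` would wrap the tokenizer and `model.generate` call.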

# Example Outputs

The outputs below compare answers with and without document search, showing how retrieval gets relevant information to the model so it can give informed responses back to the user.

In [82]:
out = rag_demo(
    'Who won the 2022 world cup?', use_search=False
)
print(out)

Returned embedding size torch.Size([1, 1024]) for input Who won the 2022 world cup?
I don't know


In [83]:
out = rag_demo(
    'Who won the 2022 world cup?', use_search=True
)
print(out)

Token indices sequence length is longer than the specified maximum sequence length for this model (956 > 512). Running this sequence through the model will result in indexing errors


Returned embedding size torch.Size([1, 1024]) for input Who won the 2022 world cup?
Argentina
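The tokenizer warning above appears because the augmented prompt is 956 tokens long, while the tokenizer's configured maximum is 512. The model still produced an answer here, but `get_embedding` has a related limit: it truncates every document to its first 512 tokens (`truncation=True`), so text past that point never influences retrieval. A common mitigation is to split each document into smaller, overlapping chunks, embed each chunk, and retrieve chunks rather than whole documents. The sketch below assumes whitespace-separated words as a rough stand-in for tokens; `chunk_words`, `chunk_size` and `overlap` are illustrative names and values, not tuned settings.

```python
from typing import List


def chunk_words(text: str, chunk_size: int = 200, overlap: int = 50) -> List[str]:
    """Split text into windows of chunk_size words, each sharing overlap words with the previous one."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be in [0, chunk_size)")
    words = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        if start + chunk_size >= len(words):
            break  # this window already reaches the end of the text
    return chunks


example = "one two three four five six seven"
print(chunk_words(example, chunk_size=3, overlap=1))
# ['one two three', 'three four five', 'five six seven']
```

Each chunk would then go through `get_embedding`, and the index returned by the search would point at a chunk rather than at a whole row of `docs`. The overlap keeps a sentence that straddles a boundary from being split across two chunks with no shared context.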


In [84]:
out = rag_demo(
    'Who were a 5 of the most important players in the 2022 world cup final?', use_search=False
)
print(out)

Returned embedding size torch.Size([1, 1024]) for input Who were a 5 of the most important players in the 2022 world cup final?
I don't know


In [85]:
out = rag_demo(
    'Who were a 5 of the most important players in the 2022 world cup final?', use_search=True
)
print(out)

Returned embedding size torch.Size([1, 1024]) for input Who were a 5 of the most important players in the 2022 world cup final?
Lionel Messi, ngel Di Mara, Kylian Mbappé, Kingsley Coman, and Aurélien Tchouaméni


# Suggested Next Steps

* Explore libraries which can help with this kind of workflow. See: [LangChain](https://github.com/hwchase17/langchain)
* Bring your own documents or information to this workflow to explore creating RAG based systems.
* Look into fine-tuning your embedding model to improve search quality.
* Integrate this RAG flow with your own search capabilities.