# Maximum marginal relevance

In [1]:
import numpy as np

In [6]:

def calculate_similarity(query_vector, document_vector):
    """
    Calculate cosine similarity between the query vector and a document vector.

    Args:
    - query_vector (np.array): Vector representing the query.
    - document_vector (np.array): Vector representing a document.

    Returns:
    - similarity (float): Cosine similarity between the query vector and the document vector.
    """
    dot_product = np.dot(query_vector, document_vector)
    query_norm = np.linalg.norm(query_vector)
    doc_norm = np.linalg.norm(document_vector)
    similarity = dot_product / (query_norm * doc_norm)
    return similarity

def mmr_reranking(documents, query_vector, alpha, beta, initial_ranking):
    """
    Rerank a list of documents using Maximum Marginal Relevance (MMR).

    Args:
    - documents (list): List of document vectors.
    - query_vector (np.array): Vector representing the query.
    - alpha (float): Weight parameter for relevance.
    - beta (float): Weight parameter for diversity.
    - initial_ranking (list): Initial ranked list of document indices.

    Returns:
    - reranked_indices (list): Reranked list of document indices.
    """
    num_documents = len(documents)
    reranked_indices = []

    for index in initial_ranking:
        remaining_indices = [i for i in initial_ranking if i not in reranked_indices]
        remaining_documents = [documents[i] for i in remaining_indices]

        # Calculate relevance score
        relevance_score = calculate_similarity(query_vector, documents[index])

        # Calculate diversity score
        diversity_scores = [calculate_similarity(documents[index], doc) for doc in remaining_documents]
        max_diversity_score = max(diversity_scores)
        diversity_score = max_diversity_score if len(diversity_scores) > 1 else 0  # Set diversity_score to 0 if only 1 document left

        # Calculate MMR score
        mmr_score = alpha * relevance_score - beta * diversity_score

        # Select document with maximum MMR score
        selected_index = remaining_indices[np.argmax(mmr_score)]
        reranked_indices.append(selected_index)

    return reranked_indices

In [7]:
# Suppose we have 5 documents, each represented by a 3-dimensional vector
documents = [
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
    [0.2, 0.3, 0.4],
    [0.5, 0.6, 0.7]
]
query_vector = np.array([0.1, 0.2, 0.3])
alpha = 0.7  # Weight parameter for relevance
beta = 0.3   # Weight parameter for diversity
initial_ranking = [2, 4, 1, 0, 3]  # Initial ranked list of document indices

reranked_indices = mmr_reranking(documents, query_vector, alpha, beta, initial_ranking)
reranked_documents = [documents[i] for i in reranked_indices]
print("Reranked Documents:", reranked_documents)


Reranked Documents: [[0.7, 0.8, 0.9], [0.5, 0.6, 0.7], [0.4, 0.5, 0.6], [0.1, 0.2, 0.3], [0.2, 0.3, 0.4]]


## MMR with LangChain

Source: https://github.com/generative-ai-on-aws/generative-ai-on-aws/blob/main/09_rag/01_langchain_llama2_sagemaker.ipynb

In [None]:
%pip install langchain==0.0.309 faiss-cpu==1.7.4 pypdf==3.15.1 -q --root-user-action=ignore

### Fetch sample data

In [None]:
!mkdir -p ./data

from urllib.request import urlretrieve
urls = [
    'https://s2.q4cdn.com/299287126/files/doc_financials/2023/ar/2022-Shareholder-Letter.pdf',
    'https://s2.q4cdn.com/299287126/files/doc_financials/2022/ar/2021-Shareholder-Letter.pdf',
    'https://s2.q4cdn.com/299287126/files/doc_financials/2021/ar/Amazon-2020-Shareholder-Letter-and-1997-Shareholder-Letter.pdf',
    'https://s2.q4cdn.com/299287126/files/doc_financials/2020/ar/2019-Shareholder-Letter.pdf'
]

filenames = [
    'AMZN-2022-Shareholder-Letter.pdf',
    'AMZN-2021-Shareholder-Letter.pdf',
    'AMZN-2020-Shareholder-Letter.pdf',
    'AMZN-2019-Shareholder-Letter.pdf'
]

metadata = [
    dict(year=2022, source=filenames[0]),
    dict(year=2021, source=filenames[1]),
    dict(year=2020, source=filenames[2]),
    dict(year=2019, source=filenames[3])]

data_root = "./data/"

for idx, url in enumerate(urls):
    file_path = data_root + filenames[idx]
    urlretrieve(url, file_path)

In [None]:
from pypdf import PdfReader, PdfWriter
import glob

local_pdfs = glob.glob(data_root + '*.pdf')

for local_pdf in local_pdfs:
    pdf_reader = PdfReader(local_pdf)
    pdf_writer = PdfWriter()
    for pagenum in range(len(pdf_reader.pages)-3):
        page = pdf_reader.pages[pagenum]
        pdf_writer.add_page(page)

    with open(local_pdf, 'wb') as new_file:
        new_file.seek(0)
        pdf_writer.write(new_file)
        new_file.truncate()

In [None]:
import numpy as np
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader, PyPDFDirectoryLoader

documents = []

for idx, file in enumerate(filenames):
    loader = PyPDFLoader(data_root + file)
    document = loader.load()
    for document_fragment in document:
        document_fragment.metadata = metadata[idx]
        
    documents += document

# - in our testing Character split works better with this PDF data set
text_splitter = RecursiveCharacterTextSplitter(
    # Set a really small chunk size, just to show.
    chunk_size = 512,
    chunk_overlap  = 100,
)

docs = text_splitter.split_documents(documents)

print(f'# of Document Pages {len(documents)}')
print(f'# of Document Chunks: {len(docs)}')

### Embedding Model

In [None]:
from sagemaker.jumpstart.model import JumpStartModel

embedding_model_id, embedding_model_version = "huggingface-textembedding-all-MiniLM-L6-v2", "*"
model = JumpStartModel(model_id=embedding_model_id, model_version=embedding_model_version)
embedding_predictor = model.deploy()

In [None]:
embedding_model_endpoint_name = embedding_predictor.endpoint_name
embedding_model_endpoint_name

In [None]:
import boto3
aws_region = boto3.Session().region_name

### Vector database

In [None]:
from typing import Dict, List
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
import json


class CustomEmbeddingsContentHandler(EmbeddingsContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"text_inputs": inputs, **model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> List[List[float]]:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["embedding"]


embeddings_content_handler = CustomEmbeddingsContentHandler()


embeddings = SagemakerEndpointEmbeddings(
    endpoint_name=embedding_model_endpoint_name,
    region_name=aws_region,
    content_handler=embeddings_content_handler,
)

In [None]:
from langchain.schema import Document
from langchain.vectorstores import FAISS

In [None]:
db = FAISS.from_documents(docs, embeddings)

### Creating Prompt

In [None]:
from langchain.prompts import PromptTemplate

prompt_template = """
<s>[INST] <<SYS>>
Use the context provided to answer the question at the end. If you dont know the answer just say that you don't know, don't try to make up an answer.
<</SYS>>

Context:
----------------
{context}
----------------

Question: {question} [/INST]
"""

PROMPT = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question"]
)

### Preparing LLM

In [None]:
from typing import Dict

from langchain import PromptTemplate, SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.chains.question_answering import load_qa_chain
from langchain.chains import RetrievalQA
import json


class QAContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
        input_str = json.dumps(
            {"inputs" : [
                [
                    {
                        "role" : "system",
                        "content" : ""
                    },
                    {
                        "role" : "user",
                        "content" : prompt
                    }
                ]],
                "parameters" : {**model_kwargs}
            })
        return input_str.encode('utf-8')
    
    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generation"]["content"]

qa_content_handler = QAContentHandler()

In [None]:
llm_model_id, llm_model_version = "meta-textgeneration-llama-2-7b-f", "2.*"
llm_model = JumpStartModel(model_id=llm_model_id, model_version=llm_model_version)
llm_predictor = llm_model.deploy()

In [None]:
llm_model_endpoint_name = llm_predictor.endpoint_name
llm_model_endpoint_name

In [None]:
llm = SagemakerEndpoint(
        endpoint_name=llm_model_endpoint_name,
        region_name=aws_region,
        model_kwargs={"max_new_tokens": 1000, "top_p": 0.9, "temperature": 1e-11},
        endpoint_kwargs={"CustomAttributes": 'accept_eula=true'},
        content_handler=qa_content_handler
    )

### Retrieval with MMR

In [None]:
qa_chain = RetrievalQA.from_chain_type(
    llm,
    chain_type='stuff',
    retriever=db.as_retriever(
        search_type="mmr", # Maximum Marginal Relevance (MMR)
        search_kwargs={"k": 3, "lambda_mult": 0.1}
    ),
    return_source_documents=True,
    chain_type_kwargs={"prompt": PROMPT}
)

In [None]:
query = "How has AWS evolved?"
result = qa_chain({"query": query})
print(f'Query: {result["query"]}\n')
print(f'Result: {result["result"]}\n')
print(f'Context Documents: ')
for srcdoc in result["source_documents"]:
      print(f'{srcdoc}\n')