In [None]:
# build a sample vectorDB

import getpass
import os

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
# StrOutputParser and ChatPromptTemplate are classes, not top-level modules:
# a bare `import StrOutputParser` raises ModuleNotFoundError and aborts the cell.
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

# Do not hard-code credentials in the notebook. Use the key already present in
# the environment, and only prompt for it (without echoing) when it is missing,
# so a real key is never overwritten and never saved with the notebook.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")

In [None]:
# Load the source article (the Investopedia page on the capital asset pricing
# model) from the web.
loader = WebBaseLoader("https://www.investopedia.com/terms/c/capital-asset-pricing-model.asp")
data = loader.load()

# Chat model shared by the query-generation step and the final answering chain.
llm = ChatOpenAI()

# Split the loaded documents into chunks of size 500 with no overlap between
# neighbouring chunks.
# NOTE(review): chunk_size is presumably measured in characters (the splitter's
# default length function) — confirm.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
splits = text_splitter.split_documents(data)

# Embed the chunks and index them in a Chroma vector store; `vectordb` is the
# store queried later by create_documents.
embeddings = OpenAIEmbeddings()
vectordb = Chroma.from_documents(splits, embeddings)

In [None]:
# generate more questions

def create_original_query(original_query):
    """Expand the user's question into a list of search queries.

    The chat model is asked to generate five more questions from the user's
    question; its reply is split into lines and appended after the original
    question.

    Args:
        original_query: Mapping that holds the user's question under the
            ``"question"`` key.

    Returns:
        list[str]: The original question first, followed by every non-blank
        line of the model's reply with surrounding whitespace removed.

    Raises:
        KeyError: If ``original_query`` has no ``"question"`` key.
    """
    query = original_query["question"]
    qa_system_prompt = """
        You are a question generator. Given a question, generate 5 more questions.
    """
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", qa_system_prompt),
            ("human", "{question}"),
        ]
    )

    generation_chain = qa_prompt | llm | StrOutputParser()
    question_string = generation_chain.invoke({"question": query})

    # The reply may contain blank or whitespace-only lines (e.g. spacing
    # between list items). Drop them so an empty string is never handed to the
    # retriever as a search query.
    generated_questions = [
        line.strip() for line in question_string.splitlines() if line.strip()
    ]
    return [query] + generated_questions

In [None]:
# retrieve document and cross encode
from sentence_transformers import CrossEncoder
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain.chains import RetrievalQA
from langchain_core.output_parsers import StrOutputParser
import numpy as np

# Instantiate the cross-encoder once at module level; create_documents uses it
# to score [query, document] pairs when re-ranking retrieved chunks.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")


In [None]:
# cross encoding happens in here

def create_documents(queries, top_k=5):
    """Retrieve documents for every query and re-rank them with the cross-encoder.

    Args:
        queries: List of query strings. ``queries[0]`` is treated as the
            original question and is the only query used for scoring.
        top_k: Maximum number of documents to return. Defaults to 5.

    Returns:
        list[dict]: Up to ``top_k`` entries of the form
        ``{"score": <cross-encoder score>, "document": <page_content>}``,
        ordered from highest to lowest score. Empty when nothing was retrieved.
    """
    # Build the retriever once instead of once per query.
    retriever = vectordb.as_retriever()

    retrieved_documents = []
    for query in queries:
        # `invoke` is the Runnable entry point that replaces the deprecated
        # `get_relevant_documents`.
        results = retriever.invoke(query)
        retrieved_documents.extend(doc.page_content for doc in results)

    # Several queries may return the same chunk; dict.fromkeys removes the
    # duplicates in linear time while keeping first-seen order.
    unique_documents = list(dict.fromkeys(retrieved_documents))
    if not unique_documents:
        # Nothing to score: do not call the model with an empty batch.
        return []

    # cross encoder scoring: each candidate is paired with the original question
    pairs = [[queries[0], doc] for doc in unique_documents]
    scores = cross_encoder.predict(pairs)
    scored_documents = [
        {"score": score, "document": doc}
        for score, doc in zip(scores, unique_documents)
    ]

    # rerank the documents (stable sort, best first) and keep the top_k
    scored_documents.sort(key=lambda item: item["score"], reverse=True)
    return scored_documents[:top_k]


In [None]:
# QnA document

# System prompt for the answering step. The {context} placeholder receives the
# joined text of the re-ranked documents; the user's question is sent as a
# separate human message.
qa_system_prompt = """
    Assistant is a large language model trained by OpenAI. \
    Use the following pieces of retrieved context to answer the question. \
    If you don't know the answer, just say that you don't know, don't try to make up an answer. \
    If the answer is not contained within the text below, say \"I don't know\". \
    {context}
"""
qa_prompt = ChatPromptTemplate.from_messages([
    ("system", qa_system_prompt),
    # "question" is not a message role understood by from_messages and makes
    # prompt construction fail; the user's question belongs in a "human" message.
    ("human", "{question}")
])

def format(docs):
    """Join the ``"document"`` text of each entry into one context string.

    Args:
        docs: Iterable of mappings that each carry their text under the
            ``"document"`` key (the shape returned by create_documents).

    Returns:
        str: The document texts in their given order, separated by a blank line.
    """
    # NOTE(review): this name shadows the builtin `format`; it is kept because
    # the chain below refers to it by this name.
    return "\n\n".join(entry["document"] for entry in docs)

chain = (
    # prepare the context using below pipeline
    # generate queries -> cross encoding -> re-ranking -> return context
    {
        "context": RunnableLambda(create_original_query) | RunnableLambda(create_documents) | RunnableLambda(format),
        # Hand only the question text to the prompt. RunnablePassthrough()
        # would forward the whole input dict, so the human message would
        # contain "{'question': ...}" instead of the question itself.
        "question": RunnableLambda(lambda inputs: inputs["question"]),
    }
    | qa_prompt
    | llm
    | StrOutputParser()
)

# NOTE(review): this question is unrelated to the indexed CAPM article, so it
# presumably exercises the "I don't know" instruction in the system prompt —
# confirm that is the intent.
result = chain.invoke({"question": "What is the capital of France?"})