In [None]:
! pip install langchain_community langchain_openai faiss_cpu

In [None]:
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_text_splitters import CharacterTextSplitter
from google.colab import userdata
import os
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor, DocumentCompressorPipeline
from langchain_openai import OpenAI, ChatOpenAI, OpenAIEmbeddings
from langchain.chains import RetrievalQA
from langchain.retrievers.document_compressors import LLMChainFilter
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_text_splitters import CharacterTextSplitter


In [None]:
# Load the speech transcript from a hard-coded absolute path.
# NOTE(review): "/content/..." looks like a Colab runtime path — the file must be uploaded there first.
documents = TextLoader("/content/state_of_the_union.txt").load()

In [None]:
# Character-based splitter configured with chunk_size=500 and chunk_overlap=100.
text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=100)

In [None]:
# Split the loaded documents into chunks; `texts` is what gets embedded and indexed below.
texts = text_splitter.split_documents(documents)

In [None]:
# Read the OpenAI key from google.colab.userdata instead of hard-coding it in the notebook.
OPENAI_API_KEY = userdata.get("OPENAI_API_KEY")

In [None]:
# Export the key through the environment. The OpenAI clients in this notebook are built without an
# explicit key, so presumably they pick it up from here — confirm against the langchain_openai docs.
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

In [None]:
# Build a FAISS vector store from the chunks using OpenAIEmbeddings and expose it as a retriever.
# This is the base retriever that every compression retriever below wraps.
retriever = FAISS.from_documents(texts, OpenAIEmbeddings()).as_retriever()

In [None]:
# Baseline: plain retrieval with no compression, using its own ad-hoc question
# (not the `query` variable defined further down).
docs = retriever.invoke("What did the president say about Ketanji Brown Jackson?")

In [None]:
# Helper function for printing docs

def pretty_print_docs(docs):
    """Print each document's text under a numbered header, separated by a dashed rule.

    Args:
        docs: Iterable of objects exposing a string ``page_content`` attribute.
            Documents are numbered starting from 1 in iteration order.
    """
    divider = "\n" + "-" * 100 + "\n"
    sections = []
    for position, doc in enumerate(docs, start=1):
        sections.append(f"Document {position}:\n\n" + doc.page_content)
    print(divider.join(sections))

In [None]:
# Show the uncompressed chunks returned by the base retriever.
pretty_print_docs(docs)

In [None]:
# Question used for every QA chain and compression retriever below.
# (Fixed typo: "tree" -> "three"; the text is sent verbatim to the retriever and the LLM.)
query = "What were the top three priorities in the most recent State of the Union address?"

In [None]:
# LLM shared by the QA chains and the LLM-based compressors below.
# temperature=0 — presumably to keep outputs as repeatable as possible.
llm = OpenAI(temperature=0)

In [None]:
# Baseline question-answering chain that reads from the uncompressed base retriever.
chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)

In [None]:
# RetrievalQA returns its answer under the "result" key (as the final cell of this notebook also
# reads it); indexing with "text" raises KeyError.
print(chain.invoke(query)["result"])

In [None]:
# Compressor #1: an LLMChainExtractor built on `llm`. Per the LangChain docs it rewrites each
# retrieved document down to the passages relevant to the query — verify for the installed version.
compressor = LLMChainExtractor.from_llm(llm)

In [None]:
# Wrap the base retriever so its results are passed through `compressor` before being returned.
compressor_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)

In [None]:
# Retrieve for `query` through the extractor-based compression retriever.
compressor_docs = compressor_retriever.invoke(query)

In [None]:
# Show what the extractor kept from each retrieved chunk.
pretty_print_docs(compressor_docs)

In [None]:
# Compressor #2: an LLMChainFilter built on `llm`. Per the LangChain docs it keeps or drops whole
# documents rather than rewriting them — verify for the installed version.
_filter = LLMChainFilter.from_llm(llm)

In [None]:
# Same base retriever, now post-processed by the LLM filter.
compressor_retriever_2 = ContextualCompressionRetriever(base_compressor=_filter, base_retriever=retriever)

In [None]:
# Retrieve for `query` through the LLM-filter compression retriever.
compressor_docs_2 = compressor_retriever_2.invoke(query)

In [None]:
# Show the documents that survived the LLM filter.
pretty_print_docs(compressor_docs_2)

In [None]:
# Size (in characters) of the uncompressed context for `query`: the raw chunks returned by the base
# retriever, joined with blank lines. Retrieved with the same `query` the compression retrievers use,
# so the ratios computed below compare like with like (`docs` was fetched for a different question).
original_contexts_len = len("\n\n".join(d.page_content for d in retriever.invoke(query)))

In [None]:
# Display the baseline (uncompressed) context length.
original_contexts_len

In [None]:
# Character count of the extractor-compressed context, with documents joined by blank lines.
compressor_contexts_len = len("\n\n".join(doc.page_content for doc in compressor_docs))

In [None]:
# Ratio of original to compressed context length, printed to two decimals.
# The 1e-5 keeps the division defined when the compressed context is empty (length 0).
print("Compressed Ratio: ", f"{original_contexts_len/ (compressor_contexts_len + 1e-5):.2f}x")

In [None]:
# Embedding model shared by the embedding-based filters below.
embeddings = OpenAIEmbeddings()

In [None]:
# Compressor #3: an embedding-similarity filter with similarity_threshold=0.76. Per the LangChain docs
# it keeps documents whose similarity to the query meets the threshold — verify for the installed version.
embeddings_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.76)

In [None]:
# Same base retriever, now post-processed by the embeddings filter (no LLM call in the compressor).
compressor_retriever_3 = ContextualCompressionRetriever(base_compressor=embeddings_filter, base_retriever=retriever)

In [None]:
# pretty_print_docs needs documents (objects with `page_content`), not the retriever object itself,
# so run the retrieval for `query` and print what comes back.
pretty_print_docs(compressor_retriever_3.invoke(query))

In [None]:
# Retrieve for `query` through the embeddings-filter compression retriever.
compressor_docs_3 = compressor_retriever_3.invoke(query)

In [None]:
# Show the documents that passed the embeddings filter.
pretty_print_docs(compressor_docs_3)

In [None]:
# Character count of the embeddings-filtered context, with documents joined by blank lines.
compressor_contexts_len_2 = len("\n\n".join(doc.page_content for doc in compressor_docs_3))

In [None]:
# Report the embeddings-filtered context length in characters.
print("Compressed context length: ", compressor_contexts_len_2)

In [None]:
# Ratio of original to embeddings-filtered context length, printed to two decimals.
# The 1e-5 keeps the division defined when the filter returns nothing (length 0).
print("Compressed Ratio: ", f"{original_contexts_len/ (compressor_contexts_len_2 + 1e-5):.2f}x")

In [None]:
# Second, finer splitter for use inside a compression pipeline: chunk_size=300, no overlap,
# splitting on ". " (sentence-like boundaries).
splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0, separator=". ")

In [None]:
# Redundancy filter based on the shared embeddings. Per the LangChain docs it drops documents that are
# near-duplicates of each other — verify for the installed version.
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)

In [None]:
# Relevance filter with the same configuration as `embeddings_filter` (similarity_threshold=0.76).
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.76)

In [None]:
# Compressor #4: a three-stage pipeline applied in order —
#   1. splitter         : break retrieved chunks into smaller pieces,
#   2. redundant_filter : drop pieces that duplicate one another,
#   3. relevant_filter  : keep only pieces similar enough to the query.
# The last stage must be `relevant_filter`; listing `redundant_filter` twice left the relevance
# filter unused and skipped the query-relevance step entirely.
pipeline_compressor = DocumentCompressorPipeline(transformers=[splitter, redundant_filter, relevant_filter])

In [None]:
# Same base retriever, now post-processed by the multi-stage pipeline compressor.
compressor_retriever_4 = ContextualCompressionRetriever(base_compressor=pipeline_compressor, base_retriever=retriever)

In [None]:
# Retrieve for `query` through the pipeline compression retriever.
compressor_docs_4 = compressor_retriever_4.invoke(query)

In [None]:
# Show the pieces that made it through the whole pipeline.
pretty_print_docs(compressor_docs_4)

In [None]:
# QA chain that answers from the pipeline-compressed context. `retriever=` must be the retriever
# object (`compressor_retriever_4`), not the already-retrieved list of documents (`compressor_docs_4`).
chain_1 = RetrievalQA.from_chain_type(llm=llm, retriever=compressor_retriever_4)

In [None]:
# Run the compressed-context QA chain and display the full response as the cell output.
# NOTE(review): the next cell invokes the chain again, so the chain runs twice.
chain_1.invoke(query)

In [None]:
# Print only the answer text, read from the response's "result" key.
print(chain_1.invoke(query)["result"])