In [None]:
import os
import numpy as np
from chromadb import HttpClient
from langchain.vectorstores import Chroma
from chromadb.utils import embedding_functions
from langchain_core.embeddings import Embeddings
from chromadb.api.types import EmbeddingFunction
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_mistralai.embeddings import MistralAIEmbeddings
from langchain.prompts import MessagesPlaceholder, ChatPromptTemplate
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import RedisChatMessageHistory
from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM
from langchain.chains import create_retrieval_chain, create_history_aware_retriever

In [None]:
from dotenv import find_dotenv, load_dotenv

# Resolve the path of a .env file first, then load it; the following cells
# read their settings (URLs, API keys) through os.getenv.
dotenv_path = find_dotenv()
load_dotenv(dotenv_path)

In [None]:
# Report only whether the Hugging Face token is configured. Leaving
# os.getenv('HF_TOKEN') as the last expression of the cell would render the
# secret itself into the saved notebook output.
hf_token_is_set = bool(os.getenv('HF_TOKEN'))
hf_token_is_set

In [None]:
class ChromaEmbeddingsAdapter(Embeddings):
    """Expose a Chroma ``EmbeddingFunction`` through LangChain's ``Embeddings`` API.

    The wrapped function is called with a list of texts and must return one
    embedding per text. Every embedding is converted to a plain list of Python
    floats before it is returned, so the result does not depend on whether the
    wrapped function produces lists, tuples or numpy arrays.
    """

    def __init__(self, ef: EmbeddingFunction):
        """Store the embedding function to delegate to.

        Args:
            ef: Callable taking a list of texts and returning their embeddings.
        """
        self.ef = ef

    @staticmethod
    def _as_float_list(vector):
        """Convert one embedding (any iterable of numbers) to a list of floats."""
        return [float(value) for value in vector]

    def embed_documents(self, texts):
        """Embed several texts.

        Args:
            texts: Iterable of strings to embed.

        Returns:
            A list with one embedding (list of floats) per input text; an
            empty list when ``texts`` is empty.
        """
        texts = list(texts)
        if not texts:
            # Nothing to embed: skip the call to the wrapped function.
            # NOTE(review): some embedding functions are believed to reject
            # empty input - confirm against the installed chromadb version.
            return []
        return [self._as_float_list(vector) for vector in self.ef(texts)]

    def embed_query(self, query):
        """Embed a single query string.

        Args:
            query: The text to embed.

        Returns:
            The embedding of ``query`` as a list of floats.
        """
        # The wrapped function works on batches, so send a one-element batch
        # and unwrap the single result.
        return self._as_float_list(self.ef([query])[0])

In [None]:
# Build Chroma's sentence-transformer embedding function for the
# "all-MiniLM-L6-v2" model, then wrap it so LangChain components can use it.
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2"
)
embedding_fn = ChromaEmbeddingsAdapter(sentence_transformer_ef)

In [None]:
# Client for the Chroma server configured through CHROMA_DB_URL.
chroma_client = HttpClient(os.getenv('CHROMA_DB_URL'))

# LangChain vector store over the collection named by DB_NAME, embedding
# texts with the adapter defined earlier in the notebook.
vector_store = Chroma(
    client=chroma_client,
    collection_name=os.getenv('DB_NAME'),
    embedding_function=embedding_fn,
)

In [None]:
# Chat model used by both the question-rewriting and the answering steps;
# the API key is taken from the MISTRAL_API_KEY environment variable.
model = ChatMistralAI(
    mistral_api_key=os.getenv('MISTRAL_API_KEY'),
)

In [None]:
# System prompt for the question-rewriting step: it is used as the 'system'
# message of the prompt given to the history-aware retriever, and asks the
# model to restate the latest user question as a standalone question.
chat_history_prompt = """
    Given a chat history and the latest user question which might reference context in the chat history, 
    formulate a standalone question which can be understood without the chat history. Do NOT answer 
    the question, just reformulate it if needed and otherwise return it as is.
"""
# System prompt for the answering step. `{context}` is a template placeholder;
# presumably it is filled with the retrieved documents by the stuff-documents
# chain built from this prompt - TODO confirm.
user_input_prompt = """
    ### [INST]
    Instruction: You are an expert political analyst with vast knowledge of the United States electoral process.
    You answer questions with certainty and you do not hallucinate. When unsure, you politely reply that you do 
    not have sufficient knowledge to answer the user question. You will generate new content by analysing the 
    context supplied with each user question. When your previous knowledge is capable of answering the questions, or when the 
    supplied context isn't enough to do so, you can default to previous knowledge. Using this instructions, answer the 
    following questions. Here is the supplied context:
    
    {context}
    
    [/INST]
"""

In [None]:
# Get a retriever over the Chroma vector store. as_retriever() is called
# without arguments, so the library's default search settings apply.
retriever = vector_store.as_retriever()

In [None]:
def get_message_history(session_id: str) -> RedisChatMessageHistory:
    """Return the Redis-backed chat history for one conversation.

    The connection URL, which embeds the Redis password, is read from the
    ``REDIS_URL`` environment variable so that the credential is not stored
    in the notebook.

    NOTE(review): a Redis URL with an embedded password used to be hardcoded
    in this function; that credential has been committed and should be
    rotated.

    Args:
        session_id: Identifier of the conversation whose messages are stored.

    Returns:
        A ``RedisChatMessageHistory`` bound to ``session_id``.

    Raises:
        RuntimeError: If ``REDIS_URL`` is not set or is empty.
    """
    redis_url = os.getenv('REDIS_URL')
    if not redis_url:
        # Fail with an actionable message instead of an obscure connection error.
        raise RuntimeError(
            "REDIS_URL is not set; define it in the environment or the .env file."
        )
    return RedisChatMessageHistory(session_id, url=redis_url)

In [None]:
# Step 1 - question rewriting.
# The prompt combines the rewriting instructions, the stored conversation and
# the latest user input.
chat_history_context_prompt = ChatPromptTemplate.from_messages(
    [
        ('system', chat_history_prompt),
        MessagesPlaceholder('chat_history'),
        ('human', '{input}'),
    ]
)

# Retriever that goes through the model and the prompt above before querying
# the Chroma retriever.
history_aware_retriever = create_history_aware_retriever(
    model, retriever, chat_history_context_prompt
)

# Step 2 - answering.
# Same message layout, but with the analyst system prompt that carries the
# `{context}` placeholder.
chat_prompt = ChatPromptTemplate.from_messages(
    [
        ('system', user_input_prompt),
        MessagesPlaceholder('chat_history'),
        ('human', '{input}'),
    ]
)

# Document chain: the model answers from the documents inserted into chat_prompt.
question_answer_chain = create_stuff_documents_chain(model, chat_prompt)

# Step 3 - full RAG chain: history-aware retrieval followed by answering.
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

# Step 4 - attach per-session message history. The key names must match the
# variables used in the prompts ("input", "chat_history"); the reply is read
# from the "answer" key of the chain output.
rag_chain_llm = RunnableWithMessageHistory(
    rag_chain,
    get_message_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

In [None]:
# Sanity check of the embedding adapter: embed one sample question and display
# the shape the embedding has once a leading batch axis is added.
TEST_QUERY = """
    What's the latest in Texas?
"""
query_embeddings = embedding_fn.embed_query(TEST_QUERY)
np.expand_dims(query_embeddings, axis=0).shape

In [None]:
# Fetch the collection named by DB_NAME directly from the Chroma client
# (the same name the vector store was created with) for inspection.
collection = chroma_client.get_collection(os.getenv('DB_NAME'))

In [None]:
# Show how many items the collection currently holds.
print(f"There are {collection.count()} items in the collection")

In [None]:
import uuid

In [None]:
# Random identifier for this conversation; it is passed to the chain as the
# "session_id" configurable when the chain is invoked.
session_id = str(uuid.uuid4())

In [None]:
# Sample user question sent to the RAG chain as its "input".
message = "What's the latest with Donald Trump"

In [None]:
# Run the chain for the sample question; the session id selects which stored
# message history is used.
invoke_config = {"configurable": {"session_id": session_id}}
chat_response = rag_chain_llm.invoke({"input": message}, config=invoke_config)

In [None]:
# Display the answer with paragraph breaks (blank lines) collapsed to single spaces.
chat_response['answer'].replace('\n\n', ' ')

In [None]:
# Display the complete chain output, not just the "answer" entry.
chat_response