In [14]:
import hashlib
import json
import os
import re

import chromadb
import pymupdf
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.schema import Document
from langchain_chroma import Chroma
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import WikipediaLoader, YoutubeLoader
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_text_splitters import RecursiveCharacterTextSplitter

In [15]:
# Chat histories keyed by session id; get_session_history creates and stores a
# ChatMessageHistory here the first time a session id is requested.
store = {}

# System prompt for the history-aware retriever: the LLM rewrites the latest
# user question into a standalone question (using the chat history) and must
# not answer it.
contextualize_q_system_prompt = (
    "Given a chat history and the latest user question "
    "which might reference context in the chat history, "
    "formulate a standalone question which can be understood "
    "without the chat history. Do NOT answer the question, "
    "just reformulate it if needed and otherwise return it as is."
)

# System prompt for the answering step. "{context}" is the template variable
# that the stuff-documents chain fills with the retrieved documents.
system_prompt = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer "
    "the question. If you don't know the answer, say that you "
    "don't know. Use three sentences maximum and keep the "
    "answer concise."
    "\n\n"
    "{context}"
)

# Chroma client holding one collection per user ("<username>_collection").
# NOTE(review): chromadb.Client() with no settings is presumably an in-memory,
# non-persistent client — confirm; if so, indexed data is lost on kernel restart.
client = chromadb.Client()
# Azure chat model; the deployment name is read from DEPLOYMENT_NAME_LLM
# (os.getenv returns None when the variable is unset).
# NOTE(review): endpoint and API key are not passed here — presumably picked up
# from environment variables by the client; confirm.
llm = AzureChatOpenAI(
    azure_deployment=os.getenv("DEPLOYMENT_NAME_LLM"),
    openai_api_version="2023-06-01-preview",
    model_version="0301",
)
# Azure embedding model used to vectorize chunks; deployment name is read from
# DEPLOYMENT_NAME_EMBEDDING.
embedding = AzureOpenAIEmbeddings(
    azure_deployment=os.getenv("DEPLOYMENT_NAME_EMBEDDING"),
    openai_api_version="2023-05-15",
)
# Chunking settings: chunks of up to 1000 characters with 200 characters of
# overlap; add_start_index=True records each chunk's start offset in metadata.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=200, add_start_index=True
)

In [16]:
def preprocess_wiki(
    page_name,
    *,
    lang="en",
    max_chars=10000,
    chunk_size=1000,
    chunk_overlap=200,
):
    """
    Fetch a Wikipedia page and return it as a list of chunked documents.

    Only the top search result for ``page_name`` is loaded; its content is
    truncated to ``max_chars`` characters and then split with a
    ``RecursiveCharacterTextSplitter``. The defaults reproduce the previous
    hard-coded behavior.

    Args:
        page_name: Search query / page title passed to ``WikipediaLoader``.
        lang: Wikipedia language edition to query (e.g. ``"en"``, ``"fr"``).
        max_chars: Maximum number of characters kept from the page content.
        chunk_size: Maximum chunk length, in characters.
        chunk_overlap: Number of characters shared by consecutive chunks.

    Returns:
        A list of ``Document`` chunks whose metadata includes ``start_index``
        (the chunk's offset in the page text). May be empty if the search
        returns no page.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=True
    )
    # load_max_docs=1: keep only the best-matching page for the query.
    pages = WikipediaLoader(
        query=page_name,
        lang=lang,
        load_max_docs=1,
        doc_content_chars_max=max_chars,
    ).load()
    return splitter.split_documents(pages)

def _document_id(doc):
    """Return a deterministic id for ``doc`` derived from its text and metadata."""
    # sort_keys makes the digest independent of metadata key order; default=str
    # keeps non-JSON metadata values (dates, objects, ...) serializable.
    payload = json.dumps(
        {"content": doc.page_content, "metadata": doc.metadata},
        sort_keys=True,
        default=str,
        ensure_ascii=False,
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def create_vectorstore_and_retriever(docs, embedding, client, username):
    """
    Store ``docs`` in the user's Chroma collection and return a retriever on it.

    The documents are written to the collection ``f"{username}_collection"``
    of ``client``. That collection is reused when it already exists, so
    documents stored by earlier calls remain in it and stay retrievable.

    Every document gets an id derived from its content and metadata. Indexing
    the same chunks again (e.g. re-running a notebook cell) therefore
    overwrites them instead of adding duplicate copies that would crowd the
    retriever's top results. Exact duplicates inside ``docs`` are stored once.

    Args:
        docs: Documents (e.g. chunks from ``preprocess_wiki``) to index.
        embedding: Embedding model used to vectorize the documents.
        client: Chroma client that owns the collection.
        username: Prefix of the collection name.

    Returns:
        A retriever over the collection's vector store.
    """
    # Keyed by id so exact duplicates collapse; Chroma rejects repeated ids
    # within a single write.
    unique_docs = {}
    for doc in docs:
        unique_docs.setdefault(_document_id(doc), doc)

    vectorstore = Chroma.from_documents(
        documents=list(unique_docs.values()),
        embedding=embedding,
        client=client,
        collection_name=f"{username}_collection",
        ids=list(unique_docs.keys()),
    )
    return vectorstore.as_retriever()

In [17]:
def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """
    Return the chat message history for ``session_id``.

    The first time a session id is seen, a new empty ``ChatMessageHistory`` is
    created, recorded in the module-level ``store`` and returned.
    """
    try:
        return store[session_id]
    except KeyError:
        fresh_history = ChatMessageHistory()
        store[session_id] = fresh_history
        return fresh_history

def create_prompts_and_chains(llm, retriever):
    """
    Build a conversational RAG chain around ``llm`` and ``retriever``.

    Two prompts share the same layout (system text, prior chat history, new
    user input). The first lets the LLM rewrite the question into a
    standalone form before retrieval. The second answers it from the
    retrieved context. The resulting retrieval chain is wrapped so that chat
    history is loaded and saved per session through ``get_session_history``.

    Returns:
        A ``RunnableWithMessageHistory`` that reads the ``"input"`` key and
        produces an ``"answer"`` key.
    """

    def build_prompt(system_text):
        # System instructions, then the running conversation, then the question.
        return ChatPromptTemplate.from_messages(
            [
                ("system", system_text),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )

    retriever_with_history = create_history_aware_retriever(
        llm, retriever, build_prompt(contextualize_q_system_prompt)
    )
    answer_chain = create_stuff_documents_chain(llm, build_prompt(system_prompt))
    retrieval_chain = create_retrieval_chain(retriever_with_history, answer_chain)

    return RunnableWithMessageHistory(
        retrieval_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )

def execute_rag_chain(conversational_rag_chain, username, query):
    """
    Send ``query`` to ``conversational_rag_chain`` and return the answer text.

    ``username`` is passed as the ``session_id``, so the chain uses (and
    extends) that user's chat history.
    """
    session_config = {"configurable": {"session_id": username}}
    result = conversational_rag_chain.invoke({"input": query}, config=session_config)
    return result["answer"]

def get_answer_llm(username, retriever, query, llm):
    """
    Answer ``query`` for ``username`` with a freshly built conversational RAG chain.

    The chain is rebuilt from ``llm`` and ``retriever`` on every call; the chat
    session it runs in is identified by ``username``.
    """
    return execute_rag_chain(
        create_prompts_and_chains(llm, retriever), username, query
    )

In [18]:
def create_retriever(source, username):
    """
    Index the Wikipedia page ``source`` for ``username`` and set the global ``retriever``.

    Nothing is returned; ``query_llm`` reads the module-level ``retriever``
    assigned here.

    NOTE(review): the chunks go into the shared ``f"{username}_collection"``
    Chroma collection, so chunks indexed by earlier calls (other sources)
    appear to remain retrievable alongside the new ones — confirm this
    accumulation is intended.
    """
    global retriever
    page_chunks = preprocess_wiki(source)
    retriever = create_vectorstore_and_retriever(
        docs=page_chunks, embedding=embedding, client=client, username=username
    )

def query_llm(query, username):
    """Answer ``query`` in ``username``'s chat session using the global ``retriever``."""
    return get_answer_llm(
        username=username,
        retriever=retriever,
        query=query,
        llm=llm,
    )

### Tests

In [19]:
username = "Q"

# Index the "Mistral AI" Wikipedia page for this user, then keep a handle on the
# resulting Chroma collection so its contents can be inspected later.
create_retriever(source="Mistral AI", username=username)
collection = client.get_collection(name=username + "_collection")

In [22]:
create_retriever(source="Apple Inc.", username=username)  # same user, second source

In [23]:
# Show the collection descriptor, then how many chunks it currently stores.
print(collection, collection.count(), sep="\n")

name='Q_collection' id=UUID('0b61fd91-3d43-413e-8a00-ae9ec4d83a88') metadata=None tenant='default_tenant' database='default_database'
27
