In [20]:
import configparser
import os
from warnings import filterwarnings

from IPython.display import display, Latex
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain_community.vectorstores import FAISS
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_mistralai.embeddings import MistralAIEmbeddings
# Install a global 'ignore' filter: every Python warning is silenced from here on.
# NOTE(review): this also hides deprecation warnings raised by the imported
# libraries; presumably done to keep cell output clean - consider narrowing it.
filterwarnings('ignore')

In [21]:
# Read the Mistral API key from the local INI-style file named `config`
# (expects an `api_key` option inside a `[Mistral]` section).
CONFIG_PATH = 'config'
config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed; check it so a missing file fails with a clear
# message instead of a confusing NoSectionError on the next line.
if not config.read(CONFIG_PATH):
    raise FileNotFoundError(
        f"Config file {CONFIG_PATH!r} was not found or could not be read; "
        "it must define api_key in a [Mistral] section"
    )
api_key = config.get('Mistral', 'api_key')

In [22]:
# Embedding client (Mistral "mistral-embed"); it is handed to FAISS when the
# vector store is loaded further down.
embeddings = MistralAIEmbeddings(
    mistral_api_key=api_key,
    model="mistral-embed",
)

## Пытаемся получить и сохранить vector store
MistralAI возвращает ошибку «too many requests», поэтому пытаемся её обойти: строим индекс отдельно для каждой главы, а затем объединяем все индексы в один.

In [23]:
# NOTE: one-off index build, kept commented out. The active cell that follows
# only loads the already merged index from "vector_store/all_indexes".
# Step 1 (first loop): for chapters 1-38, load data/chapter{i}.tex, split it
#   into chunks and, unless vector_store/faiss_index_{i} already exists, embed
#   the chunks and save a per-chapter FAISS index. On any failure the chapter
#   number is printed - presumably so the loop can be re-run after a
#   rate-limit error; the bare `except` also hides every other error.
# Step 2 (second loop): load the per-chapter indexes, merge them into the first
#   one and save the result as vector_store/all_indexes.
# for i in range(1, 39):
#     # Load data
#     loader = TextLoader(f'data/chapter{i}.tex')
#     #loader = DirectoryLoader('data', glob="**/*r1.tex", loader_cls=TextLoader)
#     docs = loader.load()
#     # Split text into chunks 
#     text_splitter = RecursiveCharacterTextSplitter()
#     documents = text_splitter.split_documents(docs)
#     # Create the vector store 
#     try:
#         if not os.path.exists(f'vector_store/faiss_index_{i}'):
#             vector = FAISS.from_documents(documents, embeddings)
#             vector.save_local(f"vector_store/faiss_index_{i}")
#     except:
#         print(i)

# # Load the vector store
# vector = FAISS.load_local("vector_store/faiss_index_1", embeddings,  allow_dangerous_deserialization=True)

# for i in range(2, 39):
#     tmp = FAISS.load_local(f"vector_store/faiss_index_{i}", embeddings,  allow_dangerous_deserialization=True)
#     vector.merge_from(tmp)
# vector.save_local(f"vector_store/all_indexes")

In [24]:
# Load the merged FAISS index from disk, using the same embedding client.
# NOTE(review): allow_dangerous_deserialization=True permits pickle-based
# loading of the stored index data; only load index directories you built
# yourself, never files from an untrusted source.
vector = FAISS.load_local(
    "vector_store/all_indexes",
    embeddings,
    allow_dangerous_deserialization=True,
)

In [25]:
# Define a retriever interface over the loaded vector store.
# No arguments are passed, so the library's default retrieval settings apply.
retriever = vector.as_retriever()

In [26]:
# Define the chat LLM used both inside the RAG chain and for the plain baseline.
# NOTE(review): no model name or sampling parameters are passed, so the
# library defaults apply - consider pinning them for reproducible answers.
model = ChatMistralAI(mistral_api_key=api_key)

In [27]:
# Prompt text (in Russian): instructs the model to answer the question based
# only on the provided context. The {context} and {input} placeholders are
# filled in when the chain is invoked.
RAG_PROMPT_TEMPLATE = """Ответь на следующий вопрос основываясь только на предоставленном контексте:


<контекст>
{context}
</контекст>

Вопрос: {input}"""

# Build the chat prompt object from the template text
prompt = ChatPromptTemplate.from_template(RAG_PROMPT_TEMPLATE)

# Смотрим, что выдает RAG на основе госбука

In [44]:
# Wire the pieces together: a document-combining chain (model + prompt),
# wrapped in a retrieval chain that is driven by the retriever.
document_chain = create_stuff_documents_chain(model, prompt)
retrieval_chain = create_retrieval_chain(retriever, document_chain)

# Ask the question through the retrieval chain and render the answer as LaTeX
question = "Критерий Коши для последовательности"
response = retrieval_chain.invoke({"input": question})
answer = response['answer']
display(Latex(answer))

<IPython.core.display.Latex object>

# Смотрим, что выдает обычная модель

In [46]:
# Baseline without retrieval: send the same question directly to the chat
# model as a single human message and render its reply as LaTeX.
messages = [HumanMessage(content="Критерий Коши для последовательности")]
display(Latex(model.invoke(messages).content))

<IPython.core.display.Latex object>