In [None]:
# Copy the public telework-agreements parquet dataset from the S3 vault into the
# current working directory, where the loading cell below reads it.
# NOTE(review): `mc` is presumably the MinIO client and VAULT_TOP_DIR an environment
# variable holding the bucket root — confirm for the target environment.
!mc cp s3/$VAULT_TOP_DIR/Accords/Construction_dataset_public/Dataset_public_accords_teletravail_Dares.parquet .

In [None]:
import pandas as pd
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain_chroma import Chroma
from langchain.embeddings import OpenAIEmbeddings, OllamaEmbeddings
from langchain.document_loaders import TextLoader
from langchain.schema import Document, Generation, LLMResult
from langchain.llms import Ollama, BaseLLM
from langchain.chains import StuffDocumentsChain, RetrievalQA, LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_community.llms import OpenAI
from langchain_text_splitters import CharacterTextSplitter
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from pathlib import Path
import json
import requests

class LocalOllamaLLM(BaseLLM):
    """LangChain LLM backed by a local Ollama server's ``/api/generate`` endpoint.

    Every prompt handed to ``_generate`` is sent as its own request. The server's
    line-delimited JSON reply is reassembled into one completion string.
    """

    # Base URL of the Ollama server, e.g. "http://127.0.0.1:11434".
    api_url: str
    # Ollama model tag to query (default keeps the previously hard-coded model).
    model: str = "mistral-large"

    def _generate(self, prompts, stop=None, run_manager=None, **kwargs):
        """Generate one completion for each prompt in *prompts*.

        Args:
            prompts: list of already-formatted prompt strings. LangChain always
                passes a list here, even for a single call.
            stop: optional stop sequences, forwarded as Ollama ``options.stop``.
            run_manager: LangChain callback manager (unused).
            **kwargs: extra LangChain call arguments (unused).

        Returns:
            ``LLMResult`` whose ``generations[i]`` holds the completion for
            ``prompts[i]``.

        Raises:
            requests.HTTPError: if the server replies with an error status.
        """
        endpoint = f"{self.api_url}/api/generate"
        generations = []
        for prompt in prompts:
            # Send each prompt string itself, not the repr of the whole list.
            payload = {"model": self.model, "prompt": str(prompt)}
            if stop:
                payload["options"] = {"stop": list(stop)}
            response = requests.post(endpoint, json=payload)
            response.raise_for_status()
            # Ollama streams by default: one JSON object per line, each one
            # carrying a fragment of the answer under "response".
            response_text = "".join(
                json.loads(line)["response"]
                for line in response.text.splitlines()
                if line.strip()
            )
            generations.append([Generation(text=response_text)])
        return LLMResult(generations=generations)

    @property
    def _llm_type(self):
        """Identifier LangChain uses for this LLM type (must be a property)."""
        return "local"  # Or whatever type is appropriate for your local setup

# Generator of the RAG chain: the model served by the local Ollama instance.
llm = LocalOllamaLLM(api_url="http://127.0.0.1:11434")

# Chunking: split each text on blank lines, targeting ~1000 characters per
# chunk (measured with len), with up to 200 characters shared by neighbours.
text_splitter = CharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200,
    separator="\n\n",
    is_separator_regex=False,
    length_function=len,
)

# Hugging Face sentence-embedding model used to index the chunks.
embedder = HuggingFaceEmbeddings(model_name="BAAI/bge-m3")

# System instructions (in French): answer from the given context and say so
# when the answer is unknown, use at most three concise sentences, start with
# "variable=(value or None)" and take the max when several values apply.
# The {context} placeholder is filled with the retrieved documents.
system_prompt = "".join(
    (
        " Utilisez le contexte donné pour répondre à la question.  ",
        " Si vous ne connaissez pas la réponse, dites que vous ne savez pas.  ",
        " Utilisez trois phrases au maximum et soyez concis dans votre réponse. ",
        " En premier lieu, répondre en donnant une variable : variable=(valeur ou None)  . ",
        " S'il y a plusieurs valeurs possibles, prendre le max : variable=max(valeurs ou None)  . ",
        " Contexte : {context}  ",
    )
)

# Chat template: system instructions, then the user's question as {input}.
prompt = ChatPromptTemplate.from_messages([("system", system_prompt), ("human", "{input}")])

# Chain that stuffs retrieved documents into {context} before calling the LLM.
question_answer_chain = create_stuff_documents_chain(llm, prompt)

In [None]:
# Load the agreements table copied into the working directory by the first cell.
file = "Dataset_public_accords_teletravail_Dares.parquet"
df = pd.read_parquet(path=file)

In [None]:
# Ask the same question of every agreement and write one answer file per row,
# named results/<row index>.answer.
new_dir = Path("results")  # previously bound to mkdir()'s return value (None)
new_dir.mkdir(exist_ok=True)

query = "Combien de jour de télétravail par semaine est autorisé au maximum ?"

for position, (index, text) in enumerate(df["texte_complet_accord"].items()):
    # Rows without usable text would crash the splitter (NaN/None) or leave
    # nothing to index, so report them and move on.
    if not isinstance(text, str) or not text.strip():
        print(f"Skipping row {index}: no agreement text")
        continue
    texts = text_splitter.create_documents([text])
    if not texts:
        print(f"Skipping row {index}: no chunks produced")
        continue

    # Use a dedicated collection per agreement and drop it afterwards. In-memory
    # Chroma clients in one process share storage, so reusing the default
    # collection would let chunks from earlier agreements leak into this
    # agreement's retrieval (and keep growing memory).
    vector_store = Chroma(
        collection_name=f"accord_{position}",
        embedding_function=embedder,
    )
    try:
        vector_store.add_documents(texts)
        retriever = vector_store.as_retriever()
        chain = create_retrieval_chain(retriever, question_answer_chain)
        result = chain.invoke({"input": query})
    finally:
        vector_store.delete_collection()

    # Answers are French text: write UTF-8 explicitly rather than relying on
    # the platform's default encoding.
    (new_dir / f"{index}.answer").write_text(result["answer"], encoding="utf-8")