In [None]:
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain_openai import AzureChatOpenAI
from dotenv import load_dotenv, find_dotenv
import PyPDF2
import os
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.document_loaders import UnstructuredFileLoader
from langchain.chains.summarize import load_summarize_chain
from langchain.chains.question_answering import load_qa_chain
from langserve import RemoteRunnable

# Load the variables of the .env file located by find_dotenv() into the
# process environment before any setting is read below.
_ = load_dotenv(find_dotenv())

# Azure OpenAI chat client. Every setting except the API version comes from the
# environment; a variable that is not set is passed through as None.
# NOTE(review): `model_version` is fed from OPENAI_API_VERSION while
# `openai_api_version` is hard-coded -- confirm this pairing is intended.
chat_model = AzureChatOpenAI(
    openai_api_version="2023-09-01-preview",
    azure_endpoint=os.environ.get("AZURE_API_ENDPOINT"),
    api_key=os.environ.get("AZURE_OPENAI_API_KEY"),
    azure_deployment=os.environ.get("OPENAI_DEPLOYMENT_NAME"),
    model_name=os.environ.get("OPENAI_MODEL_NAME"),
    model_version=os.environ.get("OPENAI_API_VERSION"),
)

In [None]:
def getText(pdf_path):
    """Extract the text of every page of a PDF and return it as one string.

    Args:
        pdf_path: Path of the PDF file to read.

    Returns:
        The text of all pages concatenated in page order. No separator is
        inserted between pages.

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(pdf_path, "rb") as file:
        reader = PyPDF2.PdfReader(file)
        # Walk the pages directly and join once instead of growing a string
        # page by page; `or ""` keeps the join valid if a page yields no text.
        return "".join(page.extract_text() or "" for page in reader.pages)

In [None]:
class CustomConversationChain(ConversationChain):
    """ConversationChain extended with helpers to summarise a PDF document."""

    def documentSummarizer(self, input):
        """Summarise the PDF located at `input` and record it in memory.

        The document text is read with `getText`, summarised by
        `summarize_content`, and then saved to the chain's memory with the
        file path as the input and the full document text as the output.
        Note that the summary itself is not what gets stored in memory.

        Args:
            input: Path of the PDF file to summarise.

        Returns:
            Whatever `summarize_content` returns for the document text.
        """
        document_text = getText(input)
        summary = self.summarize_content(document_text)
        # presumably stored so later questions can refer to the document --
        # verify against how the chain's prompt consumes its history.
        self.memory.save_context({"input": input}, {"content": document_text})
        return summary

    def summarize_content(self, content):
        """Ask the chain's LLM for a concise summary of `content`.

        Args:
            content: Text to summarise; it is interpolated into the prompt.

        Returns:
            The raw result of `self.llm.invoke` for the summary prompt.
        """
        return self.llm.invoke(f"Réalise un résumé concis du document : {content}")

In [None]:
# Buffer memory that the chain below uses to keep the conversation history.
memory = ConversationBufferMemory()

# Conversation chain wired to the Azure chat model and the memory above.
chain = CustomConversationChain(llm=chat_model, memory=memory)

In [None]:
def summarizeDoc(pdf_path):
    """Summarise the PDF at `pdf_path` with the module-level `chain` and print it.

    Args:
        pdf_path: Path of the PDF file, forwarded to `chain.documentSummarizer`.

    Returns:
        None. The summary text is written to standard output.
    """
    summary = chain.documentSummarizer(input=pdf_path)
    # The summarizer hands back the raw result of the model call, which may be
    # a message object rather than a string; print its text content when it
    # has one, otherwise print the value unchanged.
    print(getattr(summary, "content", summary))
    
def askQuestion(question):
    """Send `question` to the module-level `chain` and print its reply.

    Args:
        question: Input passed to `chain.invoke`.

    Returns:
        None. The value under the "response" key of the result is printed.
    """
    result = chain.invoke(question)
    print(result["response"])