In [181]:
from typing import Optional, Sequence, Any
from langchain.chains import create_structured_output_runnable
from langchain_core.pydantic_v1 import BaseModel
from langchain.chat_models.gigachat import GigaChat
from langchain.chains.question_answering import load_qa_chain
from PyPDF2 import PdfReader

from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import GigaChatEmbeddings
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter
)
import json

In [182]:
CREDENTIALS = 'Yjg4MTQzMmUtNDAwMS00NDk0LThjOGUtNmU5ZWQ2YzQ4NDQ2OmQ4MWMxZGZiLTFmNGYtNDk5NS05OGQzLTBiMzYyYWJmNjk3OA=='
TESTPDF = "../data/papers/10.1002@solr.201900061.pdf"

In [183]:
with open("prompts.json", encoding="utf8") as f:
    prompts = json.load(f)

LIST_OBJECTS_PROMPT, SYSTEM_PROMPT, MAKE_MD_TABLE_PROMPT = prompts['SYSTEM_PROMPT'], prompts['LIST_OBJECTS_PROMPT'], prompts['MAKE_MD_TABLE_PROMPT']

In [184]:
class Cell(BaseModel):
    perovskite: str
    HTM_type: str
    PCE: float
    VOC: float
    JSC: float
    FF: float

In [185]:
giga = GigaChat(credentials=CREDENTIALS, 
                verify_ssl_certs=False,
                scope='GIGACHAT_API_CORP',
                model="GigaChat-Pro",
                )

## PDF tools

In [186]:
def extract_raw_text_from_pdf(path) -> str:
    reader = PdfReader(stream=path)

    raw_text = ''
    for _, page in enumerate(reader.pages):
        text = page.extract_text()
        if text:
            raw_text += text
    return raw_text

def get_index_from_pdf(pdf_path) -> Any:
    text_splitter = RecursiveCharacterTextSplitter(
        # separator="\n",
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )

    raw_text = extract_raw_text_from_pdf(path=pdf_path)
    texts = text_splitter.split_text(text=raw_text)
    embeddings = GigaChatEmbeddings(credentials=CREDENTIALS, verify_ssl_certs=False, scope='GIGACHAT_API_CORP')
    index = FAISS.from_texts(texts=texts, embedding=embeddings)
    return index

def invoke_chain_with_index(chain, index, query) -> dict:
    query = "For each of these {Spiro HTM, Spiro-CB, Spiro-THF} report efficiency (PCE or optimized efficiency or η). Put that data in the markdown table with columns 'HTM' - 'PCE'"
    docs = index.similarity_search(query)
    return chain.invoke({"input_documents": docs, "question": query})

In [187]:
class PDFExtractor:
    def __init__(self) -> None:
        pass
        # self.path = path
        # self.query = input
        
    def extract_raw_text_from_pdf(self, path: str, query: str) -> str:
        self.path = path
        self.query = query
        reader = PdfReader(stream=self.path)

        raw_text = ''
        for _, page in enumerate(reader.pages):
            text = page.extract_text()
            if text:
                raw_text += text
        return raw_text

    def get_index_from_pdf(self) -> Any:
        text_splitter = RecursiveCharacterTextSplitter(
            # separator="\n",
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )

        raw_text = extract_raw_text_from_pdf(self)
        texts = text_splitter.split_text(text=raw_text)
        embeddings = GigaChatEmbeddings(credentials=CREDENTIALS, verify_ssl_certs=False, scope='GIGACHAT_API_CORP')
        index = FAISS.from_texts(texts=texts, embedding=embeddings)
        return index

    def invoke_chain_with_index(self, chain, index) -> dict:
        # query = "For each of these {Spiro HTM, Spiro-CB, Spiro-THF} report efficiency (PCE or optimized efficiency or η). Put that data in the markdown table with columns 'HTM' - 'PCE'"
        docs = index.similarity_search(self.query)
        return chain.invoke({"input_documents": docs, "question": self.query})

In [188]:
def extract_raw_text_from_pdf(path: str) -> str:
    reader = PdfReader(stream=path)

    raw_text = ''
    for _, page in enumerate(reader.pages):
        text = page.extract_text()
        if text:
            raw_text += text
    return raw_text

In [189]:
# runnable = create_structured_output_runnable(Cell, giga)

## Into action

In [190]:
# LIST_OBJECTS_PROMPT = 'Please list all the {object_type}s for which device efficiency results are given in this paper. Use python list format and names of HTM that were originally given by authors. For example, ["HTM1", "HTM2", "HTM3"].'
# for i,path in enumerate(sorted(glob.glob("../data/papers/2014/*.pdf"))):
#     # if i >= 5:
#     #     break
#     print(path)
#     index = get_index_from_pdf(path)
#     doc_chain = get_doc_chain(llm=giga, template=CONTEXT_BASED_Q_TEMPLATE, prompt=list_objects_prompt)
#     retrieve_chain = get_retrieval_chain(index=index, doc_chain=doc_chain)
#     response = retrieve_chain.invoke({"chat_history": chat_history, "input": LIST_OBJECTS_PROMPT.format(**dic)})
#     print(response["answer"])

In [191]:
from langchain.agents import AgentExecutor, Tool, create_gigachat_functions_agent, ZeroShotAgent
from langchain_core.prompts import ChatPromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain.chains import LLMChain

In [192]:
tools = [
    Tool(
        name="Extractor_from_PDF",
        func=extract_raw_text_from_pdf,
        # func=PDFExtractor.extract_raw_text_from_pdf,
        description="extracts information from the PDF provided",
    ),
    Tool(
        name="Indexer_of_the_PDF",
        func=get_index_from_pdf,
        description="PDF indexation",
    ),
    Tool(
        name="Text_similarity_search",
        func=invoke_chain_with_index,
        description="performs a text similarity search in the given raw text",
    )
]

In [193]:
suffix = """Begin!"

{chat_history}
Question: {input}
{agent_scratchpad}"""

In [194]:
prompts['LIST_OBJECTS_PROMPT']

'Please list all the {object_type}s for which device efficiency results are given in this paper. Use python list format and names of HTM that were originally given by authors. For example, ["HTM1", "HTM2", "HTM3"]'

In [195]:
list_objects_prompt = ZeroShotAgent.create_prompt(
    tools=tools,
    prefix=prompts['LIST_OBJECTS_PROMPT'].format(object_type="HTM"),
    suffix=suffix,
    input_variables=["input", "chat_history", "agent_scratchpad", "object type"]    
)

In [196]:
memory = ConversationBufferMemory(memory_key="chat_history")
llm_chain = LLMChain(llm=giga, prompt=list_objects_prompt, memory=memory)
agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True \
                      )


In [197]:
agent_chain = AgentExecutor.from_agent_and_tools(
    agent=agent, tools=tools, verbose=True, memory=memory
)

In [198]:
input = {
    # "path" : "../data/papers/2014/2014_1.pdf",
    # "query" : "Does the device efficiency in the paper exceed 35%?",
    # "object_type": "HTM",
    # "input": "HZ"
    "input": "What is the PCE of a solar cell device in the paper?",
    # "agent_scratchpad": "fddfdfdfdfdffd",
}

In [199]:
response = agent_chain(input, return_only_outputs=True)
answer = response['answer']



[1m> Entering new AgentExecutor chain...[0m


ValueError: One input key expected got ['agent_scratchpad', 'input']

In [50]:
# agent = create_gigachat_functions_agent(giga, tools, prompt=list_objects_prompt)
# agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

In [8]:
import glob

In [10]:
runnable.invoke(glob.glob("../data/papers/2014/*.pdf")[0])

ResponseError: (URL('https://gigachat.devices.sberbank.ru/api/v1/chat/completions'), 422, b'{"status":422,"message":"Requested model does not support functions"}\n', Headers([('server', 'nginx'), ('date', 'Thu, 25 Apr 2024 11:37:57 GMT'), ('content-type', 'application/json; charset=utf-8'), ('content-length', '70'), ('connection', 'keep-alive'), ('access-control-allow-credentials', 'true'), ('access-control-allow-headers', 'Origin, X-Requested-With, Content-Type, Accept, Authorization'), ('access-control-allow-methods', 'GET, POST, DELETE, OPTIONS'), ('access-control-allow-origin', 'https://beta.saluteai.sberdevices.ru'), ('x-request-id', '809fce96-9baa-47e2-811a-0f8f45997c86'), ('x-session-id', 'db48cb12-48dd-41cd-bdac-12bc246fde6e'), ('allow', 'GET, POST'), ('strict-transport-security', 'max-age=31536000; includeSubDomains'), ('allow', 'GET, POST'), ('strict-transport-security', 'max-age=31536000; includeSubDomains')]))