In [1]:
import os, json, logging
import os.path as osp
import re

from typing import Any, Union, Tuple, Dict, Callable, List, Optional, Literal
from pprint import pprint
from datetime import datetime
from langchain.docstore.document import Document
from chromadb.config import Settings

from langchain import OpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import VectorStore, FAISS, Chroma, Pinecone
import pinecone
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
from langchain.output_parsers import PydanticOutputParser, OutputFixingParser, ListOutputParser
from langchain.chains.summarize import load_summarize_chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain, _split_list_of_docs, _collapse_docs
from langchain.chains.combine_documents.refine import RefineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.callbacks.manager import Callbacks
from custom_parsers import DrugOutput, DrugParser

import config
from config import MAIN_DIR, DATA_DIR, ARTIFACT_DIR, DOCUMENT_SOURCE

from shutil import rmtree
from utils import load_single_document, load_documents, convert_json_to_documents, convert_csv_to_documents
import yaml

from pydantic import root_validator

  from tqdm.autonotebook import tqdm


In [3]:
# Project identifier; it selects the per-project sub-directories built below.
PROJECT = "uc"

# The OpenAI key is read from a JSON file under MAIN_DIR instead of being
# hard-coded in the notebook.
with open(osp.join(MAIN_DIR, "auth", "api_keys.json"), "r") as f:
    keys = json.load(f)
OPENAI_KEY = keys["OPENAI_API_KEY"]

# Per-project locations of the source documents and of the embedding store.
SOURCE_DATA = osp.join(DOCUMENT_SOURCE, PROJECT)
EMBSTORE_DIR = osp.join(
    config.EMBSTORE_DIR, PROJECT, "faiss", "text-embedding-ada-002"
)

# Maps a PDF file name to a list of integers.
# NOTE(review): presumably page indices to skip when loading -- confirm against
# the loader that consumes this dict. Element order is kept as originally written.
EXCLUDE_DICT = {
    "agrawal.pdf": list(range(13, 19)),
    "PIIS1542356520300446.pdf": list(range(12, 19)),
    "gutjnl-2021-326390R2 CLEAN.pdf": [0, 2, *range(31, 46), *range(3, 31)],
    "otad009.pdf": [15, 16],
    "1-s2.0-S2468125321003770-main.pdf": [9],
    "juillerat 2022.pdf": [6, 7, 8],
}

In [4]:
# Root logger: every record at INFO or above is written to <MAIN_DIR>/log/logfile.txt.
LOGGER = logging.getLogger()

log_path = os.path.join(MAIN_DIR, "log", "logfile.txt")
# FileHandler raises if the parent directory is missing, so create it up front.
os.makedirs(os.path.dirname(log_path), exist_ok=True)

formatter = logging.Formatter("%(asctime)s:%(levelname)s: %(message)s")

# Re-running this cell must not stack a second FileHandler on the root logger
# (that would write every record twice), so reuse one that already targets
# the same file when it exists.
file_handler = next(
    (
        handler
        for handler in LOGGER.handlers
        if isinstance(handler, logging.FileHandler)
        and handler.baseFilename == os.path.abspath(log_path)
    ),
    None,
)
if file_handler is None:
    file_handler = logging.FileHandler(filename=log_path)
    LOGGER.addHandler(file_handler)
file_handler.setFormatter(formatter)

LOGGER.setLevel(logging.INFO)

In [5]:
def convert_csv_to_documents(table_info: Dict, concatenate_rows: bool = True) -> List[Document]:
    """Turn a CSV table description into LangChain documents.

    NOTE(review): a function with the same name is also imported from ``utils``
    at the top of the notebook; this definition replaces it once the cell runs.

    Args:
        table_info: Mapping that must provide ``"mode"`` (equal to ``"table"``),
            ``"filename"`` (ending in ``.csv``, resolved relative to
            ``MAIN_DIR``), ``"description"`` and ``"metadata"``.
        concatenate_rows: When True, return a single document whose content is
            the description followed by every row, each separated by a blank
            line. When False, return one document per row.

    Returns:
        A list holding either the single concatenated table document or the
        per-row documents. In per-row mode each row's content is prefixed with
        ``description + ":"`` and its metadata is a copy of
        ``table_info["metadata"]`` extended with ``"row"`` and ``"modal"``.

    Raises:
        AssertionError: If ``table_info`` does not describe a CSV table.
    """
    assert table_info["mode"] == "table" and table_info["filename"].endswith(".csv")
    # The loop below reads ``page_content`` and ``metadata["row"]`` from each
    # item returned by the loader.
    rows = load_single_document(os.path.join(MAIN_DIR, table_info["filename"]))

    if concatenate_rows:
        # Build the document once after collecting all rows. The previous
        # version only bound the document inside the loop, so an empty CSV
        # raised UnboundLocalError; it now yields a description-only document.
        sections = [table_info["description"]]
        sections.extend(row.page_content for row in rows)
        table_content = "\n\n".join(sections) + "\n\n"
        return [
            Document(
                page_content=table_content,
                metadata=table_info["metadata"],
            )
        ]

    documents = []
    for row in rows:
        # Copy the shared metadata so per-row keys do not leak between rows.
        metadata = dict(table_info["metadata"])
        metadata["row"] = row.metadata["row"]
        metadata["modal"] = table_info["mode"]
        # The loaded row objects are updated in place and reused as the output.
        row.page_content = table_info["description"] + ":" + row.page_content
        row.metadata = metadata
        documents.append(row)

    return documents
    
def check_documents_token(
    docs: List[Document],
    llm = None
    ):
    """Measure the prompt length of ``docs`` when stuffed into a single prompt.

    Args:
        docs: Documents to measure. A single non-list value is wrapped in a list.
        llm: Model passed to the throw-away ``LLMChain`` used for measuring.
            When omitted, a ``gpt-3.5-turbo`` ``ChatOpenAI`` (temperature 0,
            authenticated with ``OPENAI_KEY``) is created on each call.

    Returns:
        Whatever ``StuffDocumentsChain.prompt_length`` returns for ``docs``
        (presumably a token count, possibly ``None`` -- confirm against the
        installed LangChain version).
    """
    if llm is None:
        # Build the default lazily. As a default argument it was constructed
        # once at definition time, which created a client as a side effect of
        # running the cell and froze the API key in use at that moment.
        llm = ChatOpenAI(temperature=0,
                         model_name="gpt-3.5-turbo",
                         openai_api_key=OPENAI_KEY)
    if not isinstance(docs, list):
        docs = [docs]
    # The prompt is just the "{summaries}" placeholder, so the measured prompt
    # consists of the stuffed documents alone.
    combine_document_chain = StuffDocumentsChain(
        llm_chain=LLMChain(
            llm=llm,
            prompt=PromptTemplate(template="{summaries}",
                                input_variables=["summaries"]),
            verbose=False,
        ),
        verbose=False
    )
    return combine_document_chain.prompt_length(docs)

In [6]:
class MapReduceDocumentsChainV2(MapReduceDocumentsChain):
    """Map-reduce documents chain with explicit token budgets.

    ``combine_docs`` applies ``llm_chain`` to every document. ``_process_results``
    then collapses the mapped outputs (at most once) when they exceed
    ``combine_max_tokens`` and finally runs ``combine_document_chain`` on them.
    """

    # Mapped results whose prompt length exceeds this value are collapsed
    # before the final combine step.
    combine_max_tokens: int = 30000
    # Token limit handed to ``_split_list_of_docs`` when grouping results for
    # the collapse step.
    collapse_max_tokens: int = 5000

    # skip_on_failure=True: do not run when field validation already failed,
    # because ``values`` may then lack the keys read below and the resulting
    # KeyError would hide the real validation error.
    @root_validator(skip_on_failure=True)
    def check_maximum_context_length(cls, values: Dict) -> Dict:
        """Cap both token budgets at the per-model limit of the chain's LLM.

        Budgets are only ever lowered, never raised. Models that are not listed
        in the table, or LLMs that expose no ``model_name``, leave the
        configured budget untouched.
        """
        # NOTE(review): these limits presumably leave headroom below each
        # model's context window for the prompt and completion -- confirm.
        max_token_dict = {
            "gpt-3.5-turbo": 3000,
            "gpt-3.5-turbo-16k": 14000,
            "gpt-4": 7000,
            "gpt-4-32k": 30000
        }

        def _capped(limit: int, chain: Any) -> int:
            # Return ``limit`` lowered to the table entry for ``chain``'s model.
            model_name = getattr(chain.llm_chain.llm, "model_name", None)
            if model_name in max_token_dict:
                return min(limit, max_token_dict[model_name])
            return limit

        combine_chain = values["combine_document_chain"]
        # Without a dedicated collapse chain, the combine chain's model decides
        # the collapse budget as well.
        collapse_chain = values["collapse_document_chain"] or combine_chain

        values["combine_max_tokens"] = _capped(
            values["combine_max_tokens"], combine_chain
        )
        values["collapse_max_tokens"] = _capped(
            values["collapse_max_tokens"], collapse_chain
        )
        return values

    def combine_docs(
        self,
        docs: List[Document],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Tuple[str, dict]:
        """Combine documents in a map reduce manner.

        Combine by mapping first chain over all documents, then reducing the results.
        This reducing can be done recursively if needed (if there are many documents).

        Args:
            docs: Documents to map over.
            callbacks: Callbacks forwarded to every chain invocation.
            **kwargs: Extra prompt inputs passed alongside each document.

        Returns:
            The ``(output, extra_return_dict)`` pair produced by
            ``_process_results``.
        """
        map_inputs = [
            {self.document_variable_name: doc.page_content, **kwargs} for doc in docs
        ]
        results = self.llm_chain.apply(map_inputs, callbacks=callbacks)
        return self._process_results(
            results, docs, callbacks=callbacks, **kwargs
        )

    def _process_results(
        self,
        results: List[Dict],
        docs: List[Document],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Tuple[str, dict]:
        """Reduce the mapped ``results`` into one output string.

        Args:
            results: One dict per input document, as returned by
                ``llm_chain.apply``.
            docs: The original documents; only their metadata is reused.
            callbacks: Callbacks forwarded to the collapse and combine chains.
            **kwargs: Extra prompt inputs forwarded to the length function and
                to the collapse and combine chains.

        Returns:
            ``(output, extra)`` where ``extra`` holds ``"intermediate_steps"``
            (the raw mapped outputs) when ``return_intermediate_steps`` is set
            and is empty otherwise.

        Raises:
            RuntimeError: If the results still exceed ``combine_max_tokens``
                after one collapse pass.
        """
        question_result_key = self.llm_chain.output_key
        # Pair each mapped text with the metadata of the document it came from.
        result_docs = [
            Document(page_content=result[question_result_key], metadata=docs[i].metadata)
            for i, result in enumerate(results)
        ]
        length_func = self.combine_document_chain.prompt_length
        num_tokens = length_func(result_docs, **kwargs)

        def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
            return self._collapse_chain.run(
                input_documents=docs, callbacks=callbacks, **kwargs
            )

        collapse_counter = 0
        while num_tokens is not None and num_tokens > self.combine_max_tokens:
            # Only a single collapse pass is allowed: needing a second one
            # aborts instead of collapsing again.
            collapse_counter += 1
            if collapse_counter == 2:
                raise RuntimeError("Double Collapse steps. Stop")

            doc_groups = _split_list_of_docs(
                result_docs, length_func, self.collapse_max_tokens, **kwargs
            )
            # A distinct loop name keeps the ``docs`` parameter from being shadowed.
            result_docs = [
                _collapse_docs(doc_group, _collapse_docs_func, **kwargs)
                for doc_group in doc_groups
            ]
            num_tokens = length_func(result_docs, **kwargs)

        if self.return_intermediate_steps:
            extra_return_dict = {
                "intermediate_steps": [r[question_result_key] for r in results]
            }
        else:
            extra_return_dict = {}
        output = self.combine_document_chain.run(
            input_documents=result_docs, callbacks=callbacks, **kwargs
        )
        return output, extra_return_dict

In [7]:
# Load the UC test queries, one per line, with trailing whitespace (including
# the newline) stripped. The "utf-8-sig" codec drops a leading BOM if present.
with open(osp.join(DATA_DIR, "queries", "uc.txt"), "r", encoding="utf-8-sig") as f:
    test_cases = [line.rstrip() for line in f]

test_cases

['40 year old male with newly diagnosed moderate UC and articular extraintestinal manifestations',
 '70 year old female with newly diagnosed severe UC',
 '35 year old male with known moderate UC with prior exposure to infliximab but has worsening colitis on endoscopy despite compliance',
 '60 year old female with newly diagnosed moderate UC with a background of congestive cardiac failure',
 '38 year old female with newly diagnosed moderate UC and psoriasis',
 '25 year old pregnant woman with severe distal ulcerative colitis',
 '56 year old man with moderate to severe ulcerative colitis and ankylosing spondylitis',
 '38 year old man with severe ulcerative colitis and has lost response to vedolizumab',
 '28 year old woman who has severe extensive ulcerative colitis and has a history of lymphoma which was treated 4 years ago',
 '36 year old woman with moderate ulcerative colitis and multiple sclerosis']

In [11]:
from langchain.agents import initialize_agent
from langchain.chat_models import ChatOpenAI
from langchain.agents import load_tools
from langchain.tools.base import BaseTool

from custom_chain import MapReduceDocumentsChainV2

In [10]:
# Chat model for this notebook: temperature 0, completions capped at 512 tokens.
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    openai_api_key=OPENAI_KEY,
    temperature=0,
    max_tokens=512,
)

In [None]:
class RetrieverTool(BaseTool):
    description: "Use this tool to search for information regarding treatment for patients with moderate to severe ulcerative colititis (UC)"