In [23]:
import os
import json
import sys
import logging
import re
import openai
import tiktoken
import camelot

from typing import Union, Dict, List, Callable
from llama_index import SimpleDirectoryReader
from llama_index.schema import Document
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from llama_index.embeddings import OpenAIEmbedding
from llama_index.schema import Document, MetadataMode, NodeWithScore, TextNode
from llama_index.callbacks import CallbackManager, TokenCountingHandler

from config import MAIN_DIR, GUIDELINES
from utils import generate_vectorindex
from utils import load_vectorindex

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
openai.log = "info"

In [13]:
DATA_DIR = os.path.join(MAIN_DIR, "data")
DOCUMENT_DIR = os.path.join(MAIN_DIR, "data", "document_sources")
EXCLUDE_DICT = os.path.join(DATA_DIR, "exclude_pages.json")

with open(os.path.join(MAIN_DIR, "auth", "api_keys.json"), "r") as f:
    api_keys = json.load(f)

os.environ["OPENAI_API_KEY"] = api_keys["OPENAI_API_KEY"]
openai.api_key = api_keys["OPENAI_API_KEY"]

In [24]:
def convert_prompt_to_string(prompt) -> str:
    return prompt.format(**{v: v for v in prompt.template_vars})

def generate_query(profile: str, scan: str):
    return "Patient Profile: {}\nScan ordered: {}".format(profile, scan)

def convert_doc_to_dict(doc: Union[Document, NodeWithScore, Dict]) -> Dict:
    if isinstance(doc, NodeWithScore):
        json_doc = {
            "page_content": doc.text,
            "metadata": doc.metadata,
            "score": doc.score
            } 
    elif isinstance(doc, Document):
        json_doc = {
            "page_content": doc.text,
            "metadata": doc.metadata,
            "score": ""
            }
    elif isinstance(doc, Dict):
        json_doc = {
            "page_content": doc["text"],
            "metadata": doc["metadata"],
            "score": "None"
        }
    return json_doc

def get_experiment_logs(description: str, log_folder: str):
    logger = logging.getLogger(description)

    stream_handler = logging.StreamHandler(sys.stdout)

    if not os.path.exists(log_folder):
        os.makedirs(log_folder, exist_ok=True)

    file_handler = logging.FileHandler(filename=os.path.join(log_folder, "logfile.log"))

    formatter = logging.Formatter("%(asctime)s:%(levelname)s: %(message)s")
    file_handler.setFormatter(formatter)
    stream_handler.setFormatter(formatter)

    logger.setLevel(logging.INFO)
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)
    
    return logger

def filter_by_pages(
    doc_list: List[Document],
    exclude_info: Dict[str, List]
) -> List[Document]:
    filtered_list = []
    for doc in doc_list:
        file_name = doc.metadata["file_name"]
        page = doc.metadata["page_label"]
        if file_name not in exclude_info.keys():
            filtered_list.append(doc)
            continue
        if int(page) not in exclude_info[file_name]:
            filtered_list.append(doc)

    return filtered_list

def count_tokens(
    texts: Union[str, TextNode, NodeWithScore, List],
    tokenizer: Callable = tiktoken.encoding_for_model("gpt-3.5-turbo")
):
    token_counter = 0
    if not isinstance(texts, List):
        texts = [texts]
    for text in texts:
        if isinstance(text, NodeWithScore):
            text_str = text.node.text
        elif isinstance(text, TextNode):
            text_str = text.text
        elif isinstance(text, str):
            text_str = text
        token_counter += len(tokenizer.encode(text_str))
    return token_counter

def organize_by_files(
    doc_list: List[Document]
):
    doc_dict = {}
    for doc in doc_list:
        filename = doc.metadata["file_name"]
        if filename not in doc_dict:
            doc_dict[filename] = [doc]
        else:
            doc_dict[filename].append(doc)
            
    return doc_dict
    

In [16]:
documents = SimpleDirectoryReader(DOCUMENT_DIR).load_data()
print("Total no of docs before filtering:", len(documents))
with open(EXCLUDE_DICT, "r") as f:
    exclude_pages = json.load(f)
documents = filter_by_pages(doc_list=documents, exclude_info=exclude_pages)
print("Total number of docs after filtering", len(documents))

Total no of docs before filtering: 546
Total number of docs after filtering 395


In [39]:
extract_prompt = """You are given a text which contains descriptions of patient variants and the corresponding table with the appropriateness of different image procedure and radiation level.
Extract only the description of the variant. Output your answer as follow:
Variant 1:
Variant 2: 
========
EXAMPLE:
TEXT: Variant 1:  Chronic ankle pain . Initial imaging . 
Procedure  Appropriateness Category  Relative Radiation Level  
Radiography ankle  Usually Appropriate  ☢ 
Bone scan ankle Usually Not Appropriate  ☢☢☢ 
US ankle  Usually Not Appropriate  O 
CT ankle without IV contrast  Usually Not Appropriate  ☢ 
CT ankle with IV contrast  Usually Not Appropriate  ☢ 
Variant 2:  Chronic ankle pain. M ultiple sites of degenerative joint disease in the hind foot detected by 
ankle radiographs . Next study.  
Procedure  Appropriateness Catego ry Relative Radiation Level  
Image -guided anesthetic injection ankle and 
hindfoot  May Be Appropriate  Varies
ANSWER:
Variant 1: Chronic ankle pain. Initial imaging.
Variant 2: Chronic ankle pain. Multiple sites of degenerative joint disease in the hind foot detected by 
ankle radiographs. Next study.
========
TEXT: {input_query}
"""

EXTRACT_PROMPT_TEMPLATE = PromptTemplate.from_template(extract_prompt)

extract_chain = LLMChain(llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, max_tokens=256), prompt=EXTRACT_PROMPT_TEMPLATE)

In [None]:
table_list = []
text_list = []

for idx, doc in enumerate(documents):
    page = doc.metadata["page_label"]
    filename = doc.metadata["file_name"]
    condition = os.path.splitext(filename)[0]
    tables = camelot.read_pdf(
        os.path.join(DOCUMENT_DIR, filename), pages=page, suppress_stdout=True
        )
    if not tables:
        doc.metadata["mode"] = "text"
        doc.metadata["condition"] = condition.lower()
        doc.excluded_embed_metadata_keys = ['file_name', 'page_label']
        doc.excluded_llm_metadata_keys = ['file_name', 'page_label']
        text_list.append(doc)
    else:
        table_texts = []
        text_flag = False
        for table in tables:
            table_df = table.df
            table_df = table_df.rename(columns=table_df.iloc[0])\
                    .rename(columns={"Radiologic Procedure": "Procedure", "Rating": "Appropriateness Category"})\
                    .drop(table_df.index[0])\
                    .reset_index(drop=True)
            if "Procedure" in table_df.columns:
                table_df = table_df[["Procedure", "Appropriateness Category"]]
                table_df["Procedure"] = table_df["Procedure"].str.replace("\n", " ")
                table_str = table_df.to_markdown(index=False)
                table_str = re.sub(r" +", " ", table_str)
                table_str = re.sub(r":-+|-+:", "---", table_str)
                table_texts.append(table_str)
            else:
                print("File Name: {}, Page: {}, Columns: {}"\
                    .format(filename, page, table_df.columns))
                text_flag = True
        if text_flag:
            doc.metadata["mode"] = "text"
            doc.metadata["condition"] = condition.lower()
            doc.excluded_embed_metadata_keys = ['file_name', 'page_label']
            doc.excluded_llm_metadata_keys = ['file_name', 'page_label']
            text_list.append(doc)
        else:
            # pattern = r"Variant ([0-9])+ *:([\s\S]+?)Procedure Appropriateness Category Relative Radiation Level"
            text = extract_chain(doc.text)["text"]
            pattern = r"Variant ([0-9])+ *:([^\n]+)"
            table_infos = re.findall(pattern, text)
            # table_descriptions = [description.strip() for description in table_descriptions]
            assert len(table_texts) == len(table_infos), f"{table_texts}\n{table_infos}"
            for table_text, table_info in zip(table_texts, table_infos):
                variant_no, table_description = table_info
                table_list.append(
                    {
                    "text": table_text,
                    "metadata": {
                        "mode": "tabular",
                        "condition": condition.lower(),
                        "description": table_description,
                        "variant": variant_no,
                        "file_name": filename,
                        "page_label": page
                    }
                    }
                )

multimodal_vector_path = os.path.join(DATA_DIR, "multimodal")
with open(os.path.join(multimodal_vector_path, "tables.json"), "w") as f:
    json.dump(table_list, f)

with open(os.path.join(multimodal_vector_path, "texts.json"), "w") as f:
    json.dump(text_list, f)

# Chroma: Only Table Descriptions & Metadata Filtering

In [None]:
multimodal_vector_path = os.path.join(DATA_DIR, "multimodal")

with open(os.path.join(multimodal_vector_path, "tables.json"), "r") as f:
    table_list = json.load(f)

with open(os.path.join(multimodal_vector_path, "texts.json"), "r") as f:
    text_list = json.load(f)

In [None]:
table_docs = []
for table in table_list:
    table["metadata"]["mode"] = "tabular"
    doc = Document(
        text=table["text"],
        metadata=table["metadata"],
        excluded_embed_metadata_keys = ['file_name', 'page_label', 'variant', 'mode'],
        excluded_llm_metadata_keys = ['file_name', 'page_label', 'variant']
        )
    table_docs.append(doc)
    
text_docs = []
for text in text_list:
    text["metadata"]["mode"] = "tabular"
    doc = Document(
        text=text["text"],
        metadata=text["metadata"],
        excluded_embed_metadata_keys = ['file_name', 'page_label', 'variant', 'mode'],
        excluded_llm_metadata_keys = ['file_name', 'page_label', 'variant']
        )
    text_docs.append(doc)

In [18]:
token_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode
)

callback_manager = CallbackManager([token_counter])

embed_model = OpenAIEmbedding()

table_docs = []
for table in table_list:
    table["metadata"]["mode"] = "tabular"
    doc = Document(
        text=table["text"],
        metadata=table["metadata"],
        excluded_embed_metadata_keys = ['file_name', 'page_label', 'variant', 'mode'],
        excluded_llm_metadata_keys = ['file_name', 'page_label', 'variant']
        )
    table_docs.append(doc)

description_texts = [doc.get_metadata_str(mode=MetadataMode.EMBED) for doc in table_docs]
description_embs = embed_model.get_text_embedding_batch(description_texts)

for doc, emb in zip(table_docs, description_embs):
    doc.embedding = emb

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embedding

In [32]:
# Count Tokens
total_tokens = 0

for doc in table_docs:
    total_tokens += count_tokens(doc.get_content(MetadataMode.EMBED))
for doc in text_docs:
    total_tokens += count_tokens(doc.get_content(MetadataMode.EMBED))
    
print(total_tokens)

381739


In [68]:
multimodal_vector_path = os.path.join(DATA_DIR, "multimodal-chroma")
desc_persist_dir = os.path.join(multimodal_vector_path, "descriptions")
if not os.path.exists(desc_persist_dir):
    os.makedirs(desc_persist_dir, exist_ok=True)

In [None]:
# Table Index
generate_vectorindex(
    embeddings=embed_model,
    emb_size=1536,
    documents=table_docs,
    output_directory=os.path.join(desc_persist_dir, "tables"),
    emb_store_type="chroma",
    chunk_size=1024,
    chunk_overlap=0,
    index_name="tables",
)

In [79]:
generate_vectorindex(
    embeddings=embed_model,
    emb_size=1536,
    documents=text_docs,
    output_directory=os.path.join(desc_persist_dir, "texts"),
    emb_store_type="chroma",
    chunk_size=512,
    chunk_overlap=20,
    index_name="texts"
)

2023-10-27 16:02:18,999:INFO: Processing documents from provided list.
INFO:config:Processing documents from provided list.
2023-10-27 16:02:19,001:INFO: 293 documents remained after page filtering.
INFO:config:293 documents remained after page filtering.
2023-10-27 16:02:19,003:INFO: Total number of text chunks to create vector index store: 293
INFO:config:Total number of text chunks to create vector index store: 293
2023-10-27 16:02:19,006:INFO: Creating chroma Vectorstore
INFO:config:Creating chroma Vectorstore
INFO:chromadb.telemetry.posthog:Anonymized telemetry enabled. See https://docs.trychroma.com/telemetry for more information.
2023-10-27 16:04:05,261:INFO: Successfully created chroma vectorstore at ../data/multimodal-chroma/descriptions/texts
INFO:config:Successfully created chroma vectorstore at ../data/multimodal-chroma/descriptions/texts


In [52]:
tables_index = load_vectorindex(
    db_directory = os.path.join(desc_persist_dir, "tables"),
    emb_store_type = "chroma", index_name = "tables",
)

INFO:chromadb.telemetry.posthog:Anonymized telemetry enabled. See https://docs.trychroma.com/telemetry for more information.


2023-10-25 23:06:12,600:INFO: chroma VectorStore successfully loaded from ../data/multimodal-chroma/descriptions/tables.
INFO:config:chroma VectorStore successfully loaded from ../data/multimodal-chroma/descriptions/tables.


In [53]:
texts_index = load_vectorindex(
    db_directory = os.path.join(desc_persist_dir, "texts"),
    emb_store_type = "chroma", index_name = "texts",
)

INFO:chromadb.telemetry.posthog:Anonymized telemetry enabled. See https://docs.trychroma.com/telemetry for more information.
2023-10-25 23:06:21,228:INFO: chroma VectorStore successfully loaded from ../data/multimodal-chroma/descriptions/texts.
INFO:config:chroma VectorStore successfully loaded from ../data/multimodal-chroma/descriptions/texts.
