# Create Directories

In [93]:
import os, json, logging, sys
import openai
from datetime import datetime
from config import MAIN_DIR, ARTIFACT_DIR, EXCLUDE_DICT
import tiktoken
import faiss
from typing import Union, Callable

from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers import CommaSeparatedListOutputParser

from llama_index.storage import StorageContext
from llama_index.graph_stores import NebulaGraphStore
from llama_index.vector_stores import FaissVectorStore
from llama_index import (
    ServiceContext,
    SimpleDirectoryReader,
    VectorStoreIndex,
    load_index_from_storage,
    get_response_synthesizer
    )
from llama_index.llms import OpenAI
from llama_index.embeddings import OpenAIEmbedding
from llama_index.callbacks import CallbackManager, TokenCountingHandler
from llama_index.indices.postprocessor import LongContextReorder
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.llms import ChatMessage, MessageRole
from llama_index.prompts import ChatPromptTemplate

logging.basicConfig(stream=sys.stdout, level=logging.WARNING)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

In [None]:
DATA_DIR = os.path.join(MAIN_DIR, "data", "document_store", "uc")
LLAMAINDEX_DIR = os.path.join(MAIN_DIR, "data", "llamaindex")
KG_PERSIST_DIR = os.path.join(LLAMAINDEX_DIR, "kg_store")
VECTOR_PERSIST_DIR = os.path.join(LLAMAINDEX_DIR, "vector_store")

with open(os.path.join(MAIN_DIR, "auth", "api_keys.json"), "r") as f:
    api_keys = json.load(f)

openai.api_key = api_keys["OPENAI_API_KEY"]
os.environ["OPENAI_API_KEY"] = api_keys["OPENAI_API_KEY"]

# Helper Functions

In [102]:
from typing import List, Dict
from llama_index.schema import Document
from langchain.output_parsers import ListOutputParser

def filter_by_pages(
    doc_list: List[Document], exclude_info: Dict[str, List]
) -> List[Document]:
    
    filtered_list = []
    for doc in doc_list:
        file_name = doc.metadata["file_name"]
        page = doc.metadata["page_label"]
        if file_name not in exclude_info.keys():
            filtered_list.append(doc)
            continue
        if int(page) not in exclude_info[file_name]:
            filtered_list.append(doc)

    return filtered_list

def correct_page_label(doc_list):
    
    page_counter = {}
    for doc in doc_list:
        document_name = doc.metadata["file_name"]
        if document_name in page_counter:
            current_page = page_counter[document_name]
            doc.metadata["page_label"] = current_page
        else:
            current_page = 1
            page_counter[document_name] = 1
            doc.metadata["page_label"] = 1
        page_counter[document_name] += 1
    return doc_list

def print_token_usage(token_counter):
    print("Tokens Consumption: Total: {}, Prompt: {}, Completion: {}, Embeddings: {}"
          .format(token_counter.total_llm_token_count,
                  token_counter.prompt_llm_token_count,
                  token_counter.completion_llm_token_count,
                  token_counter.embedding_token_counts))
    
def process_document_node(node):
    return {"text": node.text, "metadata": node.metadata}

def convert_prompt_to_string(prompt) -> str:
    return prompt.format(**{v: v for v in prompt.template_vars})

fix_prompt_template = """TASK: Extract the recommended drug names from the provided text query.
If there are no drugs from the answer, return an empty list.
===============
FORMAT INSTRUCTIONS: Your output should be a list of comma separated values of drugs extracted from the text query.
===============
EXAMPLES:

TEXT QUERY: Based on the given context, the patient is a new patient with severe ulcerative colitis (UC). However, the context does not provide information on the patient's prior response to Infliximab, prior failure to Anti-TNF agents, prior failure to Vedolizumab, pregnancy status, extraintestinal manifestations, or pouchitis. 
Given the patient's age and severity of UC, the top two choices of biological drugs could be Vedolizumab (VDZ) and Infliximab (IFX). 
1. Vedolizumab:
   - Advantages: VDZ has been shown to be superior to Adalimumab in achieving clinical remission and endoscopic improvement in moderate to severe UC. It has a good safety profile and is a gut-selective medication, which means it specifically targets the gut and spares other systems.
   - Disadvantages: VDZ is associated with an increased risk of new or worsening extraintestinal manifestations. It may not be feasible due to payer preference.
2. Infliximab:
   - Advantages: IFX is an effective option with a good safety profile. It is the preferred biologic for induction and maintenance of remission in patients with severe UC requiring hospitalization.
   - Disadvantages: There is a potential for the colon to act as a “sink” for drugs, which means that drug concentrations should be checked early to ensure proper dosing and to detect immunogenicity early.
ANSWER: Vedolizumab, Infliximab

TEXT QUERY: Based on the given context, no appropriate drug recommendation can be given
ANSWER: 

===============
TEXT QUERY: {query}
ANSWER:
"""

FIX_PROMPT = PromptTemplate.from_template(fix_prompt_template)

fixing_chain = LLMChain(
    llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, max_tokens=128),
    prompt=FIX_PROMPT
)

def count_document_tokens(
    doc_list: Union[str, Document, NodeWithScore, List[Union[str, Document, NodeWithScore]]],
    tokenize_fn: Callable = tiktoken.encoding_for_model("gpt-3.5-turbo").encode
):
    if not isinstance(doc_list, List):
        doc_list = [doc_list]
    token_counter = 0
    for doc in doc_list:
        text = doc if isinstance(doc, str) else doc.text
        token_counter += len(tokenize_fn(text))
    
    return token_counter

# Setup Nebula KG Space

In [None]:
# !curl -fsSL nebula-up.siwei.io/install.sh | bash

┌──────────────────────────────────────────────────────────────────────────────────────────┐
│ 🌌 Nebula-Graph Playground is on the way...                                              │
├──────────────────────────────────────────────────────────────────────────────────────────┤
│.__   __.  _______ .______    __    __   __          ___            __    __  .______     │
│|  \ |  | |   ____||   _  \  |  |  |  | |  |        /   \          |  |  |  | |   _  \    │
│|   \|  | |  |__   |  |_)  | |  |  |  | |  |       /  ^  \   ______|  |  |  | |  |_)  |   │
│|  . `  | |   __|  |   _  <  |  |  |  | |  |      /  /_\  \ |______|  |  |  | |   ___/    │
│|  |\   | |  |____ |  |_)  | |  `--   | |   ----./  _____  \       |  `--   | |  |        │
│|__| \__| |_______||______/   \______/  |_______/__/     \__\       \______/  | _|        │
└──────────────────────────────────────────────────────────────────────────────────────────┘

 ℹ️    VERSION not provided

 ℹ️    Installing NebulaGraph release-3.6

In [6]:
os.environ['GRAPHD_HOST'] = "127.0.0.1"
os.environ['GRAPHD_PORT'] = "9669"
os.environ['NEBULA_USER'] = "root"
os.environ['NEBULA_PASSWORD'] = "nebula"
os.environ['NEBULA_ADDRESS'] = "{}:{}".format(os.environ["GRAPHD_HOST"], os.environ["GRAPHD_PORT"])

In [7]:
%reload_ext ngql
connection_string = "--address {} --port {} --user {} --password {}"\
    .format(os.environ["GRAPHD_HOST"],
            os.environ['GRAPHD_PORT'],
            os.environ['NEBULA_USER'],
            os.environ['NEBULA_PASSWORD'])
%ngql {connection_string}

Connection Pool Created


Unnamed: 0,Name
0,uc_recommendations


In [129]:
# # Setup New NebulaGraph Space
# space_name = "uc_recommendations"

# %ngql DROP SPACE IF EXISTS $space_name;
# %ngql SHOW SPACES

Unnamed: 0,Name


In [130]:
# %ngql CREATE SPACE $space_name(vid_type=FIXED_STRING(256), partition_num=1, replica_factor=1);
# %ngql SHOW SPACES

Unnamed: 0,Name
0,uc_recommendations


In [131]:
# %ngql USE $space_name;
# %ngql CREATE TAG entity(name string);
# %ngql CREATE EDGE relationship(relationship string);
# %ngql CREATE TAG INDEX entity_index ON entity(name(256));

In [132]:
space_name = "uc_recommendations"
%ngql USE $space_name;

# Prepare Graph

In [57]:
# Prepare Documents
original_documents = SimpleDirectoryReader(DATA_DIR).load_data()
exclude_dict = {file_name: [page + 1 for page in pages] for file_name, pages in EXCLUDE_DICT.items()}

documents = filter_by_pages(correct_page_label(original_documents), exclude_dict)
print("Number of documents:", len(documents))

Number of documents: 86


In [133]:
## Create Nebula Knowledge Graph

from llama_index import KnowledgeGraphIndex

graph_args = dict(
    space_name="uc_recommendations",
    edge_types=["relationship"],
    rel_prop_names=["relationship"],
    tags=["entity"]
)

# Setup Storage Context
graph_store = NebulaGraphStore(
    **graph_args
)

kg_storage_context = StorageContext.from_defaults(
    graph_store=graph_store
)

kg_service_context = ServiceContext.from_defaults(
    llm=OpenAI(model="gpt-3.5-turbo", temperature=0, max_tokens=512),
    embed_model=OpenAIEmbedding(), chunk_size=512, chunk_overlap=20
)

kg_index = KnowledgeGraphIndex.from_documents(
    documents = documents,
    max_triplets_per_chunk=15,
    storage_context=kg_storage_context,
    service_context=kg_service_context,
    include_embeddings=True,
    **graph_args
)

if not os.path.exists(KG_PERSIST_DIR):
    os.makedirs(KG_PERSIST_DIR, exist_ok=True)

kg_index.storage_context.persist(persist_dir=KG_PERSIST_DIR)

(armamentarium, growing, treatment)
(moderate-to-severe ulcerative colitis, treatment, biologics and small molecule drugs)
(biologics and small molecule drugs, efficacy and safety, patients)
(moderate-to-severe ulcerative colitis, systematic review, network meta-analysis)
(Juan S Lasa, contributed equally, Pablo A Olivera)
(Pablo A Olivera, contributed equally, Silvio Danese)
(Silvio Danese, efficacy and safety, patients)
(Laurent Peyrin-Biroulet, correspondence to, Prof Laurent Peyrin-Biroulet)
(Laurent Peyrin-Biroulet, efficacy and safety, patients)
(moderate-to-severe ulcerative colitis, treatment, systematic review)
(moderate-to-severe ulcerative colitis, treatment, network meta-analysis)
(Laurent Peyrin-Biroulet, correspondence to, peyrinbiroulet@gmail.com)
(biologics and small molecule drugs, efficacy and safety, moderate-to-severe ulcerative colitis)
(Laurent Peyrin-Biroulet, correspondence to, INSERM NGERE and Department of Hepatogastroenterology)
(Laurent Peyrin-Biroulet, corr

In [58]:
# d = 1536
# faiss_index = faiss.IndexFlatL2(d)

# vector_store = FaissVectorStore(faiss_index=faiss_index)
# vector_storage_context = StorageContext.from_defaults(
#     vector_store=vector_store
# )
# vector_service_context = ServiceContext.from_defaults(
#     chunk_size = 512, chunk_overlap = 64,
#     embed_model=OpenAIEmbedding()
# )

# vector_index = VectorStoreIndex.from_documents(
#     documents = documents,
#     service_context = vector_service_context,
#     storage_context = vector_storage_context,
#     show_progress=True
# )

# if not os.path.exists(VECTOR_PERSIST_DIR):
#     os.makedirs(VECTOR_PERSIST_DIR, exist_ok=True)

# vector_index.storage_context.persist(persist_dir=VECTOR_PERSIST_DIR)

  from .autonotebook import tqdm as notebook_tqdm
Parsing documents into nodes: 100%|██████████| 86/86 [00:00<00:00, 109.89it/s]
Generating embeddings: 100%|██████████| 283/283 [00:28<00:00, 10.07it/s]


# Experiment 1: TableKGRetriever

In [None]:
with open(os.path.join(MAIN_DIR, "data", "queries", "uc_all.txt"), "r") as f:
    test_cases = f.readlines()

## Experiment Settings

In [None]:
temperature = 0
model_name = "gpt-4"
max_tokens = 512
k = 7
description = "KG_Chat4"
time = datetime.now().strftime("%d-%m-%Y-%H-%M")
save_path = os.path.join(ARTIFACT_DIR, f"{model_name}_{description}_{time}")
print("Save directory:", save_path)
gt_path = os.path.join(MAIN_DIR, "data", "queries", "uc_all_gt.csv")
response_mode = "simple_summarize"

token_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model(model_name).encode
)
callback_manager = CallbackManager([token_counter])

## KG Retriever

### Load KG Index

In [21]:
## Load Knowledge Graph
graph_args = dict(
    space_name="uc_recommendations",
    edge_types=["relationship"],
    rel_prop_names=["relationship"],
    tags=["entity"]
)

# Setup Storage Context
graph_store = NebulaGraphStore(
    **graph_args
)

storage_context = StorageContext.from_defaults(
    persist_dir=KG_PERSIST_DIR,
    graph_store=graph_store
)

kg_index = load_index_from_storage(
    storage_context=storage_context, **graph_args, verbose=True
)

### Setup Retriever

In [69]:
kg_retriever = kg_index.as_retriever(
    retrieval_mode = "hybrid", # hybrid, keyword, embedding
    max_keywords_per_query=10,
    num_chunks_per_query=10,
    similarity_top_k=2,
    graph_store_query_depth=2,
    include_text=True,
    verbose=False
)

## Setup Response Synthesizer

### Prompt

In [64]:
system_template = """
You are a physician assistant giving advice on treatment for moderate to severe ulcerative colitis (UC).
Make reference to the CONTEXT given to assess the scenario.
If the answer cannot be inferred from CONTEXT, return "NO ANSWER", don't try to make up an answer.
=========
TASK: ANALYSE the given patient profile based on given query based on one of the following criteria:
- Whether treated patient is new patient or patient under maintenance
- Prior response to Infliximab
- Prior failure to Anti-TNF agents
- Prior failure to Vedolizumab
- Age
- Pregnancy
- Extraintestinale manifestations
- Pouchitis

FINALLY RETURN up to 2 TOP choices of biological drugs given patient profile and context. Explain the PROS and CONS of the 2 choices.
If answer cannot be derived from context, RETURN "NO ANSWER" and explain reason.
=========
OUTPUT INSTRUCTIONS:
Output your answer as a list of JSON objects with keys: drug_name, advantages, disadvantages.
=========
CONTEXT:
{context_str}
=========
"""

human_template = "PATIENT PROFILE: {query_str}"
messages = [
    ChatMessage(role=MessageRole.SYSTEM, content=system_template),
    ChatMessage(role=MessageRole.USER, content=human_template)   
]

CHAT_PROMPT_TEMPLATE = ChatPromptTemplate(messages)

In [135]:
# Service Context
llm = OpenAI(temperature=temperature, model=model_name, max_tokens=max_tokens)
embs = OpenAIEmbedding()
service_context = ServiceContext.from_defaults(
    llm=llm, embed_model=embs, callback_manager=callback_manager
)

response_synthesizer = get_response_synthesizer(
    service_context=service_context,
    response_mode=response_mode,
    text_qa_template=CHAT_PROMPT_TEMPLATE,
    verbose=False
)

## Setup Query Engine

In [153]:
# Query Engine
query_engine = RetrieverQueryEngine(
    retriever=kg_retriever,
    response_synthesizer=response_synthesizer
    # node_postprocessors=LongContextReorder()
)

## Run Test Cases

In [None]:
token_counter.reset_counts()
responses = []
for test_case in test_cases:
    print("Test Case:", test_case)
    response = query_engine.query(test_case)
    responses.append(response)
    
print_token_usage(token_counter)

In [175]:
questions = test_cases
answers = [response.response for response in responses]
source_documents = [[process_document_node(node) for node in response.source_nodes]
                    for response in responses]

json_results = [dict(question=question, answer=answer, source_document=source_document)
    for question, answer, source_document in zip(questions, answers, source_documents)
]

In [179]:
if not os.path.exists(save_path):
    os.makedirs(save_path, exist_ok=True)

with open(os.path.join(save_path, "results.json"), "w") as f:
    json.dump(json_results, f)

In [None]:
import pandas as pd

ground_truth = pd.read_csv(gt_path, encoding="ISO-8859-1")

info = {"question": questions}

info["gt_rec1"] = ground_truth["Recommendation 1"].tolist()
info["gt_rec2"] = ground_truth["Recommendation 2"].tolist()
info["gt_rec3"] = ground_truth["Recommendation 3"].tolist()
info["gt_avoid"] = ground_truth["Drug Avoid"].tolist()
info["gt_reason"] = ground_truth["Reasoning"].tolist()

info["prompt"] = [convert_prompt_to_string(CHAT_PROMPT_TEMPLATE)] * len(questions)

pd_answers = [[], []]

list_parser = CommaSeparatedListOutputParser()

for answer in answers:
    fixed_answer = fixing_chain(answer)["text"]
    if "NO ANSWER" in fixed_answer or fixed_answer=="":
        drugs = []
    else:
        drugs = list_parser.parse(fixed_answer)

    pd_answers[0].append(drugs[0] if len(drugs) > 0 else None)
    pd_answers[1].append(drugs[1] if len(drugs) > 1 else None)

info["raw_answer"] = answers
info["answer1"] = pd_answers[0]
info["answer2"] = pd_answers[1]

panda_df = pd.DataFrame(info)

panda_df.to_csv(os.path.join(save_path, "results.csv"), header=True)

# Experiment 2: KGRagRetriever

In [35]:
with open(os.path.join(MAIN_DIR, "data", "queries", "uc_all.txt"), "r") as f:
    test_cases = f.readlines()

## Experiment Settings

In [46]:
import tiktoken
from llama_index.callbacks import CallbackManager, TokenCountingHandler

temperature = 0
synthesizer_model = "gpt-4"
kg_rag_model = "gpt-3.5-turbo"
max_tokens = 512
description = "KG_Chat4_RAG_Retrieval"
time = datetime.now().strftime("%d-%m-%Y-%H-%M")
save_path = os.path.join(ARTIFACT_DIR, f"{model_name}_{description}_{time}")
print("Save directory:", save_path)
gt_path = os.path.join(MAIN_DIR, "data", "queries", "uc_all_gt.csv")
response_mode = "simple_summarize"

token_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model(model_name).encode
)
callback_manager = CallbackManager([token_counter])

Save directory: /mnt/c/Users/QUAN/Desktop/medical-chatbot/artifacts/gpt-4_KG_Chat4_RAG_Retrieval_21-10-2023-15-30


## KG Retriever

### Load KG Index

In [37]:
from llama_index import load_index_from_storage

## Load Knowledge Graph
graph_args = dict(
    space_name="uc_recommendations",
    edge_types=["relationship"],
    rel_prop_names=["relationship"],
    tags=["entity"]
)

# Setup Storage Context
graph_store = NebulaGraphStore(
    **graph_args
)

storage_context = StorageContext.from_defaults(
    persist_dir=KG_PERSIST_DIR,
    graph_store=graph_store
)

kg_index = load_index_from_storage(
    storage_context=storage_context, **graph_args, verbose=True
)

### Setup Retriever

In [42]:
from llama_index.retrievers import KnowledgeGraphRAGRetriever
from llama_index.llms import OpenAI

max_entities = 5
max_synonyms = 5
graph_traversal_depth = 2
retriever_mode = "keyword"

with_nl2graphquery = False
verbose = False

graph_rag_retriever = KnowledgeGraphRAGRetriever(
    storage_context=storage_context,
    service_context=service_context,
    llm=OpenAI(
        model=kg_rag_model,
        temperature=temperature, max_tokens=max_tokens),
    verbose=False,
)

In [None]:
sample_testcase = test_cases[0]
sample_retrieved_nodes =  graph_rag_retriever.retrieve(sample_testcase)
sample_retrieved_nodes

## Setup Response Synthesizer

### Prompt

In [43]:
from llama_index.llms import ChatMessage, MessageRole
from llama_index.prompts import ChatPromptTemplate

system_template = """
You are a physician assistant giving advice on treatment for moderate to severe ulcerative colitis (UC).
Make reference to the CONTEXT given to assess the scenario.
If the answer cannot be inferred from CONTEXT, return "NO ANSWER", don't try to make up an answer.
=========
TASK: ANALYSE the given patient profile based on given query based on one of the following criteria:
- Whether treated patient is new patient or patient under maintenance
- Prior response to Infliximab
- Prior failure to Anti-TNF agents
- Prior failure to Vedolizumab
- Age
- Pregnancy
- Extraintestinale manifestations
- Pouchitis

FINALLY RETURN up to 2 TOP choices of biological drugs given patient profile and context. Explain the PROS and CONS of the 2 choices.
If answer cannot be derived from context, RETURN "NO ANSWER" and explain reason.
=========
OUTPUT INSTRUCTIONS:
Output your answer as a list of JSON objects with keys: drug_name, advantages, disadvantages.
=========
CONTEXT:
{context_str}
=========
"""

human_template = "PATIENT PROFILE: {query_str}"
messages = [
    ChatMessage(role=MessageRole.SYSTEM, content=system_template),
    ChatMessage(role=MessageRole.USER, content=human_template)   
]

CHAT_PROMPT_TEMPLATE = ChatPromptTemplate(messages)

In [47]:
from llama_index import get_response_synthesizer

# Service Context
embs = OpenAIEmbedding()
service_context = ServiceContext.from_defaults(
    llm=OpenAI(
        model = synthesizer_model,
        temperature=temperature, max_tokens=max_tokens),
    embed_model=embs, callback_manager=callback_manager
)

response_synthesizer = get_response_synthesizer(
    service_context=service_context,
    response_mode=response_mode,
    text_qa_template=CHAT_PROMPT_TEMPLATE,
    verbose=False
)

## Setup Query Engine

In [50]:
# Query Engine
# from llama_index.indices.postprocessor import LongContextReorder
from llama_index.query_engine import RetrieverQueryEngine

kg_rag_query_engine = RetrieverQueryEngine(
    retriever=graph_rag_retriever,
    response_synthesizer=response_synthesizer
    # node_postprocessors=LongContextReorder()
)

## Run Test Cases

In [51]:
token_counter.reset_counts()
responses = []
for test_case in test_cases:
    print("Test Case:", test_case)
    response = kg_rag_query_engine.query(test_case)
    responses.append(response)
    
print_token_usage(token_counter)

Test Case: 40 year old male with newly diagnosed moderate UC and articular extraintestinal manifestations

[1;3;32mEntities processed: ['moderate UC', 'male', 'old', '40 year old male', 'newly diagnosed', 'newly', 'articular extraintestinal manifestations', 'year', '40', 'extraintestinal', 'UC', 'manifestations', 'articular', 'diagnosed', 'moderate']
[0m[1;3;32mEntities processed: ['MODERATE ulcerative colitis', 'Moderate', 'YEAR', 'Articular Extraintestinal Manifestationses', "MODERATE ulcerative colitises'", '40 Year Old Males', '40 year old males', '40', 'Diagnoses', 'Moderate ulcerative colitis', '40 YEAR OLD MALE', 'MODERATE ulcerative colitises', "moderate ulcerative colitises'", "moderate ulcerative colitis's", 'moderate UCs', 'MODERATE uc', 'newly', 'uC', '40s', 'news', 'MODERATE', 'moderate uc', 'Articular', "moderate Ulcerative Colitis's", 'Moderate UCs', 'NEWLY DIAGNOSES', 'ARTICULAR EXTRAINTESTINAL MANIFESTATION', 'NEWLY', 'ULCERATIVE COLITI', 'diagnosed', 'Diagnosing', 

In [52]:
questions = test_cases
answers = [response.response for response in responses]
source_documents = [[process_document_node(node) for node in response.source_nodes]
                    for response in responses]

json_results = [dict(question=question, answer=answer, source_document=source_document)
    for question, answer, source_document in zip(questions, answers, source_documents)
]

In [53]:
if not os.path.exists(save_path):
    os.makedirs(save_path, exist_ok=True)

with open(os.path.join(save_path, "results.json"), "w") as f:
    json.dump(json_results, f)

In [54]:
import pandas as pd

ground_truth = pd.read_csv(gt_path, encoding="ISO-8859-1")

info = {"question": questions}

info["gt_rec1"] = ground_truth["Recommendation 1"].tolist()
info["gt_rec2"] = ground_truth["Recommendation 2"].tolist()
info["gt_rec3"] = ground_truth["Recommendation 3"].tolist()
info["gt_avoid"] = ground_truth["Drug Avoid"].tolist()
info["gt_reason"] = ground_truth["Reasoning"].tolist()

info["prompt"] = [convert_prompt_to_string(CHAT_PROMPT_TEMPLATE)] * len(questions)

pd_answers = [[], []]

list_parser = CommaSeparatedListOutputParser()

for answer in answers:
    fixed_answer = fixing_chain(answer)["text"]
    if "NO ANSWER" in fixed_answer or fixed_answer=="":
        drugs = []
    else:
        drugs = list_parser.parse(fixed_answer)

    pd_answers[0].append(drugs[0] if len(drugs) > 0 else None)
    pd_answers[1].append(drugs[1] if len(drugs) > 1 else None)

info["raw_answer"] = answers
info["answer1"] = pd_answers[0]
info["answer2"] = pd_answers[1]

panda_df = pd.DataFrame(info)

panda_df.to_csv(os.path.join(save_path, "results.csv"), header=True)

# Experiment 3: Combined Vector and KG Retriever

## Settings

In [116]:
# General
temperature = 0
max_tokens = 512
verbose = False

# Retrieval Settings
## KG
kg_llm = "gpt-3.5-turbo"
kg_retrieval_mode = "hybrid"
max_keywords_per_query = 10
num_chunks_per_query = 3
kg_similarity_top_k = 2
graph_store_query_depth = 2
include_text = True
verbose = False

## Vector
vector_similarity_top_k = 4
vector_retrieval_mode = "default"

# Synthesizer Settings
synthesize_llm = "gpt-4"
response_mode = "simple_summarize"

# Experiment Settings
description = "CombinedKGVector_Chat4"
time = datetime.now().strftime("%d-%m-%Y-%H-%M")
save_path = os.path.join(ARTIFACT_DIR, f"{model_name}_{description}_{time}")
print("Save directory:", save_path)
gt_path = os.path.join(MAIN_DIR, "data", "queries", "uc_all_gt.csv")

oken_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model(model_name).encode
)
callback_manager = CallbackManager([token_counter])

Save directory: /mnt/c/Users/QUAN/Desktop/medical-chatbot/artifacts/gpt-4_CombinedKGVector_Chat4_21-10-2023-17-21


In [104]:
from llama_index.retrievers import KGTableRetriever, VectorIndexRetriever, BaseRetriever
from llama_index import QueryBundle
from llama_index.schema import NodeWithScore
from typing import Literal

class CustomCombinedRetriever(BaseRetriever):
    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        kg_retriever: KGTableRetriever,
        mode: Literal["AND", "OR"] = "OR"
    ) -> None:
        self.vector_retriever = vector_retriever
        self.kg_retriever = kg_retriever
        self._mode = mode
        
    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""

        vector_nodes = self.vector_retriever.retrieve(query_bundle)
        kg_nodes = self.kg_retriever.retrieve(query_bundle)
        
        vector_ids = {n.node.node_id for n in vector_nodes}
        kg_ids = {n.node.node_id for n in kg_nodes}

        combined_dict = {n.node.node_id: n for n in vector_nodes}
        combined_dict.update({n.node.node_id: n for n in kg_nodes})

        if self._mode == "AND":
            retrieve_ids = vector_ids.intersection(kg_ids)
        else:
            retrieve_ids = vector_ids.union(kg_ids)

        retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]
        return retrieve_nodes

In [134]:
graph_args = dict(
    space_name="uc_recommendations",
    edge_types=["relationship"],
    rel_prop_names=["relationship"],
    tags=["entity"]
)

# Setup Storage Context
graph_store = NebulaGraphStore(**graph_args)

kg_storage_context = StorageContext.from_defaults(
    persist_dir=KG_PERSIST_DIR, graph_store=graph_store
)

kg_index = load_index_from_storage(
    storage_context=kg_storage_context, **graph_args, verbose=True
)

vector_store = FaissVectorStore.from_persist_dir(VECTOR_PERSIST_DIR)
vector_storage_context = StorageContext.from_defaults(
    vector_store=vector_store, persist_dir=VECTOR_PERSIST_DIR
)

vector_index = load_index_from_storage(storage_context=vector_storage_context)

In [136]:
vector_retriever = VectorIndexRetriever(
    index=vector_index,
    similarity_top_k=vector_similarity_top_k,
    vector_store_query_mode=vector_retrieval_mode,
    verbose=verbose
)

kg_retriever = KGTableRetriever(
    index=kg_index,
    retrieval_mode=kg_retrieval_mode,
    max_keywords_per_query=max_keywords_per_query,
    num_chunks_per_query=num_chunks_per_query,
    similarity_top_k=kg_similarity_top_k,
    graph_store_query_depth=graph_store_query_depth,
    include_text=True,
    verbose=verbose
)

custom_retriever = CustomCombinedRetriever(
    vector_retriever=vector_retriever, kg_retriever=kg_retriever,
)

## Setup Response Synthesizer

### Prompt

In [137]:
system_template = """
You are a physician assistant giving advice on treatment for moderate to severe ulcerative colitis (UC).
Make reference to the CONTEXT given to assess the scenario.
If the answer cannot be inferred from CONTEXT, return "NO ANSWER", don't try to make up an answer.
=========
TASK: ANALYSE the given patient profile based on given query based on one of the following criteria:
- Whether treated patient is new patient or patient under maintenance
- Prior response to Infliximab
- Prior failure to Anti-TNF agents
- Prior failure to Vedolizumab
- Age
- Pregnancy
- Extraintestinale manifestations
- Pouchitis

FINALLY RETURN up to 2 TOP choices of biological drugs given patient profile and context. Explain the PROS and CONS of the 2 choices.
If answer cannot be derived from context, RETURN "NO ANSWER" and explain reason.
=========
OUTPUT INSTRUCTIONS:
Output your answer as a list of JSON objects with keys: drug_name, advantages, disadvantages.
=========
CONTEXT:
{context_str}
=========
"""

human_template = "PATIENT PROFILE: {query_str}"
messages = [
    ChatMessage(role=MessageRole.SYSTEM, content=system_template),
    ChatMessage(role=MessageRole.USER, content=human_template)   
]

CHAT_PROMPT_TEMPLATE = ChatPromptTemplate(messages)

In [154]:
# Service Context
synthesizer_service_context = ServiceContext.from_defaults(
    llm=OpenAI(temperature=temperature, model=synthesize_llm, max_tokens=max_tokens),
    embed_model=OpenAIEmbedding(), callback_manager=callback_manager
)

synthesizer_service_context_long = ServiceContext.from_defaults(
    llm=OpenAI(temperature=temperature, model="gpt-3.5-turbo-16k", max_tokens=max_tokens),
    embed_model=OpenAIEmbedding(), callback_manager=callback_manager
)

combined_response_synthesizer = get_response_synthesizer(
    service_context=synthesizer_service_context,
    response_mode=response_mode,
    text_qa_template=CHAT_PROMPT_TEMPLATE,
    verbose=False
)

combined_response_synthesizer_long = get_response_synthesizer(
    service_context=synthesizer_service_context_long,
    response_mode=response_mode,
    text_qa_template=CHAT_PROMPT_TEMPLATE,
    verbose=False
)

## Setup Query Engine

In [155]:
# Query Engine
combined_query_engine = RetrieverQueryEngine(
    retriever=custom_retriever,
    response_synthesizer=combined_response_synthesizer
    # node_postprocessors=LongContextReorder()
)

combined_query_engine_long = RetrieverQueryEngine(
    retriever=custom_retriever,
    response_synthesizer=combined_response_synthesizer_long
    # node_postprocessors=LongContextReorder()
)

## Run Test Cases

In [159]:
GPT4_TOKEN_LIMIT = 8000
# token_counter.reset_counts()
# responses = []
for test_case in test_cases[11:]:
    print("Test Case:", test_case)
    retrieved_nodes = custom_retriever.retrieve(test_case)
    total_prompt_tokens = count_document_tokens(convert_prompt_to_string(CHAT_PROMPT_TEMPLATE)) + count_document_tokens(retrieved_nodes)
    if total_prompt_tokens <= GPT4_TOKEN_LIMIT:
        response = combined_query_engine.query(test_case)
    else:
        response = combined_query_engine_long.query(test_case)
    responses.append(response)
    
print_token_usage(token_counter)

Test Case: 52 year-old woman with moderate to severe distal ulcerative colitis that had a successful induction with vedolizumab. What would be the maintenance therapy?

Test Case: 24 year-old man with moderate to severe extensive ulcerative colitis previously in clinical remission with infliximab develops loss of response due to antibody formation.

Test Case: 44 year-old woman with moderate to severe extensive ulcerative colitis and rheumatoid arthritis.

Test Case: 55 year-old man with moderate to severe extensive ulcerative colitis who avlues convenience and limited time spent in hospital

Test Case: 60 year-old woman with severe ulcerative colitis that has loss response to anti-TNF, vedolizumab and ustekinumab

Test Case: 36 year-old man with moderate to severe extensive ulcerative colitis and spondylarthristis.

Test Case: 42 year-old woman with moderate ulcerative colitis on azathioprine and not responding to therapy

Test Case: 53 year-old man with moderate to severe extensive u

In [160]:
questions = test_cases
answers = [response.response for response in responses]
source_documents = [[process_document_node(node) for node in response.source_nodes]
                    for response in responses]

json_results = [dict(question=question, answer=answer, source_document=source_document)
    for question, answer, source_document in zip(questions, answers, source_documents)
]

In [161]:
if not os.path.exists(save_path):
    os.makedirs(save_path, exist_ok=True)

with open(os.path.join(save_path, "results.json"), "w") as f:
    json.dump(json_results, f)

In [162]:
import pandas as pd

ground_truth = pd.read_csv(gt_path, encoding="ISO-8859-1")

info = {"question": questions}

info["gt_rec1"] = ground_truth["Recommendation 1"].tolist()
info["gt_rec2"] = ground_truth["Recommendation 2"].tolist()
info["gt_rec3"] = ground_truth["Recommendation 3"].tolist()
info["gt_avoid"] = ground_truth["Drug Avoid"].tolist()
info["gt_reason"] = ground_truth["Reasoning"].tolist()

info["prompt"] = [convert_prompt_to_string(CHAT_PROMPT_TEMPLATE)] * len(questions)

pd_answers = [[], []]

list_parser = CommaSeparatedListOutputParser()

for answer in answers:
    fixed_answer = fixing_chain(answer)["text"]
    if "NO ANSWER" in fixed_answer or fixed_answer=="":
        drugs = []
    else:
        drugs = list_parser.parse(fixed_answer)

    pd_answers[0].append(drugs[0] if len(drugs) > 0 else None)
    pd_answers[1].append(drugs[1] if len(drugs) > 1 else None)

info["raw_answer"] = answers
info["answer1"] = pd_answers[0]
info["answer2"] = pd_answers[1]

panda_df = pd.DataFrame(info)

panda_df.to_csv(os.path.join(save_path, "results.csv"), header=True)