In [1]:
import os,json

from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate

In [87]:
# Project directories, all expressed relative to the notebook's parent folder.
MAIN_DIR = ".."
DATA_DIR = os.path.join(MAIN_DIR, "data")
ARTIFACT_DIR = os.path.join(MAIN_DIR, "artifacts")
DOCUMENTS_DIR = os.path.join(DATA_DIR, "document_sources")
EMB_DIR = os.path.join(DATA_DIR, "emb_store")

# API keys are kept in a JSON file outside the notebook so no secret is hardcoded here.
api_keys_path = os.path.join(MAIN_DIR, "auth", "api_keys.json")
with open(api_keys_path, "r") as key_file:
    api_keys = json.load(key_file)

# Export the OpenAI key through the process environment.
os.environ["OPENAI_API_KEY"] = api_keys["OPENAI_API_KEY"]

## Setup VectorStore

In [50]:
from langchain.embeddings import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.vectorstores import FAISS

# Embedding model used to build and query the vector store.
embeddings = OpenAIEmbeddings()

# Chat model: temperature 0, answers capped at 512 tokens.
llm = ChatOpenAI(
    model_name="gpt-4",
    temperature=0,
    max_tokens=512,
)

In [4]:
# Text-splitting parameters; they also name the on-disk index folder so that
# indexes built with different settings do not overwrite each other.
vectordb_params = {
    "chunk_size": 1024,
    "chunk_overlap": 128,
}

emb_db_name = f"openai_{vectordb_params['chunk_size']}_{vectordb_params['chunk_overlap']}"
emb_db_path = os.path.join(EMB_DIR, "faiss", emb_db_name)

In [75]:
# Create document stores
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# PDF guideline files to index, looked up under DOCUMENTS_DIR.
documents = [
    "ACR low back pain.pdf",
    "ACR suspected spine infection.pdf",
    "ACR suspected spine trauma.pdf",
]

# The splitter depends only on vectordb_params, so build it once instead of
# once per document.
text_splitter = RecursiveCharacterTextSplitter(**vectordb_params)

# Accumulates the chunks of every document.
texts = []

for document in documents:
    document_path = os.path.join(DOCUMENTS_DIR, document)
    docs = PyPDFLoader(document_path).load()
    print("Number of document pages:", len(docs))
    texts.extend(text_splitter.split_documents(docs))

print("Number of text chunks:", len(texts))

# exist_ok=True already makes this a no-op when the folder is present.
os.makedirs(emb_db_path, exist_ok=True)

# Embed all chunks and persist the index to emb_db_path.
docsearch = FAISS.from_documents(texts, embeddings)
docsearch.save_local(emb_db_path)

Number of document pages: 22
Number of document pages: 17
Number of document pages: 25
Number of text chunks: 321


In [76]:
# Load the FAISS index stored under emb_db_path, using the same embeddings object.
docsearch = FAISS.load_local(
    folder_path=emb_db_path,
    embeddings=embeddings,
)

In [77]:
# Retriever over the vector store, configured with k=5.
retriever = docsearch.as_retriever(search_kwargs={"k": 5})

# Example case used to eyeball what the retriever returns.
patient_profile = (
    "48 year old Indian male.  CEO of shipping company.  Recreational cricket player. "
    "Prior history of fatty liver, hyperlipidemia on statins. "
    "Now presenting with low back pain for 3 days post cricket match. "
    "No radiation to groin or lower limbs.  Able to walk.  Urination and bowel motion are ok. "
    "No fever.  On examination, straight leg raise test is positive. "
    "Power in bilateral lower limbs is full.  Digital rectal examination shows good anal tone. "
    "No prior imaging.  For MRI lumbar spine with IV contrast to assess for prolapsed intervertebral disc or muscle strain. "
)
scan_order = "MRI lumbar spine with IV contrast"

# Query layout: patient profile on the first line, ordered scan on the second.
sample_query = f"PATIENT PROFILE: {patient_profile}\nMRI SCAN ORDERED: {scan_order}\n"

print("Sample Query:", sample_query)
print()
docs = retriever.get_relevant_documents(sample_query)
for doc in docs:
    # A blank line, then the chunk text.
    print("", doc.page_content, sep="\n")

Sample Query: PATIENT PROFILE: 48 year old Indian male.  CEO of shipping company.  Recreational cricket player. Prior history of fatty liver, hyperlipidemia on statins. Now presenting with low back pain for 3 days post cricket match. No radiation to groin or lower limbs.  Able to walk.  Urination and bowel motion are ok. No fever.  On examination, straight leg raise test is positive. Power in bilateral lower limbs is full.  Digital rectal examination shows good anal tone. No prior imaging.  For MRI lumbar spine with IV contrast to assess for prolapsed intervertebral disc or muscle strain. 
MRI SCAN ORDERED: MRI lumbar spine with IV contrast



ACR Appropriateness  Criteria® 12 Low Back  Pain spinal  canal  patency.  MRI lumbar  spine  without  IV contrast  is most  useful  in the evaluation  of suspected  CES, 
multifocal deficit, or progressive  neurologic  deficit  because of  its ability  to accurately  depict  soft-tissue pathology,  
assess vertebral  marrow,  and assess the spina

In [78]:
# System prompt: role, task steps, required output format, and a {context}
# placeholder for the reference text.
system_template = """
You are a radiologist expert at providing imaging recommendations for patients with musculoskeletal conditions.
If you do not know an answer, just say "I dont know", do not make up an answer.
==========
TASK:
1. Extract from given PATIENT PROFILE relevant information for classification of imaging appropriateness.
Important information includes AGE, SYMPTOMS, DIAGNOSIS (IF ANY), which stage of diagnosis (INITIAL IMAGING OR NEXT STUDY).
2. Refer to the reference information given under CONTEXT to analyse the appropriate imaging recommendations given the patient profile.
3. Recommend if the image scan ordered is appropriate given the PATIENT PROFILE and CONTEXT. If the scan is not appropriate, recommend an appropriate procedure.
STRICTLY answer based on the given PATIENT PROFILE and CONTEXT.
==========
OUTPUT INSTRUCTIONS:
Your output should contain the following:
1. Classification of appropriateness for the ordered scan. Can be one of [USUALLY APPROPRIATE, MAY BE APPROPRIATE, USUALLY NOT APPROPRIATE, INSUFFICIENT INFORMATION]
2. Provide explanation for the appropriateness classification.
3. If classification answer is USUALLY NOT APPROPRIATE, either recommend an alternative appropriate scan procedure or return NO SCAN REQUIRED.
==========
CONTEXT:
{context}
==========
"""

# The human turn is the {question} placeholder and nothing else.
human_template = "{question}"

# Build each message template separately, then assemble the chat prompt.
system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
PROMPT_TEMPLATE = ChatPromptTemplate.from_messages(
    [system_message_prompt, human_message_prompt]
)

# Render with dummy values to check the final prompt layout.
print(PROMPT_TEMPLATE.format(context="context", question="question"))

System: 
You are a radiologist expert at providing imaging recommendations for patients with musculoskeletal conditions.
If you do not know an answer, just say "I dont know", do not make up an answer.
TASK:
1. Extract from given PATIENT PROFILE relevant information for classification of imaging appropriateness.
Important information includes AGE, SYMPTOMS, DIAGNOSIS (IF ANY), which stage of diagnosis (INITIAL IMAGING OR NEXT STUDY).
2. Refer to the reference information given under CONTEXT to analyse the appropriate imaging recommendations given the patient profile.
3. Recommend if the image scan ordered is appropriate given the PATIENT PROFILE and CONTEXT. If the scan is not appropriate, recommend an appropriate procedure.
STRICTLY answer based on the given PATIENT PROFILE and CONTEXT.
OUTPUT INSTRUCTIONS:
Your output should contain the following:
1. Classification of appropriateness for the ordered scan. Can be one of [USUALLY APPROPRIATE, MAY BE APPROPRIATE, USUALLY NOT APPROPRIATE,

In [79]:
from langchain.chains import RetrievalQA
from typing import List, Optional
from langchain.schema import Document, BaseDocumentTransformer
from langchain.callbacks.manager import CallbackManagerForChainRun

class ReOrderQARetrieval(RetrievalQA):
    """RetrievalQA subclass that can pass retrieved documents through a transformer.

    When ``reorder_fn`` is set, the retrieved documents are handed to its
    ``transform_documents`` method and the result is used in their place.
    When it is unset, the documents are returned as retrieved.
    """

    # Optional transformer applied to the retrieved documents; None disables it.
    reorder_fn: Optional[BaseDocumentTransformer] = None

    def _get_docs(
        self,
        question: str,
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Fetch the documents relevant to ``question`` and optionally transform them.

        Args:
            question: Query text passed to the retriever.
            run_manager: Run manager whose child callbacks are forwarded to the
                retriever call.

        Returns:
            The retrieved documents, passed through ``reorder_fn`` when it is set.
        """
        # Echo the incoming question to stdout.
        print(question)

        retrieved = self.retriever.get_relevant_documents(
            question, callbacks=run_manager.get_child()
        )

        # No transformer configured: return the retriever's result untouched.
        if not self.reorder_fn:
            return retrieved

        return self.reorder_fn.transform_documents(retrieved)

In [102]:
# Scratch check: prints 1 when the environment variable "123" is unset or empty.
if os.environ.get("123", "") == "":
    print(1)

1


In [80]:
from langchain.document_transformers import LongContextReorder

# Build the QA chain. A stray `b` after the chain_type argument made this
# cell a syntax error; it has been removed.
qa_chains = ReOrderQARetrieval.from_chain_type(
    llm=llm,
    retriever=docsearch.as_retriever(search_kwargs={"k": 5}),
    # Transformer applied to the retrieved documents before they reach the prompt.
    reorder_fn=LongContextReorder(),
    chain_type="stuff",
    chain_type_kwargs=dict(
        # Must match the {context} placeholder in the system prompt.
        document_variable_name="context",
        prompt=PROMPT_TEMPLATE,
    ),
    # The chain's input, and the key it is echoed under in each response.
    input_key="question",
    # Keep the retrieved documents in each response for later inspection.
    return_source_documents=True,
    verbose=True,
)

In [83]:
# Load the test cases from the queries folder.
lbp_cases_path = os.path.join(DATA_DIR, "queries", "lbp_cases.json")
with open(lbp_cases_path, "r") as cases_file:
    lbp_testcases = json.load(cases_file)

In [84]:
# Run every test case through the QA chain and collect the raw responses.
responses = []
for test_case in lbp_testcases:
    # Build the query from an explicit single-line template. A triple-quoted
    # string written inside the loop would carry the loop's indentation into
    # the text sent to the retriever and the model.
    sample_query = "PATIENT PROFILE: {}\nMRI SCAN ORDERED: {}\n".format(
        test_case["patient_profile"], test_case["scan_order"]
    )
    response = qa_chains(sample_query)
    responses.append(response)



[1m> Entering new ReOrderQARetrieval chain...[0m
PATIENT PROFILE: 41 year old Chinese male.  No prior medical history.  Smoker.  Now coming into emergency room for road traffic accident. Patient was driving company car when rear-ended by a truck.  Car hit road divider.  On examintion, patient is alert, but power in bilateral lower limbs is 2 out of 5.  Slightly lax anal tone on digital rectal examination.  X-rays show T12 chance fracture.  MRI thoracic and lumbar spine without IV contrast to assess for neurologic injury.
    MRI SCAN ORDERED: MRI thoracic and lumbar spine without IV contrast
    

[1m> Finished chain.[0m


[1m> Entering new ReOrderQARetrieval chain...[0m
PATIENT PROFILE: 48 year old Indian male.  CEO of shipping company.  Recreational cricket player.  Prior history of fatty liver, hyperlipidemia on statins.   Now presenting with low back pain for 3 days post cricket match.  No radiation to groin or lower limbs.  Able to walk.  Urination and bowel motion are ok.

In [92]:
# Replace each source document with its text so the responses can be written as JSON.
# getattr falls back to the item itself, so re-running this cell after the
# documents are already plain strings leaves them unchanged instead of raising.
for response in responses:
    response["source_documents"] = [
        getattr(doc, "page_content", doc) for doc in response["source_documents"]
    ]

# Make sure the output folder exists before writing into it.
os.makedirs(ARTIFACT_DIR, exist_ok=True)

with open(os.path.join(ARTIFACT_DIR, "lbp_sample_results.json"), "w") as f:
    json.dump(responses, f)

In [97]:
import pandas as pd

# Number of source-document columns to export; equal to the retriever's k of 5.
N_SOURCE_DOCS = 5


def _source_text(response, idx):
    """Return the text of the idx-th source document in ``response``.

    Returns None when the response holds fewer than ``idx + 1`` documents, so a
    short retrieval yields an empty cell instead of an IndexError. Items that
    have a ``page_content`` attribute are unwrapped; other items are returned
    unchanged.
    """
    source_docs = response["source_documents"]
    if idx >= len(source_docs):
        return None
    doc = source_docs[idx]
    return getattr(doc, "page_content", doc)


# One row per response: the question, the model's answer, then one column per
# source document.
lbp_sample_dict = {
    "question": [response["question"] for response in responses],
    "answer": [response["result"] for response in responses]
}

for idx in range(N_SOURCE_DOCS):
    lbp_sample_dict[f"doc_{idx+1}"] = [_source_text(response, idx) for response in responses]

lbp_sample_df = pd.DataFrame(lbp_sample_dict)

In [100]:
# Write the results table, with its header row, to the artifacts folder.
lbp_sample_csv_path = os.path.join(ARTIFACT_DIR, "lbp_sample_results.csv")
lbp_sample_df.to_csv(lbp_sample_csv_path, header=True)