In [None]:
!pip install -r requirements.txt -q

In [None]:
# LangSmith settings used to access the vignettes dataset.
# The API key is a secret, so it is never written into the notebook: it is taken
# from an already-exported LANGSMITH_API_KEY, or prompted for without echo.
# NOTE(review): a key was previously committed in this cell; it should be
# revoked and rotated in LangSmith.
import os
from getpass import getpass

os.environ["LANGSMITH_TRACING"] = "true"
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGSMITH_PROJECT"] = "pr-shadowy-nightgown-69"
if not os.environ.get("LANGSMITH_API_KEY"):
    os.environ["LANGSMITH_API_KEY"] = getpass("LangSmith API key: ")

In [None]:
from langchain_ollama import ChatOllama
from langsmith import Client

# Default chat model for the experiments: llama3.2 served through Ollama.
llm = ChatOllama(model="llama3.2")

In [None]:
from llmconstants import ClinicalLLMConstants, PatientLLMConstants
# Prompt/constant providers: `cllm` for the clinician role, `pllm` for the patient role.
cllm, pllm = ClinicalLLMConstants(), PatientLLMConstants()

In [None]:
# Fetch every example of the "casevignettes" dataset from LangSmith and
# materialize the iterator so the rows can be indexed and counted.
client = Client()
rows_iter = client.list_examples(dataset_name="casevignettes")
rows_arr = [example for example in rows_iter]
len(rows_arr)

In [None]:
from typing import List, Optional, TypedDict
from langchain_core.messages import BaseMessage

class ExtendedMessagesState(TypedDict):
    """State passed between the clinician and patient nodes of the graph."""

    # Conversation turns exchanged between the clinician and the patient.
    messages: List[BaseMessage]
    # Per-case instruction text supplied when the graph is invoked.
    clinical_instructions: Optional[str]
    patient_instructions: Optional[str]
    # Private prompt history of each role. The nodes store lists of messages
    # here (system prompt plus case instructions), not plain strings.
    clinical_messages: Optional[List[BaseMessage]]
    patient_messages: Optional[List[BaseMessage]]
    # Set once the clinician's reply contains "Final Diagnosis".
    termination: Optional[bool]

In [None]:
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, END, StateGraph
from langchain_core.messages import SystemMessage, HumanMessage

def _extract_choice_letter(text):
    """Pull the chosen option letter (A-D) out of the clinician's final reply.

    The letter that directly follows the "Final Diagnosis" marker wins; failing
    that, the last standalone A/B/C/D token in the text is used. Only whole-word
    letters count, so words such as "Acute" or the "D" in "Diagnosis" are not
    mistaken for an answer, and a reply consisting of just the letter is accepted.

    Returns the letter as a string, or None when no letter can be found.
    """
    import re  # function-scope import: `re` is not imported at the top of this cell

    after_marker = re.search(r"Final Diagnosis\W*([ABCD])\b", text)
    if after_marker:
        return after_marker.group(1)
    standalone = re.findall(r"\b[ABCD]\b", text)
    return standalone[-1] if standalone else None


def run_with_llm_model_as_(llm, rows_arr, num_qs):
    """Evaluate diagnostic accuracy of `llm` for several clinician question limits.

    For every value in `num_qs` a clinician/patient conversation graph is built
    and run once per row of `rows_arr` (rows are processed in a thread pool).
    A row scores 1 when the letter in the clinician's final message matches the
    letter of the reference answer, and 0 otherwise (including rows without a
    usable reference answer and rows whose run raised an exception).

    Args:
        llm: chat model exposing `.invoke(messages)`; used for both roles.
        rows_arr: indexable collection of dataset examples exposing `.inputs`
            and `.outputs`.
        num_qs: iterable of question limits handed to the clinician prompt.

    Returns:
        A list with one accuracy value (fraction between 0 and 1) per entry of
        `num_qs`, in the same order.
    """
    from concurrent.futures import ThreadPoolExecutor

    acc_rates = []

    for num_q in num_qs:
        print(f"Beginning diagnoses with LLM question limit at {num_q}!")

        workflow = StateGraph(state_schema=ExtendedMessagesState)

        def call_cllm_model(state: ExtendedMessagesState):
            # Work on copies so the state handed in by the graph is never mutated in place.
            clinical_history = list(state.get("clinical_messages") or [])
            if not clinical_history:
                clinical_history = [SystemMessage(content=cllm.get_initial_prompt(num_q))]

            # The case-specific instructions are added once, right after the system prompt.
            data_specific_instructions = state.get("clinical_instructions") or ""
            if data_specific_instructions and len(clinical_history) <= 1:
                clinical_history.append(HumanMessage(content=data_specific_instructions))

            # Conversational turns between clinician and patient
            chat_messages = list(state.get("messages") or [])

            response = llm.invoke(clinical_history + chat_messages)
            chat_messages.append(response)

            return {
                "clinical_messages": clinical_history,
                "messages": chat_messages,
                "termination": "Final Diagnosis" in response.content,
            }

        def call_patient_model(state: ExtendedMessagesState):
            patient_history = list(state.get("patient_messages") or [])
            if not patient_history:
                patient_history = [SystemMessage(content=pllm.get_initial_prompt())]
            data_specific_instructions = state.get("patient_instructions") or ""
            if data_specific_instructions and len(patient_history) <= 1:
                patient_history.append(SystemMessage(content=data_specific_instructions))

            # The patient only sees the clinician's latest question, not the whole chat.
            chat_messages = list(state.get("messages") or [])
            question = chat_messages[-1]

            response = llm.invoke(patient_history + [HumanMessage(content=question.content)])
            # Stored as a human turn so the clinician model reads it as user input.
            chat_messages.append(HumanMessage(content=response.content))

            return {
                "patient_messages": patient_history,
                "messages": chat_messages,
                "termination": state.get("termination", False),
            }

        workflow.add_node("clinical_model", call_cllm_model)
        workflow.add_node("patient_model", call_patient_model)

        def should_continue(state: ExtendedMessagesState) -> str:
            if state.get("termination", False):
                return END
            return "patient_model"

        def should_continue_back(state: ExtendedMessagesState) -> str:
            if state.get("termination", False):
                return END
            return "clinical_model"

        # Clinician speaks first; the two roles then alternate until termination.
        workflow.add_edge(START, "clinical_model")
        workflow.add_conditional_edges("clinical_model", should_continue)
        workflow.add_conditional_edges("patient_model", should_continue_back)

        memory = MemorySaver()
        app = workflow.compile(checkpointer=memory)

        def execute(i):
            try:
                row = rows_arr[i]
                if not row.outputs or "answer" not in row.outputs:
                    return 0

                props = row.inputs

                # Map the reference answer text to its option letter before
                # spending an LLM conversation on a row that cannot be scored.
                options = {
                    "A": props["choice_1"],
                    "B": props["choice_2"],
                    "C": props["choice_3"],
                    "D": props["choice_4"],
                }
                p_answer = row.outputs["answer"]
                answer = next(
                    (letter for letter, option in options.items() if option == p_answer),
                    None,
                )
                if answer is None:
                    return 0

                resp = app.invoke(
                    {
                        "clinical_instructions": f"{cllm.get_specialty(props['category'])} At the end, you will choose one diagnosis from these four options, separated by slashes:  (A) {props['choice_1']} / (B) {props['choice_2']} / (C) {props['choice_3']} / (D) {props['choice_4']}. IMPORTANT: At the end, respond ONLY with the letter of your choice — exactly one of A, B, C, or D, on its own line. No explanation. No parentheses. Just the letter.",
                        "patient_instructions": f"**Case vignette**: {pllm.get_case_vignette(props['case_vignette'])}"
                    },
                    config={"configurable": {"thread_id": i}, "recursion_limit": 300}
                )

                diagnosis = _extract_choice_letter(resp["messages"][-1].content)
                return 1 if diagnosis == answer else 0
            except Exception as exc:
                # Best effort: a failed case counts as incorrect, but the reason is
                # reported so that infrastructure errors are not mistaken for wrong answers.
                print(f"Case {i} failed and is scored as incorrect: {exc!r}")
                return 0

        with ThreadPoolExecutor() as executor:
            results = list(executor.map(execute, range(len(rows_arr))))

        success_rate = (sum(results) / len(results)) * 100 if results else 0
        print(f"Success Rate: {success_rate:.2f}%")

        acc_rates.append(success_rate / 100)

    return acc_rates

In [None]:
import numpy as np
# Draw a reproducible evaluation minibatch of up to 50 distinct cases.
# Indices are sampled without replacement so the same vignette is never scored
# twice, and the size is capped so a dataset smaller than 50 rows still works.
np.random.seed(42)
sample_idx = np.random.choice(len(rows_arr), size=min(50, len(rows_arr)), replace=False)
minibatch = [rows_arr[i] for i in sample_idx]

In [None]:
# Accuracy of the default model for each clinician question limit.
run_with_llm_model_as_(
    llm=llm,
    rows_arr=minibatch,
    num_qs=[5, 10, 15, 20],
)

With Llama 3.2, diagnostic accuracy was 30%, 32%, 28%, and 40% when the clinician was allowed to ask the patient up to 5, 10, 15, and 20 questions, respectively.

In [None]:
# To evaluate another Ollama model, replace the "MODEL_NAME" placeholder below.
my_llm = ChatOllama(model="MODEL_NAME")

# Optionally evaluate on all case vignettes instead of the minibatch:
# minibatch = rows_arr
run_with_llm_model_as_(
    llm=my_llm,
    rows_arr=minibatch,
    num_qs=[5, 10, 15, 20],
)