<a href="https://colab.research.google.com/github/graphlit/graphlit-samples/blob/main/python/Notebook%20Examples/Graphlit_2024_09_21_CrewAI_Medical_AI_Agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Description**

This example shows how to use Graphlit with CrewAI. Based on this [great blog post](https://farzzy.hashnode.dev/building-healthcare-ai-agent-medical-guidelines) and [notebook](https://github.com/farzad528/azure-ai-search-python-playground) written by Farzad Sunavala, we emulate Farzad's workflow and show how to use Graphlit instead of LlamaIndex. Like the original, this approach still relies on Azure AI Search, which Graphlit leverages under the hood.

**Requirements**

Prior to running this notebook, you will need to [sign up](https://docs.graphlit.dev/getting-started/signup) for Graphlit, and [create a project](https://docs.graphlit.dev/getting-started/create-project).

You will need the Graphlit organization ID, preview environment ID and JWT secret from your created project.

Assign these properties as Colab secrets: GRAPHLIT_ORGANIZATION_ID, GRAPHLIT_ENVIRONMENT_ID and GRAPHLIT_JWT_SECRET.

For CrewAI, you will need your OpenAI API key.

Assign this property as a Colab secret: OPENAI_API_KEY

---

In [None]:
!pip install --upgrade graphlit-tools[crewai]

Clone the original repo to access the PDFs

In [None]:
# Remove any previous checkout first: `git clone` refuses to clone into an existing non-empty directory,
# so this keeps the cell safe to re-run.
!rm -rf azure-ai-search-python-playground
!git clone https://github.com/farzad528/azure-ai-search-python-playground

Initialize Graphlit

In [None]:
import os
from google.colab import userdata
from graphlit import Graphlit
from graphlit_api import input_types, enums, exceptions

# Copy the Graphlit project credentials from Colab secrets into environment variables.
# Graphlit() is constructed without arguments, so it presumably reads them from the environment.
for secret_name in ('GRAPHLIT_ORGANIZATION_ID', 'GRAPHLIT_ENVIRONMENT_ID', 'GRAPHLIT_JWT_SECRET'):
    os.environ[secret_name] = userdata.get(secret_name)

graphlit = Graphlit()

Initialize sample data

In [None]:
# Folder of the cloned sample repo, and the sub-folder whose files are ingested below
# (presumably the ACC guideline PDFs).
DATA_DIRECTORY = "azure-ai-search-python-playground"
PATH_PDFS = f"{DATA_DIRECTORY}/data/acc/"

Define Graphlit helper functions

In [None]:
import base64
import mimetypes
import asyncio
import time
from typing import List, Optional
from tqdm.asyncio import tqdm

async def process_file(filename, progress_bar):
    """Ingest one file from PATH_PDFS into Graphlit and advance the progress bar.

    Args:
        filename: Name of the file inside PATH_PDFS.
        progress_bar: tqdm progress bar shared by all concurrent ingestion tasks.

    The bar is advanced exactly once per call, even when ingestion fails or
    raises, so its count always reaches the total number of files.
    """
    pdf_path = os.path.join(PATH_PDFS, filename)

    print(f'Starting to ingest content from [{pdf_path}].')

    try:
        content_id = await ingest_file(pdf_path)

        if content_id is not None:
            print(f'Completed ingesting content [{content_id}] from [{pdf_path}].')
        else:
            # ingest_file returns None on any handled failure; report it instead of staying silent
            print(f'Failed to ingest content from [{pdf_path}].')
    finally:
        # Always advance, otherwise an exception would leave the bar stuck below its total
        progress_bar.update(1)

# NOTE: for local files, load from disk and convert to Base64 data
async def ingest_file(file_path: str) -> Optional[str]:
    """Upload a local file to Graphlit as Base64-encoded content.

    The content is named after the file's base name without its extension, and
    the MIME type is inferred from the file extension.

    Args:
        file_path: Path to the local file to ingest.

    Returns:
        The ID of the ingested content, or None if the Graphlit client is
        unavailable, the MIME type cannot be inferred, the file cannot be read,
        or the Graphlit API reports an error.
    """
    if graphlit.client is None:
        return None

    file_name = os.path.basename(file_path)
    content_name, _ = os.path.splitext(file_name)

    mime_type, _ = mimetypes.guess_type(file_name)

    if mime_type is None:
        print(f'Failed to infer MIME type for [{file_path}]')
        return None

    try:
        with open(file_path, "rb") as file:
            file_content = file.read()
    except OSError as e:
        # Report unreadable files instead of letting the error abort every other task in asyncio.gather
        print(f'Failed to read [{file_path}]: {e}')
        return None

    base64_content = base64.b64encode(file_content).decode('utf-8')

    try:
        # is_synchronous=True presumably makes the call wait until ingestion has finished
        response = await graphlit.client.ingest_encoded_file(content_name, base64_content, mime_type, is_synchronous=True)
    except exceptions.GraphQLClientError as e:
        print(str(e))
        return None

    return response.ingest_encoded_file.id if response.ingest_encoded_file is not None else None

async def create_openai_specification(model: enums.OpenAIModels) -> Optional[str]:
    """Create a Graphlit completion specification backed by an OpenAI model.

    The specification enables embedded citations, uses the SECTION retrieval
    strategy, and reranks retrieved content with Cohere.

    Args:
        model: The OpenAI model to use for completions.

    Returns:
        The new specification's ID, or None if the Graphlit client is
        unavailable or the Graphlit API reports an error.
    """
    if graphlit.client is None:
        return None

    specification_input = input_types.SpecificationInput(
        name=f"OpenAI [{str(model)}]",
        type=enums.SpecificationTypes.COMPLETION,
        serviceType=enums.ModelServiceTypes.OPEN_AI,
        openAI=input_types.OpenAIModelPropertiesInput(
            model=model,
        ),
        strategy=input_types.ConversationStrategyInput(
            embedCitations=True
        ),
        retrievalStrategy=input_types.RetrievalStrategyInput(
            type=enums.RetrievalStrategyTypes.SECTION
        ),
        rerankingStrategy=input_types.RerankingStrategyInput(
            serviceType=enums.RerankingModelServiceTypes.COHERE
        )
    )

    try:
        response = await graphlit.client.create_specification(specification_input)
    except exceptions.GraphQLClientError as e:
        print(str(e))
        return None

    return response.create_specification.id if response.create_specification is not None else None

async def create_conversation(specification_id: str) -> Optional[str]:
    """Create a Graphlit conversation that uses the given specification.

    Args:
        specification_id: ID of the specification the conversation should use.

    Returns:
        The new conversation's ID, or None if the Graphlit client is
        unavailable or the Graphlit API reports an error.
    """
    if graphlit.client is None:
        return None

    conversation_input = input_types.ConversationInput(
        name="Conversation",
        specification=input_types.EntityReferenceInput(
            id=specification_id
        )
    )

    try:
        response = await graphlit.client.create_conversation(conversation_input)
    except exceptions.GraphQLClientError as e:
        print(str(e))
        return None

    return response.create_conversation.id if response.create_conversation is not None else None

async def prompt_conversation(conversation_id: str, prompt: str):
    """Send a prompt to an existing Graphlit conversation.

    Args:
        conversation_id: ID of the conversation to prompt.
        prompt: The user prompt text.

    Returns:
        A ``(message, citations)`` tuple taken from the reply message, or
        ``(None, None)`` if the client is unavailable, the reply has no
        message, or the Graphlit API reports an error.
    """
    client = graphlit.client
    if client is None:
        return None, None

    try:
        response = await client.prompt_conversation(prompt, conversation_id)
    except exceptions.GraphQLClientError as error:
        print(str(error))
        return None, None

    reply = response.prompt_conversation
    reply_message = reply.message if reply is not None else None
    if reply_message is None:
        return None, None

    return reply_message.message, reply_message.citations

# NOTE: these functions are just used to clean up old data before executing the example
async def delete_all_specifications():
    """Delete every specification in the Graphlit project (called with is_synchronous=True)."""
    client = graphlit.client
    if client is not None:
        await client.delete_all_specifications(is_synchronous=True)

async def delete_all_conversations():
    """Delete every conversation in the Graphlit project (called with is_synchronous=True)."""
    client = graphlit.client
    if client is not None:
        await client.delete_all_conversations(is_synchronous=True)

async def delete_all_contents():
    """Delete every content item in the Graphlit project (called with is_synchronous=True)."""
    client = graphlit.client
    if client is not None:
        await client.delete_all_contents(is_synchronous=True)

Ingest sample PDFs

In [None]:
import nest_asyncio

# Patch asyncio so nested event-loop use (asyncio.run / loop.run_until_complete) works
# inside the notebook's already-running event loop; presumably required by libraries used below.
nest_asyncio.apply()

In [None]:
# Remove any existing contents; only needed for notebook example
await delete_all_contents()

print('Deleted all contents.')

# List of files to process
files = os.listdir(PATH_PDFS)

progress_bar = tqdm(total=len(files))

tasks = [process_file(filename, progress_bar) for filename in files]

await asyncio.gather(*tasks)

print('Ingested all contents.')

Patient profile configuration, copied from the original notebook

In [None]:
# Define complex patient profile
# Sample patient record copied from the original notebook. Fields such as 'age' and 'name'
# are interpolated into the prompts and agent goals below; unit suffixes in keys (_cm, _kg,
# _mg_dl, _mm_hg) give each value's unit.
patient_profile = {
    "patient_id": "43454357890",
    "name": "Sarah Johnson",
    "age": 68,
    "gender": "Female",
    "height_cm": 165,
    "weight_kg": 72,
    "bmi": 26.4,  # consistent with weight_kg / (height_cm / 100) ** 2 ≈ 26.4
    "blood_type": "A-",
    "allergies": ["Sulfa drugs"],
    "current_medications": [
        "Flecainide",
        "Apixaban",
        "Metoprolol",
        "Atorvastatin",
        "Cilostazol",
    ],
    "chronic_conditions": {
        "atrial_fibrillation": True,
        "peripheral_arterial_disease": True,
        "hypertension": True,
        "hyperlipidemia": True,
        "coronary_artery_disease": False,
    },
    "family_medical_history": {"heart_disease": True, "stroke": True, "cancer": False},
    "lifestyle_factors": {
        "smoking": "Former smoker, quit 5 years ago",
        "alcohol_use": "Rare",
        "physical_activity_per_week": "4-5 days",
        "diet": "Heart-healthy, low-sodium",
        "sleep_hours_per_night": 6,
    },
    "recent_lab_results": {
        "inr": 1.4,
        "creatinine_mg_dl": 1.1,
        "ldl_cholesterol_mg_dl": 85,
        "hdl_cholesterol_mg_dl": 55,
        "blood_pressure_mm_hg": "135/80",  # systolic/diastolic, stored as a string
    },
    "vaccination_status": {
        "influenza_vaccine": True,
        "covid_vaccine": True,
        "pneumonia_vaccine": True,
    },
    "surgical_history": [
        "Carotid endarterectomy",
        "Cholecystectomy",
    ],
    "imaging_history": {  # dates are ISO 8601 strings (YYYY-MM-DD)
        "last_echocardiogram_date": "2024-07-15",
        "last_carotid_ultrasound_date": "2024-06-20",
    },
    "mental_health": {
        "anxiety": True,
        "depression": False,
        "cognitive_function_issues": False,
    },
    "preferences": {
        "preferred_treatment_type": [
            "Minimally invasive procedures",
            "Evidence-based treatments",
        ],
        "end_of_life_care": "Yes",  # NOTE(review): meaning of "Yes" is not specified here — confirm with the original notebook
        "pain_management": "Non-opioid",
    },
    "recent_visits": [
        {
            "visit_date": "2024-08-25",
            "reason": "Follow-up for atrial fibrillation",
            "notes": "Recurrence of AF despite flecainide. Discussing ablation vs alternative antiarrhythmic options.",
        },
        {
            "visit_date": "2024-07-10",
            "reason": "Peripheral arterial disease management",
            "notes": "Stable claudication symptoms. Continue current management and exercise program.",
        },
    ],
    "af_management": {
        "current_treatment": "Flecainide",
        "treatment_history": [
            "Rate control with metoprolol",
            "Rhythm control with flecainide",
        ],
        "af_recurrence": True,
        "considering_options": ["Catheter ablation", "Alternative antiarrhythmic drug"],
        "chads2_vasc_score": 4,  # CHA2DS2-VASc stroke-risk score
        "has_bled_score": 2,  # HAS-BLED bleeding-risk score
    },
    "pad_management": {
        "ankle_brachial_index": 0.78,
        "fontaine_classification": "Stage II",
        "current_treatment": ["Cilostazol", "Supervised exercise program"],
        "last_vascular_assessment": "2024-06-20",
    },
}

Execute example RAG prompt with citations

In [None]:
import nest_asyncio

# Same patch as in the ingestion section, repeated so this section's cells can run on their own.
# Re-applying is presumably a no-op once the running loop is already patched — verify with nest_asyncio docs.
nest_asyncio.apply()

In [None]:
from IPython.display import display, Markdown, HTML
import time

# Remove any existing conversations and specifications; only needed for notebook example
await delete_all_conversations()
await delete_all_specifications()

print('Deleted all conversations and specifications.')

# Formulate a query based on Sarah Johnson's profile
prompt = (
    f"Which option reduces stroke risk for a {patient_profile['age']}-year-old "
    f"female with atrial fibrillation and peripheral arterial disease, "
    f"who is considering catheter ablation or antiarrhythmic drug therapy?"
)

model = enums.OpenAIModels.GPT4O_MINI_128K
#model = enums.OpenAIModels.O1_MINI_128K

specification_id = await create_openai_specification(model)

if specification_id is not None:
    print(f'Created specification [{specification_id}].')

    conversation_id = await create_conversation(specification_id=specification_id)

    if conversation_id is not None:
        message, citations = await prompt_conversation(conversation_id, prompt)

        if message is not None:
            display(Markdown(f'### Patient-Specific Question Response for {patient_profile["name"]}:'))
            display(Markdown(f'**Final Response:**\n{message}'))
            print()

        if citations is not None:
            print("\nReference Information:")

            for citation in citations:
                if citation is not None and citation.content is not None:
                    display(Markdown(f'**Citation [{citation.index}]:** {citation.content.name}'))
                    display(Markdown(citation.text))
                    print()


Initialize CrewAI

In [None]:
from crewai import Agent, Task, Crew, Process
from graphlit_tools import PromptTool, CrewAIConverter
from textwrap import dedent

# CrewAI needs the OpenAI API key from Colab secrets in the environment before the agents run
os.environ.update({'OPENAI_API_KEY': userdata.get('OPENAI_API_KEY')})

# Expose Graphlit's RAG prompting as a CrewAI-compatible tool
prompt_tool = CrewAIConverter.from_tool(
    PromptTool(graphlit)
)

# Embed patient profile information into the agent goals
# Define the CrewAI agents; patient details are read from patient_profile rather than hard-coded
guideline_expert = Agent(
    role="Guideline Expert",
    goal=(
        f"Retrieve and summarize relevant ACC guidelines for a {patient_profile['age']}-year-old "
        "with atrial fibrillation and peripheral arterial disease. "
        "Focus on treatment options such as catheter ablation and antiarrhythmic therapy."
    ),
    backstory="You are an expert on ACC guidelines for managing atrial fibrillation and PAD.",
    # Only this agent is given the Graphlit prompt tool, i.e. access to the ingested documents
    tools=[prompt_tool],
    verbose=True,
)

patient_educator = Agent(
    role="Patient Educator",
    goal=(
        f"Translate the medical guidelines into easy-to-understand terms for {patient_profile['name']}. "
        "Focus on explaining the treatment options available for managing atrial fibrillation."
    ),
    backstory="You explain complex medical terms in patient-friendly language.",
    verbose=True,
)

treatment_planner = Agent(
    role="Treatment Planner",
    goal=(
        f"Create a personalized treatment plan for {patient_profile['name']}, considering her preference for minimally invasive procedures "
        "and her options of catheter ablation or antiarrhythmic therapy."
    ),
    backstory="You specialize in personalized treatment plans based on patient history and preferences.",
    verbose=True,
)

output_generator = Agent(
    role="Output Generator",
    goal=(
        f"Compile the information into a comprehensive patient decision aid document for {patient_profile['name']}, "
        "including a clear summary of her treatment options and next steps."
    ),
    backstory="You ensure that the medical recommendations are presented clearly and concisely.",
    verbose=True,
)

# Define tasks for each agent as (description, expected_output, agent) triples,
# listed in the same order the tasks are handed to the crew below
task1, task2, task3, task4 = (
    Task(description=description, expected_output=expected_output, agent=agent)
    for description, expected_output, agent in (
        (
            "Retrieve ACC guidelines for managing atrial fibrillation and PAD.",
            "Summarized guidelines with a focus on catheter ablation and stroke risk reduction.",
            guideline_expert,
        ),
        (
            "Translate the guidelines into patient-friendly language.",
            "Simplified, patient-friendly explanations of the treatment options.",
            patient_educator,
        ),
        (
            "Personalize the treatment plan according to the patient's preferences.",
            "A treatment plan tailored to the patient's specific conditions and preferences.",
            treatment_planner,
        ),
        (
            "Generate a patient decision aid document.",
            "A final decision aid document summarizing the patient's condition and treatment recommendations.",
            output_generator,
        ),
    )
)

# Create the Crew and define the process flow
crew = Crew(
    agents=[guideline_expert, patient_educator, treatment_planner, output_generator],
    tasks=[task1, task2, task3, task4],
    process=Process.sequential,  # Ensures tasks are executed in sequence
    verbose=True
)

# Execute the multi-step reasoning process
result = await crew.kickoff_async()

# Display the final patient decision aid
print("Final Patient Decision Aid:")
print(result)