<a href="https://colab.research.google.com/github/kedar-bhumkar/Clinical-companion-backend/blob/main/singular.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [57]:
pip install fastapi pydantic openai PyYAML tiktoken duckduckgo_search langchain langchain_openai faker



In [58]:
import json
import yaml
import random
import re
from difflib import *
import tiktoken
from typing import List, Dict, Any, Optional
from datetime import date, datetime, timedelta
from pydantic import BaseModel, Field
from fastapi import FastAPI, Request
import time
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from openai import OpenAI
import duckduckgo_search
from langchain.agents import tool
from faker import Faker
from langchain_openai import ChatOpenAI
from langchain.agents import tool
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.agents.format_scratchpad.openai_tools import (
    format_to_openai_tool_messages,
)
from langchain_core.messages import AIMessage, HumanMessage
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from langchain.agents import AgentExecutor
import os, asyncio


In [59]:
# Constants
# Model name constant; requests actually use ``default_model`` (not referenced elsewhere in the visible code).
standard_model: str = "gpt-3.5-turbo-0125"
# YAML files read at request time; relative paths resolve against the process working directory.
config_file: str = "./config/config.yaml"      # per model family: API "key" and base "url"
prompts_file: str = "./config/prompts.yaml"    # system prompt and per-entity/intent user prompt templates
db_conn_file: str = "./config/db_config.yaml"  # not read in the visible code
action_file: str = "./config/action.yaml"      # entity/intent -> data functions and UI options

In [60]:
# Defaults applied to incoming requests (several feed the ``Message`` model fields).
default_mode: str = "serial"
default_page: str = "demo"
default_model_family: str = "openai"  # key into config.yaml
default_model: str = "gpt-4o"
default_usecase: str = "demo"
default_temperature: float = 0.1      # sampling temperature for the plain completion path
default_run_mode: str = "same-llm"
default_run_count: int = 1
default_sleep: float = 0.75
default_accuracy_check: str = "ON"
default_encoding: str = "cl100k_base"  # tiktoken encoding used for token logging
default_fuzzy_matching_threshold: int = 80
default_negative_prompt: str = "ON"
default_formatter: str = "ros_pe_formatter"
default_use_for_training: bool = False
default_error_detection: bool = True

In [61]:
# Pydantic models
class Chat(BaseModel):
    """One prior conversation turn, as carried in ``Message.history``."""
    message: str  # text of the turn
    sender: str  # who produced the turn; allowed values are not validated here

In [62]:
class Message(BaseModel):
    """Request body accepted by the /chat and /config endpoints.

    Field descriptions are published in the generated OpenAPI schema, so each
    one describes its own field.
    """
    page: Optional[str] = Field(default_page, description="Client page that issued the request.")
    mode: Optional[str] = Field(default_mode, description="Execution mode for the request.")
    family: Optional[str] = Field(default_model_family, description="Model family; selects the API key and base URL in config.yaml.")
    model: Optional[str] = Field(default_model, description="Name of the LLM model to call.")
    negative_prompt: Optional[str] = Field(default_negative_prompt, description="Negative-prompt switch ('ON' by default).")
    use_for_training: Optional[bool] = Field(default_use_for_training, description="use this for training.")
    message: Optional[str] = Field(None, description="The message to be sent to the model.")
    history: list[Chat] = Field(default_factory=list, description="Prior conversation turns supplied by the client.")
    intent: Optional[str] = Field(None, description="The intent of the message.")
    entity: Optional[str] = Field(None, description="The entity of the message.")
    patient_id: Optional[str] = Field(None, description="The patient id of the message.")
    # Pydantic copies this mutable default per instance, so it is not shared between requests.
    patient_ids: Optional[list[str]] = Field(["1"], description="Patient ids whose data is fetched for the request.")
    form_data: Optional[str] = Field(None, description="The form data of the message used in auto complete.")

In [63]:
class Patient(BaseModel):
    """Demographic record for a patient (built by the mock ``get_patient``)."""
    name: str
    # Optional leading "+", optional "1", then 9-15 digits; no spaces or dashes allowed.
    phone_number: str = Field(pattern=r'^\+?1?\d{9,15}$')
    address: str
    age: int = Field(ge=0, le=120)  # validated to the inclusive range 0..120
    gender: str
    weight: float = Field(ge=0)  # non-negative; get_patient describes it as kilograms
    height: float = Field(ge=0)  # non-negative; get_patient describes it as centimeters

In [64]:
class Allergy(BaseModel):
    """A single allergy entry with its LOINC code and active period."""
    allergy_name: str
    # 1-5 digits, a dash, then one trailing digit (LOINC's check digit), e.g. "12345-6".
    loinc_code: str = Field(pattern=r'^[0-9]{1,5}-[0-9]$')
    start_date: date
    status: str  # free text; get_patient uses "Active" / "Inactive"
    end_date: Optional[date] = None  # defaults to None; get_patient sets it only on the inactive entry

In [65]:
class PatientWithAllergies(BaseModel):
    """A patient together with the list of their allergies."""
    patient: Patient
    allergies: List[Allergy]

In [66]:
# Shared data instance
class SharedData:
    """Process-wide singleton key/value store.

    Every ``SharedData()`` call returns the same object, and its ``data`` dict
    is created only once. Request handlers and LangChain tools use it to pass
    values to each other.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is not None:
            return cls._instance
        singleton = super().__new__(cls)
        singleton.data = {}
        cls._instance = singleton
        return singleton

    def get_data(self, key):
        """Return the value stored under ``key``, or None when absent."""
        return self.data.get(key)

    def set_data(self, key, value):
        """Store ``value`` under ``key``, replacing any previous value."""
        self.data[key] = value

    def clear(self, key=None):
        """Drop ``key`` if present, or every entry when ``key`` is None."""
        if key is None:
            self.data.clear()
        else:
            self.data.pop(key, None)

In [67]:
# Module-wide handle on the SharedData singleton. Handlers and tools use it to pass
# per-request values (API key, base URL, the auto_populate marker) to each other.
shared_data_instance = SharedData()

In [68]:
# Utility functions
def getConfig(file_path):
    """Load the YAML file at ``file_path`` and return its parsed contents.

    Uses ``yaml.safe_load``, so only plain YAML types are constructed, and an
    empty file yields None. The file is decoded as UTF-8 explicitly. Prompt and
    config text often contains non-ASCII characters, and the platform default
    encoding (e.g. cp1252 on Windows) would otherwise mis-decode or reject it.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)

In [69]:
def num_tokens_from_string(string: str, encoding_name: str, type: str) -> int:
    """Count the tokens of ``string`` under the tiktoken encoding ``encoding_name``.

    ``type`` is only a label for the printed log line (e.g. "input" or
    "output"). Returns the token count.
    """
    token_count = len(tiktoken.get_encoding(encoding_name).encode(string))
    print(f'For {type} the no of tokens are {token_count}')
    return token_count

In [70]:
def get_nested_value(config: Dict[str, Any], keys: list) -> Optional[Any]:
    """Follow ``keys`` down through nested dicts and return the value found.

    Returns None as soon as the current level is not a dict or lacks the
    next key. An empty ``keys`` returns ``config`` unchanged.
    """
    node = config
    for key in keys:
        if not isinstance(node, dict) or key not in node:
            return None
        node = node[key]
    return node

In [71]:
def date_to_str(obj):
    """``default=`` hook for ``json.dumps``: render date/datetime values as ISO-8601 strings.

    Raises:
        TypeError: for any other type, as the json module expects from a
            ``default`` hook.
    """
    if not isinstance(obj, date):
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
    return obj.isoformat()

In [72]:
# Mock data generation
# Shared Faker instance used by the mock data generators that follow.
fake = Faker()

In [73]:
def generate_patient_id() -> str:
    """Return a random mock patient id: "P" followed by exactly six digits."""
    six_digits = fake.random_number(digits=6, fix_len=True)
    return "P" + str(six_digits)

In [74]:
def getPatientSummary(patient_id: str) -> Dict[str, Any]:
    """Assemble a full mock chart for ``patient_id`` under a top-level "Summary" key.

    Basic demographics are randomised here. Every other section comes from the
    matching mock generator, so ``vitals`` is itself wrapped as
    ``{"Vitals": {...}}``. Generators run in the same order as the keys appear.
    """
    demographics = {
        "name": fake.name(),
        "age": random.randint(18, 90),
        "gender": random.choice(["Male", "Female", "Other"]),
        "blood_type": random.choice(["A+", "A-", "B+", "B-", "AB+", "AB-", "O+", "O-"]),
        "height": round(random.uniform(150, 200), 1),
        "weight": round(random.uniform(45, 120), 1),
        "is_member_self_responsible": random.choice([True, False]),
    }

    chart = {"patient_id": patient_id, "basic_info": demographics}
    chart["vitals"] = getPatientVitals(patient_id)
    chart["allergies"] = getPatientAllergies(patient_id)
    chart["conditions"] = getPatientConditions(patient_id)
    chart["immunizations"] = getPatientImmunizations(patient_id)
    chart["lab_results"] = getPatientLabResults(patient_id)
    chart["procedures"] = getPatientProcedures(patient_id)
    chart["medications"] = getPatientMedications(patient_id)
    chart["appointments"] = getPatientAppointments(patient_id)
    chart["messages"] = getPatientMessages(patient_id)
    chart["soap_note"] = getSoapNoteData()
    return {"Summary": chart}

In [75]:
def getPatientVitals(patient_id: str) -> Dict[str, Any]:
    """Return random mock vital signs for ``patient_id``, wrapped under a "Vitals" key.

    Temperature is drawn from 36.1-37.5 (one decimal). Heart rate, systolic and
    diastolic pressure, respiratory rate and SpO2 are integers drawn from
    fixed ranges.
    """
    temp = round(random.uniform(36.1, 37.5), 1)
    pulse = random.randint(60, 100)
    systolic = random.randint(90, 140)
    diastolic = random.randint(60, 90)
    breaths = random.randint(12, 20)
    spo2 = random.randint(95, 100)

    readings = {
        "patient_id": patient_id,
        "temperature": temp,
        "heart_rate": pulse,
        "blood_pressure": f"{systolic}/{diastolic}",
        "respiratory_rate": breaths,
        "oxygen_saturation": spo2,
    }
    return {"Vitals": readings}

In [76]:
def getPatientAllergies(patient_id: str) -> List[str]:
    """Return 0-3 distinct mock allergy names (``patient_id`` is not used)."""
    catalog = ["Penicillin", "Peanuts", "Latex", "Aspirin", "Shellfish", "Eggs", "Soy", "Wheat", "Milk", "Tree nuts"]
    how_many = random.randint(0, 3)
    return random.sample(catalog, how_many)

In [77]:
def getPatientConditions(patient_id: str) -> List[str]:
    """Return 0-4 distinct mock chronic-condition names (``patient_id`` is not used)."""
    catalog = ["Hypertension", "Type 2 Diabetes", "Asthma", "Osteoarthritis", "Depression", "Anxiety", "GERD", "Hypothyroidism", "Hyperlipidemia", "Chronic kidney disease"]
    how_many = random.randint(0, 4)
    return random.sample(catalog, how_many)

In [78]:
def getPatientImmunizations(patient_id: str) -> List[Dict[str, Any]]:
    """Return 2-5 distinct mock immunization records (``patient_id`` is not used).

    Each record has the vaccine ``name`` and a ``date`` (YYYY-MM-DD) within
    the last five years.
    """
    vaccines = ["Influenza", "Tetanus", "Hepatitis B", "MMR", "Pneumococcal", "HPV", "Varicella", "Shingles"]
    records = []
    for vaccine in random.sample(vaccines, random.randint(2, 5)):
        given_on = fake.date_between(start_date="-5y", end_date="today")
        records.append({"name": vaccine, "date": given_on.strftime("%Y-%m-%d")})
    return records

In [79]:
def getPatientLabResults(patient_id: str) -> List[Dict[str, Any]]:
    """Return 2-4 distinct mock lab results from the past year (``patient_id`` is not used).

    ``result`` and the ``reference_range`` bounds are random numbers formatted
    as strings. They are not clinically tied to the test or to the unit.
    """
    panels = ["Complete Blood Count", "Lipid Panel", "Comprehensive Metabolic Panel", "Hemoglobin A1C", "Thyroid Function Tests", "Urinalysis"]
    results = []
    for panel in random.sample(panels, random.randint(2, 4)):
        drawn_on = fake.date_between(start_date="-1y", end_date="today")
        value = random.uniform(0.5, 2.0)
        unit = random.choice(["mg/dL", "mmol/L", "%", "U/L"])
        low = random.uniform(0.1, 0.9)
        high = random.uniform(1.1, 3.0)
        results.append({
            "test_name": panel,
            "date": drawn_on.strftime("%Y-%m-%d"),
            "result": f"{value:.2f}",
            "unit": unit,
            "reference_range": f"{low:.1f} - {high:.1f}",
        })
    return results

In [80]:
def getPatientProcedures(patient_id: str) -> List[Dict[str, Any]]:
    """Return 0-2 distinct mock past procedures from the last three years (``patient_id`` is not used)."""
    catalog = ["Appendectomy", "Colonoscopy", "Knee Arthroscopy", "Cataract Surgery", "Tonsillectomy", "Wisdom Teeth Extraction"]
    history = []
    for procedure in random.sample(catalog, random.randint(0, 2)):
        performed_on = fake.date_between(start_date="-3y", end_date="today")
        history.append({
            "name": procedure,
            "date": performed_on.strftime("%Y-%m-%d"),
            "provider": fake.name(),
        })
    return history

In [81]:
def getPatientMedications(patient_id: str) -> List[Dict[str, Any]]:
    """Return 1-4 distinct mock active prescriptions (``patient_id`` is not used)."""
    formulary = ["Lisinopril", "Metformin", "Levothyroxine", "Amlodipine", "Metoprolol", "Omeprazole", "Gabapentin", "Sertraline"]
    prescriptions = []
    for drug in random.sample(formulary, random.randint(1, 4)):
        strength = random.choice([5, 10, 20, 25, 50, 100])
        schedule = random.choice(["Once daily", "Twice daily", "Three times daily", "As needed"])
        written_on = fake.date_between(start_date="-1y", end_date="today")
        prescriptions.append({
            "name": drug,
            "dosage": f"{strength} mg",
            "frequency": schedule,
            "prescribed_date": written_on.strftime("%Y-%m-%d"),
        })
    return prescriptions

In [82]:
def getPatientAppointments(patient_id: str) -> List[Dict[str, Any]]:
    """Return 1-3 mock upcoming appointments within the next six months (``patient_id`` is not used).

    Appointment types are drawn with replacement, so repeats are possible.
    """
    visit_kinds = ["Annual Physical", "Follow-up", "Specialist Consultation", "Vaccination", "Lab Work"]
    upcoming = []
    for _ in range(random.randint(1, 3)):
        kind = random.choice(visit_kinds)
        scheduled_for = fake.date_between(start_date="today", end_date="+6m")
        upcoming.append({
            "type": kind,
            "date": scheduled_for.strftime("%Y-%m-%d"),
            "time": fake.time(),
            "provider": fake.name(),
        })
    return upcoming

In [83]:
def getPatientMessages(patient_id: str) -> List[Dict[str, Any]]:
    """Return 2-5 mock patient-portal messages, one per distinct template (``patient_id`` is not used).

    Each template fixes its ``sender``. Confirmations, refill notices, lab
    notifications and reminders come from the provider; a "Health Question"
    comes from the patient. A random sender would produce contradictions such
    as a patient "confirming" their own appointment. Placeholders are filled
    with random values: the doctor is built from first and last name only,
    because ``fake.name()`` may already include a "Dr." prefix ("Dr. Dr. ...").
    The medication is taken from a realistic drug list rather than a lorem
    word.
    """
    message_templates = [
        {
            "subject": "Appointment Confirmation",
            "sender": "Provider",
            "content": "Your appointment with Dr. {doctor} is confirmed for {date} at {time}. Please arrive 15 minutes early to complete any necessary paperwork. If you need to reschedule, please call our office at least 24 hours in advance."
        },
        {
            "subject": "Prescription Refill Request",
            "sender": "Provider",
            "content": "This is to confirm that we've received your request to refill your prescription for {medication}. We'll process this and send it to your pharmacy within 48 hours. If you have any questions, please don't hesitate to contact us."
        },
        {
            "subject": "Lab Results Available",
            "sender": "Provider",
            "content": "Your recent lab results for {test} are now available. Please log in to your patient portal to view them. If you have any questions about your results, please schedule a follow-up appointment with Dr. {doctor}."
        },
        {
            "subject": "Appointment Reminder",
            "sender": "Provider",
            "content": "This is a reminder that you have an appointment scheduled with Dr. {doctor} on {date} at {time} for your {appointment_type}. Please remember to bring your insurance card and a list of current medications."
        },
        {
            "subject": "Health Question",
            "sender": "Patient",
            "content": "I've been experiencing {symptom} for the past {duration}. It seems to worsen when {trigger}. Is this something I should be concerned about or schedule an appointment for?"
        }
    ]
    medications = ["Lisinopril", "Metformin", "Levothyroxine", "Amlodipine", "Metoprolol", "Omeprazole", "Gabapentin", "Sertraline"]

    return [
        {
            "subject": template["subject"],
            "date": fake.date_between(start_date="-3m", end_date="today").strftime("%Y-%m-%d"),
            "time": fake.time(),
            "sender": template["sender"],
            # str.format ignores unused keyword arguments, so every template can
            # receive the full set of placeholder values.
            "content": template["content"].format(
                doctor=f"{fake.first_name()} {fake.last_name()}",
                date=fake.date_between(start_date="today", end_date="+30d").strftime("%Y-%m-%d"),
                time=fake.time(),
                medication=random.choice(medications),
                test=random.choice(["Complete Blood Count", "Lipid Panel", "Thyroid Function"]),
                appointment_type=random.choice(["annual check-up", "follow-up visit", "consultation"]),
                symptom=random.choice(["headache", "back pain", "nausea", "fatigue"]),
                duration=random.choice(["two days", "a week", "three days"]),
                trigger=random.choice(["I stand for long periods", "I eat certain foods", "I exercise"])
            )
        }
        for template in random.sample(message_templates, random.randint(2, 5))
    ]

In [84]:
def getSoapNoteData() -> Dict[str, Any]:
    """Return a randomly generated mock SOAP-note fragment.

    Keys:
        chief_complaint: "<complaint>: <1-3 comma-separated descriptors>".
        history_of_present_illness: 3-6 "<Element>: <value>." sentences.
        review_of_systems: one entry per body system. Each entry is either
            1-3 reported findings or "No abnormalities noted.", chosen with
            equal probability.
        physical_exam: 1-3 findings for every exam section.
    """
    chief_complaints = {
        "Chest pain": ["sharp", "dull", "crushing", "radiating", "intermittent", "constant"],
        "Shortness of breath": ["at rest", "on exertion", "when lying flat", "sudden onset"],
        "Abdominal pain": ["cramping", "sharp", "dull", "localized", "diffuse", "intermittent"],
        "Headache": ["throbbing", "pressure-like", "unilateral", "bilateral", "with aura"],
        "Back pain": ["lower", "upper", "mid", "radiating", "constant", "intermittent"],
        "Fever": ["high-grade", "low-grade", "intermittent", "with chills", "persistent"],
        "Cough": ["dry", "productive", "persistent", "with blood-tinged sputum", "nocturnal"],
        "Fatigue": ["generalized", "sudden onset", "progressive", "with weakness"],
        "Nausea": ["with vomiting", "without vomiting", "intermittent", "persistent"],
        "Dizziness": ["vertigo", "lightheadedness", "with fainting", "positional"]
    }

    hpi_elements = {
        "onset": ["sudden", "gradual", "acute", "chronic", "intermittent"],
        "location": ["localized", "diffuse", "radiating", "migrating"],
        "duration": ["for the past day", "for several days", "for a week", "for several weeks", "for months"],
        "characterization": ["sharp", "dull", "aching", "burning", "throbbing", "stabbing"],
        "alleviating factors": ["rest", "medication", "position change", "ice", "heat"],
        "aggravating factors": ["movement", "eating", "stress", "certain positions", "time of day"],
        "associated symptoms": ["nausea", "vomiting", "fever", "chills", "sweating", "fatigue"]
    }

    ros_systems = {
        "Constitutional": ["fever", "chills", "fatigue", "weight loss", "weight gain"],
        "Eyes": ["vision changes", "eye pain", "redness", "discharge"],
        "Ears, Nose, Mouth, Throat": ["hearing loss", "tinnitus", "sore throat", "nasal congestion"],
        "Cardiovascular": ["chest pain", "palpitations", "edema", "orthopnea"],
        "Respiratory": ["cough", "shortness of breath", "wheezing", "hemoptysis"],
        "Gastrointestinal": ["nausea", "vomiting", "diarrhea", "constipation", "abdominal pain"],
        "Genitourinary": ["dysuria", "frequency", "urgency", "hematuria"],
        "Musculoskeletal": ["joint pain", "muscle pain", "stiffness", "swelling"],
        "Integumentary": ["rash", "itching", "skin lesions", "changes in moles"],
        "Neurological": ["headache", "dizziness", "numbness", "tingling", "weakness"],
        "Psychiatric": ["depression", "anxiety", "sleep disturbances", "mood changes"],
        "Endocrine": ["heat/cold intolerance", "excessive thirst", "excessive urination"],
        "Hematologic/Lymphatic": ["easy bruising", "bleeding", "swollen lymph nodes"],
        "Allergic/Immunologic": ["seasonal allergies", "food allergies", "frequent infections"]
    }

    pe_sections = {
        "General": ["Well-appearing", "Ill-appearing", "Comfortable", "Distressed"],
        "HEENT": ["PERRL", "EOMI", "TMs clear", "Oropharynx clear", "Mild erythema"],
        "Neck": ["Supple", "No lymphadenopathy", "Thyromegaly", "JVD present"],
        "Chest": ["Clear to auscultation", "Wheezes", "Crackles", "Rhonchi"],
        "Cardiovascular": ["RRR", "Murmur present", "S3 gallop", "S4 gallop"],
        "Abdomen": ["Soft", "Non-tender", "Distended", "Rebound tenderness", "Guarding"],
        "Musculoskeletal": ["Full ROM", "Tenderness", "Swelling", "Deformity"],
        "Neurological": ["Alert and oriented", "CN II-XII intact", "Sensory intact", "Motor strength 5/5"],
        "Skin": ["No rashes", "Erythema", "Petechiae", "Ecchymosis"]
    }

    def generate_chief_complaint() -> str:
        # One complaint with 1-3 distinct descriptors.
        complaint, details = random.choice(list(chief_complaints.items()))
        return f"{complaint}: {', '.join(random.sample(details, random.randint(1, min(3, len(details)))))}"

    def generate_hpi() -> str:
        # 3-6 distinct HPI elements, each rendered as a short sentence.
        return " ".join([
            f"{element.capitalize()}: {random.choice(details)}."
            for element, details in random.sample(list(hpi_elements.items()), random.randint(3, 6))
        ])

    def generate_ros() -> Dict[str, str]:
        # The "abnormal" branch samples at least one finding. A count of zero
        # would yield an empty string, which is neither a finding list nor the
        # explicit "No abnormalities noted." marker.
        return {
            system: ", ".join(random.sample(findings, random.randint(1, min(3, len(findings)))))
            if random.choice([True, False]) else "No abnormalities noted."
            for system, findings in ros_systems.items()
        }

    def generate_pe() -> Dict[str, str]:
        # Every exam section reports 1-3 findings.
        return {
            section: ", ".join(random.sample(findings, random.randint(1, min(3, len(findings)))))
            for section, findings in pe_sections.items()
        }

    return {
        "chief_complaint": generate_chief_complaint(),
        "history_of_present_illness": generate_hpi(),
        "review_of_systems": generate_ros(),
        "physical_exam": generate_pe()
    }

In [85]:
def getAllPatientSummary(patient_ids: List[str]) -> Dict[str, Any]:
    """Return ``{'summaries': [...]}`` with one ``getPatientSummary`` result per id, in input order."""
    return {'summaries': [getPatientSummary(pid) for pid in patient_ids]}

In [86]:
def get_patient(patient_id: str) -> Dict[str, Any]:
    """Return a fixed mock patient with two allergies, plus field types and descriptions.

    ``patient_id`` is not used: every call returns the same "John Doe" record.
    The record is validated through the Pydantic models and then round-tripped
    through JSON (dates via ``date_to_str``), so ``data`` holds only
    JSON-native values such as ISO date strings.

    Returns:
        ``{"data": ..., "types": ..., "descriptions": ...}``. ``types`` and
        ``descriptions`` are both keyed by "patient" and "allergies".
    """
    record = PatientWithAllergies(
        patient=Patient(
            name="John Doe",
            phone_number="+1234567890",
            address="123 Main St, Anytown, USA",
            age=35,
            gender="Male",
            weight=75.5,
            height=180.0,
        ),
        allergies=[
            Allergy(
                allergy_name="Peanut Allergy",
                loinc_code="12345-6",
                start_date=date(2020, 1, 15),
                status="Active",
            ),
            Allergy(
                allergy_name="Penicillin Allergy",
                loinc_code="78901-2",
                start_date=date(2018, 5, 20),
                status="Inactive",
                end_date=date(2022, 3, 10),
            ),
        ],
    )
    json_ready = json.loads(json.dumps(record.model_dump(), default=date_to_str))

    field_types = {
        "patient": {
            "name": "str",
            "phone_number": "str",
            "address": "str",
            "age": "int",
            "gender": "str",
            "weight": "float",
            "height": "float",
        },
        "allergies": {
            "allergy_name": "str",
            "loinc_code": "str",
            "start_date": "date",
            "status": "str",
            "end_date": "Optional[date]",
        },
    }

    field_descriptions = {
        "patient": {
            "name": "Full name of the patient",
            "phone_number": "Contact phone number with country code",
            "address": "Current residential address",
            "age": "Age in years",
            "gender": "Self-identified gender",
            "weight": "Weight in kilograms",
            "height": "Height in centimeters",
        },
        "allergies": {
            "allergy_name": "Common name of the allergy",
            "loinc_code": "LOINC code for standardized allergy identification",
            "start_date": "Date when the allergy was first diagnosed or reported",
            "status": "Current status of the allergy (e.g., Active, Inactive)",
            "end_date": "Date when the allergy was resolved or became inactive, if applicable",
        },
    }

    return {"data": json_ready, "types": field_types, "descriptions": field_descriptions}

In [87]:
@tool
def get_advanced_care_plan(patient_id: str) -> Dict[str, Any]:
    """
    A tool that returns the patient's Advanced care plan (ACP) data, data types, and field descriptions.

    The returned dictionary wraps the payload under a single "Advanced care plan (ACP)" key:
    {
        "Advanced care plan (ACP)": {
            "data": { ... },
            "types": { ... },
            "descriptions": { ... }
        }
    }
    """
    # Mock implementation: patient_id is not used and the values below are fixed.
    # The docstring above is the LLM-facing tool description, so it must match
    # the returned shape.
    # Overwrite the per-request 'auto_populate' marker with a value that does not
    # equal 'auto_populate', so the shared-data check in handlemessage does not
    # treat this reply as an auto-populate response.
    shared_data_instance.set_data('auto_populate', [])
    today = date.today()

    acp_data = {
        "is_member_self_responsible": True,
        "status": "Completed",
        "created_date": today.isoformat()
    }

    return {"Advanced care plan (ACP)": {
        "data": acp_data,
        "types": {
            "is_member_self_responsible": "bool",
            "status": "str",
            "created_date": "date"
        },
        "descriptions": {
            "is_member_self_responsible": "Indicates if the member is responsible for their own decisions",
            "status": "Current status of the ACP document (options: Draft, Completed, Entered in error)",
            "created_date": "Date when the ACP document was created"
        }
    }}

In [88]:
@tool
def get_Responsible_party(patient_id: str) -> Dict[str, Any]:
    """
    A tool that returns the patient's Responsible_party details (MemberContacts) data, data types, and field descriptions.

    The returned dictionary wraps the payload under a single "Member contacts (MemContacts)" key:
    {
        "Member contacts (MemContacts)": {
            "data": { ... },
            "types": { ... },
            "descriptions": { ... }
        }
    }
    """
    # Mock implementation: patient_id is not used and the contact below is fixed.
    # The docstring above is the LLM-facing tool description, so it must match
    # the returned shape.
    # Mark the request so handlemessage wraps the LLM reply as {"auto_populate": ...}.
    shared_data_instance.set_data('auto_populate', 'auto_populate')
    today = date.today()

    member_contacts_data = {
        "responsible_party_name": "Ron H",
        "responsible_party_role": "POA",
        "created_date": today.isoformat(),
        "status": "active"
    }

    return {"Member contacts (MemContacts)": {
        "data": member_contacts_data,
        "types": {
            "responsible_party_name": "str",
            "responsible_party_role": "str",
            "created_date": "date",
            "status": "str"
        },
        "descriptions": {
            "responsible_party_name": "Represents the responsible party's name",
            "responsible_party_role": "Represents the  role of the responsible party (options: Legal Guardian, POA, Health surrogate)",
            "created_date": "Date when the contact was created",
            "status": "Current status of the contact (options: active, inactive)"
        }
    }}

In [89]:
# DTO functions
def get_All_patient_summary(patient_ids: List[str]) -> Dict[str, Any]:
    """Data function resolved by name from action.yaml: summaries for every id in ``patient_ids``."""
    summaries = getAllPatientSummary(patient_ids)
    return summaries

In [90]:
def get_Patient_Summary(patient_id: str) -> Dict[str, Any]:
    """Thin wrapper around ``getPatientSummary``; presumably referenced by name from action.yaml.

    NOTE(review): handlemessage calls data functions with ``message.patient_ids``
    (a list), while this signature expects a single id. Confirm which one is
    intended.
    """
    summary = getPatientSummary(patient_id)
    return summary

In [91]:
def get_patient_vitals(patient_id: str) -> Dict[str, Any]:
    """Thin wrapper around ``getPatientVitals``; presumably referenced by name from action.yaml.

    NOTE(review): handlemessage passes ``message.patient_ids`` (a list) to data
    functions, so ``patient_id`` may arrive as a list. Confirm.
    """
    vitals = getPatientVitals(patient_id)
    return vitals

In [92]:
@tool
def web_search(query: str) -> List[Dict[str, Any]]:
    """
    Perform a web search using DuckDuckGo based on the user's query.

    Args:
        query (str): The search query provided by the user.

    Returns:
        List[Dict[str, Any]]: A list of search results, where each result is a dictionary
        containing information about the search result.
    """
    # The docstring above is the LLM-facing tool description; keep it in sync with behaviour.
    # At most five hits are requested, and the result is materialised as a list.
    return list(duckduckgo_search.DDGS().text(query, max_results=5))

In [93]:
# Service functions
# Module-level config holders; handlemessage reloads both from YAML on every request.
prompt_config: dict = {}  # parsed prompts.yaml
action_config: dict = {}  # parsed action.yaml

In [94]:
async def handlemessage(message: Message) -> Dict[str, Any]:
    """Process one /chat request end to end and return the reply payload.

    1. Reset the per-request shared state, then reload prompts.yaml and
       action.yaml into the module-level ``prompt_config`` / ``action_config``.
    2. An ``@web:`` prefix turns the request into a free-form "user-intent"
       question. An ``@Job aid:`` prefix currently gets no special handling.
    3. Look up the data functions configured at
       ``<entity> -> <intent> -> <message> -> functions``, falling back to
       ``<entity> -> base-function``, and call each with ``message.patient_ids``.
    4. Run the LLM via ``runner`` and wrap its reply.

    Returns:
        ``{"html": reply}`` if the reply contains an HTML-like tag,
        ``{"auto_populate": reply}`` for auto-populate requests, otherwise the
        bare reply.
    """
    print("Inside handlemessage")
    print('message', message)
    global prompt_config
    global action_config

    shared_data_instance.clear()

    prompt_config = getConfig(prompts_file)
    action_config = getConfig(action_file)

    # message.message is Optional. Match the prefixes against "" when it is
    # absent, instead of letting re.search raise TypeError on None.
    text = message.message or ""
    job_aid_match = re.search(r'@Job aid:\s*(.*)', text)
    web_search_match = re.search(r'@web:\s*(.*)', text)
    # "@Job aid:" takes precedence over "@web:". Job-aid routing is not wired up
    # yet, so such requests continue through the normal pipeline unchanged.
    if web_search_match and not job_aid_match:
        message.message = web_search_match.group(1).strip()
        message.intent = "user-intent"

    functions = get_nested_value(action_config, [message.entity, message.intent, message.message, "functions"])
    print("functions", functions)
    if functions is None:
        base_function = get_nested_value(action_config, [message.entity, "base-function"])
        functions = [base_function] if base_function else []
        print("base-functions", functions)

    # Function names come from action.yaml and are resolved against this module's globals.
    patient_data = [globals()[function](message.patient_ids) for function in functions]

    # NOTE(review): runner() does blocking network I/O on the event loop. This
    # handler has no await points, so requests run to completion one at a time
    # and never interleave on the process-wide prompt_config / action_config /
    # shared_data_instance state. Offloading to a thread would first require
    # making that state per-request.
    response = runner(message, patient_data)
    print("response", response)

    # Heuristic: any "<...>" span is treated as HTML. The isinstance guard keeps
    # a non-string reply from crashing re.search.
    if isinstance(response, str) and re.search(r'<[^>]+>', response):
        return {"html": response}
    if message.message == "auto_populate" or shared_data_instance.get_data('auto_populate') == 'auto_populate':
        return {"auto_populate": response}
    return response

In [122]:
async def handle_config(message: Message) -> Dict[str, Any]:
    """Return the "system-intent" action options configured for ``message.entity``.

    Reads action.yaml (``<entity> -> system-intent -> options``).

    Returns:
        ``{"options": {...}}`` holding a copy of the configured options, or
        ``{"options": {"error": "No Action has been configured"}}`` when the
        entity has no options mapping.
    """
    print("Inside handle_config")
    print('message', message)

    # Use the shared loader and the action_file constant rather than a second
    # hard-coded path to the same file. The distinct local name avoids shadowing
    # the module-level prompt_config.
    action_cfg = getConfig(action_file)

    # get_nested_value tolerates an empty file, or an entity/intent whose value is
    # null or not a mapping; chained .get() calls crashed on those.
    entity_options = get_nested_value(action_cfg, [message.entity, 'system-intent', 'options']) or {}
    print("entity_options", entity_options)

    if not isinstance(entity_options, dict) or not entity_options:
        return {"options": {"error": "No Action has been configured"}}
    return {"options": dict(entity_options)}

In [96]:
# Name of the chat-history placeholder in the agent prompt built by chat().
MEMORY_KEY: str = "chat_history"
# Not referenced in the visible code.
cache: dict = {}

In [97]:
def chat(system_prompt, user_input, functions, model):
    """Answer ``user_input`` with a LangChain OpenAI-tools agent and return its text output.

    Args:
        system_prompt: System message for the agent. It becomes part of a
            LangChain ``ChatPromptTemplate``, so literal ``{``/``}`` in it are
            read as template placeholders.
        user_input: The already-formatted user prompt.
        functions: Names of module-level tools to expose to the agent. Unknown
            names are skipped; a falsy value means "no tools".
        model: Chat model name.

    Returns:
        The agent's final ``output`` string.

    The provider key and base URL are read from ``shared_data_instance``
    (stored there by ``init_AI_client``).
    """
    # Each call starts from an empty history; earlier turns are not forwarded.
    chat_history = []

    api_key = shared_data_instance.get_data('key')
    base_url = shared_data_instance.get_data('base_url')
    llm_kwargs = {"model": model, "temperature": 0}
    if api_key is not None:
        # Keep the process-wide env var for any code that reads it. Setting it to
        # None would raise TypeError, so skip it when no key is configured.
        os.environ["OPENAI_API_KEY"] = api_key
        llm_kwargs["api_key"] = api_key
    if base_url:
        # Talk to the same endpoint as the plain-completion path (generate via
        # init_AI_client) instead of falling back to the default OpenAI URL.
        llm_kwargs["base_url"] = base_url
    llm = ChatOpenAI(**llm_kwargs)

    tools = [globals()[name] for name in functions if name in globals()] if functions else []
    print("tools", tools)

    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                system_prompt,
            ),
            MessagesPlaceholder(variable_name=MEMORY_KEY),
            ("user", "Question from the user : {input}."),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    llm_with_tools = llm.bind_tools(tools) if tools else llm

    agent = (
        {
            "input": lambda x: x["input"],
            "agent_scratchpad": lambda x: format_to_openai_tool_messages(
                x["intermediate_steps"]
            ),
            "chat_history": lambda x: x["chat_history"],
        }
        | prompt
        | llm_with_tools
        | OpenAIToolsAgentOutputParser()
    )

    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

    result = agent_executor.invoke({"input": "User Question-" + user_input, "chat_history": chat_history})
    return result['output']

In [98]:
def runner(message, patient_data):
    """Build the prompts for ``message`` and dispatch to the matching LLM path.

    "user-intent" requests go to the tool-using agent (``chat``), with the tool
    names configured at ``<entity> -> user-intent -> functions`` in
    action.yaml. Every other intent gets a single completion via ``generate``.
    """
    print("Inside runner")
    user_prompt, system_prompt = create_prompt(message, patient_data)
    # Always initialise the client: besides returning it, init_AI_client stores
    # the provider key and base URL in shared_data_instance, which chat() reads.
    client = init_AI_client(message.family)

    if message.intent == "user-intent":
        # A missing "functions" entry means "no tools" (chat() accepts a falsy
        # value) instead of raising KeyError.
        functions = get_nested_value(action_config, [message.entity, 'user-intent', "functions"])
        print("functions", functions)
        return chat(system_prompt, user_prompt, functions, message.model)
    return generate(client, message.model, user_prompt, system_prompt)

In [99]:
def create_prompt(message, patient_data):
    """Render the user prompt for ``message`` and return ``(user_prompt, system_prompt)``.

    Templates come from the module-level ``prompt_config`` (prompts.yaml) at
    ``user_prompt -> <entity> -> <intent> -> <key>``. ``<key>`` is
    ``message.message`` for "system-intent" requests and ``"general"`` for any
    other intent. Each entry supplies ``input`` (the template), ``output``
    (substituted for ``{format}``) and ``rules``. General requests also
    receive ``question`` and ``form_data``; system-intent requests receive
    ``form_data`` only for "auto_populate".
    """
    print("Inside create_prompt")
    system_prompt = prompt_config["system_prompt"]

    system_intent = message.intent == "system-intent"
    entry_key = message.message if system_intent else 'general'
    entry = prompt_config["user_prompt"][message.entity][message.intent][entry_key]
    template = entry['input']
    print(f"user_prompt - {template}")

    if not system_intent:
        print(f"message.form_data - {message.form_data}")
        output = entry["output"]
        print(f"output - {output}")
        rendered = template.format(
            patient_data=patient_data,
            format=output,
            rules=entry["rules"],
            question=message.message,
            form_data=message.form_data,
        )
        # Replace any literal "{format}" still present after formatting.
        user_prompt = rendered.replace("{format}", output)
    else:
        extra = {}
        if message.message == "auto_populate":
            print(f"message.form_data - {message.form_data}")
            extra["form_data"] = message.form_data
        user_prompt = template.format(
            patient_data=patient_data,
            format=entry["output"],
            rules=entry["rules"],
            **extra,
        )

    print("\n user_prompt", user_prompt)
    print("\n system_prompt", system_prompt)
    return user_prompt, system_prompt

In [100]:
def init_AI_client(model_family):
    """Create an OpenAI-compatible client for ``model_family`` from config.yaml.

    The family's ``key`` and ``url`` are also stored in ``shared_data_instance``
    as 'key' and 'base_url'; chat() reads the key from there.

    Raises:
        KeyError: if the family, or its key/url entry, is missing from the config.
    """
    settings = getConfig(config_file)
    api_key = settings[model_family]["key"]
    endpoint = settings[model_family]["url"]

    for slot, value in (('key', api_key), ('base_url', endpoint)):
        shared_data_instance.set_data(slot, value)

    return OpenAI(base_url=endpoint, api_key=api_key)

In [101]:
def generate(client, model, user_prompt, system_prompt):
    """Run one chat completion (system + user message) and return the reply text.

    Token counts for the input and output are printed via
    ``num_tokens_from_string``. Sampling uses ``default_temperature``.

    Returns:
        The assistant's text. An empty string is returned when the API returns
        no text content.
    """
    num_tokens_from_string(system_prompt + user_prompt, default_encoding, "input")

    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        model=model,
        temperature=default_temperature,
    )
    # message.content is Optional in the OpenAI SDK (e.g. refusals or tool calls).
    # Normalise to "" so token counting and the caller's string checks do not fail.
    response = chat_completion.choices[0].message.content or ""
    num_tokens_from_string(response, default_encoding, "output")
    return response

In [102]:
# FastAPI app
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: log startup, let the app serve, then log shutdown."""
    startup_banner, shutdown_banner = "Server startup .....", "Server shutdown ......."
    print(startup_banner)
    yield  # the application serves requests while suspended here
    print(shutdown_banner)

In [103]:
# ASGI application; ``lifespan`` prints startup/shutdown banners.
app = FastAPI(lifespan=lifespan)

In [104]:
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True lets any
# website make credentialed cross-origin requests. In Starlette's implementation,
# a wildcard plus credentials echoes the caller's Origin back instead of "*".
# This API returns patient data, so restrict allow_origins to the known
# front-end origin(s).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

In [105]:
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """HTTP middleware: time each request and report the duration.

    The elapsed time in seconds is added to the response as the
    ``X-Process-Time`` header (as a string) and is also printed.
    """
    # perf_counter is monotonic and high-resolution, so the duration is not
    # skewed by system clock adjustments the way time.time() can be.
    start_time = time.perf_counter()
    response = await call_next(request)
    process_time = time.perf_counter() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    print(f"Request took {process_time} secs to complete")
    return response

In [106]:
@app.get("/")
def doGet(request: Request):
    """GET /: trivial liveness response."""
    greeting = {"Hello": "World"}
    return greeting

In [107]:
@app.post("/chat")
async def doChat(request: Request, message: Message):
    """POST /chat: run ``message`` through handlemessage and wrap the reply under "llm_response"."""
    print("Inside /chat")
    print(f'prompt - {message.entity}, mode - {message.message}')
    reply = await handlemessage(message)
    return {"llm_response": reply}

In [108]:
@app.post("/config")
async def doConfig(request: Request, message: Message):
    """POST /config: return the configured actions for ``message.entity`` under "llm_response".

    ``handle_config`` is declared ``async``, so its result must be awaited.
    Without ``await`` the response would hold an un-run coroutine object
    instead of the options dict.
    """
    print("Inside /config")
    print(f'prompt - {message.entity}, mode - {message.message}')

    response = await handle_config(message)
    return {"llm_response": response}

In [109]:

async def test_web_search():
    """Manual smoke test: send an ``@web:`` question through handlemessage and print the reply."""
    request = Message(
        intent="user-intent",
        entity="landing_page",
        message="@web:Explain me the pensrose diagram for black holes",
        patient_id="12345",
    )
    reply = await handlemessage(request)
    print(f"Response from handlemessage for web search: {reply}")

In [120]:
# Test function
async def test_handle_config():
    """Manual check: print the configured options for the "order_page" entity."""
    outcome = await handle_config(Message(entity="order_page"))
    print("\nResult:")
    print(outcome)

async def test_get_All_patient_summary():
    """Manual check: summarise two mock patients via the memberlist_page "patient_summary" system intent."""
    request = Message(
        intent="system-intent",
        entity="memberlist_page",
        message="patient_summary",
        patient_ids=["1", "2"],
    )
    reply = await handlemessage(request)
    print("Response from handlemessage:")
    print(f"LLM Response: {reply}")

In [124]:
async def main():
    print("Starting the server")
    #await test_web_search()
    #await test_handle_config()
    await test_get_All_patient_summary()
    print("Shutting down the server")

if __name__ == "__main__":
    await main()


Starting the server
Inside handlemessage
message page='demo' mode='serial' family='openai' model='gpt-4o' negative_prompt='ON' use_for_training=False message='patient_summary' history=[] intent='system-intent' entity='memberlist_page' patient_id=None patient_ids=['1', '2'] form_data=None
functions None
base-functions ['get_All_patient_summary']
Inside runner
Inside create_prompt
user_prompt - Summarize the below patient data focussing on key details Patient Data: {patient_data}. Use the below rules to provide the summary. Rules: {rules}. Provide the summary in the below format. Format:{format}

 user_prompt Summarize the below patient data focussing on key details Patient Data: [{'summaries': [{'Summary': {'patient_id': '1', 'basic_info': {'name': 'Robert Molina', 'age': 86, 'gender': 'Male', 'blood_type': 'O+', 'height': 157.6, 'weight': 49.0, 'is_member_self_responsible': True}, 'vitals': {'Vitals': {'patient_id': '1', 'temperature': 36.4, 'heart_rate': 94, 'blood_pressure': '137/61'