<a href="https://colab.research.google.com/github/ekrombouts/GenCareAI/blob/main/notebooks/100_note_generation/120_GenerateClientScenarios.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GenCare AI: Generating client scenarios

**Author:** Eva Rombouts  
**Date:**   2024-06-13  
**Updated:** 2024-09-30  
**Version:** 2.0

### Description
This script generates client scenarios based on the profiles generated [here](https://colab.research.google.com/github/ekrombouts/GenCareAI/blob/main/notebooks/100_note_generation/110_GenerateClientProfiles.ipynb).

Generating scenarios for 24 client profiles with an average of 8 periods each costs approximately $0.03 per run.

In [None]:
!pip install GenCareAI
from GenCareAI.GenCareAIUtils import GenCareAISetup

setup = GenCareAISetup()

if setup.environment == 'Colab':
        !pip install -q langchain langchain_core langchain_openai langchain_community

In [2]:
from typing import List

from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_community.callbacks import get_openai_callback
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI

import os
import pandas as pd
import random
import numpy as np

In [3]:
# Constants and Configurations

# Input (profiles) and output (scenarios) files are both keyed on the ward name
ward_name = 'Hermes'
fn_profiles = setup.get_file_path(f'data/gcai_client_profiles_{ward_name}.csv')
fn_scenarios = setup.get_file_path(f'data/gcai_client_scenarios_{ward_name}.csv')

# LLM settings: model name and sampling temperature
model = 'gpt-3.5-turbo-0125'
temp = 1.1

# Scenario length: mean and standard deviation of the number of periods per client
duration = 8
duration_sd = 3

# Inclusive bounds on how many complications are drawn per client
num_complications_min = 1
num_complications_max = 3

# Pool of complications (Dutch, inserted verbatim into the prompt) to sample from
complications_library = [
    "gewichtsverlies",
    "algehele achteruitgang",
    "decubitus",
    "urineweginfectie",
    "pneumonie",
    "delier",
    "verergering van onderliggende lichamelijke klachten",
    "verbetering van de klachten",
    "overlijden",
    "valpartij",
]

# When True, the notebook prints the rendered prompt and a sample result
verbose = True

In [4]:
# Load the client profiles
df = pd.read_csv(fn_profiles)

# Fail fast when a column used further down (profile text, client_id) is missing,
# instead of hitting a KeyError part-way through the paid generation loop.
required_columns = {'client_id', 'naam', 'type_dementie', 'somatiek', 'adl', 'mobiliteit', 'gedrag'}
missing_columns = required_columns - set(df.columns)
if missing_columns:
    raise ValueError(f"Profiles file {fn_profiles} is missing columns: {sorted(missing_columns)}")

In [5]:
# Pydantic models
# Documented with `#` comments rather than docstrings: pydantic copies a class
# docstring into the model schema, which would likely alter the format
# instructions derived from these models -- verify before adding docstrings.
# ClientScenario: one entry of the timeline (a period label plus its narrative).
class ClientScenario(BaseModel):
    # Sequence number of the period, kept as a string
    period: str = Field(description="Volgnummer van de periode")
    # Description of the events and care in that period
    events_description: str = Field(description="Beschrijving van de gebeurtenissen en zorg")

# ClientScenarios: the full timeline; this is the model the output parser is built on.
class ClientScenarios(BaseModel):
    # The individual period entries
    scenario: List[ClientScenario]

In [6]:
# Initialize model and parser
# `model` starts out as the model-name string from the configuration cell and is
# rebound to the ChatOpenAI client below. When this cell is re-run, `model` is
# already a client, so recover the name from it instead of passing the client
# object itself as the model name.
model_name = model if isinstance(model, str) else model.model_name
model = ChatOpenAI(api_key=setup.get_openai_key(), temperature=temp, model=model_name)

# Parser that turns the LLM reply into a ClientScenarios object; its format
# instructions are embedded in the prompt.
pyd_parser = PydanticOutputParser(pydantic_object=ClientScenarios)
format_instructions = pyd_parser.get_format_instructions()

In [None]:
# Define the prompt template
# The Dutch prompt text is sent to the model as-is; placeholders in braces are
# filled in per client, except format_instructions which is bound once below.
template = """
Dit is het profiel van een fictieve client in het verpleeghuis:
---
{client_profile}
---

Schrijf in een tijdlijn het beloop van zijn/haar verblijf in het verpleeghuis gedurende {num_periods} periodes.
Verwerk de volgende complicatie(s) hierin: {complications}.
Hou wijzigingen subtiel. Vermijd al te grote dramatiek.
Vermijd het noemen van de naam.

{format_instructions}
"""

prompt_template = PromptTemplate(
    input_variables=["client_profile", "num_periods", "complications"],
    partial_variables={"format_instructions": format_instructions},
    template=template,
)

# Show the rendered prompt with placeholder values as a sanity check
if verbose:
    print(
        prompt_template.format(
            client_profile="client profiel",
            num_periods=6,
            complications="complicatie(s)",
        )
    )


In [8]:
# Compose the pipeline: the prompt template feeds the model, whose reply is
# handed to the pydantic output parser.
chain_scenario = (
    prompt_template
    | model
    | pyd_parser
)

In [None]:
# Generate and save scenarios
# The scenarios file doubles as a cache: generation (which calls the OpenAI API)
# only runs when the file does not exist yet; otherwise the file is loaded.
if not os.path.exists(fn_scenarios):
    print("Data file not found. Generating new data...")

    def generate_scenarios(df, chain, max_attempts=2):
        """Generate a scenario timeline for every client profile in `df`.

        For each row a number of periods and a random set of complications are
        drawn, the chain is invoked with the formatted profile, and every
        period of the parsed result becomes one output tuple.

        Args:
            df: Client profiles; must provide the columns client_id, naam,
                type_dementie, somatiek, adl, mobiliteit and gedrag.
            chain: Object whose `invoke` accepts a dict with client_profile,
                num_periods and complications, and returns a result exposing
                a `scenario` list of items with `period` and
                `events_description`.
            max_attempts: Total number of invocations per client before the
                last error is re-raised (default 2: one call plus one retry).

        Returns:
            List of tuples (client_id, period, events_description,
            complications, num_periods).

        Raises:
            Exception: whatever `chain.invoke` raised on the final attempt.
        """
        def display_profile(row):
            """Format one profile row as the text block inserted into the prompt."""
            return (
                f"Naam: {row['naam']}\n"
                f"Type Dementie: {row['type_dementie']}\n"
                f"Lichamelijke klachten: {row['somatiek']}\n"
                f"ADL: {row['adl']}\n"
                f"Mobiliteit: {row['mobiliteit']}\n"
                f"Cognitie/gedrag: {row['gedrag']}"
            )

        def determine_duration(mean=6, std_dev=2):
            """Draw the number of periods from a rounded normal distribution, at least 1."""
            # An unclamped draw can round to 0 or a negative number, which would
            # ask the model for a timeline with no (or a negative number of) periods.
            return max(1, int(np.round(np.random.normal(mean, std_dev))))

        def determine_num_complications(lower=1, upper=3):
            """Draw how many complications to use, never more than the library holds."""
            # random.sample raises ValueError when asked for more items than exist.
            upper = min(upper, len(complications_library))
            return random.randint(min(lower, upper), upper)

        # Always try at least once, even if a caller passes 0 or a negative value.
        max_attempts = max(1, max_attempts)

        scenario_list = []
        for _, row in df.iterrows():
            client_profile = display_profile(row)
            print(f"Generating scenario for client: {row['naam']}")
            num_periods = determine_duration(mean=duration, std_dev=duration_sd)
            num_complications = determine_num_complications(lower=num_complications_min, upper=num_complications_max)
            chosen_complications = random.sample(complications_library, num_complications)
            complications = ", ".join(chosen_complications)
            chain_input = {
                "client_profile": client_profile,
                "num_periods": str(num_periods),
                "complications": complications,
            }

            # Retry on any failure of the chain; the error of the final attempt
            # propagates to the caller.
            for attempt in range(1, max_attempts + 1):
                try:
                    result = chain.invoke(chain_input)
                except Exception as e:
                    if attempt == max_attempts:
                        raise
                    print(f"Error encountered: {e}. Retrying...")
                else:
                    if attempt > 1:
                        print("Retry successful")
                    break

            # One output row per period, repeating the client-level values.
            for scenario in result.scenario:
                scenario_list.append((row['client_id'], scenario.period, scenario.events_description, complications, num_periods))
        return scenario_list

    # The callback collects token usage and cost for all calls made inside the block.
    with get_openai_callback() as cb:
        scenario_data = generate_scenarios(df, chain_scenario)
        print(cb)

    df_scenarios = pd.DataFrame(scenario_data, columns=['client_id', 'period', 'events_description', 'complications', 'num_periods'])
    df_scenarios.to_csv(fn_scenarios, index=False)
    print(f"Data saved successfully to {fn_scenarios}.")
else:
    print("Data file found. Loading data...")
    df_scenarios = pd.read_csv(fn_scenarios)

In [None]:
# Preview: print one client's profile followed by the scenario stored for that client
if verbose:
    sample_client = 4

    def display_profile(row):
        """Format one profile row as a multi-line text block."""
        return (
            f"Naam: {row['naam']}\n"
            f"Type Dementie: {row['type_dementie']}\n"
            f"Lichamelijke klachten: {row['somatiek']}\n"
            f"ADL: {row['adl']}\n"
            f"Mobiliteit: {row['mobiliteit']}\n"
            f"Cognitie/gedrag: {row['gedrag']}"
        )

    sample_profiles = df[df['client_id'] == sample_client]
    if sample_profiles.empty:
        # .iloc[0] on an empty selection would raise an opaque IndexError.
        print(f"No profile found for client_id {sample_client}.")
    else:
        print(display_profile(sample_profiles.iloc[0]))

        print(100*'*')
        print(df_scenarios[df_scenarios['client_id'] == sample_client][['period', 'events_description']].to_string(index=False))
