<a href="https://colab.research.google.com/github/ekrombouts/GenCareAI/blob/main/scripts/110_GenerateClientProfiles.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GenCare AI: Generating client profiles

**Author:** Eva Rombouts  
**Date:** 2024-06-01  
**Updated:** 2024-06-13  
**Version:** 1.2

### Description
This script generates synthetic healthcare data for NLP experiments. It creates diverse client profiles for a psychogeriatric ward using OpenAI's GPT-4o model.

The output parser is built on Pydantic models: `ClientProfile` describes a single client, and `ClientProfiles` wraps the list of profiles returned by one query. Pydantic helps define and validate the output, ensuring that each profile has the right format and contains the necessary information.

The goal is to produce a comprehensive and varied dataset of client profiles for use in a psychogeriatric setting, avoiding repetitive or deterministic outputs. To achieve this, we use GPT-4o with a high temperature setting to enhance variability. Additionally, each query asks for eight profiles at once, which further promotes diversity.

### Imports and constants

In [None]:
import os
# Determines the current environment (Google Colab or local)
def check_environment():
    """Report where the notebook is running.

    Returns:
        "Google Colab" when the ``google.colab`` package can be imported,
        otherwise "Local Environment".
    """
    try:
        import google.colab  # noqa: F401 -- imported only to probe for Colab
    except ImportError:
        return "Local Environment"
    else:
        return "Google Colab"

In [None]:
# Environment-dependent setup: package installation, data directory and API keys.
# In Colab the required packages are installed, Google Drive is mounted and the
# keys are read through google.colab.userdata. Locally, load_dotenv() is called
# and the keys are read from environment variables.

env = check_environment()

if env == "Google Colab":
    print("Running in Google Colab")
    # IPython shell escape: only valid inside a notebook cell, not in a plain .py script
    !pip install -q langchain langchain_core langchain_openai langchain_community
    from google.colab import drive, userdata
    drive.mount('/content/drive')
    # The data lives on the mounted Google Drive
    DATA_DIR = '/content/drive/My Drive/Colab Notebooks/GenCareAI/data'
    OPENAI_API_KEY = userdata.get('GCI_OPENAI_API_KEY')
    # NOTE(review): HF_TOKEN is not used in the visible part of this notebook -- confirm it is still needed
    HF_TOKEN = userdata.get('HF_TOKEN')
else:
    print("Running in Local Environment")
    # !pip install python-dotenv langchain langchain_core langchain-community langchain_openai
    # Relative path, resolved against the current working directory
    DATA_DIR = '../data'
    from dotenv import load_dotenv
    load_dotenv()
    # os.getenv returns None when a variable is not set
    OPENAI_API_KEY = os.getenv('GCI_OPENAI_API_KEY')
    HF_TOKEN = os.getenv('HF_TOKEN')

In [None]:
import os
import pandas as pd
from typing import List
from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_community.callbacks import get_openai_callback

In [None]:
# Constants and configuration
# WARD_NAME is embedded in the output filename, which keeps the results of
# separate experiments apart.
WARD_NAME = 'Tulip'
# CSV file in DATA_DIR where the generated profiles are stored
FN_PROFILES = os.path.join(DATA_DIR, f'gcai_client_profiles_{WARD_NAME}.csv')
# The prompt asks for eight profiles per query and the query is run NUM_WINGS
# times, so NUM_WINGS = 3 yields 24 client profiles in total.
NUM_WINGS = 3
# Model used to generate the profiles. Author's note: GPT-4o yields better,
# more diverse results.
MODEL_PROFILES = 'gpt-4o-2024-05-13'
# Sampling temperature passed to the chat model
TEMP = 1.0

### Data

In [None]:
# Pydantic model that structures the data of a single client profile.
# The Dutch field descriptions are runtime text, not documentation: they are part
# of the schema from which the output parser builds its format instructions.
# NOTE: documented with comments rather than a docstring, because pydantic copies
# a class docstring into the generated schema, which would alter those instructions.
class ClientProfile(BaseModel):
    # Name of the client (Mr/Mrs first name last name)
    naam: str = Field(description="naam van de client (Meneer/Mevrouw Voornaam Achternaam, gebruik een naam die je normaal niet zou kiezen)")
    # Type of dementia
    type_dementie: str = Field(description="type dementie (Alzheimer, gemengde dementie, vasculaire dementie, lewy body dementie, parkinsondementie, FTD: varieer, de kans op Alzheimer, gemengde en vasculaire dementie is het grootst)")
    # Physical (somatic) complaints
    somatiek: str = Field(description="lichamelijke klachten")
    # Disabled field (short biography and character description):
    # biografie: str = Field(description="een korte beschrijving van karakter en relevante biografische gegevens (vermijd stereotypen in beroep en achtergrond)")
    # Assistance the client needs with activities of daily living (ADL)
    adl: str = Field(description="beschrijf welke ADL hulp de cliënt nodig heeft")
    # Mobility (e.g. wheelchair-dependent, uses a walker, risk of falling)
    mobiliteit: str = Field(description="beschrijf de mobiliteit (bv rolstoelafhankelijk, gebruik rollator, valgevaar)")
    # Care-relevant aspects of cognition and problem behaviour
    gedrag: str = Field(description="beschrijf voor de zorg relevante aspecten van cognitie en probleemgedrag. Varieer met de ernst van het probleemgedrag van rustige cliënten, gemiddeld onrustige cliënten tot cliënten die fors apathisch, onrustig, angstig, geagiteerd of zelfs agressief kunnen zijn")

# Pydantic model holding the list of client profiles returned by a single query.
# This is the model handed to the output parser.
# NOTE: kept as a comment instead of a docstring, because pydantic copies a class
# docstring into the generated schema.
class ClientProfiles(BaseModel):
    # One entry per generated client
    clients: List[ClientProfile]

### Functions

In [None]:
def generate_data(chain_client_profiles, num_wings):
    """Run the profile chain once per wing and collect all profiles.

    Args:
        chain_client_profiles: Object with an ``invoke`` method whose result
            exposes a ``clients`` sequence; each client must provide ``dict()``.
        num_wings: Number of times the chain is invoked.

    Returns:
        A pandas DataFrame with one row per generated client profile.

    Raises:
        ValueError: If an invocation returns ``None`` or a result without a
            ``clients`` attribute.
    """
    records = []
    for wing_number in range(1, num_wings + 1):
        print(f'Generating data for wing{wing_number}')
        response = chain_client_profiles.invoke({})
        if response is None or not hasattr(response, 'clients'):
            raise ValueError("No valid response received from the model.")
        # Flatten the parsed profiles of this wing into plain dictionaries
        records += [profile.dict() for profile in response.clients]
    return pd.DataFrame(records)

In [None]:
def add_client_id(df):
    """Return the profiles with a sequential ``client_id`` as first column.

    The input DataFrame is left untouched: the id column is added to a copy
    rather than written into the caller's frame.

    Args:
        df: DataFrame containing the profile columns ``naam``,
            ``type_dementie``, ``somatiek``, ``adl``, ``mobiliteit`` and
            ``gedrag``.

    Returns:
        A new DataFrame with ``client_id`` (1..len(df)) followed by the
        profile columns in a fixed order.

    Raises:
        KeyError: If one of the expected profile columns is missing.
    """
    column_order = ['client_id', 'naam', 'type_dementie', 'somatiek', 'adl', 'mobiliteit', 'gedrag']
    # assign() works on a copy, so the caller's DataFrame is not mutated
    numbered = df.assign(client_id=range(1, len(df) + 1))
    return numbered[column_order]

In [None]:
def save_data(df, file_path):
    """Write the DataFrame to a CSV file without the index column.

    Args:
        df: DataFrame to store.
        file_path: Destination path of the CSV file.
    """
    df.to_csv(file_path, index=False)
    # Report success only after the write has completed, so a failed write is
    # never announced as saved.
    print(f"Data saved successfully to {file_path}.")

In [None]:
def main(file_path, chain_client_profiles, num_wings):
    """Load the client profiles from disk, or generate and save them.

    Args:
        file_path: CSV file that holds (or will hold) the profiles.
        chain_client_profiles: Chain passed on to ``generate_data``.
        num_wings: Number of queries to run when new data is generated.

    Returns:
        The DataFrame read from ``file_path`` when that file exists; otherwise
        the newly generated profiles including their ``client_id`` column.
    """
    # Reuse earlier results so the model is only queried when necessary
    if os.path.exists(file_path):
        print("Data file found. Loading data...")
        return pd.read_csv(file_path)

    print("Data file not found. Generating new data...")

    # The callback object is printed after generation to report on the run
    with get_openai_callback() as usage:
        profiles = generate_data(
            chain_client_profiles=chain_client_profiles,
            num_wings=num_wings,
        )
        print("Data generated successfully.\n")
        print(usage)

    numbered_profiles = add_client_id(profiles)
    save_data(numbered_profiles, file_path)
    return numbered_profiles


### Model initialization

In [None]:
# Chat model that generates the client profiles
model = ChatOpenAI(
    api_key=OPENAI_API_KEY,
    model=MODEL_PROFILES,
    temperature=TEMP,
)

In [None]:
# Output parser configured with the ClientProfiles model; it also supplies the
# format instructions that are inserted into the prompt.
pyd_parser = PydanticOutputParser(pydantic_object=ClientProfiles)

### Prompt template

In [None]:
# Prompt template. It declares no input variables; the parser's format
# instructions are filled in as a partial variable when the template is created.
PT_client_profiles = PromptTemplate(
    template="""
Schrijf acht profielen van cliënten die zijn opgenomen op een psychogeriatrische afdeling van het verpleeghuis. Hier wonen mensen met een gevorderde dementie met een hoge zorgzwaarte.
Zorg dat de profielen erg van elkaar verschillen.

{format_instructions}
""",
    input_variables=[],
    partial_variables={"format_instructions": pyd_parser.get_format_instructions()},
)

# Render the prompt once and print it for inspection. The template takes no
# input variables, so format() is called without arguments.
P_client_profiles = PT_client_profiles.format()
print(P_client_profiles)

In [None]:
# Compose the pipeline: prompt template -> chat model -> Pydantic output parser
chain_client_profiles = PT_client_profiles | model | pyd_parser

## Main workflow

In [None]:
# Load the profiles from FN_PROFILES when that file exists; otherwise generate
# them with the chain and save them to that file.
if __name__ == "__main__":
    df = main(file_path=FN_PROFILES, chain_client_profiles=chain_client_profiles, num_wings=NUM_WINGS)

In [None]:
# Display up to the first 24 profiles (eight per query with NUM_WINGS = 3)
df.head(24)