<a href="https://colab.research.google.com/github/ekrombouts/GenCareAI/blob/main/notebooks/100_note_generation/140_GenerateClientRecords.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GenCare AI: Generating client records

## Info


**Author:** Eva Rombouts  
**Date:** 2024-06-15  
**Updated:** 2024-10-10  
**Version:** 2.0

### Description
This notebook generates synthetic care records for fictional clients in nursing homes. It uses client profiles and weekly scenarios created in earlier notebooks. Example notes are pulled from a Chroma vector database and filtered based on gender to avoid mismatches in pronouns. A GPT-based model then generates 21 care notes per week for each client. The goal is to create diverse data that can be used for analysis or machine learning tasks.

## Setup

In [None]:
# Install/upgrade the GenCareAI helper package (notebook shell magic)
!pip install -U GenCareAI
from GenCareAI.GenCareAIUtils import GenCareAISetup, ClientProfileFormatter

# Environment helper; used further down for file paths and the OpenAI key
setup = GenCareAISetup()

# When running on Colab, install the LangChain packages on the fly
if setup.environment == 'Colab':
        !pip install -q langchain langchain-openai langchain-community langchain-chroma

In [20]:
# Imports
import os
import random
import pandas as pd
from typing import List

from langchain.output_parsers import PydanticOutputParser
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel, Field
from langchain_community.callbacks import get_openai_callback

from tqdm.notebook import tqdm

In [21]:
# Paths to various data files and constants for the model and temperature settings.
ward_name = 'Dahlia' # Make sure to use the same ward name for which client profiles and scenarios have been generated
path_profiles = setup.get_file_path(f'data/gcai_client_profiles_{ward_name}.csv')
path_scenarios = setup.get_file_path(f'data/gcai_client_scenarios_{ward_name}.csv')
path_notes = setup.get_file_path(f'data/gcai_client_notes_{ward_name}.csv')  # Output file for the generated notes

# Path to the Chroma vector database
path_db_gcai = setup.get_file_path('data/chroma_db_gcai')
# NOTE(review): the collection name differs from ward_name ('Dahlia') —
# presumably the example notes come from another ward on purpose; confirm
collection_name = 'Gardenia'
num_examples = 3  # Number of example notes inserted into each prompt

model_notes = 'gpt-3.5-turbo-0125'  # Chat model that writes the notes
model_embeddings = 'text-embedding-ada-002'  # Embedding model used to query the vector database
temp_notes = 1.1  # Sampling temperature passed to the chat model

verbose = False # Toggle for printing debug information
sample_client_id = 3  # Client used for the verbose previews
sample_weekno = 2  # Week used for the verbose previews
sep_line = '-' * 100  # Visual separator for printed output

### Load data

In [22]:
# Load scenarios and profiles from CSV files
df_scenarios = pd.read_csv(path_scenarios)
df_profiles = pd.read_csv(path_profiles)

# Optional overview of both DataFrames
if verbose:
    print(df_profiles.info())
    print(df_scenarios.info())

# The sample rows below are only defined when verbose is True; the preview
# code that uses them is guarded by the same flag.
if verbose:
    # Select the profile row for the sample client
    # (.iloc[0] raises IndexError when no row matches)
    sample_profile_row = df_profiles.loc[df_profiles['client_id'] == sample_client_id].iloc[0]

    # Select the scenario row for the sample client and scenario
    sample_scenario_row = df_scenarios.loc[
        (df_scenarios['client_id'] == sample_client_id) &
        (df_scenarios['week'] == sample_weekno)
    ].iloc[0]

In [23]:
# Format and print the profile and scenario for the sample client
# (debug preview only; relies on the sample rows selected under verbose)
if verbose:
    cpf = ClientProfileFormatter()
    sample_profile = cpf.format_client_profile(
        profile_row=sample_profile_row
    )
    print(sample_profile)
    print(sep_line)
    # The scenario text is stored in the 'events_description' column
    sample_scenario = sample_scenario_row['events_description']
    print(sample_scenario)

## Setting up the example library

In [24]:
# Retrieving notes for the example library
# Initialize Chroma vector database
# (opens the persisted store at path_db_gcai; queries are embedded with OpenAI)
vectordb = Chroma(
    persist_directory=path_db_gcai,
    embedding_function=OpenAIEmbeddings(api_key=setup.get_openai_key(), model=model_embeddings),
    collection_name = collection_name
    )

# Define a retriever
# (requests 40 documents per query; process_client builds its own retriever with k=20)
retriever = vectordb.as_retriever(search_kwargs={"k": 40})

# Function to retrieve example notes
def retrieve_examples(profile, retriever):
    """
    Fetches example notes that the retriever returns for a client profile.

    Args:
    profile (str): Profile text used as the retrieval query.
    retriever: Object with an ``invoke(query)`` method returning documents
        that expose a ``page_content`` string.

    Returns:
    list: Page contents of the retrieved documents, with leading and
        trailing double quotes removed.
    """
    return [doc.page_content.strip('"') for doc in retriever.invoke(profile)]

# Retrieve example library based on the formatted profile
# (debug preview only; the diagnosis is left out of the query profile)
if verbose:
    example_library = retrieve_examples(
        profile=cpf.format_client_profile(
            profile_row=sample_profile_row,
            display_diagnosis=False
        ),
        retriever=retriever)
    # Print the first five retrieved notes
    for i in example_library[0:5]:
        print(f"- {i}")


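The example notes often contain gender-specific pronouns or titles. When these are included in the prompt, the model tends to generate notes with incorrect pronouns for our client. Therefore, we filter the example library to exclude notes that refer to the opposite gender: notes mentioning ‘mw’/‘mevrouw’ (Mrs) are dropped for male clients, and notes mentioning ‘dhr’/‘meneer’ (Mr) are dropped for female clients.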

In [25]:
def filter_notes_by_gender(notes, gender):
    """
    Filters example notes so that none refers to the opposite gender.

    Titles are matched case-insensitively and as whole words only, so that
    ordinary words which merely contain a title as a substring (for example
    'omweg', which contains 'mw') do not cause a note to be dropped.

    Args:
    notes (list): List of example notes to filter.
    gender (str): Gender to filter for ('male', 'female', or 'unknown').

    Returns:
    list: Notes without titles of the opposite gender. For any gender other
        than 'male' or 'female' the input list is returned unchanged.
    """
    import re  # Local import: `re` is not among the notebook-level imports

    if gender == 'male':
        gender_words = ['mw', 'mevr', 'mvr', 'mevrouw']
    elif gender == 'female':
        gender_words = ['dhr', 'meneer']
    else:
        return notes  # No filtering for unknown gender

    # \b anchors restrict the match to whole words; abbreviations followed by
    # a period ('mw.', 'dhr.') still match because '.' is not a word character
    title_pattern = re.compile(
        r'\b(?:' + '|'.join(re.escape(word) for word in gender_words) + r')\b',
        re.IGNORECASE,
    )

    return [note for note in notes if not title_pattern.search(note)]

# Debug preview of the gender filter applied to the sample client's examples
if verbose:
    ex_lib_gender_filtered = filter_notes_by_gender(
        notes=example_library,
        gender=cpf.determine_client_gender(
            profile_row=sample_profile_row))

    # Show at most five random notes; capping the size avoids the ValueError
    # that random.sample raises when fewer than five notes survive filtering
    preview_size = min(5, len(ex_lib_gender_filtered))
    for item in random.sample(ex_lib_gender_filtered, preview_size):
        print('- '+item)

    print(f"\nOorspronkelijk aantal voorbeelden: {len(example_library)}")
    print(f"Gefilterd aantal voorbeelden: {len(ex_lib_gender_filtered)}")

In [26]:
def sample_and_format_example_library(example_library, num_items=3):
    """
    Selects a random set of examples from the example library and returns them as a bulleted string.

    When the library holds fewer than ``num_items`` notes (which can happen
    after gender filtering), all available notes are used instead of raising
    a ValueError.

    Args:
    example_library (list): List of example notes.
    num_items (int): Maximum number of random items to select.

    Returns:
    str: The selected notes, one per line, each prefixed with '- '.
        An empty string when the library is empty.
    """
    # random.sample refuses to draw more items than the population holds
    sample_size = min(num_items, len(example_library))
    random_items = random.sample(example_library, sample_size)
    return '\n'.join('- ' + item for item in random_items)

# Debug preview: draw five notes from the gender-filtered sample library
# and print them as a bulleted list
if verbose:
    ex_lib_sample = sample_and_format_example_library(ex_lib_gender_filtered, 5)
    print(ex_lib_sample)

## Generating client notes

In [27]:
# Structure for a single care note
# The Dutch field descriptions are part of the schema that the output parser
# turns into format instructions for the model, so they are runtime text.
class CareNote(BaseModel):
    # Sequence number of the day
    dag: int = Field(description="volgnummer dag")
    # Time of the note, as text in hh:mm form
    tijd: str = Field(description="tijd van de rapportage (hh:mm)")
    # Content of the note
    rapportage: str = Field(description="Inhoud van de rapportage. Een rapportage beschrijft over het algemeen één zorgaspect, soms meer")

# Structure for multiple notes
# Container the model output is parsed into; the description asks for three
# notes per day, 21 in total, but the count is not enforced by validation.
class CareNotes(BaseModel):
    notes: List[CareNote] = Field(description="Rapportages voor een week, drie per dag. Totaal 21 rapportages")

In [28]:
# Define the OpenAI model to generate notes
# NOTE(review): max_tokens caps each completion at 2048 tokens; a full week of
# 21 notes as JSON has to fit within that — confirm this is sufficient
notes_model = ChatOpenAI(
    api_key=setup.get_openai_key(),
    temperature=temp_notes,
    model=model_notes,
    max_tokens=2048)

# Parser for model output
# (validates the response against CareNotes and supplies the format
# instructions that are appended to the prompt)
notes_parser = PydanticOutputParser(pydantic_object=CareNotes)
notes_format_instructions = notes_parser.get_format_instructions()
if verbose:
    print(notes_format_instructions)

### Note generation prompts & chain

In [29]:
# Dutch prompt for one week of notes. Placeholders: {scenario}, {profile} and
# {examples} are filled per call; {format_instructions} is filled once below.
notes_template = """**SCENARIO**
{scenario}
**EINDE SCENARIO**

Schrijf rapportages op basis van dit scenario van een fictieve client die verblijft op een psychogeriatrische afdeling van het verpleeghuis. De rapportages staan op zich, maar zorg voor een subtiele en geleidelijke opbouw zodat ze samen het verhaal van de client vertellen.

Schrijf rapportages voor een week (7 dagen). Per dag worden drie rapportages geschreven, dus er zijn 21 rapportages totaal.

Profiel van de client:
{profile}

Instructies:
- **Volg het scenario nauwkeurig**
- Vermijd het noemen van de naam van de client.
- Je spreekt de taal van een niveau 3 zorgmedewerker (Verzorgende IG).  Varieer in zinsopbouw en stijl, soms zijn de rapportages langer en meer gedetailleerd.
- Zorg dat de beschreven zorg realistisch is. Een fysiotherapeut komt niet elke dag langs, bijvoorbeeld

Voorbeeld rapportages (herhaal deze niet, maar gebruik ze als leidraad):
{examples}

{format_instructions}
"""

# Define the prompt template with variables and partial instructions
# (the parser's format instructions are bound up front as a partial variable)
notes_prompt_template = PromptTemplate(
    template=notes_template,
    input_variables=["scenario", "profile", "examples"],
    partial_variables={"format_instructions": notes_format_instructions},
)

# Debug preview: render the full prompt for the sample client, using three
# randomly drawn examples from the gender-filtered sample library
if verbose:
    notes_prompt = notes_prompt_template.format(
        scenario = sample_scenario,
        profile = sample_profile,
        examples = sample_and_format_example_library(ex_lib_gender_filtered,3),
        )
    print(notes_prompt)

In [30]:
# Define chain for generating notes
# (prompt -> chat model -> parser; the parser output is a CareNotes object)
notes_chain = notes_prompt_template | notes_model | notes_parser

# Function to generate care notes for a client
def generate_notes(profile, scenario, examples):
    """
    Generates the care notes for one scenario via ``notes_chain``, retrying once on failure.

    Args:
    profile: Formatted client profile inserted into the prompt.
    scenario: Scenario description the notes must follow.
    examples: Example notes inserted into the prompt.

    Returns:
    list: The ``notes`` attribute of the parsed chain result.

    Raises:
    Exception: Whatever the chain raises when the retry fails as well.
    """
    # Build the payload once so both attempts are guaranteed to be identical
    chain_input = {
        "scenario": scenario,
        "profile": profile,
        "examples": examples,
    }

    try:
        notes = notes_chain.invoke(chain_input)
    except Exception as e:
        # Try once more in case of failure; report the kind of error so that
        # recurring problems can be diagnosed
        print(f"Error in generating notes ({type(e).__name__}), retrying...")
        notes = notes_chain.invoke(chain_input)
        print("Retry successful")

    return notes.notes # Return the list of generated notes


In [31]:
def process_client(df_profile_row, df_scenarios):
    """
    Generates care notes for one client, one model call per scenario row.

    The example library is retrieved and gender-filtered once per client,
    because it depends only on the client profile; a fresh random sample of
    examples is then drawn for every scenario.

    Args:
    df_profile_row: Profile row of the client; must provide 'client_id'.
    df_scenarios: DataFrame with the scenarios to process; must provide the
        columns 'events_description' and 'week'.

    Returns:
    list: One dict per generated note with the keys 'client_id', 'weekno',
        'dag', 'tijd' and 'rapportage'.
    """
    cpf = ClientProfileFormatter()
    client_id = df_profile_row['client_id']

    # Format the client profile for the prompt
    profile = cpf.format_client_profile(df_profile_row, display_name=False)
    client_gender = cpf.determine_client_gender(df_profile_row)

    # Initialize a list to store generated notes with metadata
    ct_notes = []

    # Set up the retriever
    retriever = vectordb.as_retriever(search_kwargs={"k": 20})

    # Loop over the scenarios
    with get_openai_callback() as cb:
        # Build the example library once: the query is the profile only, so
        # repeating the retrieval for every week would return the same notes.
        # Skipped for an empty frame, where no examples are needed.
        if not df_scenarios.empty:
            example_library = retrieve_examples(
                profile=profile,
                retriever=retriever
            )
            ex_lib_gender_filtered = filter_notes_by_gender(
                notes=example_library,
                gender=client_gender
            )

        for _, scenario_row in tqdm(df_scenarios.iterrows(), total=df_scenarios.shape[0], desc="Processing Scenarios"):
            # Retrieve the current scenario description
            scenario = scenario_row['events_description']
            scenario_id = scenario_row['week']

            # Draw a fresh random set of examples for this scenario
            ex_lib_sample = sample_and_format_example_library(
                example_library=ex_lib_gender_filtered,
                num_items=num_examples
            )

            # Generate care notes for the current scenario
            result_notes = generate_notes(
                scenario=scenario,
                profile=profile,
                examples=ex_lib_sample
            )

            # Add generated notes to the notes list along with metadata
            for note in result_notes:
                ct_notes.append({
                    "client_id": client_id,
                    "weekno": scenario_id,
                    "dag": note.dag,
                    "tijd": note.tijd,
                    "rapportage": note.rapportage,
                })

        print(f"Cost for client {client_id}: {cb.total_cost}")

    return ct_notes

In [32]:
def process_clients(df_profiles, df_scenarios, output_path):
    """
    Generates notes for every client and writes them to a CSV file, resuming earlier runs.

    Notes already present in ``output_path`` are loaded first. For each
    client only the weeks without existing notes are generated, and the
    combined result is written back after every processed client.

    Args:
    df_profiles: DataFrame with one profile row per client ('client_id').
    df_scenarios: DataFrame with the scenarios ('client_id', 'week').
    output_path: Path of the CSV file that is read (if present) and written.
    """
    # Start from earlier output when there is any
    all_notes_df = pd.read_csv(output_path) if os.path.exists(output_path) else pd.DataFrame()

    client_rows = tqdm(df_profiles.iterrows(), total=df_profiles.shape[0], desc="Processing Clients")
    for _, profile_row in client_rows:
        client_id = profile_row['client_id']

        # Weeks for which this client already has notes
        if all_notes_df.empty:
            done_weeks = set()
        else:
            is_this_client = all_notes_df['client_id'] == client_id
            done_weeks = set(all_notes_df[is_this_client]['weekno'].unique())

        # Scenarios of this client that still need notes
        scenarios_for_client = df_scenarios[df_scenarios['client_id'] == client_id]
        todo_scenarios = scenarios_for_client[~scenarios_for_client['week'].isin(done_weeks)]

        if todo_scenarios.empty:
            print(f"Client {client_id} is already fully processed. Skipping.")
            continue

        # Generate the missing notes and turn them into a DataFrame
        new_notes_df = pd.DataFrame(
            process_client(
                df_profile_row=profile_row,
                df_scenarios=todo_scenarios,
            )
        )

        # Merge with what is already there
        if all_notes_df.empty:
            all_notes_df = new_notes_df
        else:
            all_notes_df = pd.concat([all_notes_df, new_notes_df], ignore_index=True)

        # Persist after each client so an interrupted run can be resumed
        all_notes_df.to_csv(output_path, index=False)
        print(f"Data saved to {output_path} after processing client {client_id}")

In [None]:
# Run the generation for all loaded profiles and scenarios; results are
# written to path_notes
process_clients(
    df_profiles=df_profiles,
    df_scenarios=df_scenarios,
    output_path=path_notes,
)