In [1]:
"""
This script loads the dataset compiled from the previous step and generate a responses from OpenAI's API for each record.

Pseudo code
-----------
1. Load results storage as array.
    1.1 If result storage is does not exist, create an empty array.
2. For each row in the dataframe:
    2.1 Check if rowid is in the result storage.
        2.1.1 If exists: record, skip the row.
    2.2 Compose the two prompts and generate a response using the OpenAI API.
    2.3 Store the response and other relevant information in the result storage.
    2.4 Save the result storage to a file periodically.
3. Save the result storage one last time.
    3.1 If rows were skipped, print a warning message.
    3.2 save the skipped rows to a file.


Details regarding #2.1.1 of the pseudo code:
--------------------------------------------

Since responses are billed by token consumption, we want to avoid reprocessing the same record. Previously processed 
are stored and saved using the unique identifier as key, and when the same storage file is loaded, the script checks
whether the unique identifier is already in the storage.

For accounting purposes, we store any unique identifier that was skipped in a separate file. in #3.2.

    
Details regarding #2.2 of the pseudo code:
------------------------------------------

We utilize the OpenAI API's Chat Completions to generate a response from the model. The parameters used are as follows:

    message: This is input text to be processed by the API. It is composed of two text prompts: 
        System Prompt: Provides the model context on its role and the expected output format. Same for all records. 
            We used the following system prompt:
            "You are a physician with expertise in determining underlying causes of death in Sierra Leone by assigning 
            ICD-10 codes for deaths using verbal autopsy narratives. Return only the ICD-10 code without description. 
            E.g. A00 If there are multiple ICD-10 codes, show one code per line"
            
        User Prompt: Specific instructions and individual record data.
            We used the following user prompt:
            "With the highest certainty, determine the underlying cause of death and provide the most accurate ICD-10 
            code for a verbal autopsy narrative of a AGE_VALUE_DEATH AGE_UNIT_DEATH old SEX_COD death in Sierra Leone: 
            {open_narrative}"
            
            AGE_VALUE_DEATH and AGE_UNIT_DEATH: replaced with age_value_death and age_unit_death values from the age dataset.
            SEX_COD: replaced with sex_cod value from the questionnaire dataset.
            open_narrative: replaced with summary value from the open narrative dataset.

    model: This parameter specifies the language model to be used. To strive for consistency and reproducibility of the 
    results, we used specific versions of the GPT-3 and GPT-4 models; gpt-3.5-turbo-0125 and gpt-4-0613, respectively.
    
    temperature: This parameter can be any value between 0 and 2. Low temperatures result in more deterministic responses, 
    while high temperatures result in more random responses. We set this to 0.

    lobprobs: This parameters controls whether the output includes the log probabilities of the tokens. We set it to True.

    AGE_VALUE_DEATH and AGE_UNIT_DEATH: replaced with age_value_death and age_unit_death values from the age dataset.
    SEX_COD: replaced with sex_cod value from the questionnaire dataset.
    open_narrative: replaced with summary value from the open narrative dataset.


Details regarding #2.3 of the pseudo code:
------------------------------------------

Below is the data structure of the output data:

    record[uid] = {
        'uid': the unique identifier for the dataset  --> change to uid
        'rowid': original rowid,
        'param_model': model used
        'param_temperature': temperature used
        'param_logprobs': logprobs used
        'param_system_prompt': system prompt used
        'param_user_prompt': user prompt used
        # 'output_msg': output text message from the API // drop
        # 'output_logprobs': the log probabilities of the tokens in the output message // drop
        # 'output_usage': token consumption // drop
        'output_timestamp': current timestamp
        # 'other_columns': extra columns that were not recognized by the script. Can be useful for debugging or incorporating additional information.
        'output': serialized response from the Chat Completions API
        ... : passthough columns not required by the script but included in the original dataset
    }
    
Duration
- GPT4 took 686m 31s to process all data.
- observed API stuck for more than 7 mins
    
"""
pass



In [2]:
import os
import pandas as pd
import numpy as np
from openai import OpenAI
import textwrap
import datetime
import pytz
import json
import re


open_api_key = os.environ.get('OPEN_API_KEY')
client = OpenAI(api_key=open_api_key)

DEMO_MODE = False        # If set to True, the script will only process a subset of the data
DEMO_RANDOM = False     # If set to True, the script will process a random subset of the data
DEMO_SIZE_LIMIT = 10    # Size of the demo


# Discard columns that are not recognized by the script in the output file.
# Keep this to False if you want the output file to retain columns needed for post-processing purposes.
DROP_EXCESS_COLUMNS = False

INPUT_DATASET_FILE = "healsl_dataset_all_240309_040141.csv"     # Input file# Input datasets

OUTPUT_DATA_FILE = "all_data_gpt4_0313.json"                         # Output file



# Set the timezone to Eastern Time
TIMEZONE = pytz.timezone('US/Eastern')

WORDWRAP_WIDTH = 100

# How often the temp storage is saved to disk
SAVE_FREQ = 5

# Models
GPT4 = "gpt-4-0613"
GPT3 = "gpt-3.5-turbo-0125"
MODEL_NAME = GPT4
TEMPERATURE = 0
LOGPROBS = True



SYS_PROMPT = """You are a physician with expertise in determining underlying causes of death in Sierra Leone by assigning the most probable ICD-10 code for each death using verbal autopsy narratives. Return only the ICD-10 code without description. E.g. A00 
If there are multiple ICD-10 codes, show one code per line."""

USR_PROMPT = """Determine the underlying cause of death and provide the most probable ICD-10 code for a verbal autopsy narrative of a AGE_VALUE_DEATH AGE_UNIT_DEATH old SEX_COD death in Sierra Leone: {open_narrative}"""

In [3]:
# Variables
dataset_passed = True
required_colnames = ['uid', 'rowid', 'age_value_death', 'age_unit_death',
                     'open_narrative', 'sex_cod']

# Load dataset
merged_df = pd.read_csv(INPUT_DATASET_FILE)


# Check all required columns are in the df
for colnames in required_colnames:
    if colnames not in merged_df.columns:
        print(f"Missing column \"{colnames}\"")
        dataset_passed = False

if not dataset_passed:
    print("Please ensure dataset has all the required columns.")
    raise ValueError(f"Error: Missing columns required for processing.")
    
# Get columns names that are not required
extra_colnames = ['uid']
extra_colnames += [colname for colname in merged_df.columns if colname not in required_colnames]

# Transform non-required columns as dictionary in a new column
# merged_df['model_residual_columns'] = merged_df[extra_colnames].apply(lambda x: x.to_dict(), axis=1)

In [4]:
# When DEMO_MODE is set to True, process only a subset of the data
if DEMO_MODE:
    limit_records = int(DEMO_SIZE_LIMIT)
    merged_df.sort_values(by='uid', inplace=True)
    print(f"DEMO MODE: Processing only {limit_records} records.")
    if DEMO_RANDOM:
        merged_df = merged_df.sample(limit_records)
    else:
        merged_df = merged_df.head(limit_records)


In [5]:
# F(x): Misc. functions

# Function to get current time in string YYMMDD_HHMMSS format
def get_current_str_time():
    return datetime.datetime.now(tz=TIMEZONE).strftime("%y%m%d_%H%M%S")

# Used to convert ChatCompletion object to a dictionary, recursively
def recursive_dict(obj):
    if isinstance(obj, dict):
        return {k: recursive_dict(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [recursive_dict(v) for v in obj]
    elif hasattr(obj, '__dict__'):
        return recursive_dict(obj.__dict__)
    else:
        return obj

In [6]:
# F(x): Initialize the data storage dictionary

def load_data(filename=OUTPUT_DATA_FILE):
    
    if os.path.exists(filename):
        print(f"{filename} found. Loading data...")
        with open(filename, 'r') as file:
            data = json.load(file)
        return data
    else:
        print(f"{filename} not found. Initializing empty dictionary...")
        return {}

def save_data(data, filename=OUTPUT_DATA_FILE):           
    # Save data to a file   
    with open(filename, 'w') as file:
        json.dump(data, file)

In [7]:
# F(x): Send a message to the chatbot
def get_completion(
    messages: list[dict[str, str]],
    model: str = "gpt-3.5-turbo-0125",
    # model: str = "gpt-3.5-turbo-0125",
    # max_tokens=500,
    temperature=0,
    # stop=None,
    # seed=123,
    tools=None,
    logprobs=None,
    top_logprobs=None,
) -> str:

    params = {
        "model": model,
        # "response_format": { "type": "json_object" },
        "messages": messages,
        # "max_tokens": max_tokens,
        "temperature": temperature,
        # "stop": stop,
        # "seed": seed,
        "logprobs": logprobs,
        "top_logprobs": top_logprobs,
    }
    if tools:
        params["tools"] = tools

    completion = client.chat.completions.create(**params)
    return completion

In [12]:
# Load existing data or initialize an empty dictionary
data_storage = load_data()
skipped_rows = []
repeated_skips = False
print()


for index, row in merged_df.iterrows():
    # Access the values of each column in the current row
    # hijacking row 
    # row = merged_df[merged_df['rowid'] == 14005966].iloc[0]
    
    uid = row['uid']    
    
    # Check if rowid already processed. Testing both because json changes int keys to str    
    if (uid) in data_storage or str(uid) in data_storage:
        # if repeated_skips:
        #     print("\r", end='', flush=True)
        # print(f"Skipping index {index}, row {uid} - Already processed.", end='', flush=True)
        repeated_skips = True
        skipped_rows.append(uid)
        continue

    rowid = row['rowid']
    narrative = row['open_narrative']
    sex_cod = row['sex_cod']
    age_value_death = row['age_value_death']
    age_unit_death = row['age_unit_death']
    other_columns = row[extra_colnames].to_dict()
    
    prompt = USR_PROMPT
    prompt = prompt.replace('AGE_VALUE_DEATH', str(age_value_death))
    prompt = prompt.replace('AGE_UNIT_DEATH', age_unit_death.lower())
    prompt = prompt.replace('SEX_COD', str(sex_cod).lower())
    prompt = prompt.format(open_narrative=narrative)
    
    # print("Prompt:")    
    # print(textwrap.fill(prompt, width=WORDWRAP_WIDTH))
    # print()
    
    # for a in range(5):
    completion = get_completion(
        [
            {"role": "system", "content": SYS_PROMPT},
            {"role": "user", "content": prompt}
        ] ,
        model=MODEL_NAME,
        logprobs=LOGPROBS,
        temperature=TEMPERATURE,
        # top_logprobs=2,
    )
    
    # print(completion.choices[0].message)
    
    # for token in completion.choices[0].logprobs.content:
    #     print(f"{repr(str(token.token)).ljust(15)}  {str(token.logprob).ljust(20)} {np.round(np.exp(token.logprob)*100,2)}%")
        
    output_msg = completion.choices[0].message.content
    logprob_data = [(token.token, float(token.logprob)) for token in completion.choices[0].logprobs.content]
    usage_data = list(completion.usage)    
    current_time = datetime.datetime.now(tz=TIMEZONE).isoformat()
       
    data_storage[str(uid)] = {
        'uid': uid,               # 'uid' is the unique identifier for the dataset
        'rowid': rowid,
        'param_model': MODEL_NAME,
        'param_temperature': 0,
        'param_logprobs': True,
        'param_system_prompt': SYS_PROMPT,
        'param_user_prompt': prompt,
        # 'output_msg': output_msg,
        # 'logprobs': logprob_data,
        # 'usage': usage_data,
        'timestamp': current_time,
        # 'other_columns': other_columns,
        'output': recursive_dict(completion),
    }

    if not DROP_EXCESS_COLUMNS:
        # Append columns not required by the script but exists on the original dataset
        data_storage[str(uid)].update(other_columns)

    # Save data periodically (you can adjust the frequency based on your needs)    
    if index % SAVE_FREQ == 0 and index > 0:
        if repeated_skips:
            print("\n", flush=True)
        repeated_skips = False
        
        save_data(data_storage)
        print(f"Saving index: {str(index).ljust(8)} Processing: {str(uid).ljust(12)} Rows skipped: {len(skipped_rows)}", sep=' ', end='\r', flush=True)
        # break
    
try:
    save_data(data_storage)
    print(f"Saving index: {str(index).ljust(8)} Processing: {str(uid).ljust(12)} Rows skipped: {len(skipped_rows)}", sep=' ', end='\r', flush=True)
    print("\nData saved successfully. Processing Complete.")
except Exception as e:
    print(f"Error saving data: {e}")

if len(skipped_rows) > 0:
    print(f"{len(skipped_rows)} rows skipped. Check skipped_rows for details.")
    
    # Write skipped rows to a file
    with open(f"Step02_process_openai_report_{get_current_str_time()}.txt", "w") as file:
        file.write(f"The follow rows are skipped because they were already processed.\n")
        for item in skipped_rows:        
            file.write(f"{str(item)}\n")

all_data_gpt4_0313.json found. Loading data...



Saving index: 1025     Processing: 14000757     Rows skipped: 91

Saving index: 1375     Processing: 14003545     Rows skipped: 92

Saving index: 2345     Processing: 14001953     Rows skipped: 93

Saving index: 3710     Processing: 14000093     Rows skipped: 94

Saving index: 4915     Processing: 14001076     Rows skipped: 95

Saving index: 5700     Processing: 14006579     Rows skipped: 96

Saving index: 5960     Processing: 14001581     Rows skipped: 97

Saving index: 6990     Processing: 14000596     Rows skipped: 98

Saving index: 7320     Processing: 14004826     Rows skipped: 99

Saving index: 8330     Processing: 14001654     Rows skipped: 100

Saving index: 11886    Processing: 24001069     Rows skipped: 101
Data saved successfully. Processing Complete.
101 rows skipped. Check skipped_rows for details.


In [11]:
str(sex_cod)

'nan'