In [1]:
# Validating responses from the API
# 100 sampled records, 10 requests each: 1,000 responses in total

In [2]:
import os
import pandas as pd
import numpy as np
from openai import OpenAI
import textwrap
import datetime
import pytz
import json
import re


# OpenAI client; the API key is read from the OPEN_API_KEY environment variable.
# NOTE(review): the OpenAI SDK's conventional variable name is OPENAI_API_KEY -
# confirm that OPEN_API_KEY is the name actually set in the environment.
open_api_key = os.environ.get('OPEN_API_KEY')
client = OpenAI(api_key=open_api_key)

WORDWRAP_WIDTH = 100  # column width for wrapped text output
DATA_FILE = "response_validation_data_storage.json"  # default JSON checkpoint file for load_data / save_data
SAVE_FREQ = 5  # checkpoint to DATA_FILE after every SAVE_FREQ completed requests
N_REPEAT_RESPONSES = 10  # how many times each sampled record is sent to the model
N_SAMPLES = 100  # target number of records to sample across all groups
# 100 samples
# x 10 requests per sample = 1000 requests in total

# Models
GPT4 = "gpt-4-0613"
GPT3 = "gpt-3.5-turbo-0125"
MODEL_NAME = GPT3  # model used for every request in this notebook

# Set the timezone to Eastern Time (used for the timestamp stored with each response)
TIMEZONE = pytz.timezone('US/Eastern')

# System prompt: sent unchanged as the "system" message of every request.
SYS_PROMPT = """You are a physician with expertise in determining underlying causes of death in Sierra Leone by assigning ICD-10 codes for deaths using verbal autopsy narratives. Return only the ICD-10 code without description. E.g. A00 
If there are multiple ICD-10 codes, show one code per line."""

# User prompt template: AGE_VALUE_DEATH, AGE_UNIT_DEATH and SEX_COD are filled in with
# str.replace, and {open_narrative} with str.format, when each request is built.
USR_PROMPT = """With the highest certainty, determine the underlying cause of death and provide the most accurate ICD-10 code for a verbal autopsy narrative of a AGE_VALUE_DEATH AGE_UNIT_DEATH old SEX_COD death in Sierra Leone: {open_narrative}"""

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [3]:
# Build `merged_all_df`: one row per death record with its narrative, sex, age and a
# "<age group>_<round>" label, for every combination of survey round and age group.
path_prefix = "../data_202402/"

rounds = ['rd1', 'rd2']
age_groups = ['adult', 'child', 'neo']

# Per-group frames are collected here and concatenated once after the loop, instead of
# growing `merged_all_df` with pd.concat on every iteration (quadratic copying, and it
# started from an empty DataFrame, which newer pandas versions warn about).
group_frames = []

for r in rounds:
    for a in age_groups:
        # Read only the columns that are used below.
        # NOTE(review): this presumably also avoids the warning pandas raised on the
        # questionnaire read_csv line (likely mixed-type columns) - confirm on a re-run.
        questionnaire_df = pd.read_csv(
            f"{path_prefix}healsl_{r}_{a}_v1.csv",
            usecols=['rowid', 'sex_cod'],
        )
        age_df = pd.read_csv(
            f"{path_prefix}healsl_{r}_{a}_age_v1.csv",
            usecols=['rowid', 'age_value_death', 'age_unit_death'],
        )
        narrative_df = pd.read_csv(
            f"{path_prefix}healsl_{r}_{a}_narrative_v1.csv",
            usecols=['rowid', 'summary'],
        )

        narrative_df = narrative_df.rename(columns={'summary': 'open_narrative'})

        # Explicit column selection pins the column order of the merged frame
        # (usecols keeps the file's order, not the order listed).
        narrative_only = narrative_df[['rowid', 'open_narrative']]
        sex_only = questionnaire_df[['rowid', 'sex_cod']]
        age_only = age_df[['rowid', 'age_value_death', 'age_unit_death']]

        # Inner joins on rowid: a record missing from any of the three files is dropped.
        merged_df = narrative_only.merge(sex_only, on='rowid').merge(age_only, on='rowid')

        # Fill in missing values with empty string
        merged_df['sex_cod'] = merged_df['sex_cod'].fillna('')

        merged_df['group'] = f"{a}_{r}"

        assert not merged_df.isnull().values.any(), "Execution halted: NaN values found in merged_df"

        print(f"round: {r.ljust(10)} age group: {a.ljust(10)} len: {str(merged_df.shape[0]).ljust(10)}")

        group_frames.append(merged_df)

# ignore_index gives the combined frame a unique 0..n-1 index; without it every group
# restarts at 0 and index labels repeat across groups.
merged_all_df = pd.concat(group_frames, ignore_index=True)
        


  questionnaire_df =  pd.read_csv(f"{path_prefix}healsl_{r}_{a}_v1.csv")
  questionnaire_df =  pd.read_csv(f"{path_prefix}healsl_{r}_{a}_v1.csv")


round: rd1        age group: adult      len: 4987      
round: rd1        age group: child      len: 2998      
round: rd1        age group: neo        len: 585       
round: rd2        age group: adult      len: 2025      
round: rd2        age group: child      len: 1059      
round: rd2        age group: neo        len: 233       


In [4]:
# Proportional allocation: each group gets a share of N_SAMPLES equal to its share of
# all records, rounded to a whole number. Despite its name, `sampling_frac` maps each
# group label to a record COUNT. Because every count is rounded independently, the
# counts are not guaranteed to add up to exactly N_SAMPLES.
group_share = merged_all_df.value_counts('group') / len(merged_all_df)
sampling_frac = (group_share * N_SAMPLES).round(0).astype(int).to_dict()

# Draw the rowids for every group with a fixed seed so the selection is repeatable.
sample_ids = {}
for sample, n_records in sampling_frac.items():
    group_rows = merged_all_df[merged_all_df['group'] == sample]
    sample_ids[sample] = group_rows.sample(n_records, random_state=1).rowid.tolist()
    print(f"{sample}: {n_records} records")

adult_rd1: 42 records
child_rd1: 25 records
adult_rd2: 17 records
child_rd2: 9 records
neo_rd1: 5 records
neo_rd2: 2 records


In [5]:
# Reconcile the per-group samples with N_SAMPLES and build `random_rowids`, the frame of
# records that will be sent to the model.

# Groups are never trimmed below this many sampled ids.
MIN_SAMPLES_PER_GROUP = 10

# Largest group first, so excess ids are taken from the biggest group.
sorted_sample_ids = dict(sorted(sample_ids.items(), key=lambda item: len(item[1]), reverse=True))

# Get the actual samples count
sample_values_count = sum(len(ids) for ids in sorted_sample_ids.values())

# If the sample values count is greater than N_SAMPLES, remove the excess samples
if sample_values_count > N_SAMPLES:
    excess = sample_values_count - N_SAMPLES
    print(f"There are more than {N_SAMPLES} samples. Removing excess samples.")

    for _ in range(excess):
        # First group (in largest-first order) that is still above the floor.
        trimmable = next(
            (ids for ids in sorted_sample_ids.values() if len(ids) > MIN_SAMPLES_PER_GROUP),
            None,
        )
        if trimmable is None:
            # Previously this case silently left the sample oversized.
            print(f"Warning: no group has more than {MIN_SAMPLES_PER_GROUP} samples left; cannot trim further.")
            break
        trimmable.pop()
else:
    print(f"There are {sample_values_count} samples. No need to remove any samples.")

# Flatten the dictionary
sample_ids_list = [item for sublist in sorted_sample_ids.values() for item in sublist]

# Independent rounding per group can also undershoot N_SAMPLES; make that visible.
if len(sample_ids_list) != N_SAMPLES:
    print(f"Warning: expected {N_SAMPLES} samples but {len(sample_ids_list)} were selected.")

# Compile the final dataframe based on the sample_ids_list
random_rowids = merged_all_df[merged_all_df['rowid'].isin(sample_ids_list)]

# `isin` matches on rowid alone, so a rowid that occurs in more than one record would
# pull in extra rows; report it rather than silently inflating the sample.
if len(random_rowids) != len(sample_ids_list):
    print(f"Warning: {len(sample_ids_list)} rowids selected but {len(random_rowids)} rows matched; check rowid uniqueness.")

There are 100 samples. No need to remove any samples.


In [6]:
# F(x): Initialize the data storage dictionary

def load_data(filename=DATA_FILE):
    """Return the JSON content stored in `filename`, or an empty dict if it is absent.

    Args:
        filename: Path of the JSON file to read. Defaults to DATA_FILE.

    Returns:
        Whatever json.load produces for the file, or {} when the file does not exist.
        A status line is printed in both cases.
    """
    # Guard clause: nothing stored yet, start from scratch.
    if not os.path.exists(filename):
        print(f"{filename} not found. Initializing empty dictionary...")
        return {}

    print(f"{filename} found. Loading data...")
    with open(filename, 'r') as stored:
        return json.load(stored)

def save_data(data, filename=DATA_FILE):
    """Serialize `data` as JSON into `filename`, replacing any previous content.

    The JSON is written to a sibling "<filename>.tmp" file first and then moved over
    `filename` with os.replace. If the write is interrupted or json.dump fails, the
    previous file is left untouched instead of being truncated.

    Args:
        data: JSON-serializable object to store.
        filename: Destination path. Defaults to DATA_FILE.

    Raises:
        TypeError / ValueError: propagated from json.dump for non-serializable data.
        OSError: propagated from the file operations.
    """
    tmp_filename = f"{filename}.tmp"
    try:
        with open(tmp_filename, 'w') as file:
            json.dump(data, file)
        # Swap the finished file into place in a single step.
        os.replace(tmp_filename, filename)
    except BaseException:
        # Do not leave a partial temp file behind; the old checkpoint stays valid.
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)
        raise

In [7]:
# F(x): Send a message to the chatbot
def get_completion(
    messages: list[dict[str, str]],
    model: str = "gpt-3.5-turbo-0125",
    temperature=0,
    tools=None,
    logprobs=None,
    top_logprobs=None,
):
    """Send a chat-completion request through the module-level `client`.

    Args:
        messages: Chat messages, each a dict with "role" and "content".
        model: Model identifier passed to the API.
        temperature: Sampling temperature passed to the API.
        tools: Optional tool definitions; sent only when truthy.
        logprobs: Optional flag requesting token log-probabilities; sent only when
            not None.
        top_logprobs: Optional number of alternatives per token; sent only when
            not None.

    Returns:
        The object returned by client.chat.completions.create. This is the full
        completion, not a string (the previous "-> str" annotation was incorrect);
        callers read e.g. `.choices[0].message.content` from it.
    """
    params = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
    }

    # Optional arguments are added only when the caller supplied them, so unset options
    # are left to the API defaults instead of being sent as explicit nulls.
    if tools:
        params["tools"] = tools
    if logprobs is not None:
        params["logprobs"] = logprobs
    if top_logprobs is not None:
        params["top_logprobs"] = top_logprobs

    return client.chat.completions.create(**params)

In [8]:
# Query the model N_REPEAT_RESPONSES times for every sampled record and checkpoint the
# responses to disk, so an interrupted run can be resumed by re-running this cell.
# NOTE(review): the resume key is "<rowid>_<repeat>" and does not include MODEL_NAME, so
# switching models while reusing the same data file would skip every stored key.

# Load existing data or initialize an empty dictionary
data_storage = load_data()
skipped_rows = []
repeated_skips = False
count = 0
rowid = None  # keeps the end-of-repeat progress line valid even if no row is iterated
print()

try:
    for rp in range(N_REPEAT_RESPONSES):
        for _, row in random_rowids.iterrows():
            rowid = row['rowid']
            # One stored entry per (record, repetition). The key is always a str, which
            # is also how JSON object keys come back after a reload.
            u_rowid = f"{rowid}_{rp}"

            if u_rowid in data_storage:
                # Overwrite the previous "Skipping" line instead of stacking them up.
                if repeated_skips:
                    print("\r", end='', flush=True)
                print(f"Skipping count {count}, row {rowid} - Already processed.", end='', flush=True)
                repeated_skips = True
                skipped_rows.append(rowid)
                continue

            # Fill the demographic placeholders first, then insert the narrative.
            prompt = (
                USR_PROMPT
                .replace('AGE_VALUE_DEATH', str(row['age_value_death']))
                .replace('AGE_UNIT_DEATH', row['age_unit_death'].lower())
                .replace('SEX_COD', row['sex_cod'].lower())
                .format(open_narrative=row['open_narrative'])
            )

            completion = get_completion(
                [
                    {"role": "system", "content": SYS_PROMPT},
                    {"role": "user", "content": prompt},
                ],
                model=MODEL_NAME,
                logprobs=True,
            )

            count += 1

            choice = completion.choices[0]
            data_storage[u_rowid] = {
                'rowid': rowid,
                'u_rowid': u_rowid,
                'model': MODEL_NAME,
                'system_prompt': SYS_PROMPT,
                'user_prompt': prompt,
                'output_msg': choice.message.content,
                # (token, log-probability) pairs for every generated token
                'logprobs': [(token.token, float(token.logprob)) for token in choice.logprobs.content],
                # presumably (field, value) pairs of the usage object - TODO confirm
                # against the installed SDK version
                'usage': list(completion.usage),
                'timestamp': datetime.datetime.now(tz=TIMEZONE).isoformat(),
            }

            # Save data periodically (adjust SAVE_FREQ based on your needs)
            if count % SAVE_FREQ == 0:
                if repeated_skips:
                    print("\n", flush=True)
                repeated_skips = False

                save_data(data_storage)
                print(f"Saving count: {str(count).ljust(8)} Processing: {str(rowid).ljust(12)} Rows skipped: {len(skipped_rows)}", sep=' ', end='\r', flush=True)

        # End of one full pass over the sample: checkpoint whatever is still unsaved.
        try:
            save_data(data_storage)
            print(f"Saving count: {str(count).ljust(8)} Processing: {str(rowid).ljust(12)} Rows skipped: {len(skipped_rows)}", sep=' ', end='\r', flush=True)
        except Exception as e:
            print(f"Error saving data: {e}")
except BaseException:
    # Interrupted (API error, Ctrl-C, ...): persist the responses gathered since the
    # last checkpoint before re-raising, so completed requests are not lost.
    save_data(data_storage)
    raise

if len(skipped_rows) > 0:
    # Report the size of the frame that was actually iterated. The previous code printed
    # len(merged_df), a leftover loop variable holding only the last group loaded.
    print(f"DF length: {len(random_rowids)}")
    print(f"Rows skipped: {len(skipped_rows)}")

response_validation_data_storage.json not found. Initializing empty dictionary...

Saving count: 1000     Processing: 24000133     Rows skipped: 0