In [2]:
# Validating responses from the API
# 100 samples for 10 requests each. 1000 samples in total

In [3]:
import os
import pandas as pd
import numpy as np
from openai import OpenAI
import textwrap
import datetime
import pytz
import json
import re


open_api_key = os.environ.get('OPEN_API_KEY')
client = OpenAI(api_key=open_api_key)

WORDWRAP_WIDTH = 100
DATA_FILE = "response_validation.json"
SAVE_FREQ = 3
N_REPEAT_RESPONSES = 3
N_SAMPLES = 3  # Replace X with the desired number of values to select
# 100 SAMPLES
# 10 times per sample

# Models
GPT4 = "gpt-4-0613"
GPT3 = "gpt-3.5-turbo-0125"
MODEL_NAME = GPT3

# Set the timezone to Eastern Time
TIMEZONE = pytz.timezone('US/Eastern')

SYS_PROMPT = """You are a physician with expertise in determining underlying causes of death in Sierra Leone by assigning ICD-10 codes for deaths using verbal autopsy narratives. Return only the ICD-10 code without description. E.g. A00 
If there are multiple ICD-10 codes, show one code per line."""

USR_PROMPT = """With the highest certainty, determine the underlying cause of death and provide the most accurate ICD-10 code for a verbal autopsy narrative of a AGE_VALUE_DEATH AGE_UNIT_DEATH old SEX_COD death in Sierra Leone: {open_narrative}"""

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [4]:
path_prefix = "../data_202402/"
merged_all_df = pd.DataFrame()

rounds = ['rd1', 'rd2']
age_groups = ['adult', 'child', 'neo']

for r in rounds:
    for a in age_groups:
        
        questionnaire_df =  pd.read_csv(f"{path_prefix}healsl_{r}_{a}_v1.csv")
        age_df =            pd.read_csv(f"{path_prefix}healsl_{r}_{a}_age_v1.csv")
        narrative_df =      pd.read_csv(f"{path_prefix}healsl_{r}_{a}_narrative_v1.csv")

        narrative_df = narrative_df.rename(columns={'summary': 'open_narrative'})
        
        # Merge the dataframes
        narrative_only = narrative_df[['rowid','open_narrative']]
        sex_only = questionnaire_df[['rowid','sex_cod']]
        age_only = age_df[['rowid','age_value_death','age_unit_death']]
        
        merged_df = narrative_only.merge(sex_only, on='rowid').merge(age_only, on='rowid')

        # Fill in missing values with empty string
        merged_df['sex_cod'] = merged_df['sex_cod'].fillna('')

        assert not merged_df.isnull().values.any(), "Execution halted: NaN values found in merged_df"

        print(f"round: {r.ljust(10)} age group: {a.ljust(10)} len: {str(merged_df.shape[0]).ljust(10)}")
        # print(f"Sample of merged_df {merged_df.shape}:")
        # display(merged_df.sample(5))
        
        merged_all_df = pd.concat([merged_all_df, merged_df])
        


  questionnaire_df =  pd.read_csv(f"{path_prefix}healsl_{r}_{a}_v1.csv")
  questionnaire_df =  pd.read_csv(f"{path_prefix}healsl_{r}_{a}_v1.csv")


round: rd1        age group: adult      len: 4987      
round: rd1        age group: child      len: 2998      
round: rd1        age group: neo        len: 585       
round: rd2        age group: adult      len: 2025      
round: rd2        age group: child      len: 1059      
round: rd2        age group: neo        len: 233       


In [3]:
random_rowids = merged_all_df.sample(N_SAMPLES)
print(random_rowids.rowid.tolist())
# print(random_rowids['rowids'])


[14005287, 14001013, 14003752]


In [4]:
# F(x): Initialize the data storage dictionary

def load_data(filename=DATA_FILE):
    
    if os.path.exists(filename):
        print(f"{filename} found. Loading data...")
        with open(filename, 'r') as file:
            data = json.load(file)
        return data
    else:
        print(f"{filename} not found. Initializing empty dictionary...")
        return {}

def save_data(data, filename=DATA_FILE):
    # Save data to a file    
    with open(filename, 'w') as file:
        json.dump(data, file)

In [5]:
# F(x): Send a message to the chatbot
def get_completion(
    messages: list[dict[str, str]],
    model: str = "gpt-3.5-turbo-0125",
    # model: str = "gpt-3.5-turbo-0125",
    # max_tokens=500,
    temperature=0,
    # stop=None,
    # seed=123,
    tools=None,
    logprobs=None,
    top_logprobs=None,
) -> str:

    params = {
        "model": model,
        # "response_format": { "type": "json_object" },
        "messages": messages,
        # "max_tokens": max_tokens,
        "temperature": temperature,
        # "stop": stop,
        # "seed": seed,
        "logprobs": logprobs,
        "top_logprobs": top_logprobs,
    }
    if tools:
        params["tools"] = tools

    completion = client.chat.completions.create(**params)
    return completion

In [6]:
# F(x): Extract ICD probabilities from tokens

def extract_icd_probabilities(logprobs):
    parsed_icd = []
    for pos in range(len(logprobs)):
        temp_df = pd.DataFrame(logprobs[pos: pos+4])
        temp_df = temp_df[temp_df[0].notna() & (temp_df[0].str.strip() != '')]
        temp_df = temp_df[temp_df[0].str.strip() != '\n']
        temp_concat = ''.join(temp_df.iloc[:, 0]).strip()
        if len(temp_concat) > 9:
            continue
        # pattern = r'^[A-Z]\d{0,4}(\.\d{0,4})?$'
        pattern_4part = r'^[A-Z]\d{0,4}(\.\d{1,4})?$'
        match = re.match(pattern_4part, temp_concat)

        if match:
            # print(f"{temp_concat} - valid ICD {np.round((np.exp(temp_df.iloc[:, 1]).mean())*100,2)}%")
            # print(f"**** {temp_concat} - VALID 2-parts ICD ****")
            parsed_icd.append((temp_concat, (np.exp(temp_df.iloc[:, 1]).mean())))
        else:
            # print(f"{temp_concat} - invalid 4-parts.")
            
            #trying 2-parts
            temp_df = pd.DataFrame(logprobs[pos: pos+2])
            temp_concat = ''.join(temp_df.iloc[:, 0]).strip()
            pattern_2part = r'^[A-Z]\d{1,4}$'
            match = re.match(pattern_2part, temp_concat)
            if match:
                # print(f"**** {temp_concat} - VALID 2-parts ICD ****")
                parsed_icd.append((temp_concat, (np.exp(temp_df.iloc[:, 1]).mean())))
            else:
                # print(f"{temp_concat} - invalid 2-parts.")
                pass
            pass

    return parsed_icd

In [7]:
import datetime
# Load existing data or initialize an empty dictionary
data_storage = load_data()
skipped_rows = []
repeated_skips = False
count = 0
print()

for rp in range(N_REPEAT_RESPONSES):
    for _, row in random_rowids.iterrows():
        # Access the values of each column in the current row
        # hijacking row 
        # row = merged_df[merged_df['rowid'] == 14005966].iloc[0]
        
        rowid = row['rowid']        
        u_rowid = f"{rowid}_{rp}"
        
        # # Check if rowid already processed. Testing both because json changes int keys to str    
        if (u_rowid) in data_storage or str(u_rowid) in data_storage:
            if repeated_skips:
                print("\r", end='', flush=True)
            print(f"Skipping count {count}, row {rowid} - Already processed.", end='', flush=True)
            repeated_skips = True
            skipped_rows.append(rowid)
            continue

        
        narrative = row['open_narrative']
        sex_cod = row['sex_cod']
        age_value_death = row['age_value_death']
        age_unit_death = row['age_unit_death']
        
        prompt = USR_PROMPT
        prompt = prompt.replace('AGE_VALUE_DEATH', str(age_value_death))
        prompt = prompt.replace('AGE_UNIT_DEATH', age_unit_death.lower())
        prompt = prompt.replace('SEX_COD', sex_cod.lower())
        prompt = prompt.format(open_narrative=narrative)
        
        # print("Prompt:")    
        # print(textwrap.fill(prompt, width=WORDWRAP_WIDTH))
        # print()
        
        # for a in range(5):
        completion = get_completion(
            [
                {"role": "system", "content": SYS_PROMPT},
                {"role": "user", "content": prompt}
            ] ,
            model=MODEL_NAME,
            logprobs=True,
            # top_logprobs=2,
        )

        count += 1
        
        # print(completion.choices[0].message)
        
        # for token in completion.choices[0].logprobs.content:
        #     print(f"{repr(str(token.token)).ljust(15)}  {str(token.logprob).ljust(20)} {np.round(np.exp(token.logprob)*100,2)}%")
            
        output_msg = completion.choices[0].message.content
        logprob_data = [(token.token, float(token.logprob)) for token in completion.choices[0].logprobs.content]
        usage_data = list(completion.usage)    
        current_time = datetime.datetime.now(tz=TIMEZONE).isoformat()
        
        data_storage[str(rowid)] = {
            'rowid': rowid,
            'u_rowid': u_rowid,
            'model': MODEL_NAME,
            'system_prompt': SYS_PROMPT,
            'user_prompt': prompt,
            'output_msg': output_msg,
            'logprobs': logprob_data,
            'usage': usage_data,
            'timestamp': current_time
        }

        # Save data periodically (you can adjust the frequency based on your needs)    
        if count % SAVE_FREQ == 0 and count > 0:
            if repeated_skips:
                print("\n", flush=True)
            repeated_skips = False
            
            save_data(data_storage)
            print(f"Saving count: {str(count).ljust(8)} Processing: {str(rowid).ljust(12)} Rows skipped: {len(skipped_rows)}", sep=' ', end='\r', flush=True)
            # break
        
    try:
        save_data(data_storage)
        print(f"Saving count: {str(count).ljust(8)} Processing: {str(rowid).ljust(12)} Rows skipped: {len(skipped_rows)}", sep=' ', end='\r', flush=True)
        print("\nData saved successfully.")
    except Exception as e:
        print(f"Error saving data: {e}")

if len(skipped_rows) > 0:
    print(f"DF length: {len(merged_df)}")
    print(f"Rows skipped: {len(skipped_rows)}")    

response_validation.json not found. Initializing empty dictionary...

Saving count: 3        Processing: 14003752     Rows skipped: 0
Data saved successfully.
Saving count: 6        Processing: 14003752     Rows skipped: 0
Data saved successfully.
Saving count: 9        Processing: 14003752     Rows skipped: 0
Data saved successfully.


In [19]:
data_storage = load_data()

data_storage.json found. Loading data...


In [20]:
df = pd.DataFrame(data_storage).T
df['icds'] = df.apply(lambda x: extract_icd_probabilities(x['logprobs']), axis=1)
df['best_icd'] = df.apply(lambda x: pd.DataFrame(x['icds']).sort_values(by=1, ascending=False).iloc[0,0], axis=1)


Exporting files

In [21]:
# df.to_json("data_storage_parsed.json", orient='records')
# df.to_csv("data_storage_parsed.csv", index=True)

Showing `output_msg` that exceeds ICD length

In [22]:
print(f"output_msg exceeding normal ICD length:")
df[df['output_msg'].apply(lambda x:len(x) > 10)][['output_msg','icds','best_icd']]

output_msg exceeding normal ICD length:


Unnamed: 0,output_msg,icds,best_icd
14002658,Malignant neoplasm of bone and articular carti...,"[(C40.9, 0.8579251792598699)]",C40.9
14004747,A09\nR50.9\nR11.0\nR63.4,"[(A09, 0.3801835598412292), (R50.9, 0.66909530...",R50.9
14006258,V89.2 (Person injured in collision between oth...,"[(V89.2, 0.8129130625381266)]",V89.2
14003822,T79.8\nR40.2,"[(T79.8, 0.6216546258358351), (R40.2, 0.681784...",R40.2
14007755,Tetanus: A33,"[(A33, 0.7714985657154876)]",A33
14002323,T79.3\nT88.9\nN17.9,"[(T79.3, 0.6682882892563279), (T88.9, 0.574929...",N17.9
14008510,V89.2 (Pedestrian injured in collision with ot...,"[(V89.2, 0.8549707965091256)]",V89.2
14007670,Tetanus: A33,"[(A33, 0.7297871458081743)]",A33
14008863,W17.89 - Other specified fall from one level t...,"[(W17.89, 0.7211488803081331)]",W17.89
14001730,Malaria\nB54,"[(B54, 0.9215047096956849), (B54, 0.9215047096...",B54


In [34]:
questionnaire_filename_list = [
    "../data_202402/healsl_rd1_adult_v1.csv",
    "../data_202402/healsl_rd1_child_v1.csv",
    "../data_202402/healsl_rd1_neo_v1.csv",
    "../data_202402/healsl_rd2_adult_v1.csv",
    "../data_202402/healsl_rd2_child_v1.csv",
    "../data_202402/healsl_rd2_neo_v1.csv",
]

q_rowids = set()

for filename in questionnaire_filename_list:
    temp_df = pd.read_csv(filename)
    q_rowids.update(temp_df['rowid'].tolist())

gpt_rowids = set(df['rowid'].tolist())


  temp_df = pd.read_csv(filename)
  temp_df = pd.read_csv(filename)


In [37]:
# ids that are in questionnaire but not in gpt
# spot checked a few, reason is that the id does not exist in the narrative file
q_rowids - gpt_rowids

{14000974,
 14002357,
 14002664,
 14003259,
 14003583,
 14003847,
 14004056,
 14004130,
 14004483,
 14004600,
 14004714,
 14005087,
 14005957,
 14006271,
 14006706,
 14007419,
 14007877,
 14008104,
 14008114,
 14008180,
 14008385,
 14008683,
 14009166,
 24000192,
 24000630,
 24001644,
 24002102,
 24002201,
 24002552,
 24002719,
 24003065,
 24003489,
 24003649}

In [46]:
len(gpt_rowids)

11887

In [45]:
len(q_rowids - gpt_rowids)

33

In [47]:
merged_df.sample(5)

Unnamed: 0,rowid,open_narrative,sex_cod,age_value_death,age_unit_death
217,24003743,"According to the respondent, a 1 day old neona...",Male,1,Days
166,24003091,A 3 days old new-born female baby died after 2...,Female,3,Days
163,24003022,A 0 day old newly born female baby was born de...,Female,0,Days
95,24000329,"According to the respondent, the deceased was ...",Female,1,Days
229,24002598,The deceased was a 0 old day female neonate wh...,Female,0,Days
