In [2]:
import os
import pandas as pd
import numpy as np
from openai import OpenAI
import textwrap
import datetime
import pytz
import json
import re


open_api_key = os.environ.get('OPEN_API_KEY')
client = OpenAI(api_key=open_api_key)

WORDWRAP_WIDTH = 100
DATA_FILE = "data_storage.json"
SAVE_FREQ = 5

# Models
GPT4 = "gpt-4-0613"
GPT3 = "gpt-3.5-turbo-0125"
MODEL_NAME = GPT3

# Set the timezone to Eastern Time
TIMEZONE = pytz.timezone('US/Eastern')



# SYS_PROMPT = """You are a physician with expertise in determining underlying causes of death in Sierra Leone 
# by assigning ICD-10 codes for deaths using verbal autopsy narratives. Return only the ICD-10 code in JSON format: {“icd10”: [code1, code2, code3, code4, code5]}"""

# SYS_PROMPT = """You are a physician with expertise in determining underlying causes of death in Sierra Leone 
# by assigning ICD-10 codes for deaths using verbal autopsy narratives. Return only the ICD-10 code in JSON format: {“icd10”: code1}"""

# USR_PROMPT = """Determine the underlying cause of death and provide an ICD-10 code for a verbal autopsy narrative
# of a AGE_VALUE_DEATH AGE_UNIT_DEATH old SEX_COD death in Sierra Leone: {open_narrative}"""

SYS_PROMPT = """You are a physician with expertise in determining underlying causes of death in Sierra Leone by assigning ICD-10 codes for deaths using verbal autopsy narratives. Return only the ICD-10 code without description. E.g. A00 
If there are multiple ICD-10 codes, show one code per line."""

USR_PROMPT = """With the highest certainty, determine the underlying cause of death and provide the most accurate ICD-10 code for a verbal autopsy narrative of a AGE_VALUE_DEATH AGE_UNIT_DEATH old SEX_COD death in Sierra Leone: {open_narrative}"""

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [3]:
# cost projection
# (((len(SYS_PROMPT + prompt) // 4) / 1000) * 0.0005 + (15/1000) * 0.0015 ) * 12000

In [4]:
questionnaire_df =  pd.read_csv("../data_202402/healsl_rd2_neo_v1.csv")
age_df =            pd.read_csv("../data_202402/healsl_rd2_neo_age_v1.csv")
narrative_df =      pd.read_csv("../data_202402/healsl_rd2_neo_narrative_v1.csv")

narrative_df = narrative_df.rename(columns={'summary': 'open_narrative'})


In [5]:
# quick_gp3 = pd.read_csv("../data_202402/healsl_rd1_rapid_chatgpt3_v1.csv")
# quick_gp4 = pd.read_csv("../data_202402/healsl_rd1_rapid_chatgpt4_v1.csv")

# questionnaire_df[questionnaire_df['p1_recon_icd_cod'].isna()][['rowid','p1_icd_cod','p2_icd_cod']]

In [6]:
narrative_only = narrative_df[['rowid','open_narrative']]
sex_only = questionnaire_df[['rowid','sex_cod']]
age_only = age_df[['rowid','age_value_death','age_unit_death']]

merged_df = narrative_only.merge(sex_only, on='rowid').merge(age_only, on='rowid')

# Fill in missing values with empty string
merged_df['sex_cod'] = merged_df['sex_cod'].fillna('')

assert not merged_df.isnull().values.any(), "Execution halted: NaN values found in merged_df"

print(f"Sample of merged_df {merged_df.shape}:")
display(merged_df.sample(5))

Sample of merged_df (233, 5):


Unnamed: 0,rowid,open_narrative,sex_cod,age_value_death,age_unit_death
30,24001110,"As per respondent, the deceased, a zero day ma...",Male,0,Days
98,24002511,According to the respondent that happened to b...,Female,2,Days
199,24002982,"As per respondent, the deceased was a 0 days o...",Male,0,Days
219,24001328,The deceased was 0 day old neonate who died be...,Female,0,Days
173,24000529,"According to the respondent, the deceased was ...",Female,23,Days


In [7]:
# F(x): Initialize the data storage dictionary

def load_data(filename=DATA_FILE):
    
    if os.path.exists(filename):
        print(f"{filename} found. Loading data...")
        with open(filename, 'r') as file:
            data = json.load(file)
        return data
    else:
        print(f"{filename} not found. Initializing empty dictionary...")
        return {}

def save_data(data, filename=DATA_FILE):
    # Save data to a file    
    with open(filename, 'w') as file:
        json.dump(data, file)

In [8]:
# F(x): Send a message to the chatbot
def get_completion(
    messages: list[dict[str, str]],
    model: str = "gpt-3.5-turbo-0125",
    # model: str = "gpt-3.5-turbo-0125",
    # max_tokens=500,
    temperature=0,
    # stop=None,
    # seed=123,
    tools=None,
    logprobs=None,
    top_logprobs=None,
) -> str:

    params = {
        "model": model,
        # "response_format": { "type": "json_object" },
        "messages": messages,
        # "max_tokens": max_tokens,
        "temperature": temperature,
        # "stop": stop,
        # "seed": seed,
        "logprobs": logprobs,
        "top_logprobs": top_logprobs,
    }
    if tools:
        params["tools"] = tools

    completion = client.chat.completions.create(**params)
    return completion

In [9]:
# F(x): Extract ICD probabilities from tokens

def extract_icd_probabilities(logprobs):
    parsed_icd = []
    for pos in range(len(logprobs)):
        temp_df = pd.DataFrame(logprobs[pos: pos+4])
        temp_df = temp_df[temp_df[0].notna() & (temp_df[0].str.strip() != '')]
        temp_df = temp_df[temp_df[0].str.strip() != '\n']
        temp_concat = ''.join(temp_df.iloc[:, 0]).strip()
        if len(temp_concat) > 9:
            continue
        # pattern = r'^[A-Z]\d{0,4}(\.\d{0,4})?$'
        pattern_4part = r'^[A-Z]\d{0,4}(\.\d{1,4})?$'
        match = re.match(pattern_4part, temp_concat)

        if match:
            # print(f"{temp_concat} - valid ICD {np.round((np.exp(temp_df.iloc[:, 1]).mean())*100,2)}%")
            # print(f"**** {temp_concat} - VALID 2-parts ICD ****")
            parsed_icd.append((temp_concat, (np.exp(temp_df.iloc[:, 1]).mean())))
        else:
            # print(f"{temp_concat} - invalid 4-parts.")
            
            #trying 2-parts
            temp_df = pd.DataFrame(logprobs[pos: pos+2])
            temp_concat = ''.join(temp_df.iloc[:, 0]).strip()
            pattern_2part = r'^[A-Z]\d{1,4}$'
            match = re.match(pattern_2part, temp_concat)
            if match:
                # print(f"**** {temp_concat} - VALID 2-parts ICD ****")
                parsed_icd.append((temp_concat, (np.exp(temp_df.iloc[:, 1]).mean())))
            else:
                # print(f"{temp_concat} - invalid 2-parts.")
                pass
            pass

    return parsed_icd

In [17]:
import datetime
# Load existing data or initialize an empty dictionary
data_storage = load_data()
skipped_rows = []
repeated_skips = False
print()


for index, row in merged_df.iterrows():
    # Access the values of each column in the current row
    # hijacking row 
    # row = merged_df[merged_df['rowid'] == 14005966].iloc[0]
    
    rowid = row['rowid']
    
    # Check if rowid already processed. Testing both because json changes int keys to str    
    if (rowid) in data_storage or str(rowid) in data_storage:
        if repeated_skips:
            print("\r", end='', flush=True)
        print(f"Skipping index {index}, row {rowid} - Already processed.", end='', flush=True)
        repeated_skips = True
        skipped_rows.append(rowid)
        continue

    
    narrative = row['open_narrative']
    sex_cod = row['sex_cod']
    age_value_death = row['age_value_death']
    age_unit_death = row['age_unit_death']
    
    prompt = USR_PROMPT
    prompt = prompt.replace('AGE_VALUE_DEATH', str(age_value_death))
    prompt = prompt.replace('AGE_UNIT_DEATH', age_unit_death.lower())
    prompt = prompt.replace('SEX_COD', sex_cod.lower())
    prompt = prompt.format(open_narrative=narrative)
    
    # print("Prompt:")    
    # print(textwrap.fill(prompt, width=WORDWRAP_WIDTH))
    # print()
    
    # for a in range(5):
    completion = get_completion(
        [
            {"role": "system", "content": SYS_PROMPT},
            {"role": "user", "content": prompt}
        ] ,
        model=MODEL_NAME,
        logprobs=True,
        # top_logprobs=2,
    )
    
    # print(completion.choices[0].message)
    
    # for token in completion.choices[0].logprobs.content:
    #     print(f"{repr(str(token.token)).ljust(15)}  {str(token.logprob).ljust(20)} {np.round(np.exp(token.logprob)*100,2)}%")
        
    output_msg = completion.choices[0].message.content
    logprob_data = [(token.token, float(token.logprob)) for token in completion.choices[0].logprobs.content]
    usage_data = list(completion.usage)    
    current_time = datetime.datetime.now(tz=TIMEZONE).isoformat()
       
    data_storage[str(rowid)] = {
        'rowid': rowid,
        'model': MODEL_NAME,
        'system_prompt': SYS_PROMPT,
        'user_prompt': prompt,
        'output_msg': output_msg,
        'logprobs': logprob_data,
        'usage': usage_data,
        'timestamp': current_time
    }

    # Save data periodically (you can adjust the frequency based on your needs)    
    if index % SAVE_FREQ == 0 and index > 0:
        if repeated_skips:
            print("\n", flush=True)
        repeated_skips = False
        
        save_data(data_storage)
        print(f"Saving index: {str(index).ljust(8)} Processing: {str(rowid).ljust(12)} Rows skipped: {len(skipped_rows)}", sep=' ', end='\r', flush=True)
        # break
    
try:
    save_data(data_storage)
    print(f"Saving index: {str(index).ljust(8)} Processing: {str(rowid).ljust(12)} Rows skipped: {len(skipped_rows)}", sep=' ', end='\r', flush=True)
    print("\nData saved successfully.")
except Exception as e:
    print(f"Error saving data: {e}")

if len(skipped_rows) > 0:
    print(f"DF length: {len(merged_df)}")
    print(f"Rows skipped: {len(skipped_rows)}")    

data_storage.json found. Loading data...

Saving index: 232      Processing: 24001069     Rows skipped: 0
Data saved successfully.


In [10]:
data_storage = load_data()

data_storage.json found. Loading data...


In [11]:
df = pd.DataFrame(data_storage).T
df['icds'] = df.apply(lambda x: extract_icd_probabilities(x['logprobs']), axis=1)
df['best_icd'] = df.apply(lambda x: pd.DataFrame(x['icds']).sort_values(by=1, ascending=False).iloc[0,0], axis=1)


In [16]:
df.to_json("data_storage_parsed.json", orient='records')

In [14]:
print(f"output_msg exceeding normal ICD length:")
df[df['output_msg'].apply(lambda x:len(x) > 10)][['output_msg','icds','best_icd']]

output_msg exceeding normal ICD length:


Unnamed: 0,output_msg,icds,best_icd
14002658,Malignant neoplasm of bone and articular carti...,"[(C40.9, 0.8579251792598699)]",C40.9
14004747,A09\nR50.9\nR11.0\nR63.4,"[(A09, 0.3801835598412292), (R50.9, 0.66909530...",R50.9
14006258,V89.2 (Person injured in collision between oth...,"[(V89.2, 0.8129130625381266)]",V89.2
14003822,T79.8\nR40.2,"[(T79.8, 0.6216546258358351), (R40.2, 0.681784...",R40.2
14007755,Tetanus: A33,"[(A33, 0.7714985657154876)]",A33
14002323,T79.3\nT88.9\nN17.9,"[(T79.3, 0.6682882892563279), (T88.9, 0.574929...",N17.9
14008510,V89.2 (Pedestrian injured in collision with ot...,"[(V89.2, 0.8549707965091256)]",V89.2
14007670,Tetanus: A33,"[(A33, 0.7297871458081744)]",A33
14008863,W17.89 - Other specified fall from one level t...,"[(W17.89, 0.7211488803081331)]",W17.89
14001730,Malaria\nB54,"[(B54, 0.9215047096956849), (B54, 0.9215047096...",B54


In [207]:
# # comment this out. Used to fix f(x)
# def extract_icd_probabilities(logprobs):
#     parsed_icd = []
#     for pos in range(len(logprobs)):
#         temp_df = pd.DataFrame(logprobs[pos: pos+4])
#         temp_df = temp_df[temp_df[0].notna() & (temp_df[0].str.strip() != '')]
#         temp_df = temp_df[temp_df[0].str.strip() != '\n']
#         temp_concat = ''.join(temp_df.iloc[:, 0]).strip()
#         if len(temp_concat) > 9:
#             continue
#         # pattern = r'^[A-Z]\d{0,4}(\.\d{0,4})?$'
#         pattern_4part = r'^[A-Z]\d{0,4}(\.\d{1,4})?$'
#         match = re.match(pattern_4part, temp_concat)

#         if match:            
#             print(f"**** {temp_concat} - VALID 4-parts ICD ****")

            

#             display(temp_df)
#             parsed_icd.append((temp_concat, (np.exp(temp_df.iloc[:, 1]).mean())))
#         else:
#             print(f"{temp_concat} - invalid 4-parts.")
            
#             #trying 2-parts
#             temp_df = pd.DataFrame(logprobs[pos: pos+2])
#             temp_concat = ''.join(temp_df.iloc[:, 0]).strip()
#             pattern_2part = r'^[A-Z]\d{1,4}$'
#             match = re.match(pattern_2part, temp_concat)
#             if match:
#                 print(f"**** {temp_concat} - VALID 2-parts ICD ****")
#                 display(temp_df)
#                 parsed_icd.append((temp_concat, (np.exp(temp_df.iloc[:, 1]).mean())))
#             else:
#                 print(f"{temp_concat} - invalid 2-parts.")
#                 pass
#             pass

#     return parsed_icd

# print(data_storage['14000201']['output_msg'])
# print()
# print(extract_icd_probabilities(data_storage['14000201']['logprobs']))

K29.5
R57.0
R11.2
J18.9
K92.1

**** K29.5 - VALID 4-parts ICD ****


Unnamed: 0,0,1
0,K,-0.226075
1,29,-1.688498
2,.,-0.306559
3,5,-0.870149


29.5 - invalid 4-parts.
29. - invalid 2-parts.
.5R - invalid 4-parts.
.5 - invalid 2-parts.
5R57 - invalid 4-parts.
5 - invalid 2-parts.
R57. - invalid 4-parts.
R - invalid 2-parts.
**** R57.0 - VALID 4-parts ICD ****


Unnamed: 0,0,1
0,R,-0.156713
1,57,-1.423716
2,.,-0.018007
3,0,-0.032591


57.0 - invalid 4-parts.
57. - invalid 2-parts.
.0R - invalid 4-parts.
.0 - invalid 2-parts.
0R11 - invalid 4-parts.
0 - invalid 2-parts.
R11. - invalid 4-parts.
R - invalid 2-parts.
**** R11.2 - VALID 4-parts ICD ****


Unnamed: 0,0,1
0,R,-0.532045
1,11,-1.047437
2,.,-0.228074
3,2,-0.473701


11.2 - invalid 4-parts.
11. - invalid 2-parts.
.2J - invalid 4-parts.
.2 - invalid 2-parts.
2J18 - invalid 4-parts.
2 - invalid 2-parts.
J18. - invalid 4-parts.
J - invalid 2-parts.
**** J18.9 - VALID 4-parts ICD ****


Unnamed: 0,0,1
0,J,-1.313691
1,18,-0.361045
2,.,-0.005379
3,9,-0.00547


18.9 - invalid 4-parts.
18. - invalid 2-parts.
.9K - invalid 4-parts.
.9 - invalid 2-parts.
9K92 - invalid 4-parts.
9 - invalid 2-parts.
K92. - invalid 4-parts.
K - invalid 2-parts.
**** K92.1 - VALID 4-parts ICD ****


Unnamed: 0,0,1
0,K,-0.503517
1,92,-1.109303
2,.,-8e-06
3,1,-0.214341


92.1 - invalid 4-parts.
92. - invalid 2-parts.
.1 - invalid 4-parts.
.1 - invalid 2-parts.
1 - invalid 4-parts.
1 - invalid 2-parts.
[('K29.5', 0.5343298732313193), ('R57.0', 0.761463588036748), ('R11.2', 0.5892493045547196), ('J18.9', 0.7387385413692558), ('K92.1', 0.6853138496405538)]
