In [None]:
"""
The HEALSL project provides verbal autopsy data split across multiple files: round one and round two; adult, child, and neonate; and questionnaire,
age, and open narrative are all separate files. This script aims to combine these files into a single dataset (dataframe) to simplify processing.

We extract the deceased's sex from the questionnaire dataset, the deceased's age from the age dataset, and the open narrative recorded from the
verbal autopsy from the narrative dataset. Then, we combine the extracted features (columns) using their row id as the key.

We utilize the OpenAI API's Chat Completions to generate a response for each verbal autopsy record. Several parameters were used for this project:
message: the input text to be processed by the API. It consists of the combination of two text prompts: the system prompt, which is the same for
all requests and provides the model with some context of its role and objective, and the user prompt, which concatenates more specific instructions
regarding the output response, along with the data from the dataframe.

model: this parameter specifies the language model to be used. To strive for consistency and reproducibility of the results, we used specific versions
of the GPT-3.5 and GPT-4 models; gpt-3.5-turbo-0125 and gpt-4-0613, respectively.

temperature: this parameter can be any value between 0 and 2. It controls the randomness of the output response. Low temperatures result in more 
deterministic responses, while high temperatures result in more random responses. We used a temperature of 0 to ensure the responses were as 
deterministic as possible.

logprobs: this parameter controls whether the output includes the log probabilities of the tokens. We set it to True to provide more parametric
information about the output.

The prompts are as follows:

System prompt:
"You are a physician with expertise in determining underlying causes of death in Sierra Leone by assigning ICD-10 codes for deaths using verbal autopsy 
narratives. Return only the ICD-10 code without description. E.g. A00 
If there are multiple ICD-10 codes, show one code per line"

User prompt:
"With the highest certainty, determine the underlying cause of death and provide the most accurate ICD-10 code for a verbal autopsy narrative of a
AGE_VALUE_DEATH AGE_UNIT_DEATH old SEX_COD death in Sierra Leone: {open_narrative}"

    AGE_VALUE_DEATH and AGE_UNIT_DEATH: replaced with age_value_death and age_unit_death values from the age dataset.
    SEX_COD: replaced with sex_cod value from the questionnaire dataset.
    open_narrative: replaced with summary value from the open narrative dataset.

The response from the API is then dissected to extract relevant information in plain text, such as the response text, the log probabilities, 
along with other accounting information such as row ids, models used, timestamps, and token consumption, into the result storage, which is exported to a file.
        

1. Load the result storage as a dictionary.
    If the result storage does not exist, create an empty dictionary.
2. For each row in the dataframe:
    Check if rowid is in the result storage.
        If rowid is in the result storage, skip the row.
    Compose the two prompts and generate a response using the OpenAI API.
    Store the response and other relevant information in the result storage.
    Save the result storage to a file periodically.
    
"""

In [None]:
import os
import pandas as pd
import numpy as np
from openai import OpenAI
import textwrap
import datetime
import pytz
import json
import re


# Module-level configuration: API client, file paths, model selection and the
# two prompt templates used for every request.

# The API key is read from the OPEN_API_KEY environment variable; .get()
# returns None when the variable is unset.
# NOTE(review): the OpenAI SDK's conventional variable name is OPENAI_API_KEY —
# confirm that OPEN_API_KEY is the intended name.
open_api_key = os.environ.get('OPEN_API_KEY')
client = OpenAI(api_key=open_api_key)

# Input datasets
# CSV file read into the working dataframe.
INPUT_DATASET_FILE = "output.csv"

# Output file
# JSON file used both as the final output and as the resume checkpoint
# (it is the default filename of load_data and save_data).
OUTPUT_DATA_FILE = "testing_response.json"


# Line width for wrapped debug printing; in the visible code it is only
# referenced from commented-out print statements.
WORDWRAP_WIDTH = 100
# How often the processing loop writes the result storage to disk.
SAVE_FREQ = 5

# Models
# Pinned model snapshots; MODEL_NAME selects the one used for this run.
GPT4 = "gpt-4-0613"
GPT3 = "gpt-3.5-turbo-0125"
MODEL_NAME = GPT3

# Set the timezone to Eastern Time
# Used for the stored timestamps and for the report filename.
TIMEZONE = pytz.timezone('US/Eastern')


# System prompt: identical for every request.
SYS_PROMPT = """You are a physician with expertise in determining underlying causes of death in Sierra Leone by assigning ICD-10 codes for deaths using verbal autopsy narratives. Return only the ICD-10 code without description. E.g. A00 
If there are multiple ICD-10 codes, show one code per line."""

# User prompt template. AGE_VALUE_DEATH, AGE_UNIT_DEATH and SEX_COD are
# substituted with str.replace, and {open_narrative} with str.format, by the
# processing loop.
USR_PROMPT = """With the highest certainty, determine the underlying cause of death and provide the most accurate ICD-10 code for a verbal autopsy narrative of a AGE_VALUE_DEATH AGE_UNIT_DEATH old SEX_COD death in Sierra Leone: {open_narrative}"""

In [None]:
# Load the input dataset, verify that it has every column the prompts need,
# and pack all remaining columns into one dictionary-valued column so they
# can be carried through to the output unchanged.

# Columns the processing loop reads directly from each row.
required_colnames = ['u_id', 'rowid', 'age_value_death', 'age_unit_death',
                     'open_narrative', 'sex_cod']

# Load dataset
merged_df = pd.read_csv(INPUT_DATASET_FILE)

# TEST: making up a new column
# NOTE(review): this adds a constant placeholder column to every row and is
# marked as test code by the author — confirm whether it should be removed.
merged_df = merged_df.assign(other_2="extra stuff")

# Check all required columns are in the df
missing_colnames = [colname for colname in required_colnames
                    if colname not in merged_df.columns]
dataset_passed = not missing_colnames

for colname in missing_colnames:
    print(f"Missing column \"{colname}\"")

if not dataset_passed:
    print("Please ensure dataset has all the required columns.")
    # Name the missing columns in the exception itself so the failure is
    # still diagnosable when only the traceback is captured.
    raise ValueError(
        "Error: Missing columns required for processing: "
        + ", ".join(missing_colnames)
    )

# Get columns names that are not required
extra_colnames = [colname for colname in merged_df.columns if colname not in required_colnames]

# Transform non-required columns as dictionary in a new column
merged_df['model_residual_columns'] = merged_df[extra_colnames].apply(lambda x: x.to_dict(), axis=1)

In [None]:
# ***** REMOVE THIS WHEN DONE TESTING *****
# Restrict the run to a small random subset of rows. The sample size is
# clamped to the number of rows available, because DataFrame.sample raises
# ValueError when asked for more rows than the frame holds.
merged_df = merged_df.sample(min(5, len(merged_df)))
merged_df

In [None]:
# F(x): Misc. functions

def get_current_str_time():
    """Return the current time in TIMEZONE as a YYMMDD_HHMMSS string."""
    now = datetime.datetime.now(tz=TIMEZONE)
    return now.strftime("%y%m%d_%H%M%S")

def recursive_dict(obj):
    """Recursively convert an object tree into plain dicts and lists.

    Used to turn a ChatCompletion object into a dictionary. Dictionaries and
    lists are rebuilt with their contents converted; any other object that
    has a ``__dict__`` is replaced by the converted contents of that
    ``__dict__``; everything else is returned unchanged.
    """
    if isinstance(obj, dict):
        converted = {}
        for key, value in obj.items():
            converted[key] = recursive_dict(value)
        return converted

    if isinstance(obj, list):
        return list(map(recursive_dict, obj))

    if hasattr(obj, '__dict__'):
        return recursive_dict(obj.__dict__)

    return obj

In [None]:
# F(x): Initialize the data storage dictionary

def load_data(filename=OUTPUT_DATA_FILE):
    """Load previously stored results from a JSON file.

    Args:
        filename: Path of the JSON results file.

    Returns:
        The parsed JSON content when the file exists, otherwise an empty
        dictionary. A status message is printed in either case.
    """
    if not os.path.exists(filename):
        print(f"{filename} not found. Initializing empty dictionary...")
        return {}

    print(f"{filename} found. Loading data...")
    with open(filename, 'r') as stored:
        return json.load(stored)

def save_data(data, filename=OUTPUT_DATA_FILE):
    """Write the result storage to a JSON file without risking the old copy.

    The data is first serialised to a temporary sibling file and then moved
    over ``filename`` with ``os.replace``. If serialisation fails or the
    process is interrupted part-way through, the previously saved file is
    left untouched instead of being truncated.

    Args:
        data: JSON-serialisable object to store.
        filename: Destination path of the JSON file.

    Raises:
        Whatever ``open``, ``json.dump`` or ``os.replace`` raise; the
        temporary file is removed before the error propagates.
    """
    tmp_filename = f"{filename}.tmp"
    try:
        with open(tmp_filename, 'w') as file:
            json.dump(data, file)
        # Swap the fully written file into place in a single step.
        os.replace(tmp_filename, filename)
    except BaseException:
        # Do not leave a half-written temporary file behind.
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)
        raise

In [None]:
# F(x): Send a message to the chatbot
def get_completion(
    messages: list[dict[str, str]],
    model: str = "gpt-3.5-turbo-0125",
    # max_tokens=500,
    temperature=0,
    # stop=None,
    # seed=123,
    tools=None,
    logprobs=None,
    top_logprobs=None,
):
    """Send a chat conversation to the API and return the completion object.

    Args:
        messages: Chat messages, each a dict with "role" and "content" keys.
        model: Name of the model to query.
        temperature: Sampling temperature passed to the API.
        tools: Optional tool definitions; only sent when truthy.
        logprobs: Passed through as the request's "logprobs" parameter.
        top_logprobs: Passed through as the request's "top_logprobs"
            parameter.

    Returns:
        The object returned by ``client.chat.completions.create`` — not a
        plain string. The caller reads ``.choices`` and ``.usage`` from it.
    """
    params = {
        "model": model,
        # "response_format": { "type": "json_object" },
        "messages": messages,
        # "max_tokens": max_tokens,
        "temperature": temperature,
        # "stop": stop,
        # "seed": seed,
        "logprobs": logprobs,
        "top_logprobs": top_logprobs,
    }
    # "tools" is only added when provided, so the key is absent otherwise.
    if tools:
        params["tools"] = tools

    return client.chat.completions.create(**params)

In [None]:
# Main processing cell: for every row not yet present in the result storage,
# build the prompts, query the model, record the response, and checkpoint the
# storage to disk every SAVE_FREQ processed rows.

# Load existing data or initialize an empty dictionary
data_storage = load_data()
skipped_rows = []
repeated_skips = False
# Rows sent to the API during this run. The periodic save is keyed on this
# counter rather than on the DataFrame index label, which is arbitrary once
# the frame has been sampled or filtered.
processed_count = 0
print()

try:
    for index, row in merged_df.iterrows():
        u_id = row['u_id']

        # Check if row already processed. Testing both because json changes int keys to str
        if u_id in data_storage or str(u_id) in data_storage:
            repeated_skips = True
            skipped_rows.append(u_id)
            continue

        rowid = row['rowid']
        narrative = row['open_narrative']
        sex_cod = row['sex_cod']
        age_value_death = row['age_value_death']
        age_unit_death = row['age_unit_death']
        other_columns = row['model_residual_columns']

        # Fill the upper-case placeholders first and insert the narrative
        # last, so placeholder-like text inside a narrative is never replaced.
        prompt = USR_PROMPT
        prompt = prompt.replace('AGE_VALUE_DEATH', str(age_value_death))
        prompt = prompt.replace('AGE_UNIT_DEATH', age_unit_death.lower())
        prompt = prompt.replace('SEX_COD', sex_cod.lower())
        prompt = prompt.format(open_narrative=narrative)

        completion = get_completion(
            [
                {"role": "system", "content": SYS_PROMPT},
                {"role": "user", "content": prompt},
            ],
            model=MODEL_NAME,
            logprobs=True,
        )

        output_msg = completion.choices[0].message.content
        logprob_data = [(token.token, float(token.logprob))
                        for token in completion.choices[0].logprobs.content]
        # Same (field, value) pair layout as list(completion.usage), but any
        # nested object value is flattened so json.dump can serialise it.
        usage_data = [(field, recursive_dict(value))
                      for field, value in completion.usage]
        current_time = datetime.datetime.now(tz=TIMEZONE).isoformat()

        data_storage[str(u_id)] = {
            'u_id': u_id,               # 'u_id' is the unique identifier for the dataset
            'rowid': rowid,
            'model': MODEL_NAME,
            'system_prompt': SYS_PROMPT,
            'user_prompt': prompt,
            'output_msg': output_msg,
            'logprobs': logprob_data,
            'usage': usage_data,
            'timestamp': current_time,
            'other_columns': other_columns,
            'raw': recursive_dict(completion),
        }

        # Save data periodically (you can adjust the frequency based on your needs)
        processed_count += 1
        if processed_count % SAVE_FREQ == 0:
            if repeated_skips:
                print("\n", flush=True)
            repeated_skips = False

            save_data(data_storage)
            print(f"Saving index: {str(index).ljust(8)} Processing: {str(u_id).ljust(12)} Rows skipped: {len(skipped_rows)}", sep=' ', end='\r', flush=True)
except BaseException:
    # A failed request or an interrupt must not discard responses collected
    # since the last checkpoint: store them, then let the error propagate.
    save_data(data_storage)
    raise

try:
    save_data(data_storage)
    # index and u_id only exist once the loop body has run at least once.
    if processed_count or skipped_rows:
        print(f"Saving index: {str(index).ljust(8)} Processing: {str(u_id).ljust(12)} Rows skipped: {len(skipped_rows)}", sep=' ', end='\r', flush=True)
    print("\nData saved successfully. Processing Complete.")
except Exception as e:
    print(f"Error saving data: {e}")

if skipped_rows:
    print(f"{len(skipped_rows)} rows skipped. Check skipped_rows for details.")

In [None]:
# Write a timestamped text report listing every row that was skipped because
# it was already present in the result storage.
report_path = f"response_report_{get_current_str_time()}.txt"
with open(report_path, "w") as report:
    report_lines = ["The follow rows are skipped because they were already processed.\n"]
    report_lines.extend(f"{str(item)}\n" for item in skipped_rows)
    report.writelines(report_lines)
