In [90]:
"""
This script takes the raw responses from the OpenAI API, extracts the ICD-10 codes and their associated probabilities,
and saves the top-ranked codes to CSV files.

Pseudo code
-----------
1. Load the data from file.
2. Parse the feature "logprobs" and match with various ICD10 code patterns.
    2.1 Extract only the tokens that form the ICD10 code pattern and their associated probabilities.
    2.2 Calculate the mean linear probability of all the tokens involved.
    2.3 Save the ICD10 code, mean linear probability, and relevant information into "output_probs".
3. Sort ICD10 codes in "output_probs" by their mean linear probability in descending order.
4. Extract the top 5 ICD10 codes and their associated mean linear probabilities into their own columns.
5. Reorder the columns and save the dataframe to file.


Details regarding #2 of the pseudo code:
----------------------------------------

Although we specified the requirements in the API prompts, the response output sometimes contains additional information, 
such as extra descriptions, multiple ICD10 codes, or other unrelated information. There are only a few dozen such cases
in over ten thousand responses. Nevertheless, these need to be handled as there can only be one best ICD10 code. In 
general, we look into the output message, find all the ICD10 codes, calculate their mean probabilities, and save only the 
ICD10 codes with the top 5 highest mean probabilities.


Details regarding #2.1 of the pseudo code:
------------------------------------------

An output message may show only one ICD10 code, but behind the scenes, the code is formed by a number of tokens. For 
example, an output message of "M54.2" is composed of four tokens: "M", "54", ".", and "2". Each token has its own log 
probability. All probabilities are recorded in the "logprobs" feature as an array. 

The mean probability of a single ICD10 code is simple to calculate, as we can just take the mean of the whole array.
However, when the output message consists of multiple ICD10 codes or unrelated text, "logprobs" must be parsed to extract 
only the relevant tokens.

We use a sliding window of various sizes to match different ICD10 code patterns using regular expressions. The patterns
are as follows:

    - ANN.NNNA
    - ANN.NNN (one to three digits after the decimal point)
    - ANN

... where A is a letter and N is a digit (the third character may also be a letter).

This will allow us to capture from the most detailed ICD10 code (e.g. G83.9) to the broadest (e.g. B54).


Details regarding #2.2 of the pseudo code:
------------------------------------------
The formula used for calculating the mean linear probability is:

    Linear_Mean_Probability = (1/n) * sum(exp(logprob_i) for i in 1 to n)

... where logprob_i is the log probability of the i-th token that forms the ICD10 code, and n is the number of such tokens.

        
Details regarding #5 of the pseudo code:
----------------------------------------

Below is the data structure of the parsed data:

    dataframe() = []
        'cause(n)_icd10': the n-th ICD10 code extracted from the response. (n) can be 1 to 5.
        'cause(n)_icd10_prob': the mean linear probability of the ICD10 code. (n) can be 1 to 5.
        'output_timestamp': 
        'output_model': 
        'output_system_prompt': 
        'output_user_prompt':
        'output_usage_completion_tokens': Number of tokens used by completion
        'output_usage_prompt_tokens': Number of tokens used by prompt
        'output_probs': Extracted ICD-10 codes, linear mean, and their associated token probabilities.        
        'other_columns': columns carried over from the original dataframe. (optional; by setting)
        'raw': the original raw response (optional; by setting)
    ]
    
Notes
- limit to 40 max-tokens
- don't care whether the best ICD is within the max tokens

"""
pass
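
As a minimal sketch of the mean linear probability described in the docstring above, the snippet below exponentiates each token's log probability and averages the results. The helper name `mean_linear_probability` and the log probability values are hypothetical and chosen only for illustration; they are not taken from the data.

```python
import math


def mean_linear_probability(logprobs):
    """Average of exp(logprob) over the tokens that make up one ICD-10 code."""
    return sum(math.exp(lp) for lp in logprobs) / len(logprobs)


# Hypothetical log probabilities for the four tokens "M", "54", ".", "2"
example_tokens = [("M", -0.10), ("54", -0.25), (".", -0.001), ("2", -0.05)]
example_mean = mean_linear_probability([lp for _, lp in example_tokens])
print(f"M54.2 mean linear probability: {example_mean:.4f}")
```

Note that this is an arithmetic mean of per-token probabilities, not the joint probability of the token sequence, which would be the product of the same values.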

In [91]:
import os
import pandas as pd
import numpy as np
import json
import re
from datetime import datetime
import logging


# return the current date and time as a string
def get_datetime_string():
    return datetime.now().strftime('%Y%m%d_%H%M')

# File Settings
IMPORT_O2_JSON_DATA = "./_working_data_240315/02_(sampled)_gpt4_0315.json"
# OUTPUT_FILE filename is automatically generated

# Data Analysis Settings
PAIRS = 5           # Generate up to 5 ICDs

# Output Settings
DROP_EXCESS_COLUMNS = False         # Set True to remove 'other_columns' from output dataframe
DROP_RAW = False                    # Set True to remove the 'raw' column, the original raw response, from the export file



In [92]:
# F(x): Get export filenames based on input filename

def generate_export_filename(file_path) -> dict[str, str]:
    '''
    Takes a file path as input and returns the names of the two parsed CSV export files.

    Parameters:
        file_path (str): The full path of the input file.

    Returns:
        dict[str, str]: Export file paths keyed by 'sorted_icd' and 'first_icd'.
    '''
    directory = os.path.dirname(file_path)
    file = os.path.basename(file_path)
    
    # Drop the ".json" extension and the two-character step prefix (e.g. "02")
    temp = file.split(".json")[0]
    temp = temp[2:]
            
    return {
        'sorted_icd': f"{directory}/03{temp}_parsed_sorted_ICD.csv",
        'first_icd': f"{directory}/03{temp}_parsed_first_ICD.csv"
    }
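
To make the renaming rule concrete, here is a small self-contained sketch that mirrors the logic above and applies it to the input path used in this notebook. The function name `derive_export_names` is hypothetical and is not part of the notebook.

```python
import os


def derive_export_names(file_path):
    """Swap the leading two-character step prefix (e.g. '02') for '03'."""
    directory = os.path.dirname(file_path)
    stem = os.path.basename(file_path).split(".json")[0]
    suffix = stem[2:]  # drop the two-character step prefix
    return {
        "sorted_icd": f"{directory}/03{suffix}_parsed_sorted_ICD.csv",
        "first_icd": f"{directory}/03{suffix}_parsed_first_ICD.csv",
    }


names = derive_export_names("./_working_data_240315/02_(sampled)_gpt4_0315.json")
print(names["sorted_icd"])
print(names["first_icd"])
```

The rule assumes the input filename starts with a two-character step prefix; a filename without one would lose its first two characters.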

In [93]:
# F(x): Initialize the data storage dictionary

def load_data(filename=IMPORT_O2_JSON_DATA):
    if os.path.exists(filename):
        print(f"{filename} found. Loading data...")
        with open(filename, 'r') as file:
            data = json.load(file)
        return data
    else:
        print(f"{filename} not found. Initializing empty dictionary...")
        return {}

def save_data(data, filename=IMPORT_O2_JSON_DATA):
    with open(filename, 'w') as file:
        json.dump(data, file)

In [94]:
# Load JSON data and convert to dataframe
data_storage = load_data()
df = pd.DataFrame(data_storage).T


./_working_data_240315/02_(sampled)_gpt4_0315.json found. Loading data...


In [95]:
# F(x): Extract ICD probabilities from tokens


def extract_icd_probabilities(logprobs, debug=False):
    """
    Extracts ICD-10 codes and their associated probabilities from a list of tokens and log probabilities.

    This function iterates over the list of tokens and log probabilities, concatenating tokens together 
    and checking if they match the pattern of an ICD-10 code. If a match is found, it calculates the mean 
    linear probability of the ICD-10 code and packages the ICD-10 code, mean linear probability, and 
    associated tokens and log probabilities into a dictionary. It then appends this dictionary to a list 
    of parsed ICD-10 codes.

    Args:
        logprobs (list): A list of lists, where each inner list contains a token and its associated log probability.
        debug (bool, optional): If set to True, the function prints debug information. Defaults to False.

    Returns:
        list: A list of dictionaries, where each dictionary contains an ICD-10 code, its mean linear probability, 
              and a dictionary of associated tokens and log probabilities.
    """
    parsed_icds = []
    tmp_df = pd.DataFrame(logprobs)
    if debug > 0:
        print(repr(''.join(tmp_df.iloc[:,0])))
    tmp_df_limit = len(tmp_df)
    for pos in range(tmp_df_limit):
        # Concatenate 2, 4, or 5 tokens to form ICD-10 codes
        temp_concat_ANN = ''.join(tmp_df.iloc[pos:pos+2, 0]).strip()
        temp_concat_ANN_NNN = ''.join(tmp_df.iloc[pos:pos+4, 0]).strip()
        temp_concat_ANN_NNN_A = ''.join(tmp_df.iloc[pos:pos+5, 0]).strip()
        temp_concat_ANA_NNN = ''.join(tmp_df.iloc[pos:pos+5, 0]).strip()
        
        # Reference: https://www.webpt.com/blog/understanding-icd-10-code-structure
        
        # Regular expression pattern for various ICD-10 codes in the format
        # 'ANN' (e.g., 'A10')
        # 'ANN.NNN' (e.g., 'A10.001')
        # 'ANN.NNNA' (e.g., 'A10.001A') 
        # Note: last alphabet valid only if there are 6 characters before it
        # pattern_ANN = r"^[A-Z]\d{2}$"
        pattern_ANN = r"^[A-Z]\d[0-9A-Z]$"
        # pattern_ANN_NNN = r"^[A-Z]\d{2}\.\d{1,3}$"        
        pattern_ANN_NNN = r"^[A-Z]\d[0-9A-Z]\.\d{1,3}$"        
        # pattern_ANN_NNN_A = r"^[A-Z]\d{2}\.\d{3}[A-Z]$"
        pattern_ANN_NNN_A = r"^[A-Z]\d[0-9A-Z]\.\d{3}[A-Z]$"        
        
        # Check if the concatenated tokens match the ICD-10 code patterns
        match_ANN = re.match(pattern_ANN, temp_concat_ANN)
        match_ANN_NNN = re.match(pattern_ANN_NNN, temp_concat_ANN_NNN)
        match_ANN_NNN_A = re.match(pattern_ANN_NNN_A, temp_concat_ANN_NNN_A)
        match_ANA_NNN = re.match(pattern_ANN_NNN, temp_concat_ANA_NNN)
        
        # [debug] Each line shows which of the 4 patterns matched at this position
        if debug == 2:
            print(
                str(pos).ljust(4), 
                repr(temp_concat_ANN).ljust(10), 
                ('yes' if match_ANN else 'no').ljust(15), 
                repr(temp_concat_ANN_NNN).ljust(10), 
                ('yes' if match_ANN_NNN else 'no').ljust(15), 
                repr(temp_concat_ANN_NNN_A).ljust(10), 
                ('yes' if match_ANN_NNN_A else 'no').ljust(15),
                repr(temp_concat_ANA_NNN).ljust(10), 
                ('yes' if match_ANA_NNN else 'no').ljust(5)
                )
        
        # Check match from longest to shortest
        # If a match is found, calculate the mean linear probability 
        # and package the ICD-10 code and associated data
        if match_ANN_NNN_A:
            winning_df = pd.DataFrame(logprobs[pos:pos+5])
            winning_icd = temp_concat_ANN_NNN_A
        elif match_ANA_NNN:
            winning_df = pd.DataFrame(logprobs[pos:pos+5])
            winning_icd = temp_concat_ANA_NNN
        elif match_ANN_NNN:
            winning_df = pd.DataFrame(logprobs[pos:pos+4])
            winning_icd = temp_concat_ANN_NNN            
        elif match_ANN:
            winning_df = pd.DataFrame(logprobs[pos:pos+2])
            winning_icd = temp_concat_ANN            
        else:
            continue

        # detect any rows that are just whitespace (e.g. \n), and drop those rows
        whitespc_index = winning_df[winning_df.loc[:, 0].str.isspace()].index.tolist()
        winning_df = winning_df.drop(whitespc_index)
        
        # [debug] Display the winning ICD-10 code and its associated data
        if debug == 2:
            print(f"**** {winning_icd} - VALID ICD ****")
            display(winning_df)
        
        # Convert log probabilities to linear probabilities and calculate the mean
        winning_mean = np.exp(winning_df.iloc[:, 1]).mean()
        
        # Package the ICD-10 code and associated data
        winning_package = {
            'icd': winning_icd,
            'icd_linprob_mean': winning_mean,
            'logprobs': winning_df.rename(columns={0: 'token', 1:'logprob'}).to_dict(orient='list')
        }

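        # Skip exact duplicates (same code, mean, and token log probabilities).
        # Overlapping windows can match the same tokens twice when a whitespace token sits next to the code.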
        if winning_package in parsed_icds:
            if debug > 0:
                logging.debug("Duplicate ICD-10 code found. Skipping...")
            continue
        
        # Append the package to the list of parsed ICD-10 codes
        parsed_icds.append(winning_package)
    
    # [debug] Display the parsed ICD-10 codes
    if debug > 0:
        display(parsed_icds) 
    
    # Check if parsed_icds is empty
    if not parsed_icds:
        # If it is, log a warning and show the logprobs in question
        logging.warning(f"ICD-10 code not found in this logprobs: {logprobs}")
        
        # winning_package = {
        #     'icd': 'R99',
        #     'icd_linprob_mean': 1,
        #     'logprobs': {'token': ['R', '99'], 'logprob': [-0.00001, -0.00001]}
        # }
        
        # parsed_icds.append(winning_package)
        # raise ValueError(f"No ICD-10 codes could be parsed from the provided logprobs: {logprobs}")    
    
    # Drop the last element if there are more than 5 ICD10 extracted.
    if len(parsed_icds) > 5:
        parsed_icds = parsed_icds[:-1]

    return parsed_icds

# # Uncomment the following lines to test the function. 
# # `test` is an example of the `logprobs` field from the JSON data.
# test = [['A', -0.63648945],  ['09', -1.4643841], ['\n', -0.9866263], ['R', -0.6599979], ['50', -1.5362289],
#  ['.', -0.05481864],  ['9', -0.002321772], ['\n', -0.3524723], ['R', -0.56709456], ['11', -1.263591],
#  ['.', -0.05834798], ['0', -0.73551023], ['\n', -0.5051807], ['R', -0.65759194], ['63', -1.0282977],
#  ['.', -0.0006772888], ['4', -0.71002203]]

# test_output = extract_icd_probabilities(test)
# test_output

# # Uncomment to test a specific case
# extract_icd_probabilities(df.loc['14008356_0','output']['choices'][0]['logprobs']['content'])
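
The sliding-window matching can also be sketched without pandas. The version below is a simplified illustration, not the notebook's function: `ICD_PATTERNS` and `sliding_window_icds` are hypothetical names, and the sample tokens are the commented-out `test` list above. It tries the longest window first, drops whitespace-only tokens before averaging, and skips exact duplicates produced by overlapping windows.

```python
import math
import re

# (window width in tokens, pattern), tried from longest to shortest
ICD_PATTERNS = [
    (5, re.compile(r"^[A-Z]\d[0-9A-Z]\.\d{3}[A-Z]$")),  # ANN.NNNA
    (5, re.compile(r"^[A-Z]\d[0-9A-Z]\.\d{1,3}$")),     # ANN.N to ANN.NNN in 5 tokens
    (4, re.compile(r"^[A-Z]\d[0-9A-Z]\.\d{1,3}$")),     # ANN.N to ANN.NNN in 4 tokens
    (2, re.compile(r"^[A-Z]\d[0-9A-Z]$")),              # ANN
]


def sliding_window_icds(logprobs):
    """Return (code, mean linear probability) pairs in order of appearance."""
    found = []
    for pos in range(len(logprobs)):
        for width, pattern in ICD_PATTERNS:
            window = logprobs[pos:pos + width]
            candidate = "".join(tok for tok, _ in window).strip()
            if not pattern.match(candidate):
                continue
            # Ignore whitespace-only tokens (e.g. "\n") when averaging
            kept = [lp for tok, lp in window if not tok.isspace()]
            entry = (candidate, sum(math.exp(lp) for lp in kept) / len(kept))
            if entry not in found:  # overlapping windows can repeat a match
                found.append(entry)
            break
    return found


test = [['A', -0.63648945], ['09', -1.4643841], ['\n', -0.9866263],
        ['R', -0.6599979], ['50', -1.5362289], ['.', -0.05481864],
        ['9', -0.002321772], ['\n', -0.3524723], ['R', -0.56709456],
        ['11', -1.263591], ['.', -0.05834798], ['0', -0.73551023],
        ['\n', -0.5051807], ['R', -0.65759194], ['63', -1.0282977],
        ['.', -0.0006772888], ['4', -0.71002203]]

result = sliding_window_icds(test)
for code, mean in result:
    print(code, round(mean, 4))
```

The duplicate check matters here because a window that starts on the preceding `"\n"` and one that starts on the letter both strip down to the same code and the same tokens.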


In [96]:
# F(x): Given a list of ICDs in the form of a list of dictionaries, convert the ICDs into a 1-dimensional Series

def output_icds_to_cols(value, pairs=PAIRS, sort_probs=True):
    """
    Converts a list of ICD-10 codes and their associated probabilities into a one-dimensional pandas Series.

    This function takes a list of dictionaries, as returned by `extract_icd_probabilities`, where each 
    dictionary contains an ICD-10 code, its mean linear probability, and the underlying token log 
    probabilities. It converts this list into a DataFrame, optionally sorts the DataFrame by descending 
    probability, drops the 'logprobs' column, reshapes the DataFrame into a one-dimensional Series, and 
    pads or truncates the Series to a specified number of code/probability pairs.

    Args:
        value (list): A list of dictionaries with the keys 'icd', 'icd_linprob_mean', and 'logprobs'.
        pairs (int, optional): The number of code/probability pairs to keep. Defaults to PAIRS.
        sort_probs (bool, optional): If True, sort by descending mean linear probability. If False, keep 
            the order in which the codes appear in the response. Defaults to True.

    Returns:
        pandas.Series: A one-dimensional Series containing the ICD-10 codes and their associated probabilities.
    """
    if value == []:
        return pd.Series([np.nan] * pairs * 2).astype(object)

    tmp = pd.DataFrame(value) # convert list of dictionaries to dataframe
    
    if sort_probs:
        tmp = tmp.sort_values(by="icd_linprob_mean", ascending=False) # sort by descending probability
        
    tmp = tmp.drop(columns=['logprobs'])
    tmp = tmp.stack().reset_index(drop=True) # convert to 1 row
    tmp = tmp.reindex(range(pairs*2)) # pad or truncate to PAIRS*2 entries
    

    return tmp

# Test
# output_icds_to_cols(test_output)
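
The flattening step can be illustrated in plain Python. This is a sketch under stated assumptions rather than the function above: `icds_to_row` is a hypothetical name, missing slots are filled with `None` instead of `NaN`, and the probabilities are rounded illustrative values.

```python
def icds_to_row(parsed_icds, pairs=5, sort_probs=True):
    """Flatten parsed ICD entries into cause{n}_icd10 / cause{n}_icd10_prob fields."""
    entries = list(parsed_icds)
    if sort_probs:
        # Highest mean linear probability first
        entries = sorted(entries, key=lambda e: e["icd_linprob_mean"], reverse=True)
    row = {}
    for n in range(1, pairs + 1):
        # Entries beyond `pairs` are dropped; missing entries are padded with None
        entry = entries[n - 1] if n <= len(entries) else None
        row[f"cause{n}_icd10"] = entry["icd"] if entry else None
        row[f"cause{n}_icd10_prob"] = entry["icd_linprob_mean"] if entry else None
    return row


# Illustrative values only
parsed = [
    {"icd": "A09", "icd_linprob_mean": 0.98},
    {"icd": "J22", "icd_linprob_mean": 0.31},
    {"icd": "R50.9", "icd_linprob_mean": 0.96},
]

sorted_row = icds_to_row(parsed, sort_probs=True)
first_row = icds_to_row(parsed, sort_probs=False)
print(sorted_row["cause2_icd10"], first_row["cause2_icd10"])
```

With `sort_probs=True` the second slot holds the second most probable code; with `sort_probs=False` it holds the second code in order of appearance.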

In [97]:
export_filenames = generate_export_filename(IMPORT_O2_JSON_DATA)
EXPORT_SORTED_ICD_CSV_FILE = export_filenames['sorted_icd']
EXPORT_FIRST_ICD_CSV_FILE = export_filenames['first_icd']

In [98]:
# Get unrecognized colnames
required_colnames = ['uid', 'rowid', 'param_model', 'param_temperature',
                     'param_logprobs', 'param_system_prompt', 'param_user_prompt',
                     'timestamp', 'output']

# Get columns names that are not required
extra_colnames = [colname for colname in df.columns if colname not in required_colnames]


In [99]:
df = df.assign(
    output_msg = df.output.apply(lambda x: x['choices'][0]['message']['content']),
    output_logprobs = df.output.apply(lambda x: [(token['token'], float(token['logprob'])) for token in x['choices'][0]['logprobs']['content']]),
    output_usage_completion_tokens = df.output.apply(lambda x: x['usage']['completion_tokens']),
    output_usage_prompt_tokens = df.output.apply(lambda x: x['usage']['prompt_tokens'])
    
)

# Extract ICD-10 codes and their associated probabilities to a new column
df = df.assign(output_probs=df['output_logprobs'].apply(extract_icd_probabilities))

# Count the number of ICD-10 codes in each response
df['icd10_count'] = df['output_probs'].apply(len)

In [100]:
# Generate column names for the exploded ICDs in cause{n}_icd10 and cause{n}_icd10_prob format
icd_column_names_mapping = {i: f"cause{i // 2 + 1}_icd10" if i % 2 == 0 else f"cause{i // 2 + 1}_icd10_prob" for i in range(PAIRS*2)}

# Apply the `output_icds_to_cols` function to the `output_probs` column
# This will explode the ICDs into separate columns
# parsed_df = df.merge(df.output_probs.apply(output_icds_to_cols).rename(columns=icd_column_names_mapping), left_index=True, right_index=True)

# cause1...5 are filled in the order of the highest probability
parsed_sorted_icd10_df = df.merge(df.output_probs.apply(lambda x: output_icds_to_cols(x, sort_probs=True)).rename(columns=icd_column_names_mapping), left_index=True, right_index=True)

# cause1...5 are filled in the order they appear in the logprobs
parsed_first_icd10_df = df.merge(df.output_probs.apply(lambda x: output_icds_to_cols(x, sort_probs=False)).rename(columns=icd_column_names_mapping), left_index=True, right_index=True)

In [156]:
parsed_sorted_icd10_df.columns

Index(['uid', 'rowid', 'param_model', 'param_temperature', 'param_logprobs',
       'param_system_prompt', 'param_user_prompt', 'timestamp', 'output',
       'age_group', 'round', 'output_msg', 'output_logprobs',
       'output_usage_completion_tokens', 'output_usage_prompt_tokens',
       'output_probs', 'icd10_count', 'cause1_icd10', 'cause1_icd10_prob',
       'cause2_icd10', 'cause2_icd10_prob', 'cause3_icd10',
       'cause3_icd10_prob', 'cause4_icd10', 'cause4_icd10_prob',
       'cause5_icd10', 'cause5_icd10_prob'],
      dtype='object')

In [161]:
for _, a in parsed_sorted_icd10_df[parsed_sorted_icd10_df.icd10_count > 3].sample(15).iterrows():
    print(a['cause1_icd10'])
    display(a['output_probs'])

R63.4


[{'icd': 'I12.0',
  'icd_linprob_mean': 0.8595338383543164,
  'logprobs': {'token': ['I', '12', '.', '0'],
   'logprob': [-0.059283797, -0.11453045, -8.180258e-06, -0.5043144]}},
 {'icd': 'K76.7',
  'icd_linprob_mean': 0.7139134648352159,
  'logprobs': {'token': ['K', '76', '.', '7'],
   'logprob': [-0.36636382, -0.8042357, -1.9816675e-06, -0.33550695]}},
 {'icd': 'R50.9',
  'icd_linprob_mean': 0.8049610423354654,
  'logprobs': {'token': ['R', '50', '.', '9'],
   'logprob': [-0.46263978, -0.5234484, -0.00024001303, -0.0020111948]}},
 {'icd': 'R63.4',
  'icd_linprob_mean': 0.9263324364480308,
  'logprobs': {'token': ['R', '63', '.', '4'],
   'logprob': [-0.06300457, -0.2529098, -1.4855664e-05, -0.010184187]}},
 {'icd': 'R63.0',
  'icd_linprob_mean': 0.7758363820003169,
  'logprobs': {'token': ['R', '63', '.', '0'],
   'logprob': [-0.36066207, -0.79785955, -8.776276e-06, -0.0451564]}},
 {'icd': 'R53',
  'icd_linprob_mean': 0.5055507677707753,
  'logprobs': {'token': ['R', '53'], 'logprob

R63.4


[{'icd': 'I12.0',
  'icd_linprob_mean': 0.7486459334800051,
  'logprobs': {'token': ['I', '12', '.', '0'],
   'logprob': [-0.19359861, -0.5945167, -1.8624639e-06, -0.48002517]}},
 {'icd': 'K76.7',
  'icd_linprob_mean': 0.6806250642559467,
  'logprobs': {'token': ['K', '76', '.', '7'],
   'logprob': [-0.4046402, -1.0076097, -2.1008714e-06, -0.37078124]}},
 {'icd': 'R50.9',
  'icd_linprob_mean': 0.788408252852581,
  'logprobs': {'token': ['R', '50', '.', '9'],
   'logprob': [-0.43220943, -0.6790432, -0.00022153647, -0.0023235565]}},
 {'icd': 'R63.4',
  'icd_linprob_mean': 0.918117321290755,
  'logprobs': {'token': ['R', '63', '.', '4'],
   'logprob': [-0.07451365, -0.28137907, -1.7239736e-05, -0.01050545]}},
 {'icd': 'R63.0',
  'icd_linprob_mean': 0.7565178975517381,
  'logprobs': {'token': ['R', '63', '.', '0'],
   'logprob': [-0.3943446, -0.94907904, -1.0921943e-05, -0.035768703]}},
 {'icd': 'R53',
  'icd_linprob_mean': 0.5210137308484936,
  'logprobs': {'token': ['R', '53'], 'logprob'

R56.0


[{'icd': 'A16.9',
  'icd_linprob_mean': 0.7549412745065133,
  'logprobs': {'token': ['A', '16', '.', '9'],
   'logprob': [-0.47826242, -0.6174271, -0.001908289, -0.14794128]}},
 {'icd': 'E46',
  'icd_linprob_mean': 0.5177382100769704,
  'logprobs': {'token': ['E', '46'], 'logprob': [-1.0526751, -0.3761876]}},
 {'icd': 'J21.9',
  'icd_linprob_mean': 0.7123521996760981,
  'logprobs': {'token': ['J', '21', '.', '9'],
   'logprob': [-0.5624804, -0.7556753, -0.00027968953, -0.21047275]}},
 {'icd': 'R56.0',
  'icd_linprob_mean': 0.8696457208189023,
  'logprobs': {'token': ['R', '56', '.', '0'],
   'logprob': [-0.15237157, -0.13016531, -4.7517467e-05, -0.29839128]}},
 {'icd': 'R09.2',
  'icd_linprob_mean': 0.8243076867458965,
  'logprobs': {'token': ['R', '09', '.', '2'],
   'logprob': [-0.2549476, -0.5719325, -1.2664457e-06, -0.04307318]}}]

R63.4


[{'icd': 'I12.0',
  'icd_linprob_mean': 0.8387875647182439,
  'logprobs': {'token': ['I', '12', '.', '0'],
   'logprob': [-0.04348118, -0.16591269, -1.4140442e-05, -0.59675825]}},
 {'icd': 'K76.7',
  'icd_linprob_mean': 0.7047669164629409,
  'logprobs': {'token': ['K', '76', '.', '7'],
   'logprob': [-0.40434068, -0.85563546, -2.220075e-06, -0.3193239]}},
 {'icd': 'R50.9',
  'icd_linprob_mean': 0.790847532318225,
  'logprobs': {'token': ['R', '50', '.', '9'],
   'logprob': [-0.51402193, -0.5655048, -0.00029840085, -0.002473159]}},
 {'icd': 'R63.4',
  'icd_linprob_mean': 0.9202575343938686,
  'logprobs': {'token': ['R', '63', '.', '4'],
   'logprob': [-0.07130606, -0.27135584, -1.7954959e-05, -0.012552388]}},
 {'icd': 'R63.0',
  'icd_linprob_mean': 0.7990392195908556,
  'logprobs': {'token': ['R', '63', '.', '0'],
   'logprob': [-0.34031492, -0.63670427, -8.657073e-06, -0.04543028]}},
 {'icd': 'R53',
  'icd_linprob_mean': 0.5657417653548672,
  'logprobs': {'token': ['R', '53'], 'logprob

J45.9


[{'icd': 'J45.9',
  'icd_linprob_mean': 0.9487856807158619,
  'logprobs': {'token': ['J', '45', '.', '9'],
   'logprob': [-0.05504319, -0.14052296, -0.009451235, -0.010857277]}},
 {'icd': 'I10',
  'icd_linprob_mean': 0.9481808581608021,
  'logprobs': {'token': ['I', '10'], 'logprob': [-0.01702128, -0.09075771]}},
 {'icd': 'J96.0',
  'icd_linprob_mean': 0.6973696449269214,
  'logprobs': {'token': ['J', '96', '.', '0'],
   'logprob': [-0.65572906, -0.9181754, -0.002814744, -0.1346989]}},
 {'icd': 'R50.9',
  'icd_linprob_mean': 0.8495021538394582,
  'logprobs': {'token': ['R', '50', '.', '9'],
   'logprob': [-0.56631273, -0.18474178, -0.0006089136, -0.00031615852]}},
 {'icd': 'R04.2',
  'icd_linprob_mean': 0.7939085551183627,
  'logprobs': {'token': ['R', '04', '.', '2'],
   'logprob': [-0.11815869, -1.1070409, -1.3856493e-06, -0.04442748]}}]

A09


[{'icd': 'A09',
  'icd_linprob_mean': 0.9795656856537089,
  'logprobs': {'token': ['A', '09'],
   'logprob': [-0.00074626005, -0.04094976]}},
 {'icd': 'J22',
  'icd_linprob_mean': 0.30538814074036125,
  'logprobs': {'token': ['J', '22'], 'logprob': [-1.1002022, -1.2802331]}},
 {'icd': 'R50.9',
  'icd_linprob_mean': 0.9592187309050202,
  'logprobs': {'token': ['R', '50', '.', '9'],
   'logprob': [-0.085616924, -0.018034976, -0.06523978, -4.036525e-05]}},
 {'icd': 'R51',
  'icd_linprob_mean': 0.7902674742998481,
  'logprobs': {'token': ['R', '51'], 'logprob': [-0.5415736, -0.0012978541]}},
 {'icd': 'B50',
  'icd_linprob_mean': 0.6626379732667138,
  'logprobs': {'token': ['B', '50'], 'logprob': [-0.19849193, -0.68258405]}}]

J18.9


[{'icd': 'I12.0',
  'icd_linprob_mean': 0.7733234824536765,
  'logprobs': {'token': ['I', '12', '.', '0'],
   'logprob': [-0.58620036, -0.0021130242, -1.0087517e-05, -0.61808187]}},
 {'icd': 'J18.9',
  'icd_linprob_mean': 0.8131794646799151,
  'logprobs': {'token': ['J', '18', '.', '9'],
   'logprob': [-0.46300825, -0.1582657, -6.4444386e-05, -0.26166102]}},
 {'icd': 'R53',
  'icd_linprob_mean': 0.5325025594749777,
  'logprobs': {'token': ['R', '53'], 'logprob': [-0.46157458, -0.83306533]}},
 {'icd': 'R64',
  'icd_linprob_mean': 0.5411849021257585,
  'logprobs': {'token': ['R', '64'], 'logprob': [-0.23820716, -1.2230524]}},
 {'icd': 'R50.9',
  'icd_linprob_mean': 0.7655547346890994,
  'logprobs': {'token': ['R', '50', '.', '9'],
   'logprob': [-0.1967769, -1.4185665, -0.0010841365, -0.00013214473]}},
 {'icd': 'R11',
  'icd_linprob_mean': 0.4556114044823798,
  'logprobs': {'token': ['R', '11'], 'logprob': [-0.42561468, -1.3553588]}}]

P23.9


[{'icd': 'P23.9',
  'icd_linprob_mean': 0.8228421604535032,
  'logprobs': {'token': ['P', '23', '.', '9'],
   'logprob': [-0.18527304, -0.61671746, -0.0038645624, -0.078356005]}},
 {'icd': 'A41.9',
  'icd_linprob_mean': 0.6580560431139925,
  'logprobs': {'token': ['A', '41', '.', '9'],
   'logprob': [-1.1077374, -1.1974221, -9.849109e-06, -4.310693e-05]}},
 {'icd': 'J21.9',
  'icd_linprob_mean': 0.6629082471756018,
  'logprobs': {'token': ['J', '21', '.', '9'],
   'logprob': [-1.2968919, -0.8435561, -4.9305523e-05, -0.053269897]}},
 {'icd': 'I50.9',
  'icd_linprob_mean': 0.7667977467211441,
  'logprobs': {'token': ['I', '50', '.', '9'],
   'logprob': [-1.528862, -0.14139628, -0.015381669, -0.0024755395]}},
 {'icd': 'K70.9',
  'icd_linprob_mean': 0.6977478815767088,
  'logprobs': {'token': ['K', '70', '.', '9'],
   'logprob': [-1.2378669, -0.67586374, -0.00013703208, -0.007619399]}}]

A09


[{'icd': 'A09',
  'icd_linprob_mean': 0.9747631819476541,
  'logprobs': {'token': ['A', '09'], 'logprob': [-0.00061319396, -0.0511466]}},
 {'icd': 'J22',
  'icd_linprob_mean': 0.346677683284046,
  'logprobs': {'token': ['J', '22'], 'logprob': [-1.0042969, -1.1176322]}},
 {'icd': 'R50.9',
  'icd_linprob_mean': 0.9730464967753444,
  'logprobs': {'token': ['R', '50', '.', '9'],
   'logprob': [-0.05633184, -0.007069959, -0.04705593, -2.8444882e-05]}},
 {'icd': 'R51',
  'icd_linprob_mean': 0.7937275606164953,
  'logprobs': {'token': ['R', '51'], 'logprob': [-0.53029317, -0.0009777903]}},
 {'icd': 'B54',
  'icd_linprob_mean': 0.6973305375603687,
  'logprobs': {'token': ['B', '54'], 'logprob': [-0.12193996, -0.6744048]}}]

R56.0


[{'icd': 'A16.9',
  'icd_linprob_mean': 0.7876525355486815,
  'logprobs': {'token': ['A', '16', '.', '9'],
   'logprob': [-0.46410313, -0.41124478, -0.004308107, -0.14689387]}},
 {'icd': 'E46',
  'icd_linprob_mean': 0.5618042611510842,
  'logprobs': {'token': ['E', '46'], 'logprob': [-0.93872166, -0.31131786]}},
 {'icd': 'J21.9',
  'icd_linprob_mean': 0.7307602884097296,
  'logprobs': {'token': ['J', '21', '.', '9'],
   'logprob': [-0.53369945, -0.64114875, -0.00046033994, -0.21024847]}},
 {'icd': 'R56.0',
  'icd_linprob_mean': 0.848775986708083,
  'logprobs': {'token': ['R', '56', '.', '0'],
   'logprob': [-0.1379347, -0.097457, -6.146429e-05, -0.48309943]}},
 {'icd': 'R09.2',
  'icd_linprob_mean': 0.8239937126077306,
  'logprobs': {'token': ['R', '09', '.', '2'],
   'logprob': [-0.29230177, -0.54757, -1.3856493e-06, -0.029346924]}}]

R63.4


[{'icd': 'I12.0',
  'icd_linprob_mean': 0.7713639979852831,
  'logprobs': {'token': ['I', '12', '.', '0'],
   'logprob': [-0.11272752, -0.3851168, -1.5332478e-05, -0.67000484]}},
 {'icd': 'K70.9',
  'icd_linprob_mean': 0.6072918717432707,
  'logprobs': {'token': ['K', '70', '.', '9'],
   'logprob': [-0.4261989, -1.0091586, -0.00017517358, -0.8871431]}},
 {'icd': 'R50.9',
  'icd_linprob_mean': 0.8273524221318174,
  'logprobs': {'token': ['R', '50', '.', '9'],
   'logprob': [-0.30295455, -0.5601619, -7.767599e-05, -0.00026193185]}},
 {'icd': 'R63.4',
  'icd_linprob_mean': 0.9163524951448951,
  'logprobs': {'token': ['R', '63', '.', '4'],
   'logprob': [-0.073521554, -0.28840703, -4.6563837e-05, -0.01320283]}},
 {'icd': 'R17',
  'icd_linprob_mean': 0.4897818176039933,
  'logprobs': {'token': ['R', '17'], 'logprob': [-0.39330685, -1.1882898]}}]

A09


[{'icd': 'A09',
  'icd_linprob_mean': 0.712999284399908,
  'logprobs': {'token': ['A', '09'], 'logprob': [-0.24463004, -0.4416037]}},
 {'icd': 'G40',
  'icd_linprob_mean': 0.6562899063125897,
  'logprobs': {'token': ['G', '40'], 'logprob': [-1.1551665, -0.0024282173]}},
 {'icd': 'R56',
  'icd_linprob_mean': 0.4188756327642587,
  'logprobs': {'token': ['R', '56'], 'logprob': [-0.5782107, -1.2842788]}},
 {'icd': 'R53',
  'icd_linprob_mean': 0.4313935521236462,
  'logprobs': {'token': ['R', '53'], 'logprob': [-0.60247046, -1.1541368]}},
 {'icd': 'R63.4',
  'icd_linprob_mean': 0.6388304828547613,
  'logprobs': {'token': ['R', '63', '.', '4'],
   'logprob': [-0.7492716, -1.6239448, -0.0517336, -0.06623616]}},
 {'icd': 'R63.5',
  'icd_linprob_mean': 0.6026414282847385,
  'logprobs': {'token': ['R', '63', '.', '5'],
   'logprob': [-0.80144227, -0.98324984, -0.0046396917, -0.5235396]}},
 {'icd': 'R06.02',
  'icd_linprob_mean': 0.4533704538532517,
  'logprobs': {'token': ['R', '06', '.', '02'],

A09


[{'icd': 'A09',
  'icd_linprob_mean': 0.9881784467706127,
  'logprobs': {'token': ['A', '09'],
   'logprob': [-0.022954604, -0.00095040654]}},
 {'icd': 'E86',
  'icd_linprob_mean': 0.8537254681067601,
  'logprobs': {'token': ['E', '86'], 'logprob': [-0.29495215, -0.03782262]}},
 {'icd': 'K72.9',
  'icd_linprob_mean': 0.7271886073885627,
  'logprobs': {'token': ['K', '72', '.', '9'],
   'logprob': [-0.77779555, -0.034148473, -0.0019136423, -0.7239764]}},
 {'icd': 'R57.0',
  'icd_linprob_mean': 0.6049286154919002,
  'logprobs': {'token': ['R', '57', '.', '0'],
   'logprob': [-1.1260089, -0.95826876, -0.0014445223, -0.33788612]}},
 {'icd': 'R09.2',
  'icd_linprob_mean': 0.5291197373531094,
  'logprobs': {'token': ['R', '09', '.', '2'],
   'logprob': [-0.79848677, -1.3784374, -3.4121115e-06, -0.88067997]}}]

N18.9


[{'icd': 'E11.9',
  'icd_linprob_mean': 0.5504363627977326,
  'logprobs': {'token': ['E', '11', '.', '9'],
   'logprob': [-0.8412005, -1.0157803, -0.037181444, -0.8098342]}},
 {'icd': 'N18.9',
  'icd_linprob_mean': 0.7719542684814799,
  'logprobs': {'token': ['N', '18', '.', '9'],
   'logprob': [-0.97446537, -0.3062638, -0.0017336098, -0.02433088]}},
 {'icd': 'I12.0',
  'icd_linprob_mean': 0.7681456216585981,
  'logprobs': {'token': ['I', '12', '.', '0'],
   'logprob': [-0.90316814, -0.032456398, -1.7432603e-06, -0.35776705]}},
 {'icd': 'K72.9',
  'icd_linprob_mean': 0.7566464756615245,
  'logprobs': {'token': ['K', '72', '.', '9'],
   'logprob': [-1.135973, -0.29676807, -3.4121115e-06, -0.038467042]}},
 {'icd': 'R53',
  'icd_linprob_mean': 0.5257772860892567,
  'logprobs': {'token': ['R', '53'], 'logprob': [-0.45512187, -0.8742281]}},
 {'icd': 'R60.9',
  'icd_linprob_mean': 0.7012935308963191,
  'logprobs': {'token': ['R', '60', '.', '9'],
   'logprob': [-0.5617938, -0.76205784, -0.00

R56.0


[{'icd': 'A06.4',
  'icd_linprob_mean': 0.7330490629679691,
  'logprobs': {'token': ['A', '06', '.', '4'],
   'logprob': [-0.010746774, -0.890556, -0.00011248347, -0.63004005]}},
 {'icd': 'R50.9',
  'icd_linprob_mean': 0.798805873883522,
  'logprobs': {'token': ['R', '50', '.', '9'],
   'logprob': [-0.98804533, -0.16709937, -0.023264816, -0.00019948746]}},
 {'icd': 'R17',
  'icd_linprob_mean': 0.6827161117960667,
  'logprobs': {'token': ['R', '17'], 'logprob': [-0.52139723, -0.25910527]}},
 {'icd': 'R56.0',
  'icd_linprob_mean': 0.8634883073890831,
  'logprobs': {'token': ['R', '56', '.', '0'],
   'logprob': [-0.33912933, -0.20115697, -0.00032294946, -0.07893308]}},
 {'icd': 'R63.0',
  'icd_linprob_mean': 0.6439339464690506,
  'logprobs': {'token': ['R', '63', '.', '0'],
   'logprob': [-0.8385104, -1.8954664, -0.0016641122, -0.0052173943]}},
 {'icd': 'R58',
  'icd_linprob_mean': 0.23228242007528363,
  'logprobs': {'token': ['R', '58'], 'logprob': [-1.3477597, -1.585999]}}]

In [153]:
parsed_sorted_icd10_df[['cause1_icd10', 'output_logprobs']]

Unnamed: 0,cause1_icd10,output_logprobs
14006015_0,C10.9,"[(C, -0.4171243), (10, -0.22050741), (., -7.89..."
14003152_0,A15.9,"[(A, -0.0015137888), (15, -0.0008190385), (., ..."
14004789_0,V89.2,"[(V, -1.7432603e-06), (89, -0.05246823), (., -..."
14008356_0,R50.9,"[(A, -0.23412237), (18, -0.024403129), (., -1...."
14004298_0,K25,"[(I, -0.6327079), (10, -0.09191374), (\n, -0.1..."
...,...,...
24002603_9,G91.9,"[(G, -0.0097376695), (91, -0.27452958), (., -0..."
24002738_9,B50.0,"[(B, -0.0010274507), (50, -0.00011272187), (.,..."
24000569_9,A87,"[(A, -0.09199736), (87, -0.45179173)]"
24002421_9,P22.0,"[(P, -0.00045902873), (22, -0.11226488), (., -..."


In [12]:
# Define the mapping variable
column_mapping = {
    'model': 'output_model',
    'system_prompt': 'output_system_prompt',
    'user_prompt': 'output_user_prompt',
    'timestamp': 'output_created',
}

# Rename the columns using the mapping
parsed_sorted_icd10_df = parsed_sorted_icd10_df.rename(columns=column_mapping)
parsed_first_icd10_df = parsed_first_icd10_df.rename(columns=column_mapping)

export_columns = []
export_columns += ['rowid']
export_columns += list(icd_column_names_mapping.values())
export_columns += [
                    'output_created',
                    'param_model',
                    'param_system_prompt' , 
                    'param_user_prompt', 
                    'output_usage_completion_tokens', 
                    'output_usage_prompt_tokens', 
                    'output_msg',
                    'icd10_count',
                    'output_probs',
                ]

if not DROP_EXCESS_COLUMNS:
    export_columns += extra_colnames
    
if not DROP_RAW:
    export_columns += ['output']


# Show only relevant columns in the final dataframe
export_parsed_sorted_icd10_df = parsed_sorted_icd10_df[export_columns]
export_parsed_first_icd10_df = parsed_first_icd10_df[export_columns]

In [13]:
print(f"Export Dataframe shape: {export_parsed_sorted_icd10_df.shape}")
print(f"Processed export_parsed_sorted_icd10_df exporting to: {EXPORT_SORTED_ICD_CSV_FILE}")
print(f"Processed export_parsed_first_icd10_df exporting to: {EXPORT_FIRST_ICD_CSV_FILE}")

# Save the parsed data to CSV files
export_parsed_sorted_icd10_df.to_csv(EXPORT_SORTED_ICD_CSV_FILE, index=False)
export_parsed_first_icd10_df.to_csv(EXPORT_FIRST_ICD_CSV_FILE, index=False)

Export Dataframe shape: (11887, 23)
Processed export_parsed_sorted_icd10_df exporting to: ./_working_data_240315/03_(all)_gpt3_0309_parsed_sorted_ICD.csv
Processed export_parsed_first_icd10_df exporting to: ./_working_data_240315/03_(all)_gpt3_0309_parsed_first_ICD.csv


In [14]:

# icd10 counts 
# export_parsed_sorted_icd10_df[export_parsed_sorted_icd10_df.icd10_count < 1]

# verify that the sort is working
# export_parsed_sorted_icd10_df[export_parsed_sorted_icd10_df.icd10_count > 2][list(icd_column_names_mapping.values()) + ['output_probs']]

# verify that the non-sort is working
# export_parsed_first_icd10_df[export_parsed_first_icd10_df.icd10_count > 2][list(icd_column_names_mapping.values()) + ['output_msg', 'output_probs']]