In [250]:
import pandas as pd
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score
from sklearn.preprocessing import MultiLabelBinarizer
import os
import re

In [583]:
# Load the GOLD labels used as ground truth.
#   * AHRQ:  pre-set annotations taken from the AHRQ database.
#   * NaNDA: consensus labels agreed on by our 3 human annotators.
NANDA_GOLD_PATH = "../LLM SDOH Classification/annotator_aggreement/total_NaNDA_variable_annotations.xlsx"
AHRQ_GOLD_PATH = "../LLM SDOH Classification/huggingface_dataset_upload/INPUT_AHRQ_tract_2010-2018.csv"

gold_NaNDA = pd.read_excel(NANDA_GOLD_PATH)
gold_AHRQ = pd.read_csv(AHRQ_GOLD_PATH)

# Extract LLM predictions:

In [None]:

# Load the raw LLM result files (one CSV per model / prompt variant) for ONE experiment.
#
# `subdirectory` selects the experiment folder. Options:
#   zero-shot: "ZERO_NaNDA", "ZERO_AHRQ"
#   one-shot:  "ONE_NaNDA",  "ONE_AHRQ"
# Later cells look for the substrings 'NaNDA' / 'AHRQ' in this name to pick the gold labels.
subdirectory = "ONE_NaNDA"

# os.listdir() returns entries in arbitrary order, so sort the paths: this keeps
# dfs_map (and every per-model result list derived from it) in the same order on every run.
csv_files = sorted(
    os.path.join(subdirectory, file)
    for file in os.listdir(subdirectory)
    if file.endswith('.csv')
)

# Map each CSV path -> dataframe holding that file's LLM "Domain" predictions.
dfs_map = {file: pd.read_csv(file) for file in csv_files}





# Strict parser: only a parenthesised marker "(1)" ... "(5)" counts as a domain answer.
def STRICT_map_output_to_extracted_domain(output_value):
    """Map a raw model response to a domain label using parenthesised markers only.

    The response is cast to ``str`` first, so non-string cells are accepted.

    Returns:
        int 1-5  -- the lowest-numbered marker "(1)".."(5)" found in the response.
        '?'      -- REFUSAL: no marker, but the response contains a question mark.
                    This covers (i) a bare "?" and (ii) replies that ask for more
                    context, e.g. "... Could you please provide more context or
                    information about what this variable represents?"
        'UNKNOWN' -- PROMPT NON-ADHERENCE: none of the above matched.
    """
    text = str(output_value)

    # Markers are tested in ascending order, so "(1)" wins over "(2)" if both appear.
    for domain in range(1, 6):
        if '({})'.format(domain) in text:
            return domain

    if '?' in text:
        return '?'

    return 'UNKNOWN'

# Parser for FLAN-T5 models: maps a response to an extracted_domain value.
def FLAN_T5_map_output_to_extracted_domain(output_value):
    """Map a FLAN-T5 response to a domain label by plain substring search.

    The response is cast to ``str`` first. Checks run in this order:

    1. contains 'unable'          -> '?'  (REFUSAL)
    2. contains '1', '2', ... '5' -> that digit as an int (lowest digit wins)
    3. contains '?'               -> '?'  (REFUSAL: a bare "?" or a reply that asks
                                           a follow-up question for more context)
    4. otherwise                  -> 'UNKNOWN' (PROMPT NON-ADHERENCE)

    Note that step 2 matches the digit ANYWHERE in the text, including inside
    longer numbers or identifiers (e.g. "count_emp_491" yields 1).
    """
    text = str(output_value)

    if 'unable' in text:
        return '?'

    for digit in '12345':
        if digit in text:
            return int(digit)

    if '?' in text:
        return '?'

    return 'UNKNOWN'


# Gemma helper: Gemma models tend to bury the answer digit (e.g. '1') inside a longer response.
# A bare substring test is NOT safe, because the variable name itself may contain that digit.
def contains_lonely_digit(s, digit):
    """Return True iff `digit` occurs in `s` as a stand-alone token.

    "Stand-alone" means a regex word boundary (``\\b``) on both sides, i.e. no
    letter, digit or underscore directly next to it -- so "answer: 1." matches
    while "count_emp_491" and "var_1" do not.

    Args:
        s: text to search (must be a str).
        digit: the digit to look for; it is formatted into the pattern as-is.
    """
    lonely = re.compile(r'\b{}\b'.format(digit))
    return lonely.search(s) is not None

# Default parser (all non-FLAN models): maps a response to an extracted_domain value.
def map_output_to_extracted_domain(output_value):
    """Map a raw model response to a domain label.

    The response is cast to ``str`` first. Checks run in this order:

    1. contains 'unable' -> '?' (REFUSAL)
    2. for domain 1..5 in ascending order: the response contains the "(n)" marker
       or one of the domain's names, or the digit n appears as a stand-alone
       token (see contains_lonely_digit) -> n
    3. contains '?' -> '?' (REFUSAL: a bare "?" or a reply that asks a follow-up
       question, e.g. "... Could you please provide more context or information
       about what this variable represents?")
    4. otherwise -> 'UNKNOWN' (PROMPT NON-ADHERENCE)
    """
    text = str(output_value)

    if 'unable' in text:
        return '?'

    # (domain id, substrings that identify it). Order matters: first hit wins.
    # NOTE(review): the snake_case name for domain 5 ('health_care_and_quality') omits
    # "access", unlike its title-case twin -- confirm it matches what the models emit.
    domain_markers = (
        (1, ('(1)', 'social_and_community_context', 'Social and Community Context')),
        (2, ('(2)', 'economic_stability', 'Economic Stability')),
        (3, ('(3)', 'education_access_and_quality', 'Education Access and Quality')),
        (4, ('(4)', 'neighborhood_and_built_environment', 'Neighborhood and Built Environment')),
        (5, ('(5)', 'health_care_and_quality', 'Healthcare Access and Quality')),
    )
    for domain, markers in domain_markers:
        if any(marker in text for marker in markers) or contains_lonely_digit(text, domain):
            return domain

    if '?' in text:
        return '?'

    return 'UNKNOWN'

# Add an 'extracted_domain' column to every loaded dataframe (mutates the frames in dfs_map).
#
# Because the response formats differ, Flan-T5 files are parsed from the RAW response
# (the 'raw_output' column); every other model is parsed from 'output', which has the
# prompt prefix stripped from the response.
# Importantly: this is a trivial text-parsing correction, not a limitation of Flan-T5 models.
for filename, df in dfs_map.items():
    is_flan = 'flan' in filename or 'Flan' in filename
    if is_flan:
        source_column, parser = 'raw_output', FLAN_T5_map_output_to_extracted_domain
    else:  # Evaluate non-FLAN models.
        source_column, parser = 'output', map_output_to_extracted_domain

    df['extracted_domain'] = df[source_column].apply(parser)

    # Sanity check: show the first parsed label of each file.
    print(df['extracted_domain'][0])



# Calculate accuracy compared to gold labels (gold_NaNDA, gold_AHRQ)

In [None]:

# Evaluate every model's parsed predictions against the gold labels.
#
# Result lists: one entry per file in dfs_map, in dfs_map order. Later cells read these.
fnames = []
exact_match_rates = []  # accuracy in PERCENT (0-100)
f1_scores_macro = []
f1_scores_micro = []
qmark_rates = []        # FRACTION (0-1) of responses parsed as '?'
garbage_rates = []      # FRACTION (0-1) of responses parsed as 'UNKNOWN',
                        # i.e. responses that didn't include domains (1)-(5) or ?

if not dfs_map:
    raise ValueError("No result CSVs were loaded from '{}'; nothing to evaluate.".format(subdirectory))

# Choose the gold table once. Failing loudly here avoids scoring against undefined
# (or stale, from a previous run) labels when the subdirectory name is unexpected.
if 'NaNDA' in subdirectory:
    gold_df, merge_keys, gold_column = gold_NaNDA, 'variable_name', 'RESOLVED_CONSENSUS'
elif 'AHRQ' in subdirectory:
    gold_df, merge_keys, gold_column = gold_AHRQ, ['variable_name', 'domain'], 'domain'
else:
    raise ValueError(
        "Cannot pick gold labels: subdirectory '{}' contains neither 'NaNDA' nor 'AHRQ'.".format(subdirectory)
    )

for filename, df in dfs_map.items():
    fnames.append(filename)

    # Inner join: only rows that have a gold label are scored.
    merged_preds_labels = pd.merge(df, gold_df, on=merge_keys)

    # Use the number of SCORED rows as the denominator. The labels below come from the
    # merged frame, so dividing by len(df) would skew the rates (and make them disagree
    # with the F1 scores) whenever the merge drops or duplicates rows.
    total_rows = len(merged_preds_labels)
    if total_rows == 0:
        raise ValueError("No rows of '{}' matched the gold labels; cannot evaluate.".format(filename))
    if total_rows != len(df):
        print("WARNING: {}: {} prediction rows but {} rows after merging with gold labels.".format(
            filename, len(df), total_rows))

    # Compare as strings because predictions mix ints (1-5) with '?' / 'UNKNOWN'.
    # NOTE(review): if the gold column is parsed as float (e.g. 1.0), astype(str) yields
    # '1.0', which never equals a predicted '1' -- confirm the gold column dtype.
    predicted_labels = merged_preds_labels['extracted_domain'].astype(str)
    gold_labels = merged_preds_labels[gold_column].astype(str)

    # Accuracy: percentage of rows whose predicted label equals the gold label.
    exact_matches = (predicted_labels == gold_labels).sum()
    exact_match_rates.append((exact_matches / total_rows) * 100)

    # Track poor performances: refusals ('?') and prompt non-adherence ('UNKNOWN').
    qmark_rates.append((predicted_labels == '?').sum() / total_rows)
    garbage_rates.append((predicted_labels == 'UNKNOWN').sum() / total_rows)

    # F1 scores.
    # NOTE(review): '?' and 'UNKNOWN' predictions presumably count as extra classes in
    # the macro average -- confirm that is intended.
    f1_scores_macro.append(f1_score(gold_labels, predicted_labels, average='macro'))
    f1_scores_micro.append(f1_score(gold_labels, predicted_labels, average='micro'))  # 'weighted'

# Report bare file names. os.path.basename works with whichever separator
# os.path.join used, unlike splitting on a hard-coded "/".
fnames = [os.path.basename(path) for path in fnames]


def _print_best(values, label):
    """Print all values of one metric, its maximum, and the file that achieved it."""
    best = max(values)
    print(values)
    print(label, best)
    print("FNAME: ", fnames[values.index(best)])
    print()


print(fnames)
_print_best(exact_match_rates, "MAX ACCURACY: ")
_print_best(f1_scores_macro, "MAX F1-Macro: ")
_print_best(f1_scores_micro, "MAX F1-Micro: ")

print("Q MARK RATES: ", qmark_rates)
print("GARBAGE RATES: ", garbage_rates)

# Get max values:
# idx_max_acc = exact_match_rates.index(max(exact_match_rates))
# print(fnames[idx_max_acc])


In [586]:
# Sort key that orders prompt-ablation results: a, b, c, ab, ac, bc, then everything else.
def custom_sort_key(x):
    """Return the sort rank of a name based on its ablation prefix.

    Ranks: 'a' -> 1, 'b' -> 2, 'c' -> 3, 'ab' -> 4, 'ac' -> 5, 'bc' -> 6.
    Names starting with 'abc', and names matching no prefix at all, both get 7.
    """
    # Longest prefixes first, so 'abc' is not mistaken for 'ab' or 'a', etc.
    prefix_ranks = (
        ('abc', 7),
        ('ab', 4),
        ('ac', 5),
        ('bc', 6),
        ('a', 1),
        ('b', 2),
        ('c', 3),
    )
    for prefix, rank in prefix_ranks:
        if x.startswith(prefix):
            return rank
    return 7

# Display LLM inference results as LaTeX table

In [None]:
# Zip model names with one performance metric, keep one model's rows, and print as LaTeX.

# Metric to report. To report a different metric, swap in one of the alternatives below.
nanda_zero_performance = pd.DataFrame({'Model': fnames, 'Refusal': qmark_rates})
# nanda_zero_performance = pd.DataFrame({'Model': fnames, 'Accuracy': exact_match_rates})  # ACCURACY
# nanda_zero_performance = pd.DataFrame({'Model': fnames, 'F1-macro': f1_scores_macro})  # Macro F1
# nanda_zero_performance = pd.DataFrame({'Model': fnames, 'F1-micro': f1_scores_micro})  # Micro F1
# nanda_zero_performance = pd.DataFrame({'Model': fnames, '? Rate': qmark_rates, 'Prompt Non-adherence': garbage_rates})  # REFUSAL + NON-ADHERENCE
# nanda_zero_performance = pd.DataFrame({'Model': fnames, 'Prompt Non-adherence': garbage_rates})  # NON-ADHERENCE

# Model-name fragments as they appear in the result file names, listed in table ROW ORDER.
z_shot_nanda_names = ['llama7b', 'llama13b', 'llama70b', 'gemma-2b', 'gemma7b', 'mistral7bv1', 'mistral7bv2', 'flant5xl', 'flant5xxl']
z_shot_AHRQ_names = ['llama7b', 'llama13b', 'Llama-2-70b', 'gemma-2b', 'gemma7b', 'mistral7bv1', 'mistral7bv2', 'flant5xl']
o_shot_nanda_names = ['Llama-2-7b', 'Llama-2-13b', 'Llama-2-70b', 'gemma-2b', 'gemma-7b', 'Mistral-7B-Instruct-v0.1', 'Mistral-7B-Instruct-v0.2', 'flan-t5-xl', 'flan-t5-xxl']
o_shot_AHRQ_names = []

# Filter for one model: change the index to report another model.
selected_model = o_shot_nanda_names[8]  # index 8 -> 'flan-t5-xxl'

# regex=False: match the name literally. With the default regex matching, the "." in
# names such as 'Mistral-7B-Instruct-v0.1' would act as a wildcard.
filtered_df = nanda_zero_performance[
    nanda_zero_performance['Model'].str.contains(selected_model, regex=False)
]

# Order rows by prompt-ablation prefix. A stable sort keeps rows with equal keys
# in their original relative order, so the table is reproducible.
sort_order = filtered_df['Model'].map(custom_sort_key).argsort(kind='mergesort')
print_df = filtered_df.iloc[sort_order].T

print(print_df.to_latex(index=False,
                  float_format="{:.3f}".format))