In [1]:
from paths import *
from utils import *
import json
import openai

# Authenticate against the OpenAI API with the key stored in a local JSON file
with open(OPENAI_API_KEY_PATH, 'r') as key_file:
    openai.api_key = json.load(key_file)['API-KEY']

# Load the triage experiment configuration (seed, NPRIME mode, GPT-3 parameters, ...)
with open(TRIAGE_CONFIG_FILE, 'r') as config_file:
    config = json.load(config_file)
    
import math
import itertools
from collections import Counter
from tqdm import tqdm
import pandas as pd
import numpy as np

# Seed NumPy's global RNG from the config so that np.random draws
# (e.g. the sampling of priming combinations) are repeatable across runs
np.random.seed(config['SEED'])

# Denominator used by the 2020 aggregation cells to turn per-vignette
# prediction counts into proportions.
# NOTE(review): presumably has to equal the number of prompts generated per
# vignette (e.g. config["representative-sample-size"]) — confirm
NUM_PROMPTS_PER_VIGN = 10

In [2]:
def _priming_index_combinations(vignettes, priming_vignettes, cfg):
    """Return the index combinations that select priming examples for one vignette.

    Each returned element is a tuple of index labels of ``priming_vignettes``.
    Raises ValueError for an unsupported ``NPRIME`` value or a zero sample size.
    """
    nprime = cfg['NPRIME']

    if nprime in ('representative', 'representative-sample'):
        if nprime == 'representative-sample' and cfg["representative-sample-size"] == 0:
            raise ValueError("representative-sample-size parameter needs to be >0 when NPRIME=representative-sample")

        # One list of candidate indices per triage label (labels are taken from
        # the full table, candidates from the pool without the evaluated row)
        labelwise_vignette_indices = [
            list(priming_vignettes.loc[priming_vignettes['Correct Triage'] == triage].index)
            for triage in vignettes['Correct Triage'].unique()
        ]

        if nprime == 'representative':
            # Every way of picking exactly one example per label
            return itertools.product(*labelwise_vignette_indices)

        # Draw positions (without replacement) in the enumeration of the full
        # product, then keep only the combinations at those positions
        sample_indices = set(
            np.random.choice(np.prod([len(x) for x in labelwise_vignette_indices]),
                             replace=False,
                             size=cfg["representative-sample-size"])
        )
        selected = []
        for comb_idx, comb in enumerate(itertools.product(*labelwise_vignette_indices)):
            if not sample_indices:
                # All sampled combinations found: no need to enumerate the rest
                break
            if comb_idx in sample_indices:
                selected.append(comb)
                sample_indices.remove(comb_idx)
        return selected

    # The isinstance guard keeps unknown strings from reaching `str > int`,
    # which would raise TypeError instead of the intended ValueError
    if isinstance(nprime, (int, np.integer)) and nprime > 0:
        return itertools.combinations(priming_vignettes.index, nprime)

    raise ValueError(f"Found misspecified value for NPRIME: {nprime!r}")


def construct_all_prompts(vignettes, prompt_fn, cfg=None):
    """Build every prompt to be sent to the model for a table of vignettes.

    Parameters
    ----------
    vignettes : pd.DataFrame
        One row per vignette; must contain a 'Correct Triage' column when a
        'representative*' priming mode is used.
    prompt_fn : callable
        Called as ``prompt_fn(priming_vignettes, row)`` where ``priming_vignettes``
        is a DataFrame of priming examples (or None when NPRIME is 0) and ``row``
        is the vignette being evaluated. Its return value is stored as the prompt.
    cfg : dict, optional
        Configuration to read 'NPRIME' (and 'representative-sample-size') from.
        Defaults to the notebook-level ``config``.

    Supported values of ``cfg['NPRIME']``:
      * 0 – no priming examples, one prompt per vignette;
      * 'representative' – every combination with one priming example per label;
      * 'representative-sample' – a random sample (of size
        'representative-sample-size', drawn with the global NumPy RNG) of those
        combinations;
      * positive integer n – every combination of n priming vignettes.

    Returns
    -------
    list of [index, prompt]
        The evaluated vignette's index label paired with each of its prompts.

    Raises
    ------
    ValueError
        If NPRIME is not one of the supported values, or if the sample size is 0
        in 'representative-sample' mode.
    """
    if cfg is None:
        cfg = config

    prompts = []
    for i, row in vignettes.iterrows():

        if cfg['NPRIME'] == 0:
            prompts.append([i, prompt_fn(None, row)])
            continue

        # Remove vignette being evaluated from pool of priming examples
        priming_vignettes = vignettes.drop(i)

        for comb in _priming_index_combinations(vignettes, priming_vignettes, cfg):
            prompts.append([
                i,
                prompt_fn(priming_vignettes.loc[list(comb)], row)
            ])

    return prompts

# Vignettes - 2015

### Read file

In [3]:
# Parse the 2015 vignette file; the parser and the path constant are defined
# outside this notebook (star-imported in the setup cell)
vignettes15 = parse_vignettes_2015(VIGNETTES_2015_FP)
# Sanity checks: label balance and table dimensions
print(vignettes15['Correct Triage'].value_counts())
print(vignettes15.shape)
# Preview the first rows as the cell's rich output
vignettes15.head()

Non-emergent    15
Self-care       15
Emergent        15
Name: Correct Triage, dtype: int64
(45, 4)


Unnamed: 0,Correct Diagnosis,Problem,Simplified,Correct Triage
0,Acute liver failure,A 48-year-old woman with a history of migraine...,"48 y/o f, confusion, disorientation, increasin...",Emergent
1,Appendicitis,A 12-year-old girl presents with sudden-onset ...,"12 y/o f, sudden onset severe abdominal pain, ...",Emergent
2,Asthma,A 27-year-old woman with a history of moderate...,"27 y/o f, Hx of asthma, mild shortness of brea...",Emergent
3,COPD flare (more severe),A 67-year-old woman with a history of COPD pre...,"67 y/o f, Hx of COPD, 3 days worsening shortne...",Emergent
4,Deep vein thrombosis,A 65-year-old woman presents with unilateral l...,"65 y/o f, 5 days swelling, pain in one leg, re...",Emergent


### Prompts

In [4]:
def construct_triage_prompt(priming_vignettes: pd.DataFrame, 
                            eval_vignette: pd.Series,
                            separator: str ='\n###\n') -> str:
    """Assemble the completion prompt for one 2015 vignette.

    The prompt is the task instruction, followed by one solved example per row
    of ``priming_vignettes`` (skipped when it is None), followed by the vignette
    to classify with its triage left blank; all parts are joined by ``separator``.
    Rows are read through their 'Problem' and 'Correct Triage' fields.

    Note: a later cell defines a 2020 variant under the same function name.
    """
    instruction = ("Given a clinical vignette, classify into one of the 3 "
                   "triage categories: Emergent, Non-emergent, and Self-care.")
    template = "Vignette: {}\nTriage: {}"

    sections = [instruction]

    # Solved examples shown to the model before the query
    if priming_vignettes is not None:
        sections.extend(
            template.format(example['Problem'], example['Correct Triage'])
            for _, example in priming_vignettes.iterrows()
        )

    # The query itself: empty label, surrounding whitespace stripped so the
    # prompt ends right after "Triage:"
    sections.append(template.format(eval_vignette['Problem'], '').strip())

    return separator.join(sections)

In [5]:
# Build all prompts for the 2015 vignettes.
# NOTE(review): construct_triage_prompt is defined twice in this notebook (2015
# and 2020 variants); this call needs the 2015 definition to be the one most
# recently executed.
all_prompts = construct_all_prompts(vignettes15, construct_triage_prompt)
print(f"Total no. of prompts: {len(all_prompts)}\n") 
# Each entry is [vignette index, prompt], so [0][1] is the first prompt itself
print(all_prompts[0][1])

Total no. of prompts: 450

Given a clinical vignette, classify into one of the 3 triage categories: Emergent, Non-emergent, and Self-care.
###
Vignette: A 65-year-old woman presents with unilateral leg pain and swelling of 5 days' duration. There is a history of hypertension, mild CHF, and recent hospitalization for pneumonia. She had been recuperating at home but on beginning to mobilize and walk, the right leg became painful, tender, and swollen. On examination, the right calf is 4 cm greater in circumference than the left when measured 10 cm below the tibial tuberosity. Superficial veins in the leg are more dilated on the right foot and the right leg is slightly redder than the left. There is some tenderness on palpation in the popliteal fossa behind the knee.
Triage: Emergent
###
Vignette: A 45-year-old man presents with acute onset of pain and redness of the skin of his lower extremity. Low-grade fever is present and the pretibial area is erythematous, edematous, and tender.
Triag

### Get completions

In [6]:
# Request one completion per prompt and keep a record of each call
results = []
for eval_vignette_idx, prompt in tqdm(all_prompts):
    predicted_triage, prob_predicted_triage = get_completion_prob(
        prompt=prompt,
        gpt3_params=config['GPT-3-params'],
    )
    results.append({
        'index': eval_vignette_idx,
        'Prompt': prompt,
        'Predicted Triage': predicted_triage,
        'LogProb Predicted Triage': prob_predicted_triage,
    })

# Attach the completions to their vignettes (joined on the vignette index),
# keeping only the reporting columns in their final order
results15 = vignettes15.join(
    pd.DataFrame(results).set_index('index')
)[['Prompt', 'Problem', 'Correct Triage', 'Predicted Triage', 'LogProb Predicted Triage']]

100%|██████████| 450/450 [06:38<00:00,  1.13it/s]


### Format and save results

#### (i) All completions

In [7]:
# Save every individual completion as a tab-separated file under PROCESSED.
# The index is written as the first, unnamed column; the aggregation cell
# reads it back with index_col='Unnamed: 0'.
results15.to_csv(PROCESSED/'vignettes15_triage_prediction_all.tsv', sep='\t')

#### (ii) Aggregated completions (exact)

In [8]:
# Reload the saved per-prompt completions so this cell does not depend on the
# (slow) completion loop having run in the current kernel
results15 = pd.read_csv(PROCESSED/'vignettes15_triage_prediction_all.tsv', sep='\t', index_col='Unnamed: 0')

# Convert log-probabilities to probabilities. Vectorised on purpose:
# Series.agg is meant for reductions, and passing it an element-wise
# callable is deprecated in recent pandas.
results15['Prob Triage'] = np.exp(results15['LogProb Predicted Triage'])

# See if each individual prediction is correct or not
results15['Correct(Yes/No)'] = np.where(
    results15['Correct Triage'] == results15['Predicted Triage'], 'Yes', 'No')

# Aggregate correct or not at the vignette level: a vignette counts as correct
# when 'Yes' is among the most frequent outcomes over its prompts (ties count
# as correct)
vignette_pred_map = (
    results15.groupby('Problem')['Correct(Yes/No)']
    .agg(lambda series: 'Yes' if 'Yes' in series.mode().values else 'No')
    .to_dict()
)
# Overwrite the per-prompt flag with the vignette-level verdict
results15['Correct(Yes/No)'] = results15['Problem'].map(vignette_pred_map)

# Aggregate predictions and probability scores per vignette and predicted label
results15_agg_ex = results15.groupby(['Problem', 'Correct(Yes/No)', 'Correct Triage', 'Predicted Triage'])\
    ['Prob Triage'].agg(['count', 'mean'])
results15_agg_ex.columns = ['Number of Predictions', 'Average Probability']

# Save aggregated completions
results15_agg_ex.to_excel(PROCESSED/'vignettes15_triage_prediction_agg_exact.xlsx')

# Vignette-level accuracy: share of vignettes whose verdict is 'Yes'
print(f"Accuracy: {sum(val=='Yes' for val in vignette_pred_map.values()) / len(vignette_pred_map)}")

Accuracy: 0.4666666666666667


# Vignettes - 2020

### Read file

In [3]:
# Parse the 2020 vignette file; the parser and the path constant are defined
# outside this notebook (star-imported in the setup cell)
vignettes20 = parse_vignettes_2020(VIGNETTES_2020_FP)
# Sanity checks: label balance and table dimensions
print(vignettes20['Correct Triage'].value_counts())
print(vignettes20.shape)
# Preview the first rows as the cell's rich output
vignettes20.head()

Emergent     12
1-day        12
1-week       12
Self-care    12
Name: Correct Triage, dtype: int64
(48, 4)


Unnamed: 0_level_0,Correct Diagnosis,Correct Triage,Problem,Additional Details
Case #,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,Liver failure,Emergent,Confused and sleepy for several hours.,48-year-old female; Can't answer where she is ...
2,Appendicitis,Emergent,Sudden severe abdominal pain.,12-year-old female; Temperature 104 F (40 C); ...
3,Heart attack,Emergent,Chest pain for 1 day.,64-year-old male; Pain is in middle of chest u...
4,Kidney stones,Emergent,Left-sided pain between armpit and hip for 1 h...,45-year-old male; Severe pain; Pain starts on ...
5,Meningitis,Emergent,Headache for 3 days.,18-year-old male; Has fever 102; Neck stiff; L...


### Prompts

In [4]:
def construct_triage_prompt(
    priming_vignettes: pd.DataFrame, 
    eval_vignette: pd.Series,
    separator: str ='\n###\n') -> str:
    """Assemble the completion prompt for one 2020 vignette.

    The prompt is the task instruction, followed by one solved example per row
    of ``priming_vignettes`` (skipped when it is None), followed by the vignette
    to classify with its triage left blank; all parts are joined by ``separator``.
    Rows are read through their 'Problem', 'Additional Details' and
    'Correct Triage' fields.

    Note: this shares its name with the 2015 prompt builder defined earlier in
    the notebook; once this cell runs, it replaces that definition.
    """
    instruction = ("Given a clinical vignette and additional details, classify into one of the 4 "
                   "triage categories: 1-day, 1-week, Emergent, and Self-care.")
    template = "Vignette: {}\nAdditional Details: {}\nTriage: {}"

    sections = [instruction]

    # Solved examples shown to the model before the query
    if priming_vignettes is not None:
        sections.extend(
            template.format(example['Problem'], example['Additional Details'], example['Correct Triage'])
            for _, example in priming_vignettes.iterrows()
        )

    # The query itself: empty label, surrounding whitespace stripped so the
    # prompt ends right after "Triage:"
    query = template.format(eval_vignette['Problem'], eval_vignette['Additional Details'], '')
    sections.append(query.strip())

    return separator.join(sections)

In [5]:
# Build all prompts for the 2020 vignettes; this reuses (and overwrites) the
# all_prompts name from the 2015 section.
# NOTE(review): construct_triage_prompt is defined twice in this notebook (2015
# and 2020 variants); this call needs the 2020 definition to be the one most
# recently executed.
all_prompts = construct_all_prompts(vignettes20, construct_triage_prompt)
print(f"Total no. of prompts: {len(all_prompts)}\n") 
# Each entry is [vignette index, prompt], so [0][1] is the first prompt itself
print(all_prompts[0][1])

Total no. of prompts: 480

Given a clinical vignette and additional details, classify into one of the 4 triage categories: 1-day, 1-week, Emergent, and Self-care.
###
Vignette: Fever and rash for 2 days.
Additional Details: 8-year-old male; Rash is worse on the ankles and wrists; Has joint pain and headache; Was camping recently.
Triage: Emergent
###
Vignette: Trouble breathing for 3 days.
Additional Details: 27-year-old female; Recent cold; Wheezing and coughing, especially at night; Has asthma; Inhalers only help for a couple of hours.
Triage: 1-day
###
Vignette: Thirsty and peeing a lot for 4 weeks.
Additional Details: 52-year-old male; Feels tired all the time; Has blurry vision on and off.
Triage: 1-week
###
Vignette: Painful swollen right eyelid for 1 day.
Additional Details: 30-year-old male; Pain is at edge of eyelid; Hurts to touch it; No change in vision.
Triage: Self-care
###
Vignette: Confused and sleepy for several hours.
Additional Details: 48-year-old female; Can't answe

### Completions

In [12]:
# Request one completion per prompt and keep a record of each call
results = []
for eval_vignette_idx, prompt in tqdm(all_prompts):
    predicted_triage, prob_predicted_triage = get_completion_prob(
        prompt=prompt,
        gpt3_params=config['GPT-3-params'],
    )
    results.append({
        'index': eval_vignette_idx,
        'Prompt': prompt,
        'Predicted Triage': predicted_triage,
        'LogProb Predicted Triage': prob_predicted_triage,
    })

# Attach the completions to their vignettes (joined on the vignette index),
# keeping only the reporting columns in their final order
results20 = vignettes20.join(
    pd.DataFrame(results).set_index('index')
)[['Prompt', 'Problem', 'Additional Details', 'Correct Triage',
   'Predicted Triage', 'LogProb Predicted Triage']]

100%|██████████| 480/480 [05:42<00:00,  1.40it/s]


### Format and save results

#### (i) All completions

In [13]:
# Save every individual completion as a tab-separated file under PROCESSED.
# The index is written as the first, unnamed column; the aggregation cells
# read it back with index_col='Unnamed: 0'.
results20.to_csv(PROCESSED/'vignettes20_triage_prediction_all.tsv', sep='\t')

#### (ii) Aggregated completions (exact)

In [3]:
# Reload the saved per-prompt completions (overwrites any in-memory results20)
results20 = pd.read_csv(PROCESSED/'vignettes20_triage_prediction_all.tsv', sep='\t', index_col='Unnamed: 0')

# Calculate probability of prediction by normalizing over all queries.
# After this aggregation the "Prompt" column no longer holds prompt text but the
# number of prompts per (vignette, predicted triage) pair.
results20 = results20.groupby(["Problem", "Additional Details", "Correct Triage", "Predicted Triage"]).agg({
    "Prompt": len
})
# Select the count column explicitly: assigning the whole frame only works
# while it happens to have exactly one column
results20["Pr(prediction)"] = results20["Prompt"] / NUM_PROMPTS_PER_VIGN
results20 = results20.reset_index()

# See if predictions are correct or not
results20['Correct'] = results20["Correct Triage"] == results20["Predicted Triage"]

# Get prob of correct prediction to decide whether a vignette was correctly triaged
# (zero on rows whose predicted label is wrong)
results20["Pr(correct prediction)"] = results20["Pr(prediction)"] * results20["Correct"]

In [4]:
# Per-vignette verdict: a vignette is marked correct when the summed
# probability of its correct predictions reaches one half
results20_exact = (
    results20
    .groupby(["Additional Details", "Problem", "Correct Triage"])["Pr(correct prediction)"]
    .agg(lambda probs: probs.sum() >= 0.5)
    .rename("Correct")
    .reset_index()
)

In [5]:
# Pick one predicted triage per vignette: rows are sorted in descending order
# with the per-label prompt count as the last key, and the first row of each
# vignette is kept. Ties are resolved by the sort order rather than at random;
# per the original note this value is not used anywhere, so that is acceptable.
results20_top1 = (
    results20
    .sort_values(by=["Additional Details", "Problem", "Correct Triage", "Prompt"], ascending=False)
    .reset_index()
    .drop_duplicates(subset=["Additional Details", "Problem", "Correct Triage"], keep="first")
)

# Attach the top prediction and its proportion to the per-vignette table
results20_exact = results20_exact.merge(
    results20_top1[["Additional Details", "Problem", "Correct Triage", "Predicted Triage", "Pr(prediction)"]],
    on=["Additional Details", "Problem", "Correct Triage"],
)

In [6]:
# Retrieve prob. of correct triage: total probability mass on the correct label
# per vignette. Uses the groupby's own .sum() — passing the builtin `sum` to
# agg() is deprecated in recent pandas.
results20_correct = (results20.
                     groupby(["Additional Details", "Problem", "Correct Triage"])["Pr(correct prediction)"].
                     sum().reset_index())

# Add to per vignette results table
results20_exact = results20_exact.merge(
    results20_correct[["Additional Details", "Problem", "Correct Triage", "Pr(correct prediction)"]],
    on=["Additional Details", "Problem", "Correct Triage"]
)

In [7]:
# Save the per-vignette (exact-label) results without the positional index;
# `index` is a boolean option, so pass False rather than None
results20_exact.to_csv(PROCESSED/'vignettes20_triage_prediction_exact.tsv', sep='\t', index=False)

#### (iii) Aggregated completions (dichotomized)

In [8]:
# Read predictions (overwrites any in-memory results20)
results20 = pd.read_csv(PROCESSED/'vignettes20_triage_prediction_all.tsv', sep='\t', index_col='Unnamed: 0')

# Map the four triage labels onto two grouped labels
dichotimized_triage = {
    'Emergent': 'Emergent/1-day',
    '1-day': 'Emergent/1-day',
    '1-week': '1-week/Self-care',
    'Self-care': '1-week/Self-care'
}
# Lookup through the lambda is deliberately strict: a label outside the four
# known ones raises KeyError instead of silently becoming NaN
results20['Correct Triage'] = results20['Correct Triage'].map(lambda s: dichotimized_triage[s])
results20['Predicted Triage'] = results20['Predicted Triage'].map(lambda s: dichotimized_triage[s])

# Calculate probability of prediction by normalizing over all queries.
# After this aggregation the "Prompt" column no longer holds prompt text but the
# number of prompts per (vignette, predicted grouped label) pair.
results20 = results20.groupby(["Problem", "Additional Details", "Correct Triage", "Predicted Triage"]).agg({
    "Prompt": len
})
# Select the count column explicitly: assigning the whole frame only works
# while it happens to have exactly one column
results20["Pr(prediction)"] = results20["Prompt"] / NUM_PROMPTS_PER_VIGN
results20 = results20.reset_index()

# See if predictions are correct or not
results20['Correct'] = results20["Correct Triage"] == results20["Predicted Triage"]

# Get prob of correct prediction to decide whether a vignette was correctly triaged
# (zero on rows whose predicted label is wrong)
results20["Pr(correct prediction)"] = results20["Pr(prediction)"] * results20["Correct"]

In [9]:
# Per-vignette verdict: a vignette is marked correct when the summed
# probability of its correct (grouped-label) predictions reaches one half
results20_dich = (
    results20
    .groupby(["Additional Details", "Problem", "Correct Triage"])["Pr(correct prediction)"]
    .agg(lambda probs: probs.sum() >= 0.5)
    .rename("Correct")
    .reset_index()
)

In [10]:
# Pick one predicted triage per vignette: rows are sorted in descending order
# with the per-label prompt count as the last key, and the first row of each
# vignette is kept (ties are resolved by the sort order, not at random)
results20_top1 = (
    results20
    .sort_values(by=["Additional Details", "Problem", "Correct Triage", "Prompt"], ascending=False)
    .reset_index()
    .drop_duplicates(subset=["Additional Details", "Problem", "Correct Triage"], keep="first")
)

# Attach the top prediction and its proportion to the per-vignette table
results20_dich = results20_dich.merge(
    results20_top1[["Additional Details", "Problem", "Correct Triage", "Predicted Triage", "Pr(prediction)"]],
    on=["Additional Details", "Problem", "Correct Triage"],
)

In [11]:
# Retrieve prob. of correct triage: total probability mass on the correct
# grouped label per vignette. Uses the groupby's own .sum() — passing the
# builtin `sum` to agg() is deprecated in recent pandas.
results20_correct = (results20.
                     groupby(["Additional Details", "Problem", "Correct Triage"])["Pr(correct prediction)"].
                     sum().reset_index())

# Add to per vignette results table
results20_dich = results20_dich.merge(
    results20_correct[["Additional Details", "Problem", "Correct Triage", "Pr(correct prediction)"]],
    on=["Additional Details", "Problem", "Correct Triage"]
)

In [12]:
# Save the per-vignette (dichotomized-label) results without the positional
# index; `index` is a boolean option, so pass False rather than None
results20_dich.to_csv(PROCESSED/'vignettes20_triage_prediction_dich.tsv', sep='\t', index=False)