In [1]:
from paths import *
from utils import *
import json
import openai

# Authenticate with the OpenAI API using the key stored in a JSON file on disk
with open(OPENAI_API_KEY_PATH, 'r') as key_file:
    openai.api_key = json.load(key_file)['API-KEY']

# Load the experiment configuration (JSON) used throughout the notebook
with open(TRIAGE_CONFIG_FILE, 'r') as config_file:
    config = json.load(config_file)
    
import itertools
from collections import Counter
from tqdm import tqdm
import pandas as pd
import numpy as np

# Seed NumPy's global RNG from the config so that the np.random.choice
# sampling performed in construct_all_prompts is repeatable across runs.
np.random.seed(config['SEED'])

In [2]:
def construct_all_prompts(vignettes, prompt_fn, nprime=None, sample_size=None):
    """Build every prompt used to evaluate each vignette.

    For each row of ``vignettes`` one or more prompts are produced by calling
    ``prompt_fn(priming_vignettes, row)``. The row being evaluated is always
    removed from the pool of priming examples first.

    Args:
        vignettes: DataFrame of vignettes. A 'Correct Triage' column is
            required by the two representative modes.
        prompt_fn: Callable ``(priming_vignettes, eval_vignette) -> prompt``.
            ``priming_vignettes`` is ``None`` in the zero-shot case and a
            DataFrame of the selected priming rows otherwise.
        nprime: Priming strategy; defaults to ``config['NPRIME']``.
            ``0`` gives one prompt per vignette with no priming examples.
            A positive integer ``k`` gives one prompt per ``k``-combination of
            the remaining vignettes. ``'representative'`` gives one prompt per
            way of picking one remaining vignette from every triage label.
            ``'representative-sample'`` keeps a random subset of
            ``sample_size`` of those representative combinations, drawn
            without replacement from NumPy's global RNG.
        sample_size: Number of combinations kept per vignette in
            ``'representative-sample'`` mode; defaults to
            ``config["representative-sample-size"]``.

    Returns:
        A list of ``[index, prompt]`` pairs, where ``index`` is the index label
        of the vignette being evaluated.

    Raises:
        ValueError: If ``nprime`` is not one of the supported values, or if
            the sample size is not positive or exceeds the number of
            available representative combinations.
    """
    if nprime is None:
        nprime = config['NPRIME']

    # Validate up front. Comparing an unknown string with `> 0` would raise a
    # TypeError, so strings and numbers are checked separately.
    by_label = isinstance(nprime, str)
    if by_label:
        valid = nprime in ('representative', 'representative-sample')
    else:
        valid = nprime == 0 or (isinstance(nprime, (int, np.integer)) and nprime > 0)
    if not valid:
        raise ValueError(f"Found misspecified value for NPRIME: {nprime!r}")

    sampled = by_label and nprime == 'representative-sample'
    if sampled:
        if sample_size is None:
            sample_size = config["representative-sample-size"]
        if sample_size <= 0:
            raise ValueError("representative-sample-size parameter needs to be >0 when NPRIME=representative-sample")

    zero_shot = (not by_label) and nprime == 0
    # The set of labels does not depend on the vignette being evaluated.
    labels = vignettes['Correct Triage'].unique() if by_label else None

    prompts = []
    for i, row in vignettes.iterrows():

        if zero_shot:
            prompts.append([i, prompt_fn(None, row)])
            continue

        # Remove vignette being evaluated from pool of priming examples
        priming_vignettes = vignettes.drop(i)

        if by_label:
            # Candidate priming examples grouped by label, one list per label
            labelwise_vignette_indices = [
                list(priming_vignettes.loc[priming_vignettes['Correct Triage'] == triage].index)
                for triage in labels
            ]
            combos = itertools.product(*labelwise_vignette_indices)

            if sampled:
                n_combinations = np.prod([len(x) for x in labelwise_vignette_indices])
                if sample_size > n_combinations:
                    raise ValueError(
                        f"representative-sample-size ({sample_size}) exceeds the "
                        f"{n_combinations} representative combinations available "
                        f"for vignette {i!r}"
                    )
                # Positions (in product order) of the combinations to keep
                sample_indices = set(
                    np.random.choice(n_combinations,
                                     replace=False,
                                     size=sample_size)
                )
                # Stop enumerating once the last sampled position is reached
                last_needed = int(max(sample_indices)) + 1
                combos = (
                    comb
                    for comb_idx, comb in enumerate(itertools.islice(combos, last_needed))
                    if comb_idx in sample_indices
                )
        else:
            combos = itertools.combinations(priming_vignettes.index, nprime)

        for comb in combos:
            prompts.append([
                i,
                prompt_fn(
                    priming_vignettes.loc[list(comb)],
                    row
                )
            ])

    return prompts

# Vignettes - 2015

### Read file

In [3]:
# Load the 2015 vignettes, then report the label distribution and table shape
vignettes15 = parse_vignettes_2015(VIGNETTES_2015_FP)
print(vignettes15['Correct Triage'].value_counts(), vignettes15.shape, sep='\n')
vignettes15.head()

Non-emergent    15
Emergent        15
Self-care       15
Name: Correct Triage, dtype: int64
(45, 4)


Unnamed: 0,Correct Diagnosis,Problem,Simplified,Correct Triage
0,Acute liver failure,A 48-year-old woman with a history of migraine...,"48 y/o f, confusion, disorientation, increasin...",Emergent
1,Appendicitis,A 12-year-old girl presents with sudden-onset ...,"12 y/o f, sudden onset severe abdominal pain, ...",Emergent
2,Asthma,A 27-year-old woman with a history of moderate...,"27 y/o f, Hx of asthma, mild shortness of brea...",Emergent
3,COPD flare (more severe),A 67-year-old woman with a history of COPD pre...,"67 y/o f, Hx of COPD, 3 days worsening shortne...",Emergent
4,Deep vein thrombosis,A 65-year-old woman presents with unilateral l...,"65 y/o f, 5 days swelling, pain in one leg, re...",Emergent


### Prompts

In [4]:
def construct_triage_prompt(priming_vignettes: pd.DataFrame,
                            eval_vignette: pd.Series,
                            separator: str = '\n###\n') -> str:
    """Assemble a completion prompt for the 3-category (2015) triage task.

    The prompt is the task instruction, then one solved example per priming
    vignette, then the vignette under evaluation with its triage left blank
    for the model to fill in. All sections are joined by ``separator``.

    Args:
        priming_vignettes: Rows used as solved examples; read for their
            'Problem' and 'Correct Triage' values. ``None`` means no examples.
        eval_vignette: The vignette to classify; only 'Problem' is read.
        separator: Text placed between consecutive sections of the prompt.

    Returns:
        The prompt text, ending in "Triage:" with no trailing whitespace.
    """
    instruction = ("Given a clinical vignette, classify into one of the 3 "
                   "triage categories: Emergent, Non-emergent, and Self-care.")
    template = "Vignette: {}\nTriage: {}"

    sections = [instruction]

    if priming_vignettes is not None:
        sections.extend(
            template.format(example['Problem'], example['Correct Triage'])
            for _, example in priming_vignettes.iterrows()
        )

    # The unanswered query goes last; strip the blank left after "Triage:"
    sections.append(template.format(eval_vignette['Problem'], '').strip())

    return separator.join(sections)

In [5]:
# Build every [vignette index, prompt] pair for the 2015 vignettes using the
# 3-category triage prompt template defined in the previous cell.
all_prompts = construct_all_prompts(vignettes15, construct_triage_prompt)
print(f"Total no. of prompts: {len(all_prompts)}\n") 
# Show the text of the first prompt as a sanity check of the format
print(all_prompts[0][1])

Total no. of prompts: 450

Given a clinical vignette, classify into one of the 3 triage categories: Emergent, Non-emergent, and Self-care.
###
Vignette: A 65-year-old woman presents with unilateral leg pain and swelling of 5 days' duration. There is a history of hypertension, mild CHF, and recent hospitalization for pneumonia. She had been recuperating at home but on beginning to mobilize and walk, the right leg became painful, tender, and swollen. On examination, the right calf is 4 cm greater in circumference than the left when measured 10 cm below the tibial tuberosity. Superficial veins in the leg are more dilated on the right foot and the right leg is slightly redder than the left. There is some tenderness on palpation in the popliteal fossa behind the knee.
Triage: Emergent
###
Vignette: A 45-year-old man presents with acute onset of pain and redness of the skin of his lower extremity. Low-grade fever is present and the pretibial area is erythematous, edematous, and tender.
Triag

### Get completions

In [6]:
# Query the model once per prompt, remembering which vignette each prompt evaluates
results = []
for eval_vignette_idx, prompt in tqdm(all_prompts):
    predicted_triage, prob_predicted_triage = get_completion_prob(
        prompt=prompt,
        gpt3_params=config['GPT-3-params'],
    )
    results.append({
        'index': eval_vignette_idx,
        'Prompt': prompt,
        'Predicted Triage': predicted_triage,
        'LogProb Predicted Triage': prob_predicted_triage,
    })

# Attach the completions to their vignettes by index label, then keep only
# the reporting columns in their final order
results15 = (
    vignettes15
    .join(pd.DataFrame(results).set_index('index'))
    .loc[:, ['Prompt', 'Problem', 'Correct Triage', 'Predicted Triage', 'LogProb Predicted Triage']]
)

100%|██████████| 450/450 [07:04<00:00,  1.06it/s]


### Format and save results

#### (i) All completions

In [60]:
# Save completions: write the per-prompt results for the 2015 vignettes
# (prompt, vignette text, correct and predicted triage, log-probability)
# as a tab-separated file under the PROCESSED directory.
results15.to_csv(PROCESSED/'vignettes15_triage_prediction_all.tsv', sep='\t')

#### (ii) Aggregated completions

In [56]:
# Collapse the per-prompt completions into one row per vignette, holding a
# Counter of how often each triage label was predicted across its prompts.
# This runs for every NPRIME setting: in the zero-shot case (NPRIME == 0) each
# vignette has a single prompt, so the Counter simply has one entry. Running
# it unconditionally keeps `results15_agg` defined for the accuracy print
# below and for the cell that saves it.
results15_agg = (
    results15
    .groupby(['Problem', 'Correct Triage'])['Predicted Triage']
    .agg(lambda s: Counter(s))
    .reset_index()
)

# A vignette is marked correct when its true label was predicted at least as
# often as the most frequent prediction, so ties count as correct.
results15_agg['Correct(Yes/No)'] = results15_agg.apply(
    lambda row: 'Yes' if (row['Predicted Triage'][row['Correct Triage']] >= row['Predicted Triage'].most_common(1)[0][1])
                      else 'No',
    axis=1
)

print(f"Accuracy: {(results15_agg['Correct(Yes/No)'] == 'Yes').mean()}")

Accuracy: 0.37777777777777777


In [63]:
# Save aggregated results: one row per 2015 vignette with its prediction
# counts and the Yes/No correctness flag, tab-separated under PROCESSED.
results15_agg.to_csv(PROCESSED/'vignettes15_triage_prediction_agg.tsv', sep='\t')

# Vignettes - 2020

### Read file

In [27]:
# Load the 2020 vignettes, then report the label distribution and table shape
vignettes20 = parse_vignettes_2020(VIGNETTES_2020_FP)
print(vignettes20['Correct Triage'].value_counts(), vignettes20.shape, sep='\n')
vignettes20.head()

1-week       12
Emergent     12
1-day        12
Self-care    12
Name: Correct Triage, dtype: int64
(48, 4)


Unnamed: 0_level_0,Correct Diagnosis,Correct Triage,Problem,Additional Details
Case #,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,Liver failure,Emergent,Confused and sleepy for several hours.,48-year-old female; Can't answer where she is ...
2,Appendicitis,Emergent,Sudden severe abdominal pain.,12-year-old female; Temperature 104 F (40 C); ...
3,Heart attack,Emergent,Chest pain for 1 day.,64-year-old male; Pain is in middle of chest u...
4,Kidney stones,Emergent,Left-sided pain between armpit and hip for 1 h...,45-year-old male; Severe pain; Pain starts on ...
5,Meningitis,Emergent,Headache for 3 days.,18-year-old male; Has fever 102; Neck stiff; L...


### Prompts

In [28]:
def construct_triage_prompt(
    priming_vignettes: pd.DataFrame,
    eval_vignette: pd.Series,
    separator: str = '\n###\n') -> str:
    """Assemble a completion prompt for the 4-category (2020) triage task.

    The prompt is the task instruction, then one solved example per priming
    vignette, then the vignette under evaluation with its triage left blank
    for the model to fill in. All sections are joined by ``separator``.

    Note: this definition reuses the name of the 2015 prompt builder defined
    earlier in the notebook and replaces it once this cell has run.

    Args:
        priming_vignettes: Rows used as solved examples; read for their
            'Problem', 'Additional Details' and 'Correct Triage' values.
            ``None`` means no examples.
        eval_vignette: The vignette to classify; 'Problem' and
            'Additional Details' are read.
        separator: Text placed between consecutive sections of the prompt.

    Returns:
        The prompt text, ending in "Triage:" with no trailing whitespace.
    """
    instruction = ("Given a clinical vignette and additional details, classify into one of the 4 "
                   "triage categories: 1-day, 1-week, Emergent, and Self-care.")
    template = "Vignette: {}\nAdditional Details: {}\nTriage: {}"

    sections = [instruction]

    if priming_vignettes is not None:
        sections.extend(
            template.format(example['Problem'], example['Additional Details'], example['Correct Triage'])
            for _, example in priming_vignettes.iterrows()
        )

    # The unanswered query goes last; strip the blank left after "Triage:"
    query = template.format(eval_vignette['Problem'], eval_vignette['Additional Details'], '')
    sections.append(query.strip())

    return separator.join(sections)

In [30]:
# Build every [vignette index, prompt] pair for the 2020 vignettes.
# `construct_triage_prompt` here is the 4-category definition from the
# previous cell, which shadows the 2015 one of the same name; `all_prompts`
# likewise replaces the 2015 prompt list.
all_prompts = construct_all_prompts(vignettes20, construct_triage_prompt)
print(f"Total no. of prompts: {len(all_prompts)}\n") 
# Show the text of the first prompt as a sanity check of the format
print(all_prompts[0][1])

Total no. of prompts: 480

Given a clinical vignette and additional details, classify into one of the 4 triage categories: 1-day, 1-week, Emergent, and Self-care.
###
Vignette: Sudden severe abdominal pain.
Additional Details: 12-year-old female; Temperature 104 F (40 C); Has nausea, vomiting, and diarrhea.
Triage: Emergent
###
Vignette: Belly pain and diarrhea for 7 days.
Additional Details: 4-year-old male; Nothing unusual in diet though did have a hamburger at a cookout 3 days before pain started; Has fever; Diarrhea may have blood in it.
Triage: 1-day
###
Vignette: Upper abdominal pain for 2 months.
Additional Details: 40-year-old male; Worse with spicy and fried foods; Worse if he eats late at night before sleeping; Voice is becoming hoarse; .
Triage: 1-week
###
Vignette: White stuff coming out of vagina for 2 days.
Additional Details: 40-year-old female; Vagina also itchy; Doesn't hurt to pee; No abdominal pain; No fever.
Triage: Self-care
###
Vignette: Confused and sleepy for se

### Completions

In [31]:
# Query the model once per prompt, remembering which vignette each prompt evaluates
results = []
for eval_vignette_idx, prompt in tqdm(all_prompts):
    predicted_triage, prob_predicted_triage = get_completion_prob(
        prompt=prompt,
        gpt3_params=config['GPT-3-params'],
    )
    results.append({
        'index': eval_vignette_idx,
        'Prompt': prompt,
        'Predicted Triage': predicted_triage,
        'LogProb Predicted Triage': prob_predicted_triage,
    })

# Attach the completions to their vignettes by index label, then keep only
# the reporting columns in their final order
results20 = (
    vignettes20
    .join(pd.DataFrame(results).set_index('index'))
    .loc[:, ['Prompt', 'Problem', 'Additional Details', 'Correct Triage',
             'Predicted Triage', 'LogProb Predicted Triage']]
)

100%|██████████| 480/480 [06:18<00:00,  1.27it/s]


### Format and save results

#### (i) All completions

In [53]:
# Save completions: write the per-prompt results for the 2020 vignettes as a
# tab-separated file under the PROCESSED directory.
# NOTE(review): this file name ends in '.all.tsv' whereas the 2015 completions
# are written to '..._all.tsv' — confirm the differing separator is intended.
results20.to_csv(PROCESSED/'vignettes20_triage_prediction.all.tsv', sep='\t')

#### (ii) Aggregated completions

In [49]:
# Collapse the per-prompt completions into one row per vignette, holding a
# Counter of how often each triage label was predicted across its prompts
results20_agg = (
    results20
    .groupby(['Problem', 'Additional Details', 'Correct Triage'])['Predicted Triage']
    .agg(lambda s: Counter(s))
    .reset_index()
)

# A vignette is marked 'No' only when some other label was predicted strictly
# more often than the true one; ties therefore count as 'Yes'
results20_agg['Correct(Yes/No)'] = results20_agg.apply(
    lambda row: 'No' if (row['Predicted Triage'][row['Correct Triage']] < row['Predicted Triage'].most_common(1)[0][1])
                     else 'Yes',
    axis=1
)

print(f"Accuracy: {(results20_agg['Correct(Yes/No)'] == 'Yes').mean()}")

Accuracy: 0.4375


In [62]:
# Save aggregated results: one row per 2020 vignette with its prediction
# counts and the Yes/No correctness flag, tab-separated under PROCESSED.
results20_agg.to_csv(PROCESSED/'vignettes20_triage_prediction_agg.tsv', sep='\t')