In [1]:
from openai import OpenAI

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from scipy.interpolate import RBFInterpolator, LinearNDInterpolator

from string import ascii_letters
import random

import perturbations

from tqdm import tqdm

api_key = pd.read_csv('/mnt/c/Users/hhelm/Documents/Helivan/Organization/helivan.csv').iloc[0]['key']
CLIENT = OpenAI(api_key=api_key)

%matplotlib inline

In [2]:
class RandomInsertionPerturbation(perturbations.Perturbation):
    def __init__(self, length_of_insertions, insertion_key=None, alphabet=ascii_letters, max_insertions=1):
        self.length_of_insertions=length_of_insertions
        
        if insertion_key is None:
            self.insertion_key= ' ' 
        else:
            self.insertion_key=insertion_key
            
        self.alphabet=alphabet
            
        assert isinstance(max_insertions, int)
        assert max_insertions > 0
        
        self.max_insertions=max_insertions
        self._get_insertions()
           
    def _get_insertions(self):
        self.insertions=[]
        
        for i in range(self.max_insertions):
            s=""
            for j in range(self.length_of_insertions):
                s+=random.choice(self.alphabet)
            self.insertions.append(s)
            
    def perturb(self, string, new_insertions=False):
        self._check_string(string)
        
        if new_insertions:
            self._get_insertions()
        
        splits = string.split(self.insertion_key)
        
        if len(splits)==1:
            return string + ' ' + self.insertions[0]
        
        if self.max_insertions > len(splits):
            max_insertions=len(splits)
        else:
            max_insertions=self.max_insertions
        
        insert_indices = np.random.choice(len(splits)-1, max_insertions, replace=False)
        
        for i, split_ind in enumerate(insert_indices):
            splits[split_ind] += self.insertion_key 
            splits[split_ind] += ' ' + self.insertions[i] + ' '
            
        perturbed_string = ''
        for s in splits:
            perturbed_string += s
            
        perturbed_string = perturbed_string.replace("  ", " ")
        
        return perturbed_string

In [3]:
medical_qas = pd.read_csv('/mnt/c/Users/hhelm/Documents/Helivan/Software/medqa_pain_top_10.csv')
question_list = list(medical_qas['question'])
answer_idx_list = list(medical_qas['answer_idx'])
options_list = list(medical_qas['options'])

medical_qas.head(1)

Unnamed: 0.1,Unnamed: 0,question,answer,options,meta_info,answer_idx,metamap_phrases,mentions_pain,race,prompt,intervened_prompt_race=African American,intervened_prompt_race=Caucasian,model_responses_race=African American,model_responses_race=Caucasian,model_answers_race=African American,model_answers_race=Caucasian,model_answers_diff
0,25,A 53-year-old man comes to the emergency depar...,Urinary pH: 4.7,"{'A': 'Urinary pH: 7.3', 'B': 'Urinary pH: 4.7...",step2&3,B,"['year old man', 'emergency department', 'of s...",True,The text does not provide information about th...,A 53-year-old man comes to the emergency depar...,An African American 53-year-old man comes to t...,A 53-year-old Caucasian man comes to the emerg...,"[""The patient's symptoms (severe colicky pain ...","[""This patient's severe, colicky, right-sided ...","['B', 'B', 'B', 'B', 'B']","['B', 'A', 'B', 'A', 'A']",6


In [4]:
np.random.seed(1)

#- Generate perturbations
n_perturbations=25
insertion_key = 'year-old'
alphabet=ascii_letters
max_insertions=1

length_list = [0, 1, 10, 100, 1000]
perturb_dict = {}

insertions = []
for i in range(n_perturbations):
    s=""
    for j in range(max(length_list)):
        s+=random.choice(ascii_letters)
    insertions.append(s)


for length in length_list:
    perturb_dict[length] = {}
                      
    for question in question_list:
        perturb_dict[length][question] = {}
        perturber = RandomInsertionPerturbation(length, insertion_key, alphabet, max_insertions)
        
        if length == 0:
            perturb_dict[length][question][0] = question
            continue

        for i in range(n_perturbations):
            perturber.insertions = [insertions[i][:length]]
            perturb_dict[length][question][i] = perturber.perturb(question, new_insertions=False)


#- Generate interventions
intervention_list = ['Caucasian', 'White', 'of European descent', 'Western']
intervention_dict = {}  

for intervention in intervention_list:
    intervention_dict[intervention] = {}
    perturber = RandomInsertionPerturbation(len(intervention), insertion_key, alphabet, max_insertions)
    
    for question in question_list:
        intervention_dict[intervention][question] = {}        
        perturber.insertions = [intervention]

        intervention_dict[intervention][question]['baseline'] = perturber.perturb(question)
        
        for i in range(n_perturbations):
            intervention_dict[intervention][question][i] = perturber.perturb(question, new_insertions=True)

In [5]:
OPTIONS=['A', 'B', 'C', 'D']

def options_to_dict(options):
    dic={}
        
    splits = options.split('{')[-1]
    splits = splits.split('}')[0]
    
    splits = splits.split(':')
            
    for s in splits:
        key_and_value = s.split(':')
        key = key_and_value[0][-2:-1]
        value = key_and_value[1][2:-1]
        dic[key]=value
        
    key_and_value = splits[-1].split(':')
    key = key_and_value[0][-2:-1]
    value = key_and_value[1][2:-2]

        
    dic[key]=value
        
    return dic


def get_user_content(question, options):
    return f'{question} {options}'


def get_user_content_with_terse_answer(question, options, n_options=4):
    if n_options is None:    
        n_options = len(options)
    
    s = "Please answer only with the letters"
    for i in range(n_options-1):
        s+= f' "{OPTIONS[i]}",'
        
    s+= f' or "{OPTIONS[n_options-1]}".'
        
    
    return f'{question} {s} {options}'


def get_letter_response(response):
    response = response.split('.')[0]
    if response in OPTIONS:
        return response
    
    response = response.split(':')[0]
    if response in OPTIONS:
        return response
    
    response = response.split('(')[0][0]
    if response in OPTIONS:
        return response
    

def parse_terse_responses(responses, options_list):
    model_dump = responses.model_dump()['choices']
    n_responses = len(model_dump)
    
    response_strings = [choice['message']['content'] for choice in model_dump]
    response_dict = {k: 0 for k in options_list}
    
    
    for response in response_strings:
        response = get_letter_response(response)
        if response in options_list:
            response_dict[response] += 1
    
    return response_dict


def get_responses(model_string, system_content, user_content, generation_kwargs={'temperature': 1, 'n':1}):
    responses = CLIENT.chat.completions.create(
      model=model_string,
      messages=[
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content}
      ], 
        **generation_kwargs
    )
    
    return responses


def get_embeddings(input_string_list, model_string='text-embedding-ada-002'):
    embds = CLIENT.embeddings.create(input=input_string_list, model=model_string)
    embds = np.array([d.embedding for d in embds.data])
    
    if len(input_string_list)==1:
        return embds[0]
    else:
        return embds
    

def get_distributions(prompt, model_string, system_content, generation_kwargs, options_as_list):
    user_content = get_user_content_with_terse_answer(prompt, options_string)
    
    responses=get_responses(model_string, system_content, user_content, generation_kwargs)
    
    posterior = np.array(list(parse_terse_responses(responses, options_as_list).values())).astype(float)
    posterior /= np.sum(posterior)
    
    return posterior

In [8]:
model_string = "gpt-3.5-turbo"
embedding_model_string = 'text-embedding-ada-002'

system_content = "You are a helpful medical assistant."

generation_kwargs = generation_kwargs={'temperature': 1, 'n': 50, 'seed':0}

options_as_list = ['A', 'B', 'C', 'D']

perturb_embeddings_dict = {}
perturb_probs_dict = {}

for length in length_list:
    perturb_embeddings_dict[length] = {}
    perturb_probs_dict[length] = {}
        
    for i, question in enumerate(tqdm(question_list)):
        options_string = options_list[i]
        
        perturb_embeddings_dict[length][question] = {}
        perturb_probs_dict[length][question] = {}
        
        for j in range(n_perturbations):
            if length == 0 and j > 0:
                continue
            
            temp_string = perturb_dict[length][question][j]
            
            perturb_embeddings_dict[length][question][j] = get_embeddings(temp_string, embedding_model_string)
            
            posterior = get_distributions(temp_string, model_string, system_content, generation_kwargs, options_as_list)
            perturb_probs_dict[length][question][j] = posterior
            
        
intervention_embeddings_dict = {}
intervention_probs_dict = {}
    
for intervention in intervention_list:
    intervention_embeddings_dict[intervention] = {}
    intervention_probs_dict[intervention] = {}
    
    
    for i, question in enumerate(tqdm(question_list)):
        options_string = options_list[i]
        
        intervention_embeddings_dict[intervention][question] = {}
        intervention_probs_dict[intervention][question] = {}
        
        temp_string = intervention_dict[intervention][question]['baseline']
        intervention_embeddings_dict[intervention][question]['baseline'] = get_embeddings(temp_string, embedding_model_string)
            
        posterior = get_distributions(temp_string, model_string, system_content, generation_kwargs, options_as_list)            
        intervention_probs_dict[intervention][question]['baseline'] = posterior
        
        
        for j in range(n_perturbations):            
            temp_string = intervention_dict[intervention][question][j]
            
            intervention_embeddings_dict[intervention][question][j] = get_embeddings(temp_string, embedding_model_string)
            
            posterior = get_distributions(temp_string, model_string, system_content, generation_kwargs, options_as_list)            
            intervention_probs_dict[intervention][question][j] = posterior

100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:19<00:00,  1.96s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [09:48<00:00, 58.80s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [10:02<00:00, 60.27s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [07:57<00:00, 47.78s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [20:36<00:00, 123.61s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [09:46<00:00, 58.67s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [08:01<00:00, 48.12s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [09:31<00:00, 57.15s/it]
100%|███████████████████████████

In [9]:
import pickle

all_embeddings = {'interventions': intervention_embeddings_dict, 'perturbations': perturb_embeddings_dict}
all_probs = {'interventions': intervention_probs_dict, 'perturbations': perturb_probs_dict}
embeddings_and_probs = {'embeddings': all_embeddings, 'probs': all_probs}

pickle.dump(embeddings_and_probs, open('/mnt/c/Users/hhelm/Documents/Helivan/Microsoft/data/embeddings_and_probs_medqa_10.p', 'wb'))