In [62]:
import pickle
import pandas as pd
import os
import openai
import numpy as np
import ipdb
import re
from tqdm import tqdm

from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
import spacy
import scipy

nlp = spacy.load("en_core_web_sm")
openai.api_key= os.environ['OPENAI_KEY']

from data_utils import *
from gpt3_utils import *
from eval_utils import *

pd.set_option('max_rows',500,'max_colwidth',10000)
pd.options.display.float_format = "{:,.2f}".format

In [4]:
#Loading BC5CDR Chemical and Disease Training Sets

chemical_train = pd.read_csv('../data/bc5cdr_chemical.train.processed.tsv',sep='\t')
disease_train = pd.read_csv('../data/bc5cdr_disease.train.processed.tsv',sep='\t')

In [5]:
train_half1 = disease_train[0:int(len(disease_train)/2)]
train_half2 = disease_train[int(len(disease_train)/2):]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_df['test_ready_prompt'] = [few_shot_prompt+'\n\n'+empty_prompt for empty_prompt in test_df['empty_prompts']]


In [None]:
def create_prompt_dataset_stratified(train_df, test_df, seed, few_shot_num, number_of_test_samples, selection_strategy,sep=', '):
    
    random = np.random.RandomState(seed)
    
    few_shot_prompt, chosen_prompt_ids = create_few_shot_prompt(train_df, random, few_shot_num, selection_strategy)
    
    test_df['test_ready_prompt'] = [few_shot_prompt+'\n\n'+empty_prompt for empty_prompt in test_df['empty_prompts']]
    
    if number_of_test_samples != 'all':
        #Making sure the samples are the same as the first batch
        random = np.random.RandomState(42)
        random.permutation(train_df.index)
        chosen_test_ids = random.permutation(test_df.index)[:number_of_test_samples]
        
        chosen_test_df = test_df.loc[chosen_test_ids]
    else:
        chosen_test_df = test_df
    
    return {'seed':seed,'few_shot_prompt': few_shot_prompt, 'chosen_prompt_ids':chosen_prompt_ids,'sep': sep,'test_df': chosen_test_df}

In [30]:
def test_prompt_selection(engine, train_df, eval_df, few_shot_seeds, dev_set_seeds, few_shot_size=5, eval_size=50):

    result_by_seeds = []
    result_dfs_by_seeds = {}

    for few_shot_seed in few_shot_seeds:
        
        for dev_set_seed in dev_set_seeds:
            
            if dev_set_seed > 0:
                permuted_eval_df = eval_df.sample(frac=1, random_state=np.random.RandomState(dev_set_seed))
            else:
                permuted_eval_df = eval_df
                
            dev_data = create_prompt_dataset(train_df, permuted_eval_df, few_shot_seed, few_shot_size, eval_size, 'random')
            dev_df = dev_data['test_df']
            prompts = dev_df.test_ready_prompt.values

            result_df = run_gpt3_on_df(engine, dev_df, prompts, max_tokens=30, sep=dev_data['sep'], logit_bias=10, sep_logit_bias=10, new_line_logit_bias=10)

            df = create_bio_preds(result_df, "predictions")
            f1, precision, recall = conlleval_eval(df.ner_seq,df.bio_preds)

            result_by_seeds.append((few_shot_seed, dev_set_seed, f1, precision, recall, prompts[0]))
            result_dfs_by_seeds[(few_shot_seed, dev_set_seed)] = result_df
        
    return pd.DataFrame(result_by_seeds, columns=["few_shot_seed", "dev_set_seed", "f1", "precision", "recall", "prompt"]), result_dfs_by_seeds

In [33]:
disease_performance_by_prompt_sel, disease_results_by_prompt_sel = test_prompt_selection('davinci',
                                                                                         train_half1, 
                                                                                         train_half2, 
                                                                                         [0,2,4,6],
                                                                                         [0,1])

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_df['test_ready_prompt'] = [few_shot_prompt+'\n\n'+empty_prompt for empty_prompt in test_df['empty_prompts']]
50it [00:19,  2.59it/s]
0it [00:00, ?it/s]

processed 1125 tokens with 37 phrases; found: 32 phrases; correct: 18.
accuracy:  95.73%; (non-O)
accuracy:  95.73%; precision:  56.25%; recall:  48.65%; FB1:  52.17%
                X: precision:  56.25%; recall:  48.65%; FB1:  52.17%  32


50it [00:21,  2.35it/s]
0it [00:00, ?it/s]

processed 1470 tokens with 57 phrases; found: 47 phrases; correct: 23.
accuracy:  93.40%; (non-O)
accuracy:  93.40%; precision:  48.94%; recall:  40.35%; FB1:  44.23%
                X: precision:  48.94%; recall:  40.35%; FB1:  44.23%  47


50it [00:23,  2.14it/s]
0it [00:00, ?it/s]

processed 1125 tokens with 37 phrases; found: 43 phrases; correct: 26.
accuracy:  96.27%; (non-O)
accuracy:  96.27%; precision:  60.47%; recall:  70.27%; FB1:  65.00%
                X: precision:  60.47%; recall:  70.27%; FB1:  65.00%  43


50it [00:22,  2.21it/s]
0it [00:00, ?it/s]

processed 1470 tokens with 57 phrases; found: 58 phrases; correct: 25.
accuracy:  93.88%; (non-O)
accuracy:  93.88%; precision:  43.10%; recall:  43.86%; FB1:  43.48%
                X: precision:  43.10%; recall:  43.86%; FB1:  43.48%  58


50it [00:20,  2.44it/s]
0it [00:00, ?it/s]

processed 1125 tokens with 37 phrases; found: 38 phrases; correct: 20.
accuracy:  95.91%; (non-O)
accuracy:  95.91%; precision:  52.63%; recall:  54.05%; FB1:  53.33%
                X: precision:  52.63%; recall:  54.05%; FB1:  53.33%  38


50it [00:20,  2.49it/s]
0it [00:00, ?it/s]

processed 1470 tokens with 57 phrases; found: 43 phrases; correct: 20.
accuracy:  93.33%; (non-O)
accuracy:  93.33%; precision:  46.51%; recall:  35.09%; FB1:  40.00%
                X: precision:  46.51%; recall:  35.09%; FB1:  40.00%  43


50it [00:20,  2.45it/s]
0it [00:00, ?it/s]

processed 1125 tokens with 37 phrases; found: 26 phrases; correct: 19.
accuracy:  96.62%; (non-O)
accuracy:  96.62%; precision:  73.08%; recall:  51.35%; FB1:  60.32%
                X: precision:  73.08%; recall:  51.35%; FB1:  60.32%  26


50it [00:19,  2.51it/s]

processed 1470 tokens with 57 phrases; found: 35 phrases; correct: 24.
accuracy:  94.69%; (non-O)
accuracy:  94.69%; precision:  68.57%; recall:  42.11%; FB1:  52.17%
                X: precision:  68.57%; recall:  42.11%; FB1:  52.17%  35





In [36]:
disease_performance_by_prompt_sel.sort_values('f1',ascending=False)[["few_shot_seed", "dev_set_seed", "f1", "precision", "recall"]]

Unnamed: 0,few_shot_seed,dev_set_seed,f1,precision,recall
2,2,0,65.0,60.465116,70.27027
6,6,0,60.31746,73.076923,51.351351
4,4,0,53.333333,52.631579,54.054054
0,0,0,52.173913,56.25,48.648649
7,6,1,52.173913,68.571429,42.105263
1,0,1,44.230769,48.93617,40.350877
3,2,1,43.478261,43.103448,43.859649
5,4,1,40.0,46.511628,35.087719


In [70]:
disease_performance_by_prompt_sel.groupby('few_shot_seed').agg({'mean','std'}).sort_values(('f1','mean'),ascending=False)

Unnamed: 0_level_0,dev_set_seed,dev_set_seed,f1,f1,precision,precision,recall,recall
Unnamed: 0_level_1,std,mean,std,mean,std,mean,std,mean
few_shot_seed,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
6,0.71,0.5,5.76,56.25,3.19,70.82,6.54,46.73
2,0.71,0.5,15.22,54.24,12.28,51.78,18.68,57.06
0,0.71,0.5,5.62,48.2,5.17,52.59,5.87,44.5
4,0.71,0.5,9.43,46.67,4.33,49.57,13.41,44.57


In [59]:
print(disease_results_by_prompt_sel[(4,1)].test_ready_prompt.values[0])

Sentence: Desferrioxamine withdrawal resulted in a complete recovery of visual function in 1 patient and partial recovery in 3 , and a complete reversal of hearing loss in 3 patients and partial recovery in 3 .
Diseases: hearing loss

Sentence: The block was successful and surgery was conducted as scheduled despite persisting atrial fibrillation .
Diseases: atrial fibrillation

Sentence: An objective causality assessment revealed that the adverse drug event was probably related to the use of ticlopidine .
Diseases: 

Sentence: CONCLUSION : ATT - ALF constituted 5 . 7 % of ALF at our center and had a high mortality rate .
Diseases: alf

Sentence: A policy of unrestricted prescription of appetite suppressants may lead to a high incidence of associated primary pulmonary hypertension .
Diseases: primary pulmonary hypertension

Sentence: METHODS : We present the first case report of a woman with hyperthyroidism treated with propylthiouracil in whom a syndrome of pericarditis , fever , and g

In [40]:
print(disease_results_by_prompt_sel[(2,1)].test_ready_prompt.values[0])

Sentence: We present a 43 - year - old man who developed a coronary aneurysm in the right coronary artery 6 months after receiving a paclitaxel - eluting stent .
Diseases: coronary aneurysm

Sentence: In 44 ( 62 . 8 % ) patients , ATT was prescribed empirically without definitive evidence of tuberculosis .
Diseases: tuberculosis

Sentence: The aim of the present study was to investigate changes in the plasma calcitonin gene - related peptide ( CGRP ) concentration and platelet serotonin ( 5 - hydroxytriptamine , 5 - HT ) content during the immediate headache and the delayed genuine migraine attack provoked by nitroglycerin .
Diseases: headache, migraine

Sentence: Heparan sulphate - associated anionic sites in the glomerular basement membrane were studied in rats 8 months after induction of diabetes by streptozotocin and in age - adn sex - matched control rats , employing the cationic dye cuprolinic blue .
Diseases: diabetes

Sentence: EBFF did not change during PGE1 infusion whereas i

In [39]:
print(disease_results_by_prompt_sel[(0,0)].test_ready_prompt.values[0])

Sentence: In recent years working memory deficits have been reported in users of MDMA ( 3 , 4 - methylenedioxymethamphetamine , ecstasy ) .
Diseases: memory deficits

Sentence: Molecularly , ANF mRNA increased 250 % and SERCA2 mRNA decreased 57 % .
Diseases: 

Sentence: Since nonsteroidal anti - inflammatory agents interfere with this compensatory mechanism and may cause acute renal failure , they should be used with caution in such patients .
Diseases: acute renal failure

Sentence: In conclusion , CNS complications are frequent events during ALL therapy , and require rapid detection and prompt treatment to limit permanent damage .
Diseases: all

Sentence: The relative amounts of alphaENaC , betaENaC and gammaENaC mRNAs were determined in kidneys from these rats by real - time quantitative TaqMan PCR , and the amounts of proteins by Western blot .
Diseases: 

Sentence: The site of common side effects of sumatriptan .
Diseases:


In [58]:
print(disease_results_by_prompt_sel[(6,0)].test_ready_prompt.values[0])

Sentence: Furthermore , the effects are mediated through dopamine rather than norepinephrine and do not require the carotid sinus baroreceptors .
Diseases: 

Sentence: Amiloride reduced the drinking and urine volume of rats in an acute ( 6 or 12 h ) and a subacute ( 3 days ) experiment .
Diseases: 

Sentence: An electroencephalogram showed continuous , generalized irregular slowing with admixed periodic triphasic waves indicating symptomatic encephalopathy .
Diseases: encephalopathy

Sentence: Whatever was the dose , the central administration of U - II had no effect on body temperature , nociception , apomorphine - induced penile erection and climbing behavior , and stress - induced plasma corticosterone level .
Diseases: penile erection

Sentence: Most of these patients had more than two metastatic sites , with lung metastasis predominant .
Diseases: 

Sentence: The site of common side effects of sumatriptan .
Diseases:


In [47]:
result_by_seeds = []

for s1,s2 in disease_results_by_prompt_sel.keys(): 
    result_df = disease_results_by_prompt_sel[(s1,s2)]
    
    df = create_bio_preds(result_df, "predictions")
    b_true = [s.replace('I','B') for s in df.ner_seq]
    b_pred = [s.replace('I','B') for s in df.bio_preds]
    
    f1, precision, recall = conlleval_eval(b_true,b_pred)

    result_by_seeds.append((s1, s2, f1, precision, recall))

processed 1125 tokens with 55 phrases; found: 43 phrases; correct: 25.
accuracy:  95.73%; (non-O)
accuracy:  95.73%; precision:  58.14%; recall:  45.45%; FB1:  51.02%
                X: precision:  58.14%; recall:  45.45%; FB1:  51.02%  43
processed 1470 tokens with 103 phrases; found: 79 phrases; correct: 44.
accuracy:  93.61%; (non-O)
accuracy:  93.61%; precision:  55.70%; recall:  42.72%; FB1:  48.35%
                X: precision:  55.70%; recall:  42.72%; FB1:  48.35%  79
processed 1125 tokens with 55 phrases; found: 58 phrases; correct: 36.
accuracy:  96.36%; (non-O)
accuracy:  96.36%; precision:  62.07%; recall:  65.45%; FB1:  63.72%
                X: precision:  62.07%; recall:  65.45%; FB1:  63.72%  58
processed 1470 tokens with 103 phrases; found: 93 phrases; correct: 56.
accuracy:  94.29%; (non-O)
accuracy:  94.29%; precision:  60.22%; recall:  54.37%; FB1:  57.14%
                X: precision:  60.22%; recall:  54.37%; FB1:  57.14%  93
processed 1125 tokens with 55 phrases;

In [52]:
result_by_seeds = pd.DataFrame(result_by_seeds).sort_values(2)

In [71]:
result_by_seeds.groupby(0).agg({'mean','std'}).sort_values((2,'mean'),ascending=False)

Unnamed: 0_level_0,1,1,2,2,3,3,4,4
Unnamed: 0_level_1,std,mean,std,mean,std,mean,std,mean
0,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
2,0.71,0.5,4.65,60.43,1.31,61.14,7.84,59.91
6,0.71,0.5,3.71,54.2,1.8,74.49,3.99,42.63
4,0.71,0.5,11.33,49.93,3.48,57.15,15.83,45.17
0,0.71,0.5,1.89,49.69,1.73,56.92,1.93,44.09


In [56]:
result_by_seeds

Unnamed: 0,0,1,2,3,4
5,4,1,41.916168,54.6875,33.980583
1,0,1,48.351648,55.696203,42.718447
0,0,0,51.020408,58.139535,45.454545
7,6,1,51.572327,73.214286,39.805825
6,6,0,56.818182,75.757576,45.454545
3,2,1,57.142857,60.215054,54.368932
4,4,0,57.943925,59.615385,56.363636
2,2,0,63.716814,62.068966,65.454545
