In [35]:
import pandas as pd
import matplotlib.pyplot as plt

In [36]:
# Load the per-example evaluation results.
df = pd.read_json('data.json')

# If the file mixes runs with and without literals, keep only the runs where
# `no_literals` is False and drop the flag column.
if 'no_literals' in df.columns:
    df = (
        df[df['no_literals'].eq(False)]
        .drop(columns=['no_literals'])
        .reset_index(drop=True)
    )

# nshot == 0 rows use no retrieved examples, so they are presumably independent of
# the retrieval method. This code assumes the raw data labels them 'random' and
# copies them to every other method so each method has a zero-shot baseline.
# All copies are built first and concatenated once (no concat-in-a-loop).
ZEROSHOT_METHOD = 'random'
zeroshot = df[df['nshot'] == 0]
other_methods = [m for m in df['retrieval_method'].unique() if m != ZEROSHOT_METHOD]
df = pd.concat(
    [df] + [zeroshot.assign(retrieval_method=m) for m in other_methods],
    ignore_index=True,
)

print(len(df))
print(df['retrieval_method'].value_counts())
# After the copy step every retrieval method must cover the same number of rows.
assert df['retrieval_method'].value_counts().nunique() == 1, \
    'retrieval methods have unequal row counts'

df = df.sort_values(by=['retrieval_method', 'nshot']).reset_index(drop=True)
df['generation'] = df['generation'].str.strip()

# Pair every unmasked row with its masked counterpart. The join below is
# positional (both halves get a fresh 0..n-1 index), so it is only correct if the
# masked and unmasked rows appear in the same order. Verify that before joining;
# a silent misalignment would corrupt every *_diff column.
masked = df[df['masked'].eq(True)].reset_index(drop=True)
unmasked = df[df['masked'].eq(False)].reset_index(drop=True)

ALIGN_KEYS = ['retrieval_method', 'nshot', 'question']
assert len(masked) == len(unmasked), 'masked/unmasked row counts differ'
assert masked[ALIGN_KEYS].equals(unmasked[ALIGN_KEYS]), \
    'masked/unmasked rows are not aligned; the positional join would mix examples'

joined = unmasked.join(
    masked[['generation', 'precision', 'recall', 'f1']],
    lsuffix='_unmasked',
    rsuffix='_masked',
).drop(columns=['masked'])

for metric in ['precision', 'recall', 'f1']:
    # Positive diff => the unmasked run scored higher than the masked run.
    joined[f'{metric}_diff'] = joined[f'{metric}_unmasked'] - joined[f'{metric}_masked']

joined.head()

18120
column_jaccard    6040
tfidf             6040
random            6040
Name: retrieval_method, dtype: int64


Unnamed: 0,question,query,generation_unmasked,retrieval_method,nshot,precision_unmasked,recall_unmasked,f1_unmasked,generation_masked,precision_masked,recall_masked,f1_masked,precision_diff,recall_diff,f1_diff
0,what is warfarin sodium 2.5 mg po tabs's way o...,select distinct medication.routeadmin from med...,The query extracts all of the patient's medica...,column_jaccard,0,0.590455,0.66114,0.623802,The query selects all medications that have th...,0.59909,0.635618,0.616814,-0.008635,0.025522,0.006988
1,"what is the method for ingestion of ns 1,000 ml?",select distinct medication.routeadmin from med...,This query selects all the rows from the medic...,column_jaccard,0,0.563709,0.675585,0.614598,This query SELECTs all medications that have t...,0.492766,0.627699,0.552108,0.070943,0.047886,0.06249
2,how is atorvastatin calcium 80 mg po tabs taken?,select distinct medication.routeadmin from med...,This query is looking for all medications with...,column_jaccard,0,0.653431,0.766081,0.705286,The query selects all medications where the fi...,0.643358,0.761101,0.697294,0.010073,0.00498,0.007992
3,what is metoprolol succinate er 50 mg po tb24'...,select distinct medication.routeadmin from med...,This query presents a list of distinct medicat...,column_jaccard,0,0.696146,0.752771,0.723352,This query selects the difference between the ...,0.599704,0.748873,0.666039,0.096442,0.003898,0.057313
4,tell me the price of a procedure called agent ...,select distinct cost.cost from cost where cost...,This query will return all of the different co...,column_jaccard,0,0.677709,0.668955,0.673304,The query selects all column values in cost th...,0.589344,0.745333,0.658223,0.088365,-0.076378,0.01508


In [37]:
# Keep only the grouping keys and the metric columns, ordered metric by metric
# (unmasked, masked, diff) so related numbers sit side by side.
# Exact names are used instead of substring matching, so an unrelated column
# that happens to contain "f1"/"recall" is never picked up, and a missing metric
# column raises a KeyError instead of being silently skipped.
METRICS = ['precision', 'recall', 'f1']
VARIANTS = ['unmasked', 'masked', 'diff']
cols = ['retrieval_method', 'nshot'] + [f'{m}_{v}' for m in METRICS for v in VARIANTS]
joined = joined[cols]

In [38]:
# Mean metrics per shot count for random retrieval.
# numeric_only=True excludes the string `retrieval_method` column. Without it,
# pandas < 2.0 emits a FutureWarning and pandas >= 2.0 raises TypeError.
d = joined[joined['retrieval_method'] == 'random']
d.groupby('nshot').mean(numeric_only=True).round(3)

  d.groupby('nshot').mean().round(3)


Unnamed: 0_level_0,precision_unmasked,precision_masked,precision_diff,recall_unmasked,recall_masked,recall_diff,f1_unmasked,f1_masked,f1_diff
nshot,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
0,0.53,0.5,0.03,0.539,0.533,0.005,0.533,0.514,0.019
1,0.63,0.611,0.019,0.616,0.6,0.016,0.621,0.603,0.018
2,0.649,0.624,0.025,0.644,0.61,0.034,0.645,0.615,0.03
3,0.665,0.637,0.028,0.662,0.632,0.031,0.662,0.632,0.03


In [39]:
# Mean metrics per shot count for column-Jaccard retrieval.
# numeric_only=True excludes the string `retrieval_method` column. Without it,
# pandas < 2.0 emits a FutureWarning and pandas >= 2.0 raises TypeError.
d = joined[joined['retrieval_method'] == 'column_jaccard']
d.groupby('nshot').mean(numeric_only=True).round(3)

  d.groupby('nshot').mean().round(3)


Unnamed: 0_level_0,precision_unmasked,precision_masked,precision_diff,recall_unmasked,recall_masked,recall_diff,f1_unmasked,f1_masked,f1_diff
nshot,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
0,0.53,0.5,0.03,0.539,0.533,0.005,0.533,0.514,0.019
1,0.715,0.699,0.016,0.703,0.684,0.02,0.708,0.69,0.018
2,0.748,0.738,0.01,0.743,0.729,0.014,0.745,0.732,0.012
3,0.767,0.761,0.006,0.761,0.752,0.009,0.763,0.755,0.007


In [40]:
# Mean metrics per shot count for TF-IDF retrieval.
# numeric_only=True excludes the string `retrieval_method` column. Without it,
# pandas < 2.0 emits a FutureWarning and pandas >= 2.0 raises TypeError.
d = joined[joined['retrieval_method'] == 'tfidf']
d.groupby('nshot').mean(numeric_only=True).round(3)

  d.groupby('nshot').mean().round(3)


Unnamed: 0_level_0,precision_unmasked,precision_masked,precision_diff,recall_unmasked,recall_masked,recall_diff,f1_unmasked,f1_masked,f1_diff
nshot,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
0,0.53,0.5,0.03,0.539,0.533,0.005,0.533,0.514,0.019
1,0.737,0.722,0.014,0.727,0.707,0.02,0.73,0.713,0.017
2,0.758,0.749,0.01,0.751,0.731,0.019,0.753,0.738,0.015
3,0.77,0.762,0.008,0.764,0.75,0.014,0.766,0.755,0.011
