# References

Dataset Source: https://figshare.com/articles/dataset/DDXPlus_Dataset/20043374

Dataset Files for local use with notebook (unzip in main dir of repo, should be ignored in .gitignore for commits/push to repo): https://gtvault-my.sharepoint.com/:f:/g/personal/agullapalli3_gatech_edu/ElSECU9qSLdAghI4zvFlXegBczZ9LFDRCFIo5SEWSeejQw?e=CKC9V8

Notes: Translate from French to English

# Imports

In [1]:
import pandas as pd
import numpy as np
import json

# Load Dataset

In [2]:
dataset_base_path = './data/original_dataset/'

In [3]:
train_data = f'{dataset_base_path}release_train_patients.csv'
val_data = f'{dataset_base_path}release_validate_patients.csv'
test_data = f'{dataset_base_path}release_test_patients.csv'

In [4]:
with open(f'{dataset_base_path}release_evidences.json', 'r', encoding='utf-8') as f:
    release_evidences = json.load(f)
release_evidences

{'fievre': {'name': 'fievre',
  'code_question': 'fievre',
  'question_fr': 'Avez-vous objectivé ou ressenti de la fièvre?',
  'question_en': 'Do you have a fever (either felt or measured with a thermometer)?',
  'is_antecedent': False,
  'default_value': 0,
  'value_meaning': {},
  'possible-values': [],
  'data_type': 'B'},
 'douleurxx_endroitducorps': {'name': 'douleurxx_endroitducorps',
  'code_question': 'douleurxx',
  'question_fr': 'Avez-vous de la douleur quelque part?',
  'question_en': 'Do you feel pain somewhere?',
  'is_antecedent': False,
  'default_value': 'nulle_part',
  'value_meaning': {'nulle_part': {'fr': 'nulle part', 'en': 'nowhere'},
   'aile_iliaque_D_': {'fr': 'aile iliaque(D)', 'en': 'iliac wing(R)'},
   'aile_iliaque_G_': {'fr': 'aile iliaque(G)', 'en': 'iliac wing(L)'},
   'aine_D_': {'fr': 'aine(D)', 'en': 'groin(R)'},
   'aine_G_': {'fr': 'aine(G)', 'en': 'groin(L)'},
   'aisselle_G_': {'fr': 'aisselle(G)', 'en': 'axilla(L)'},
   'aisselle_D_': {'fr': 'aiss

In [5]:
with open(f'{dataset_base_path}release_conditions.json', 'r', encoding='utf-8') as f:
    release_conditions = json.load(f)
release_conditions

{'Pneumothorax spontané': {'condition_name': 'Pneumothorax spontané',
  'cond-name-fr': 'Pneumothorax spontané',
  'cond-name-eng': 'Spontaneous pneumothorax',
  'icd10-id': 'J93',
  'symptoms': {'douleurxx_endroitducorps': {},
   'douleurxx': {},
   'douleurxx_irrad': {},
   'douleurxx_carac': {},
   'douleurxx_soudain': {},
   'douleurxx_intens': {},
   'douleurxx_precis': {},
   'dyspn': {},
   'ww_respi': {},
   'ww_effort': {},
   'angor_repos': {},
   'oedeme': {}},
  'antecedents': {'f17.210': {},
   'pneumothorax': {},
   'ap_pneumothorax': {},
   'j44_j42': {},
   'trav1': {}},
  'severity': 2},
 'Céphalée en grappe': {'condition_name': 'Céphalée en grappe',
  'cond-name-fr': 'Céphalée en grappe',
  'cond-name-eng': 'Cluster headache',
  'icd10-id': 'g44.009',
  'symptoms': {'douleurxx_endroitducorps': {},
   'douleurxx': {},
   'douleurxx_irrad': {},
   'douleurxx_carac': {},
   'douleurxx_soudain': {},
   'douleurxx_intens': {},
   'douleurxx_precis': {},
   'larmes': {},
  

In [6]:
train_df = pd.read_csv(train_data, encoding='utf-8')
train_df

Unnamed: 0,AGE,DIFFERENTIAL_DIAGNOSIS,SEX,PATHOLOGY,EVIDENCES,INITIAL_EVIDENCE
0,18,"[['Bronchite', 0.19171203430383882], ['Pneumon...",M,IVRS ou virémie,"['crowd', 'diaph', 'douleurxx', 'douleurxx_car...",fievre
1,21,"[['VIH (Primo-infection)', 0.5189500564407601]...",M,VIH (Primo-infection),"['adp_dlr', 'atcd_its', 'diaph', 'diarrhee', '...",diaph
2,19,"[['Bronchite', 0.11278064619119596], ['Pneumon...",F,Pneumonie,"['douleurxx', 'douleurxx_carac_@_un_coup_de_co...",expecto
3,34,"[['IVRS ou virémie', 0.23859396799565236], ['C...",F,IVRS ou virémie,"['crowd', 'douleurxx', 'douleurxx_carac_@_une_...",douleurxx
4,36,"[['IVRS ou virémie', 0.23677812769175735], ['P...",M,IVRS ou virémie,"['dayc', 'diaph', 'douleurxx', 'douleurxx_cara...",toux
...,...,...,...,...,...,...
1025597,18,"[['Épiglottite', 0.28156957795466475], ['VIH (...",M,Épiglottite,"['bw_bending', 'douleurxx', 'douleurxx_carac_@...",fievre
1025598,28,"[['Épiglottite', 0.3703962237298842], ['Laryng...",F,Épiglottite,"['douleurxx', 'douleurxx_carac_@_vive', 'doule...",fievre
1025599,0,"[['Épiglottite', 0.13193905052537108], ['Laryn...",F,Épiglottite,"['bw_bending', 'douleurxx', 'douleurxx_carac_@...",stridor
1025600,26,"[['Épiglottite', 0.3028258988138983], ['Laryng...",F,Épiglottite,"['douleurxx', 'douleurxx_carac_@_un_coup_de_co...",stridor


In [7]:
val_df = pd.read_csv(val_data, encoding='utf-8')
val_df

Unnamed: 0,AGE,DIFFERENTIAL_DIAGNOSIS,SEX,PATHOLOGY,EVIDENCES,INITIAL_EVIDENCE
0,55,"[['Anémie', 0.25071110167158567], ['Fibrillati...",F,Anémie,"['Mauv_aliment', 'atcd_anem', 'atcd_fam_anem',...",pale
1,10,"[['Syndrome de Guillain-Barré', 0.135558991316...",F,Attaque de panique,"['anxiete_s', 'atcdpsyfam', 'diaph', 'douleurx...",psy_depers
2,68,[['Possible influenza ou syndrome virémique ty...,F,Possible influenza ou syndrome virémique typique,"['diaph', 'douleurxx', 'douleurxx_carac_@_une_...",douleurxx
3,13,"[['Anémie', 0.18697604010451876], ['Fibrillati...",M,Anémie,"['Mauv_aliment', 'atcd_anem', 'atcd_fam_anem',...",douleurxx
4,48,"[['Syndrome de Boerhaave', 1.0]]",M,Syndrome de Boerhaave,"['douleurxx', 'douleurxx_carac_@_déchirante', ...",douleurxx
...,...,...,...,...,...,...
132443,27,"[['Pharyngite virale', 0.22702125813983617], [...",M,Pharyngite virale,"['contact', 'crowd', 'douleurxx', 'douleurxx_c...",toux
132444,57,"[['OAP/Surcharge pulmonaire', 0.12078088376840...",M,OAP/Surcharge pulmonaire,"['J81', 'douleurxx', 'douleurxx_carac_@_pénibl...",oedeme
132445,52,"[['RGO', 0.24494427036287517], ['Bronchite', 0...",F,RGO,"['douleurxx', 'douleurxx_carac_@_lancinante_/_...",pyrosis
132446,10,"[['Épiglottite', 0.2969684152571116], ['VIH (P...",M,Épiglottite,"['douleurxx', 'douleurxx_carac_@_un_coup_de_co...",fievre


In [8]:
test_df = pd.read_csv(test_data, encoding='utf-8')
test_df

Unnamed: 0,AGE,DIFFERENTIAL_DIAGNOSIS,SEX,PATHOLOGY,EVIDENCES,INITIAL_EVIDENCE
0,49,"[['Bronchite', 0.20230062181160519], ['RGO', 0...",F,RGO,"['douleurxx', 'douleurxx_carac_@_lancinante_/_...",toux
1,2,"[['Asthme exacerbé ou bronchospasme', 0.080220...",M,Bronchite,"['douleurxx', 'douleurxx_carac_@_une_brûlure_o...",douleurxx
2,49,"[['Réaction dystonique aïgue', 0.6267050848165...",M,Réaction dystonique aïgue,"['antipsy_récent', 'laryngospasme', 'nau_psy_r...",laryngospasme
3,64,"[['Bronchite', 0.2748608320637265], ['Laryngit...",M,Laryngite aigue,"['crowd', 'dayc', 'douleurxx', 'douleurxx_cara...",douleurxx
4,70,"[['IVRS ou virémie', 0.21257615919851483], ['P...",F,IVRS ou virémie,"['contact', 'diaph', 'douleurxx', 'douleurxx_c...",toux
...,...,...,...,...,...,...
134524,52,"[['Possible NSTEMI / STEMI', 0.268768209851499...",M,Lupus érythémateux disséminé (LED),"['I30', 'douleurxx', 'douleurxx_carac_@_vive',...",douleurxx
134525,88,"[['néoplasie pulmonaire', 0.09094757620611861]...",F,néoplasie pulmonaire,"['crach_sg', 'douleurxx', 'douleurxx_carac_@_u...",fatig_mod
134526,29,"[['Attaque de panique', 0.29281344656090524], ...",F,Syndrome de Boerhaave,"['douleurxx', 'douleurxx_carac_@_déchirante', ...",dyspn
134527,8,"[['Scombroïde', 0.1389590231491235], ['TSVP', ...",M,Scombroïde,"['dyspn', 'faible', 'flushing', 'lesions_peau'...",palpit


# Translate Pathology

In [77]:
def translate_pathology(df):
    def translate_pathology_french_to_english(x):
        if x in release_conditions:
            return release_conditions[x]['cond-name-eng']
        else:
            print('not found')
            return ''
    df['English Pathology'] = df['PATHOLOGY'].map(translate_pathology_french_to_english)
    return df

# Translate Evidences

In [78]:
# sample_evids = test_df['EVIDENCES'][500]
# print(sample_evids)

#we didn't have to split the symptoms

def translate_evidences(df):
    def get_qa_en(sample_evids):
        sample_evids = [evid[1:-1] for evid in sample_evids.strip('][').split(', ')]
        new_evid = []
        final_dict = {}
        for i in sample_evids:
            temp = i.split('_@_')
            new_evid.append(temp)
        for i in new_evid:
            if len(i) == 1:
                i.append('True')

        for i in new_evid:
            q_en = release_evidences[i[0]]["question_en"] 
            answ = ''
            if i[1] == "True" or i[1].isnumeric():
                answ = i[1]
            else:
                answ = release_evidences[i[0]]['value_meaning'][i[1]]['en']  
                if answ == 'Y':
                    answ = "Yes"
                elif answ == 'N':
                    answ = "No"

            final_dict[q_en] = answ # answer
        return final_dict
    df['English Evidences'] = df['EVIDENCES'].map(get_qa_en)
    return df
        
# print(get_qa_en(sample_evids))            
    
    


# Add ICD-10

In [79]:
icd10_data = './data/englishPathology2ICD.xlsx'

In [80]:
icd10_df = pd.read_excel(icd10_data)
icd10_df

Unnamed: 0,Pathology,ICD-10
0,acute copd exacerbation / infection,J44.1
1,acute dystonic reactions,G24.9
2,acute laryngitis,J04.0
3,acute otitis media,H66.9
4,acute pulmonary edema,J81.0
5,acute rhinosinusitis,J01.9
6,allergic sinusitis,J01
7,anaphylaxis,T78.0
8,anemia,D64.9
9,atrial fibrillation,I48.9


In [81]:
def add_icd10_code(df):
    def return_icd10_code(x):
        return icd10_df[icd10_df['Pathology '] == x.lower()].iloc[0]['ICD-10']
    df['ICD-10'] = df['English Pathology'].map(return_icd10_code)
    return df

# Apply Translation for Pathology & Evidences and Add ICD10 Code

In [82]:
train_df = translate_pathology(train_df)
train_df = translate_evidences(train_df)
train_df = add_icd10_code(train_df)
train_df

Unnamed: 0,AGE,DIFFERENTIAL_DIAGNOSIS,SEX,PATHOLOGY,EVIDENCES,INITIAL_EVIDENCE,English Pathology,English Evidences,ICD-10
0,18,"[['Bronchite', 0.19171203430383882], ['Pneumon...",M,IVRS ou virémie,"['crowd', 'diaph', 'douleurxx', 'douleurxx_car...",fievre,URTI,"{'Do you live with 4 or more people?': 'True',...",J06.9
1,21,"[['VIH (Primo-infection)', 0.5189500564407601]...",M,VIH (Primo-infection),"['adp_dlr', 'atcd_its', 'diaph', 'diarrhee', '...",diaph,HIV (initial infection),{'Do you have swollen or painful lymph nodes?'...,B20
2,19,"[['Bronchite', 0.11278064619119596], ['Pneumon...",F,Pneumonie,"['douleurxx', 'douleurxx_carac_@_un_coup_de_co...",expecto,Pneumonia,"{'Do you have pain somewhere, related to your ...",J18.9
3,34,"[['IVRS ou virémie', 0.23859396799565236], ['C...",F,IVRS ou virémie,"['crowd', 'douleurxx', 'douleurxx_carac_@_une_...",douleurxx,URTI,"{'Do you live with 4 or more people?': 'True',...",J06.9
4,36,"[['IVRS ou virémie', 0.23677812769175735], ['P...",M,IVRS ou virémie,"['dayc', 'diaph', 'douleurxx', 'douleurxx_cara...",toux,URTI,{'Do you attend or work in a daycare?': 'True'...,J06.9
...,...,...,...,...,...,...,...,...,...
1025597,18,"[['Épiglottite', 0.28156957795466475], ['VIH (...",M,Épiglottite,"['bw_bending', 'douleurxx', 'douleurxx_carac_@...",fievre,Epiglottitis,{'Do you have pain that improves when you lean...,J05.1
1025598,28,"[['Épiglottite', 0.3703962237298842], ['Laryng...",F,Épiglottite,"['douleurxx', 'douleurxx_carac_@_vive', 'doule...",fievre,Epiglottitis,"{'Do you have pain somewhere, related to your ...",J05.1
1025599,0,"[['Épiglottite', 0.13193905052537108], ['Laryn...",F,Épiglottite,"['bw_bending', 'douleurxx', 'douleurxx_carac_@...",stridor,Epiglottitis,{'Do you have pain that improves when you lean...,J05.1
1025600,26,"[['Épiglottite', 0.3028258988138983], ['Laryng...",F,Épiglottite,"['douleurxx', 'douleurxx_carac_@_un_coup_de_co...",stridor,Epiglottitis,"{'Do you have pain somewhere, related to your ...",J05.1


In [83]:
val_df = translate_pathology(val_df)
val_df = translate_evidences(val_df)
val_df = add_icd10_code(val_df)
val_df

Unnamed: 0,AGE,DIFFERENTIAL_DIAGNOSIS,SEX,PATHOLOGY,EVIDENCES,INITIAL_EVIDENCE,English Pathology,English Evidences,ICD-10
0,55,"[['Anémie', 0.25071110167158567], ['Fibrillati...",F,Anémie,"['Mauv_aliment', 'atcd_anem', 'atcd_fam_anem',...",pale,Anemia,"{'Do you have a poor diet?': 'True', 'Have you...",D64.9
1,10,"[['Syndrome de Guillain-Barré', 0.135558991316...",F,Attaque de panique,"['anxiete_s', 'atcdpsyfam', 'diaph', 'douleurx...",psy_depers,Panic attack,"{'Do you feel anxious?': 'True', 'Do any membe...",F41.0
2,68,[['Possible influenza ou syndrome virémique ty...,F,Possible influenza ou syndrome virémique typique,"['diaph', 'douleurxx', 'douleurxx_carac_@_une_...",douleurxx,Influenza,{'Have you had significantly increased sweatin...,J10 or J11
3,13,"[['Anémie', 0.18697604010451876], ['Fibrillati...",M,Anémie,"['Mauv_aliment', 'atcd_anem', 'atcd_fam_anem',...",douleurxx,Anemia,"{'Do you have a poor diet?': 'True', 'Have you...",D64.9
4,48,"[['Syndrome de Boerhaave', 1.0]]",M,Syndrome de Boerhaave,"['douleurxx', 'douleurxx_carac_@_déchirante', ...",douleurxx,Boerhaave,"{'Do you have pain somewhere, related to your ...",K22.3
...,...,...,...,...,...,...,...,...,...
132443,27,"[['Pharyngite virale', 0.22702125813983617], [...",M,Pharyngite virale,"['contact', 'crowd', 'douleurxx', 'douleurxx_c...",toux,Viral pharyngitis,{'Have you been in contact with a person with ...,J02.9
132444,57,"[['OAP/Surcharge pulmonaire', 0.12078088376840...",M,OAP/Surcharge pulmonaire,"['J81', 'douleurxx', 'douleurxx_carac_@_pénibl...",oedeme,Acute pulmonary edema,{'Have you ever had fluid in your lungs?': 'Tr...,J81.0
132445,52,"[['RGO', 0.24494427036287517], ['Bronchite', 0...",F,RGO,"['douleurxx', 'douleurxx_carac_@_lancinante_/_...",pyrosis,GERD,"{'Do you have pain somewhere, related to your ...",K21
132446,10,"[['Épiglottite', 0.2969684152571116], ['VIH (P...",M,Épiglottite,"['douleurxx', 'douleurxx_carac_@_un_coup_de_co...",fievre,Epiglottitis,"{'Do you have pain somewhere, related to your ...",J05.1


In [88]:
test_df = translate_pathology(test_df)
test_df = translate_evidences(test_df)
test_df = add_icd10_code(test_df)
test_df

Unnamed: 0,AGE,DIFFERENTIAL_DIAGNOSIS,SEX,PATHOLOGY,EVIDENCES,INITIAL_EVIDENCE,English Pathology,English Evidences,ICD-10
0,49,"[['Bronchite', 0.20230062181160519], ['RGO', 0...",F,RGO,"['douleurxx', 'douleurxx_carac_@_lancinante_/_...",toux,GERD,"{'Do you have pain somewhere, related to your ...",K21
1,2,"[['Asthme exacerbé ou bronchospasme', 0.080220...",M,Bronchite,"['douleurxx', 'douleurxx_carac_@_une_brûlure_o...",douleurxx,Bronchitis,"{'Do you have pain somewhere, related to your ...",J20.9
2,49,"[['Réaction dystonique aïgue', 0.6267050848165...",M,Réaction dystonique aïgue,"['antipsy_récent', 'laryngospasme', 'nau_psy_r...",laryngospasme,Acute dystonic reactions,{'Have you started or taken any antipsychotic ...,G24.9
3,64,"[['Bronchite', 0.2748608320637265], ['Laryngit...",M,Laryngite aigue,"['crowd', 'dayc', 'douleurxx', 'douleurxx_cara...",douleurxx,Acute laryngitis,"{'Do you live with 4 or more people?': 'True',...",J04.0
4,70,"[['IVRS ou virémie', 0.21257615919851483], ['P...",F,IVRS ou virémie,"['contact', 'diaph', 'douleurxx', 'douleurxx_c...",toux,URTI,{'Have you been in contact with a person with ...,J06.9
...,...,...,...,...,...,...,...,...,...
134524,52,"[['Possible NSTEMI / STEMI', 0.268768209851499...",M,Lupus érythémateux disséminé (LED),"['I30', 'douleurxx', 'douleurxx_carac_@_vive',...",douleurxx,SLE,"{'Have you ever had a pericarditis?': 'True', ...",M32.9
134525,88,"[['néoplasie pulmonaire', 0.09094757620611861]...",F,néoplasie pulmonaire,"['crach_sg', 'douleurxx', 'douleurxx_carac_@_u...",fatig_mod,Pulmonary neoplasm,"{'Have you been coughing up blood?': 'True', '...",C34.9
134526,29,"[['Attaque de panique', 0.29281344656090524], ...",F,Syndrome de Boerhaave,"['douleurxx', 'douleurxx_carac_@_déchirante', ...",dyspn,Boerhaave,"{'Do you have pain somewhere, related to your ...",K22.3
134527,8,"[['Scombroïde', 0.1389590231491235], ['TSVP', ...",M,Scombroïde,"['dyspn', 'faible', 'flushing', 'lesions_peau'...",palpit,Scombroid food poisoning,{'Are you experiencing shortness of breath or ...,T61.1


# Write DataFrames to CSV files for train-val-test splits

In [85]:
train_df.to_csv('./data/translated_train.csv', encoding='utf-8')

In [86]:
val_df.to_csv('./data/translated_val.csv', encoding='utf-8')

In [89]:
test_df.to_csv('./data/translated_test.csv', encoding='utf-8')

# Load Data Splits into DFs

In [132]:
import pandas as pd
import numpy as np

In [133]:
np.random.seed(42)

In [134]:
test_df = pd.read_csv('./data/translated_test.csv', encoding='utf-8')
test_df

Unnamed: 0.1,Unnamed: 0,AGE,DIFFERENTIAL_DIAGNOSIS,SEX,PATHOLOGY,EVIDENCES,INITIAL_EVIDENCE,English Pathology,English Evidences,ICD-10
0,0,49,"[['Bronchite', 0.20230062181160519], ['RGO', 0...",F,RGO,"['douleurxx', 'douleurxx_carac_@_lancinante_/_...",toux,GERD,"{'Do you have pain somewhere, related to your ...",K21
1,1,2,"[['Asthme exacerbé ou bronchospasme', 0.080220...",M,Bronchite,"['douleurxx', 'douleurxx_carac_@_une_brûlure_o...",douleurxx,Bronchitis,"{'Do you have pain somewhere, related to your ...",J20.9
2,2,49,"[['Réaction dystonique aïgue', 0.6267050848165...",M,Réaction dystonique aïgue,"['antipsy_récent', 'laryngospasme', 'nau_psy_r...",laryngospasme,Acute dystonic reactions,{'Have you started or taken any antipsychotic ...,G24.9
3,3,64,"[['Bronchite', 0.2748608320637265], ['Laryngit...",M,Laryngite aigue,"['crowd', 'dayc', 'douleurxx', 'douleurxx_cara...",douleurxx,Acute laryngitis,"{'Do you live with 4 or more people?': 'True',...",J04.0
4,4,70,"[['IVRS ou virémie', 0.21257615919851483], ['P...",F,IVRS ou virémie,"['contact', 'diaph', 'douleurxx', 'douleurxx_c...",toux,URTI,{'Have you been in contact with a person with ...,J06.9
...,...,...,...,...,...,...,...,...,...,...
134524,134524,52,"[['Possible NSTEMI / STEMI', 0.268768209851499...",M,Lupus érythémateux disséminé (LED),"['I30', 'douleurxx', 'douleurxx_carac_@_vive',...",douleurxx,SLE,"{'Have you ever had a pericarditis?': 'True', ...",M32.9
134525,134525,88,"[['néoplasie pulmonaire', 0.09094757620611861]...",F,néoplasie pulmonaire,"['crach_sg', 'douleurxx', 'douleurxx_carac_@_u...",fatig_mod,Pulmonary neoplasm,"{'Have you been coughing up blood?': 'True', '...",C34.9
134526,134526,29,"[['Attaque de panique', 0.29281344656090524], ...",F,Syndrome de Boerhaave,"['douleurxx', 'douleurxx_carac_@_déchirante', ...",dyspn,Boerhaave,"{'Do you have pain somewhere, related to your ...",K22.3
134527,134527,8,"[['Scombroïde', 0.1389590231491235], ['TSVP', ...",M,Scombroïde,"['dyspn', 'faible', 'flushing', 'lesions_peau'...",palpit,Scombroid food poisoning,{'Are you experiencing shortness of breath or ...,T61.1


In [135]:
val_df = pd.read_csv('./data/translated_val.csv', encoding='utf-8')
val_df

Unnamed: 0.1,Unnamed: 0,AGE,DIFFERENTIAL_DIAGNOSIS,SEX,PATHOLOGY,EVIDENCES,INITIAL_EVIDENCE,English Pathology,English Evidences,ICD-10
0,0,55,"[['Anémie', 0.25071110167158567], ['Fibrillati...",F,Anémie,"['Mauv_aliment', 'atcd_anem', 'atcd_fam_anem',...",pale,Anemia,"{'Do you have a poor diet?': 'True', 'Have you...",D64.9
1,1,10,"[['Syndrome de Guillain-Barré', 0.135558991316...",F,Attaque de panique,"['anxiete_s', 'atcdpsyfam', 'diaph', 'douleurx...",psy_depers,Panic attack,"{'Do you feel anxious?': 'True', 'Do any membe...",F41.0
2,2,68,[['Possible influenza ou syndrome virémique ty...,F,Possible influenza ou syndrome virémique typique,"['diaph', 'douleurxx', 'douleurxx_carac_@_une_...",douleurxx,Influenza,{'Have you had significantly increased sweatin...,J10 or J11
3,3,13,"[['Anémie', 0.18697604010451876], ['Fibrillati...",M,Anémie,"['Mauv_aliment', 'atcd_anem', 'atcd_fam_anem',...",douleurxx,Anemia,"{'Do you have a poor diet?': 'True', 'Have you...",D64.9
4,4,48,"[['Syndrome de Boerhaave', 1.0]]",M,Syndrome de Boerhaave,"['douleurxx', 'douleurxx_carac_@_déchirante', ...",douleurxx,Boerhaave,"{'Do you have pain somewhere, related to your ...",K22.3
...,...,...,...,...,...,...,...,...,...,...
132443,132443,27,"[['Pharyngite virale', 0.22702125813983617], [...",M,Pharyngite virale,"['contact', 'crowd', 'douleurxx', 'douleurxx_c...",toux,Viral pharyngitis,{'Have you been in contact with a person with ...,J02.9
132444,132444,57,"[['OAP/Surcharge pulmonaire', 0.12078088376840...",M,OAP/Surcharge pulmonaire,"['J81', 'douleurxx', 'douleurxx_carac_@_pénibl...",oedeme,Acute pulmonary edema,{'Have you ever had fluid in your lungs?': 'Tr...,J81.0
132445,132445,52,"[['RGO', 0.24494427036287517], ['Bronchite', 0...",F,RGO,"['douleurxx', 'douleurxx_carac_@_lancinante_/_...",pyrosis,GERD,"{'Do you have pain somewhere, related to your ...",K21
132446,132446,10,"[['Épiglottite', 0.2969684152571116], ['VIH (P...",M,Épiglottite,"['douleurxx', 'douleurxx_carac_@_un_coup_de_co...",fievre,Epiglottitis,"{'Do you have pain somewhere, related to your ...",J05.1


In [136]:
train_df = pd.read_csv('./data/translated_train.csv', encoding='utf-8')
train_df

Unnamed: 0.1,Unnamed: 0,AGE,DIFFERENTIAL_DIAGNOSIS,SEX,PATHOLOGY,EVIDENCES,INITIAL_EVIDENCE,English Pathology,English Evidences,ICD-10
0,0,18,"[['Bronchite', 0.19171203430383882], ['Pneumon...",M,IVRS ou virémie,"['crowd', 'diaph', 'douleurxx', 'douleurxx_car...",fievre,URTI,"{'Do you live with 4 or more people?': 'True',...",J06.9
1,1,21,"[['VIH (Primo-infection)', 0.5189500564407601]...",M,VIH (Primo-infection),"['adp_dlr', 'atcd_its', 'diaph', 'diarrhee', '...",diaph,HIV (initial infection),{'Do you have swollen or painful lymph nodes?'...,B20
2,2,19,"[['Bronchite', 0.11278064619119596], ['Pneumon...",F,Pneumonie,"['douleurxx', 'douleurxx_carac_@_un_coup_de_co...",expecto,Pneumonia,"{'Do you have pain somewhere, related to your ...",J18.9
3,3,34,"[['IVRS ou virémie', 0.23859396799565236], ['C...",F,IVRS ou virémie,"['crowd', 'douleurxx', 'douleurxx_carac_@_une_...",douleurxx,URTI,"{'Do you live with 4 or more people?': 'True',...",J06.9
4,4,36,"[['IVRS ou virémie', 0.23677812769175735], ['P...",M,IVRS ou virémie,"['dayc', 'diaph', 'douleurxx', 'douleurxx_cara...",toux,URTI,{'Do you attend or work in a daycare?': 'True'...,J06.9
...,...,...,...,...,...,...,...,...,...,...
1025597,1025597,18,"[['Épiglottite', 0.28156957795466475], ['VIH (...",M,Épiglottite,"['bw_bending', 'douleurxx', 'douleurxx_carac_@...",fievre,Epiglottitis,{'Do you have pain that improves when you lean...,J05.1
1025598,1025598,28,"[['Épiglottite', 0.3703962237298842], ['Laryng...",F,Épiglottite,"['douleurxx', 'douleurxx_carac_@_vive', 'doule...",fievre,Epiglottitis,"{'Do you have pain somewhere, related to your ...",J05.1
1025599,1025599,0,"[['Épiglottite', 0.13193905052537108], ['Laryn...",F,Épiglottite,"['bw_bending', 'douleurxx', 'douleurxx_carac_@...",stridor,Epiglottitis,{'Do you have pain that improves when you lean...,J05.1
1025600,1025600,26,"[['Épiglottite', 0.3028258988138983], ['Laryng...",F,Épiglottite,"['douleurxx', 'douleurxx_carac_@_un_coup_de_co...",stridor,Epiglottitis,"{'Do you have pain somewhere, related to your ...",J05.1


# Compute Dataset Statistics

In [137]:
print("Num Train Samples:", len(train_df))
print("Num Val Samples:", len(val_df))
print("Num Test Samples:", len(test_df))

Num Train Samples: 1025602
Num Val Samples: 132448
Num Test Samples: 134529


In [139]:
print("Train Data ICD-10 Distribution:\n", train_df['ICD-10'].value_counts())
print()
print("Val Data ICD-10 Distribution:\n", val_df['ICD-10'].value_counts())
print()
print("Test Data ICD-10 Distribution:\n", test_df['ICD-10'].value_counts())

Train Data ICD-10 Distribution:
 J06.9         64368
J02.9         61642
D64.9         50665
B20           29013
R60.0         27825
T78.0         27718
I26           27468
J10 or J11    26812
J20.9         26400
J01           26203
G24.9         25982
K21           25979
H66.9         25917
J18.9         25761
F41.0         25019
J04.0         24129
G61.0         22867
I30.9         22785
D86.9         21285
I21.4         21260
I20.0         21244
I48.9         21036
G44.0         20804
J32.9         20579
K40           20235
J98.0         19875
J81.0         19018
C25.9         18846
J47           18795
I47.1         18781
G70           18566
T61.1         18535
J44.1         17661
J05.1         17209
I20.9         16995
A15.9         16245
K22.3         15080
C34.9         14457
J01.9         13578
M32.9         11867
I51.4         11073
J38.5         10998
J93.1         10162
B57.5          9252
A37.9          6070
S22.3          5712
J05.0          2852
A98.4           718
J21    

# Sample Test Data for project

In [90]:
num_test_samples = 10000

In [91]:
sample_test_df = test_df.sample(n=num_test_samples, random_state=42)
sample_test_df

Unnamed: 0,AGE,DIFFERENTIAL_DIAGNOSIS,SEX,PATHOLOGY,EVIDENCES,INITIAL_EVIDENCE,English Pathology,English Evidences,ICD-10
109830,21,"[['Asthme exacerbé ou bronchospasme', 0.117008...",F,Asthme exacerbé ou bronchospasme,"['dyspn', 'fam_j45', 'j32', 'j45', 'toux', 'tr...",toux,Bronchospasm / acute asthma exacerbation,{'Are you experiencing shortness of breath or ...,J98.0
11402,31,"[['Réaction dystonique aïgue', 0.1781816671868...",M,Réaction dystonique aïgue,"['drogues_stimul', 'dyspn', 'laryngospasme', '...",protu_langue,Acute dystonic reactions,{'Do you regularly take stimulant drugs?': 'Tr...,G24.9
30840,3,"[['Rhinosinusite chronique', 0.189267488945076...",F,Rhinosinusite chronique,"['douleurxx', 'douleurxx_carac_@_une_brûlure_o...",hyponos,Chronic rhinosinusitis,"{'Do you have pain somewhere, related to your ...",J32.9
14889,46,"[['VIH (Primo-infection)', 0.48414955175409613...",F,VIH (Primo-infection),"['atcd_its', 'diaph', 'diarrhee', 'douleurxx',...",diaph,HIV (initial infection),{'Have you ever had a sexually transmitted inf...,B20
131903,72,"[['Attaque de panique', 0.09172200566348035], ...",F,Attaque de panique,"['anxiete_s', 'atcdpsyfam', 'diaph', 'douleurx...",palpit,Panic attack,"{'Do you feel anxious?': 'True', 'Do any membe...",F41.0
...,...,...,...,...,...,...,...,...,...
113311,0,"[['VIH (Primo-infection)', 0.3030252901252109]...",M,VIH (Primo-infection),"['atcd_its', 'diarrhee', 'douleurxx', 'douleur...",fatig_ext,HIV (initial infection),{'Have you ever had a sexually transmitted inf...,B20
46200,35,"[['Possible NSTEMI / STEMI', 0.170304300298058...",M,Possible NSTEMI / STEMI,"['HIV', 'diaph', 'douleurxx', 'douleurxx_carac...",fatig_mod,Possible NSTEMI / STEMI,{'Are you infected with the human immunodefici...,I21.4
54048,71,"[['Pharyngite virale', 0.5232406274878201], ['...",F,Pharyngite virale,"['crach_sg', 'dayc', 'douleurxx', 'douleurxx_c...",crach_sg,Viral pharyngitis,"{'Have you been coughing up blood?': 'True', '...",J02.9
27233,61,"[['Possible NSTEMI / STEMI', 0.479718656849398...",M,Sarcoïdose,"['convulsion', 'douleurxx', 'douleurxx_carac_@...",dyspn,Sarcoidosis,{'Have you lost consciousness associated with ...,D86.9


In [92]:
sample_test_df['English Evidences'][109830]

{'Are you experiencing shortness of breath or difficulty breathing in a significant way?': 'True',
 'Do you have any family members who have asthma?': 'True',
 'Have you been diagnosed with chronic sinusitis?': 'True',
 'Do you have asthma or have you ever had to use a bronchodilator in the past?': 'True',
 'Do you have a cough?': 'True',
 'Have you traveled out of the country in the last 4 weeks?': 'No',
 'Do you live in in a big city?': 'True',
 'Have you noticed a wheezing sound when you exhale?': 'True'}

In [93]:
sample_test_df.to_csv('./data/Project_Test_Data.csv', encoding='utf-8')

# Sample Training Data for Few-Shot Learning (one sample for each gender of each pathology)

In [116]:
train_samples_df = train_df.groupby(["SEX", "English Pathology"], group_keys=False).apply(lambda x: x.sample(n=6, random_state=1))
train_samples_df

Unnamed: 0.1,Unnamed: 0,AGE,DIFFERENTIAL_DIAGNOSIS,SEX,PATHOLOGY,EVIDENCES,INITIAL_EVIDENCE,English Pathology,English Evidences,ICD-10
890172,890172,67,[['Exacerbation aigue de MPOC et/ou surinfecti...,F,Exacerbation aigue de MPOC et/ou surinfection ...,"['bode', 'dyspn', 'expecto', 'f17.210', 'j44_j...",wheez,Acute COPD exacerbation / infection,{'Do you have severe Chronic Obstructive Pulmo...,J44.1
519403,519403,72,"[['Bronchite', 0.2429175698894958], ['Exacerba...",F,Exacerbation aigue de MPOC et/ou surinfection ...,"['eampoc1', 'expecto', 'f17.210', 'j44_j42', '...",wheez,Acute COPD exacerbation / infection,{'Have you had one or several flare ups of chr...,J44.1
764990,764990,63,"[['Bronchite', 0.2553296635383857], ['Exacerba...",F,Exacerbation aigue de MPOC et/ou surinfection ...,"['bode', 'eampoc1', 'expecto', 'j44_j42', 'k21...",toux,Acute COPD exacerbation / infection,{'Do you have severe Chronic Obstructive Pulmo...,J44.1
162274,162274,35,[['Exacerbation aigue de MPOC et/ou surinfecti...,F,Exacerbation aigue de MPOC et/ou surinfection ...,"['bode', 'dyspn', 'eampoc1', 'f17.210', 'j44_j...",toux,Acute COPD exacerbation / infection,{'Do you have severe Chronic Obstructive Pulmo...,J44.1
903972,903972,68,"[['Bronchite', 0.11403953766958734], ['Bronchi...",F,Exacerbation aigue de MPOC et/ou surinfection ...,"['dyspn', 'eampoc1', 'expecto', 'f17.210', 'j4...",dyspn,Acute COPD exacerbation / infection,{'Are you experiencing shortness of breath or ...,J44.1
...,...,...,...,...,...,...,...,...,...,...
890578,890578,23,"[['Coqueluche', 0.5363081514145991], ['Bronchi...",M,Coqueluche,"['cont_coq', 'insp_siffla', 'j45', 'posttus_em...",posttus_emesis,Whooping cough,{'Have you been in contact with someone who ha...,A37.9
445642,445642,12,"[['Bronchite', 0.6097703129614549], ['Coqueluc...",M,Coqueluche,"['cont_coq', 'e66', 'posttus_emesis', 'trav1_@...",posttus_emesis,Whooping cough,{'Have you been in contact with someone who ha...,A37.9
127641,127641,6,"[['Coqueluche', 0.4380021077337765], ['Bronchi...",M,Coqueluche,"['cont_coq', 'e66', 'insp_siffla', 'j45', 'pos...",toux_sev,Whooping cough,{'Have you been in contact with someone who ha...,A37.9
90407,90407,81,"[['Coqueluche', 0.5363081514145991], ['Bronchi...",M,Coqueluche,"['cont_coq', 'e66', 'insp_siffla', 'j45', 'pos...",posttus_emesis,Whooping cough,{'Have you been in contact with someone who ha...,A37.9


In [117]:
# sample test
train_samples_df[train_samples_df['PATHOLOGY']=='Coqueluche']

Unnamed: 0.1,Unnamed: 0,AGE,DIFFERENTIAL_DIAGNOSIS,SEX,PATHOLOGY,EVIDENCES,INITIAL_EVIDENCE,English Pathology,English Evidences,ICD-10
24112,24112,71,"[['Coqueluche', 0.5125164486996661], ['Bronchi...",F,Coqueluche,"['cont_coq', 'e66', 'j45', 'posttus_emesis', '...",posttus_emesis,Whooping cough,{'Have you been in contact with someone who ha...,A37.9
702813,702813,41,"[['Bronchite', 0.6097703129614549], ['Coqueluc...",F,Coqueluche,"['cont_coq', 'e66', 'posttus_emesis', 'trav1_@...",posttus_emesis,Whooping cough,{'Have you been in contact with someone who ha...,A37.9
297291,297291,30,"[['Coqueluche', 0.5363081514145991], ['Bronchi...",F,Coqueluche,"['cont_coq', 'insp_siffla', 'j45', 'posttus_em...",insp_siffla,Whooping cough,{'Have you been in contact with someone who ha...,A37.9
3398,3398,38,"[['Coqueluche', 0.5363081514145991], ['Bronchi...",F,Coqueluche,"['cont_coq', 'insp_siffla', 'posttus_emesis', ...",insp_siffla,Whooping cough,{'Have you been in contact with someone who ha...,A37.9
452231,452231,43,"[['Coqueluche', 0.5363081514145991], ['Bronchi...",F,Coqueluche,"['insp_siffla', 'posttus_emesis', 'toux_sev', ...",posttus_emesis,Whooping cough,{'Do you wheeze while inhaling or is your brea...,A37.9
938194,938194,15,"[['Bronchite', 0.6097703129614549], ['Coqueluc...",F,Coqueluche,"['e66', 'j45', 'posttus_emesis', 'trav1_@_N', ...",posttus_emesis,Whooping cough,{'Are you significantly overweight compared to...,A37.9
109079,109079,19,"[['Coqueluche', 0.5363081514145991], ['Bronchi...",M,Coqueluche,"['cont_coq', 'insp_siffla', 'j45', 'posttus_em...",insp_siffla,Whooping cough,{'Have you been in contact with someone who ha...,A37.9
890578,890578,23,"[['Coqueluche', 0.5363081514145991], ['Bronchi...",M,Coqueluche,"['cont_coq', 'insp_siffla', 'j45', 'posttus_em...",posttus_emesis,Whooping cough,{'Have you been in contact with someone who ha...,A37.9
445642,445642,12,"[['Bronchite', 0.6097703129614549], ['Coqueluc...",M,Coqueluche,"['cont_coq', 'e66', 'posttus_emesis', 'trav1_@...",posttus_emesis,Whooping cough,{'Have you been in contact with someone who ha...,A37.9
127641,127641,6,"[['Coqueluche', 0.4380021077337765], ['Bronchi...",M,Coqueluche,"['cont_coq', 'e66', 'insp_siffla', 'j45', 'pos...",toux_sev,Whooping cough,{'Have you been in contact with someone who ha...,A37.9


In [120]:
train_samples_df = train_samples_df.rename(columns={"Unnamed: 0": "OriginalRowKey"})
train_samples_df

Unnamed: 0,OriginalRowKey,AGE,DIFFERENTIAL_DIAGNOSIS,SEX,PATHOLOGY,EVIDENCES,INITIAL_EVIDENCE,English Pathology,English Evidences,ICD-10
890172,890172,67,[['Exacerbation aigue de MPOC et/ou surinfecti...,F,Exacerbation aigue de MPOC et/ou surinfection ...,"['bode', 'dyspn', 'expecto', 'f17.210', 'j44_j...",wheez,Acute COPD exacerbation / infection,{'Do you have severe Chronic Obstructive Pulmo...,J44.1
519403,519403,72,"[['Bronchite', 0.2429175698894958], ['Exacerba...",F,Exacerbation aigue de MPOC et/ou surinfection ...,"['eampoc1', 'expecto', 'f17.210', 'j44_j42', '...",wheez,Acute COPD exacerbation / infection,{'Have you had one or several flare ups of chr...,J44.1
764990,764990,63,"[['Bronchite', 0.2553296635383857], ['Exacerba...",F,Exacerbation aigue de MPOC et/ou surinfection ...,"['bode', 'eampoc1', 'expecto', 'j44_j42', 'k21...",toux,Acute COPD exacerbation / infection,{'Do you have severe Chronic Obstructive Pulmo...,J44.1
162274,162274,35,[['Exacerbation aigue de MPOC et/ou surinfecti...,F,Exacerbation aigue de MPOC et/ou surinfection ...,"['bode', 'dyspn', 'eampoc1', 'f17.210', 'j44_j...",toux,Acute COPD exacerbation / infection,{'Do you have severe Chronic Obstructive Pulmo...,J44.1
903972,903972,68,"[['Bronchite', 0.11403953766958734], ['Bronchi...",F,Exacerbation aigue de MPOC et/ou surinfection ...,"['dyspn', 'eampoc1', 'expecto', 'f17.210', 'j4...",dyspn,Acute COPD exacerbation / infection,{'Are you experiencing shortness of breath or ...,J44.1
...,...,...,...,...,...,...,...,...,...,...
890578,890578,23,"[['Coqueluche', 0.5363081514145991], ['Bronchi...",M,Coqueluche,"['cont_coq', 'insp_siffla', 'j45', 'posttus_em...",posttus_emesis,Whooping cough,{'Have you been in contact with someone who ha...,A37.9
445642,445642,12,"[['Bronchite', 0.6097703129614549], ['Coqueluc...",M,Coqueluche,"['cont_coq', 'e66', 'posttus_emesis', 'trav1_@...",posttus_emesis,Whooping cough,{'Have you been in contact with someone who ha...,A37.9
127641,127641,6,"[['Coqueluche', 0.4380021077337765], ['Bronchi...",M,Coqueluche,"['cont_coq', 'e66', 'insp_siffla', 'j45', 'pos...",toux_sev,Whooping cough,{'Have you been in contact with someone who ha...,A37.9
90407,90407,81,"[['Coqueluche', 0.5363081514145991], ['Bronchi...",M,Coqueluche,"['cont_coq', 'e66', 'insp_siffla', 'j45', 'pos...",posttus_emesis,Whooping cough,{'Have you been in contact with someone who ha...,A37.9


In [121]:
train_samples_df.to_csv('./data/Project_Train_Data_FewShot_Learning_588samples.csv', encoding='utf-8')

In [14]:
val_samples_df = val_df.groupby(["SEX", "English Pathology"], group_keys=False).apply(lambda x: x.sample(n=6, random_state=1))
print(val_samples_df[val_samples_df['PATHOLOGY']=='Coqueluche'])
val_samples_df = val_samples_df.rename(columns={"Unnamed: 0": "OriginalRowKey"})

        Unnamed: 0  AGE                             DIFFERENTIAL_DIAGNOSIS  \
7150          7150   74  [['Coqueluche', 0.5363081514145991], ['Bronchi...   
98197        98197   32  [['Bronchite', 0.6097703129614549], ['Coqueluc...   
24267        24267   25  [['Coqueluche', 0.5363081514145991], ['Bronchi...   
74353        74353   41  [['Coqueluche', 0.5363081514145991], ['Bronchi...   
74036        74036   17  [['Coqueluche', 0.5363081514145991], ['Bronchi...   
86104        86104   19  [['Bronchite', 0.5295150093660963], ['Coqueluc...   
43342        43342   23  [['Coqueluche', 0.5125164486996661], ['Bronchi...   
6497          6497   40  [['Coqueluche', 0.5363081514145991], ['Bronchi...   
87059        87059    8  [['Coqueluche', 0.41467144518644367], ['Bronch...   
28040        28040   50  [['Bronchite', 0.5295150093660963], ['Coqueluc...   
119223      119223    5  [['Coqueluche', 0.5093576474125793], ['Bronchi...   
21316        21316   39  [['Coqueluche', 0.5125164486996661], ['

In [15]:
val_samples_df.to_csv('./data/Project_Val_Data_FewShot_Learning_588samples.csv', encoding='utf-8')

# Parse BigBird Output

In [115]:
import pandas as pd

In [116]:
groundtruth_icd10_df = pd.read_csv('./data/Project_Test_Data.csv')['ICD-10']
groundtruth_icd10_df

0       J98.0
1       G24.9
2       J32.9
3         B20
4       F41.0
        ...  
9995      B20
9996    I21.4
9997    J02.9
9998    D86.9
9999    I48.9
Name: ICD-10, Length: 10000, dtype: object

In [126]:
def parse_pred_gtruth_bigbird(bigbird_perf_dict, bigbird_df):
    num_correct = 0
    num_incomp_pred_icd10codes = 0
    for idx, row in enumerate(bigbird_df['Predicted Values']):
        pred = row
        gtruth = groundtruth_icd10_df[idx]
        bigbird_perf_dict[idx] = {'pred_orig': pred, 'gtruth_orig': gtruth}
        # parse gtruths
        if 'or' not in gtruth:
            bigbird_perf_dict[idx]['gtruth_letter_code'] = gtruth[0]
            bigbird_perf_dict[idx]['gtruth_before_decimal_code'] = gtruth[:3]
            if '.' in gtruth:
                bigbird_perf_dict[idx]['gtruth_icd10_code'] = gtruth[:5]
            else:
                bigbird_perf_dict[idx]['gtruth_icd10_code'] = ''
        else:
            bigbird_perf_dict[idx]['gtruth_letter_code'] = 'J'
            bigbird_perf_dict[idx]['gtruth_before_decimal_code'] = ['J10', 'J11']
            bigbird_perf_dict[idx]['gtruth_icd10_code'] = ''
        # parse preds
        bigbird_perf_dict[idx]['pred_letter_code'] = pred[0]
        bigbird_perf_dict[idx]['pred_before_decimal_code'] = pred[:3]
        if '.' in pred:
            bigbird_perf_dict[idx]['pred_icd10_code'] = pred[:5]
        else:
            num_incomp_pred_icd10codes += 1
            print(pred)
            bigbird_perf_dict[idx]['pred_icd10_code'] = ''

        if pred == gtruth:
            num_correct += 1
    print("Accuracy:", num_correct * 100 / len(bigbird_perf_dict))
    print("num incomplete pred icd10 codes:", num_incomp_pred_icd10codes)
    return bigbird_perf_dict

In [127]:
def compare_bigbird_preds_and_gtruths(return_dict):
    raw_res_dict = {'M_res_letter_code': 0, 'FN_res_letter_code': 0, 'FP_res_letter_code': 0, 'TP_res_letter_code': 0,
               'M_res_before_decimal_code': 0, 'FN_res_before_decimal_code': 0, 'FP_res_before_decimal_code': 0, 'TP_res_before_decimal_code': 0,
               'M_res_icd10_code': 0, 'FN_res_icd10_code': 0, 'FP_res_icd10_code': 0, 'TP_res_icd10_code': 0,
                   'list_gtruths_letter_code': [], 'list_preds_letter_code': [],
                   'list_gtruths_before_decimal_code': [], 'list_preds_before_decimal_code': [],
                   'list_gtruths_icd10_code': [], 'list_preds_icd10_code': []}
    for row_key in return_dict:
        
        # letter code comparison
        if return_dict[row_key]['gtruth_letter_code'] == '': #should not execute
            print("blank gtruth_letter_code")
        else:
            raw_res_dict['list_gtruths_letter_code'].append(return_dict[row_key]['gtruth_letter_code'])
            raw_res_dict['list_preds_letter_code'].append(return_dict[row_key]['pred_letter_code'])
            if return_dict[row_key]['pred_letter_code'] == '':
                return_dict[row_key]['res_letter_code'] = 'FN'
                raw_res_dict['FN_res_letter_code'] += 1
            else:
                if return_dict[row_key]['gtruth_letter_code'] == return_dict[row_key]['pred_letter_code']:
                    return_dict[row_key]['res_letter_code'] = 'TP'
                    raw_res_dict['TP_res_letter_code'] += 1
                else:
                    return_dict[row_key]['res_letter_code'] = 'FP'
                    raw_res_dict['FP_res_letter_code'] += 1
        
        # before decimal comparison
        if return_dict[row_key]['gtruth_before_decimal_code'] == '' or return_dict[row_key]['gtruth_before_decimal_code'] == []: #should not execute
            print("blank gtruth_before_decimal_code")
        else:
#             raw_res_dict['list_gtruths_before_decimal_code'].append(return_dict[row_key]['gtruth_before_decimal_code'])
#             raw_res_dict['list_preds_before_decimal_code'].append(return_dict[row_key]['pred_before_decimal_code'])
            
            if type(return_dict[row_key]['gtruth_before_decimal_code']) == list:
                if return_dict[row_key]['pred_before_decimal_code'] == 'J10':
                    raw_res_dict['list_gtruths_before_decimal_code'].append('J10')
                    raw_res_dict['list_preds_before_decimal_code'].append('J10')
                elif return_dict[row_key]['pred_before_decimal_code'] == 'J11':
                    raw_res_dict['list_gtruths_before_decimal_code'].append('J11')
                    raw_res_dict['list_preds_before_decimal_code'].append('J11')
                else:
                    raw_res_dict['list_gtruths_before_decimal_code'].append('J10') #append either, force wrong
                    raw_res_dict['list_preds_before_decimal_code'].append(return_dict[row_key]['pred_before_decimal_code'])
            else:
                raw_res_dict['list_gtruths_before_decimal_code'].append(return_dict[row_key]['gtruth_before_decimal_code'])
                raw_res_dict['list_preds_before_decimal_code'].append(return_dict[row_key]['pred_before_decimal_code'])
            
            if return_dict[row_key]['pred_before_decimal_code'] == '':
                return_dict[row_key]['res_before_decimal_code'] = 'FN'
                raw_res_dict['FN_res_before_decimal_code'] += 1
#                 print(return_dict[row_key]['pred_before_decimal_code'], return_dict[row_key]['gtruth_before_decimal_code'])
            else:
                if return_dict[row_key]['pred_before_decimal_code'] not in return_dict[row_key]['gtruth_before_decimal_code']:
#                     print(return_dict[row_key]['pred_before_decimal_code'], return_dict[row_key]['gtruth_before_decimal_code'])
                    return_dict[row_key]['res_before_decimal_code'] = 'FP'
                    raw_res_dict['FP_res_before_decimal_code'] += 1
                else:
                    return_dict[row_key]['res_before_decimal_code'] = 'TP'
                    raw_res_dict['TP_res_before_decimal_code'] += 1
#                     print(return_dict[row_key]['pred_before_decimal_code'], return_dict[row_key]['gtruth_before_decimal_code'])
        
        # full icd10_code comparison
        if return_dict[row_key]['gtruth_icd10_code'] == '': # executes if 'gtruth_before_decimal_code' maps to ['J10', 'J11'] or a code without decimal value
#             print(return_dict[row_key]['gtruth_before_decimal_code'])
            return_dict[row_key]['res_icd10_code'] = 'M'
            raw_res_dict['M_res_icd10_code'] += 1
            continue
        else:
            raw_res_dict['list_gtruths_icd10_code'].append(return_dict[row_key]['gtruth_icd10_code'])
            raw_res_dict['list_preds_icd10_code'].append(return_dict[row_key]['pred_icd10_code'])
            if return_dict[row_key]['pred_icd10_code'] == '':
                return_dict[row_key]['res_icd10_code'] = 'FN'
                raw_res_dict['FN_res_icd10_code'] += 1
#                 print(return_dict[row_key]['pred_icd10_code'], return_dict[row_key]['gtruth_icd10_code'])
            else:
                if return_dict[row_key]['pred_icd10_code'] != return_dict[row_key]['gtruth_icd10_code']:
#                     print(return_dict[row_key]['pred_icd10_code'], return_dict[row_key]['gtruth_icd10_code'])
                    return_dict[row_key]['res_icd10_code'] = 'FP'
                    raw_res_dict['FP_res_icd10_code'] += 1
                else:
                    return_dict[row_key]['res_icd10_code'] = 'TP'
                    raw_res_dict['TP_res_icd10_code'] += 1
#                     print(return_dict[row_key]['pred_icd10_code'], return_dict[row_key]['gtruth_icd10_code'])
        
    return return_dict, raw_res_dict

In [128]:
def compute_metrics(raw_res_dict):
    def manual_precision(raw_res_dict):
        raw_res_dict['precision_letter_code'] = raw_res_dict['TP_res_letter_code'] / (raw_res_dict['TP_res_letter_code'] + raw_res_dict['FP_res_letter_code'])
        raw_res_dict['precision_before_decimal_code'] = raw_res_dict['TP_res_before_decimal_code'] / (raw_res_dict['TP_res_before_decimal_code'] + raw_res_dict['FP_res_before_decimal_code'])
        raw_res_dict['precision_icd10_code'] = raw_res_dict['TP_res_icd10_code'] / (raw_res_dict['TP_res_icd10_code'] + raw_res_dict['FP_res_icd10_code'])
        return raw_res_dict
    def manual_recall(raw_res_dict):
        raw_res_dict['recall_letter_code'] = raw_res_dict['TP_res_letter_code'] / (raw_res_dict['TP_res_letter_code'] + raw_res_dict['FN_res_letter_code'])
        raw_res_dict['recall_before_decimal_code'] = raw_res_dict['TP_res_before_decimal_code'] / (raw_res_dict['TP_res_before_decimal_code'] + raw_res_dict['FN_res_before_decimal_code'])
        raw_res_dict['recall_icd10_code'] = raw_res_dict['TP_res_icd10_code'] / (raw_res_dict['TP_res_icd10_code'] + raw_res_dict['FN_res_icd10_code'])
        return raw_res_dict
    def manual_f1(raw_res_dict):
        raw_res_dict['f1_letter_code'] = 2 * raw_res_dict['precision_letter_code'] * raw_res_dict['recall_letter_code'] / (raw_res_dict['precision_letter_code'] + raw_res_dict['recall_letter_code'])
        raw_res_dict['f1_before_decimal_code'] = 2 * raw_res_dict['precision_before_decimal_code'] * raw_res_dict['recall_before_decimal_code'] / (raw_res_dict['precision_before_decimal_code'] + raw_res_dict['recall_before_decimal_code'])
        raw_res_dict['f1_icd10_code'] = 2 * raw_res_dict['precision_icd10_code'] * raw_res_dict['recall_icd10_code'] / (raw_res_dict['precision_icd10_code'] + raw_res_dict['recall_icd10_code'])
        return raw_res_dict
    raw_res_dict = manual_precision(raw_res_dict)
    raw_res_dict = manual_recall(raw_res_dict)
    raw_res_dict = manual_f1(raw_res_dict)
    return raw_res_dict

In [131]:
bigbird_exp_dict = {}
for i in range(2, 3):
    curr_bigbird_df = pd.read_csv(f'./data/predictedBigBird{i}.csv')
    print(len(curr_bigbird_df))
    bigbird_exp_dict[f'exp_{i}'] = {}
    bigbird_exp_dict[f'exp_{i}']['perf_dict'] = parse_pred_gtruth_bigbird({}, curr_bigbird_df)
    bigbird_exp_dict[f'exp_{i}']['perf_dict'], bigbird_exp_dict[f'exp_{i}']['raw_res_dict'] = compare_bigbird_preds_and_gtruths(bigbird_exp_dict[f'exp_{i}']['perf_dict'])
    bigbird_exp_dict[f'exp_{i}']['raw_res_dict'] = compute_metrics(bigbird_exp_dict[f'exp_{i}']['raw_res_dict'])
    print(bigbird_exp_dict[f'exp_{i}']['raw_res_dict'])

10000
K21
J01
J01
G70
B20
K21
K40
B20
J47
J10 or J11
G70
K21
J10 or J11
K21
K40
K21
J47
G70
K21
J47
I26
J47
K40
J10 or J11
I26
J21
K21
I26
J10 or J11
J01
B20
I26
I26
K40
K40
K40
J10 or J11
B20
I26
J01
K21
K40
I26
J47
J01
I26
J10 or J11
K40
J10 or J11
G70
J10 or J11
B20
G70
J01
J10 or J11
B20
J01
K21
J10 or J11
I26
J47
I26
G70
K21
J01
K21
J10 or J11
J01
B20
J47
K40
I26
J10 or J11
K40
J10 or J11
I26
B20
J10 or J11
B20
B20
B20
K40
K21
I26
J10 or J11
G70
I26
J10 or J11
K21
J10 or J11
I26
J01
J10 or J11
K21
K21
J01
K40
I26
B20
B20
K40
I26
J01
K21
I26
K40
G70
J47
G70
J10 or J11
J47
J01
K21
G70
J10 or J11
K40
I26
K21
K21
J10 or J11
J10 or J11
K21
B20
K40
B20
K21
J47
I26
J47
J47
J47
J01
K40
J10 or J11
I26
J47
K21
G70
J01
J47
K21
J47
J10 or J11
J01
J10 or J11
K21
G70
J01
J01
J47
I26
K21
J10 or J11
J10 or J11
J01
J01
J47
K40
G70
K21
J10 or J11
J47
G70
K40
J10 or J11
K21
J01
B20
J10 or J11
J10 or J11
K21
J47
I26
J01
I26
J01
K21
B20
K21
K40
K21
J10 or J11
J47
J10 or J11
K21
I26
I26
J10 or J11
J10 

KeyError: 'raw_res_dict'

# Parse Davinci Fine-Tuned

In [58]:
import json
import re
import pandas as pd

In [59]:
def readJSONFile(filePath):
  f = open(filePath)
  data = json.load(f)
  return data #dictionary mapping to a list of ChatGPT API responses

In [60]:
# Testing with sample responses
filePath = './data/DavinciBatch1.json'
data = readJSONFile(filePath)
chatGPTResponses= data['responses']
print(len(chatGPTResponses))
for dataDict in chatGPTResponses:
  if dataDict:
    print(dataDict)
    inputPrompt = dataDict['inputPrompt']
    response = dataDict['choices'][0]['text']
    print(response)
    rowKey = dataDict['dataPoint']
    print(rowKey)

1000
{'id': 'cmpl-7A7xwnIayXf7RUcH7uWqQteAjCUr9', 'object': 'text_completion', 'created': 1682648256, 'model': 'davinci:ft-personal-2023-04-24-19-19-14', 'choices': [{'text': ' j98.0 {', 'index': 0, 'logprobs': None, 'finish_reason': 'length'}], 'usage': {'prompt_tokens': 181, 'completion_tokens': 5, 'total_tokens': 186}, 'dataPoint': 109830, 'inputPrompt': "Diagnose a 21 year-old female patient by giving the ICD-10 code. Do not give an explanation or any context. Your response should be one token with the ICD-10 code: {'Are you experiencing shortness of breath or difficulty breathing in a significant way?': 'True', 'Do you have any family members who have asthma?': 'True', 'Have you been diagnosed with chronic sinusitis?': 'True', 'Do you have asthma or have you ever had to use a bronchodilator in the past?': 'True', 'Do you have a cough?': 'True', 'Have you traveled out of the country in the last 4 weeks?': 'No', 'Do you live in in a big city?': 'True', 'Have you noticed a wheezing s

In [61]:
def gen_davinci_responses_batch_file_name(batch_num):
    return f'./data/DavinciBatch{batch_num}.json'

In [62]:
def clean_extract_davinci_icd10_response(filepath, return_dict):
#     out_with_letter_code_dict = []
#     out_with_correct_format = []
#     out_with_wrong_format = []
    print(filepath)
    data = readJSONFile(filepath)
    chatGPTResponses= data['responses']
    print(len(chatGPTResponses))
    for dataDict in chatGPTResponses:
        if dataDict:
#             print(dataDict)
            inputPrompt = dataDict['inputPrompt']
            response = dataDict['choices'][0]['text']
#             print(response)
            rowKey = dataDict['dataPoint']
#             print(rowKey)
            return_dict[rowKey] = {'pred_letter_code': '', 'pred_before_decimal_code': '', 'pred_icd10_code': '', 'original_response': response}
            split_response = response.split(' ')
            pred = split_response[1].capitalize()[:5]
            print(pred)
            return_dict[rowKey]['pred_letter_code'] = pred[0]
            return_dict[rowKey]['pred_before_decimal_code'] = pred[:3]
            if '.' in pred:
                return_dict[rowKey]['pred_icd10_code'] = pred[:5]
            else:
                return_dict[rowKey]['pred_icd10_code'] = ''
    #         print(response)
#             if "{'code': " in response:
#     #             out_with_letter_code_dict.append(response)
#                 return_dict[rowKey]['pred_letter_code'] = response[-1]
#             elif response[0].isupper() and response[1].isdigit():
#                 return_dict[rowKey]['pred_letter_code'] = response[0]
#                 return_dict[rowKey]['pred_before_decimal_code'] = response[:3]
#                 if re.search("^[A-Z][0-9][0-9].[0-9]", response.split(' ')[0]):
#                     return_dict[rowKey]['pred_icd10_code'] = response[:5]
#             out_with_correct_format.append(response)
#         else:
#             out_with_wrong_format.append(response)
#         return_dict[rowKey]['original_response'] = response
#         if corrected_response != '':
#             return_dict[rowKey]['corrected_response'] = corrected_response
#     print(len(out_with_letter_code_dict), f"out_with_letter_code_dict", out_with_letter_code_dict)
#     print()
#     print(len(out_with_correct_format), f"out_with_correct_format", out_with_correct_format)
#     print()
#     print(len(out_with_wrong_format), f"out_with_wrong_format", out_with_wrong_format)
    return return_dict

In [63]:
davinci_parsed_dict = {}
for i in range(1,3):
    davinci_parsed_dict = clean_extract_davinci_icd10_response(gen_davinci_responses_batch_file_name(i), davinci_parsed_dict)
print(len(davinci_parsed_dict))

./data/DavinciBatch1.json
1000
J98.0
G24.9
J32.9
B20.9
F41.0
J02.9
G70.0
F41.0
J05.0
G24.9
G24.9
G44.0
B20.9
K21.3
I20.9
I51.4
J32.9
C25.9
I20.9
K22.3
H66.9
J06.9
I20.0
J01.9
I48.9
J44.1
J05.1
J32.9
J81.0
J02.9
F41.0
J02.9
J98.0
J18.9
J32.9
J06.9
I20.9
J10
J98.0
J06.9
K21.3
D86.9
C25.9
J32.9
C25.9
B20.9
I51.4
J06.9
J06.9
G44.0
G24.9
K21.3
J32.9
J20.9
J18.9
J32.9
D64.9
J06.9
J18.9
J32.9
D64.9
J06.9
J01.9
T78.0
D64.9
D86.9
I30.9
J32.9
J44.1
J05.1
T78.0
J98.0
J98.0
I26.9
J02.9
C34.9
J81.0
C34.9
J06.9
M32.9
K21.3
T78.0
J81.0
F41.0
I48.9
J10
G24.9
J06.9
J20.9
J18.9
B20.9
K21.3
J10
I20.9
I20.0
J02.9
J18.9
I51.4
J81.0
F41.0
J18.9
G24.9
T78.0
K21.3
S22.3
G24.9
B20.9
J93.1
G61.0
G61.0
J32.9
J02.9
J06.9
T61.1
J81.0
I51.4
J06.9
J32.9
T61.1
D64.9
K22.3
I30.9
G70.0
J47.1
J06.9
J02.9
T61.1
J05.1
I20.0
J05.1
I20.9
H66.9
J06.9
K40.9
J18.9
I48.9
J32.9
R60.0
J32.9
C34.9
A37.9
T78.0
D64.9
D86.9
I20.0
J98.0
I20.0
C34.9
T61.1
J06.9
H66.9
F41.0
D86.9
J02.9
R60.0
B57.5
K40.9
J81.0
C25.9
I48.9
F41.9
I30.9
T78

In [64]:
if 52308 in davinci_parsed_dict:
    print("TRUE")

TRUE


In [65]:
davinci_parsed_dict

{109830: {'pred_letter_code': 'J',
  'pred_before_decimal_code': 'J98',
  'pred_icd10_code': 'J98.0',
  'original_response': ' j98.0 {'},
 11402: {'pred_letter_code': 'G',
  'pred_before_decimal_code': 'G24',
  'pred_icd10_code': 'G24.9',
  'original_response': ' g24.9 I'},
 30840: {'pred_letter_code': 'J',
  'pred_before_decimal_code': 'J32',
  'pred_icd10_code': 'J32.9',
  'original_response': ' j32.9 i'},
 14889: {'pred_letter_code': 'B',
  'pred_before_decimal_code': 'B20',
  'pred_icd10_code': 'B20.9',
  'original_response': ' b20.9 i'},
 131903: {'pred_letter_code': 'F',
  'pred_before_decimal_code': 'F41',
  'pred_icd10_code': 'F41.0',
  'original_response': ' f41.0 i'},
 104142: {'pred_letter_code': 'J',
  'pred_before_decimal_code': 'J02',
  'pred_icd10_code': 'J02.9',
  'original_response': ' j02.9 i'},
 988: {'pred_letter_code': 'G',
  'pred_before_decimal_code': 'G70',
  'pred_icd10_code': 'G70.0',
  'original_response': ' g70.0 my'},
 34183: {'pred_letter_code': 'F',
  'pr

In [66]:
groundtruth_icd10_df = pd.read_csv('./data/Project_Test_Data.csv')
groundtruth_icd10_df

Unnamed: 0,OriginalRowKey,AGE,DIFFERENTIAL_DIAGNOSIS,SEX,PATHOLOGY,EVIDENCES,INITIAL_EVIDENCE,English Pathology,English Evidences,ICD-10
0,109830,21,"[['Asthme exacerb√© ou bronchospasme', 0.11700...",F,Asthme exacerb√© ou bronchospasme,"['dyspn', 'fam_j45', 'j32', 'j45', 'toux', 'tr...",toux,Bronchospasm / acute asthma exacerbation,{'Are you experiencing shortness of breath or ...,J98.0
1,11402,31,"[['R√©action dystonique a√Øgue', 0.17818166718...",M,R√©action dystonique a√Øgue,"['drogues_stimul', 'dyspn', 'laryngospasme', '...",protu_langue,Acute dystonic reactions,{'Do you regularly take stimulant drugs?': 'Tr...,G24.9
2,30840,3,"[['Rhinosinusite chronique', 0.189267488945076...",F,Rhinosinusite chronique,"['douleurxx', 'douleurxx_carac_@_une_br√ªlure_...",hyponos,Chronic rhinosinusitis,"{'Do you have pain somewhere, related to your ...",J32.9
3,14889,46,"[['VIH (Primo-infection)', 0.48414955175409613...",F,VIH (Primo-infection),"['atcd_its', 'diaph', 'diarrhee', 'douleurxx',...",diaph,HIV (initial infection),{'Have you ever had a sexually transmitted inf...,B20
4,131903,72,"[['Attaque de panique', 0.09172200566348035], ...",F,Attaque de panique,"['anxiete_s', 'atcdpsyfam', 'diaph', 'douleurx...",palpit,Panic attack,"{'Do you feel anxious?': 'True', 'Do any membe...",F41.0
...,...,...,...,...,...,...,...,...,...,...
9995,113311,0,"[['VIH (Primo-infection)', 0.3030252901252109]...",M,VIH (Primo-infection),"['atcd_its', 'diarrhee', 'douleurxx', 'douleur...",fatig_ext,HIV (initial infection),{'Have you ever had a sexually transmitted inf...,B20
9996,46200,35,"[['Possible NSTEMI / STEMI', 0.170304300298058...",M,Possible NSTEMI / STEMI,"['HIV', 'diaph', 'douleurxx', 'douleurxx_carac...",fatig_mod,Possible NSTEMI / STEMI,{'Are you infected with the human immunodefici...,I21.4
9997,54048,71,"[['Pharyngite virale', 0.5232406274878201], ['...",F,Pharyngite virale,"['crach_sg', 'dayc', 'douleurxx', 'douleurxx_c...",crach_sg,Viral pharyngitis,"{'Have you been coughing up blood?': 'True', '...",J02.9
9998,27233,61,"[['Possible NSTEMI / STEMI', 0.479718656849398...",M,Sarco√Ødose,"['convulsion', 'douleurxx', 'douleurxx_carac_@...",dyspn,Sarcoidosis,{'Have you lost consciousness associated with ...,D86.9


In [67]:
def add_groundtruth_icd(return_dict):
    for row_key in return_dict:
        gtruth_icd10 = groundtruth_icd10_df[groundtruth_icd10_df['OriginalRowKey'] == row_key].iloc[0]['ICD-10']
        if 'or' not in gtruth_icd10:
            return_dict[row_key]['gtruth_letter_code'] = gtruth_icd10[0]
            return_dict[row_key]['gtruth_before_decimal_code'] = gtruth_icd10[:3]
            if '.' in gtruth_icd10:
                return_dict[row_key]['gtruth_icd10_code'] = gtruth_icd10[:5]
            else:
                return_dict[row_key]['gtruth_icd10_code'] = ''
        else:
            return_dict[row_key]['gtruth_letter_code'] = 'J'
            return_dict[row_key]['gtruth_before_decimal_code'] = ['J10', 'J11']
            return_dict[row_key]['gtruth_icd10_code'] = ''
    return return_dict

In [68]:
davinci_parsed_dict = add_groundtruth_icd(davinci_parsed_dict)
print(len(davinci_parsed_dict))

2000


In [69]:
davinci_parsed_dict

{109830: {'pred_letter_code': 'J',
  'pred_before_decimal_code': 'J98',
  'pred_icd10_code': 'J98.0',
  'original_response': ' j98.0 {',
  'gtruth_letter_code': 'J',
  'gtruth_before_decimal_code': 'J98',
  'gtruth_icd10_code': 'J98.0'},
 11402: {'pred_letter_code': 'G',
  'pred_before_decimal_code': 'G24',
  'pred_icd10_code': 'G24.9',
  'original_response': ' g24.9 I',
  'gtruth_letter_code': 'G',
  'gtruth_before_decimal_code': 'G24',
  'gtruth_icd10_code': 'G24.9'},
 30840: {'pred_letter_code': 'J',
  'pred_before_decimal_code': 'J32',
  'pred_icd10_code': 'J32.9',
  'original_response': ' j32.9 i',
  'gtruth_letter_code': 'J',
  'gtruth_before_decimal_code': 'J32',
  'gtruth_icd10_code': 'J32.9'},
 14889: {'pred_letter_code': 'B',
  'pred_before_decimal_code': 'B20',
  'pred_icd10_code': 'B20.9',
  'original_response': ' b20.9 i',
  'gtruth_letter_code': 'B',
  'gtruth_before_decimal_code': 'B20',
  'gtruth_icd10_code': ''},
 131903: {'pred_letter_code': 'F',
  'pred_before_decima

In [157]:
list_icd10 = pd.read_excel('./data/englishPathology2ICD.xlsx')
set_icd10_codes = list_icd10['ICD-10'].unique()
set_icd10_codes

array(['J44.1', 'G24.9', 'J04.0', 'H66.9', 'J81.0', 'J01.9', 'J01',
       'T78.0', 'D64.9', 'I48.9', 'K22.3', 'J47', 'J21', 'J20.9', 'J98.0',
       'B57.5', 'J32.9', 'G44.0', 'J05.0', 'A98.4', 'J05.1', 'K21',
       'G61.0', 'B20', 'J10 or J11', 'K40', 'J38.5', 'R60.0', 'G70',
       'I51.4', 'C25.9', 'F41.0', 'I30.9', 'J18.9', 'I21.4', 'I47.1',
       'I26', 'C34.9', 'D86.9', 'T61.1', 'M32.9', 'J93.1', 'S22.3',
       'I20.9', 'A15.9', 'I20.0', 'J06.9', 'J02.9', 'A37.9'], dtype=object)

In [160]:
def compare_preds_and_gtruths(return_dict):
    tmp_count_outofgtruth = 0
    # M for missing ground_truth, FN for false negative (includes case for empty string pred), FP for false positive, TP for true positive
    raw_res_dict = {'M_res_letter_code': 0, 'FN_res_letter_code': 0, 'FP_res_letter_code': 0, 'TP_res_letter_code': 0,
               'M_res_before_decimal_code': 0, 'FN_res_before_decimal_code': 0, 'FP_res_before_decimal_code': 0, 'TP_res_before_decimal_code': 0,
               'M_res_icd10_code': 0, 'FN_res_icd10_code': 0, 'FP_res_icd10_code': 0, 'TP_res_icd10_code': 0,
                   'list_gtruths_letter_code': [], 'list_preds_letter_code': [],
                   'list_gtruths_before_decimal_code': [], 'list_preds_before_decimal_code': [],
                   'list_gtruths_icd10_code': [], 'list_preds_icd10_code': []}
    for row_key in return_dict:
        
        # letter code comparison
        if return_dict[row_key]['gtruth_letter_code'] == '': #should not execute
            print("blank gtruth_letter_code")
        else:
            raw_res_dict['list_gtruths_letter_code'].append(return_dict[row_key]['gtruth_letter_code'])
            raw_res_dict['list_preds_letter_code'].append(return_dict[row_key]['pred_letter_code'])
            if return_dict[row_key]['pred_letter_code'] == '':
                return_dict[row_key]['res_letter_code'] = 'FN'
                raw_res_dict['FN_res_letter_code'] += 1
            else:
                if return_dict[row_key]['gtruth_letter_code'] == return_dict[row_key]['pred_letter_code']:
                    return_dict[row_key]['res_letter_code'] = 'TP'
                    raw_res_dict['TP_res_letter_code'] += 1
                else:
                    return_dict[row_key]['res_letter_code'] = 'FP'
                    raw_res_dict['FP_res_letter_code'] += 1
        
        # before decimal comparison
        if return_dict[row_key]['gtruth_before_decimal_code'] == '' or return_dict[row_key]['gtruth_before_decimal_code'] == []: #should not execute
            print("blank gtruth_before_decimal_code")
        else:
#             raw_res_dict['list_gtruths_before_decimal_code'].append(return_dict[row_key]['gtruth_before_decimal_code'])
#             raw_res_dict['list_preds_before_decimal_code'].append(return_dict[row_key]['pred_before_decimal_code'])
            
            if type(return_dict[row_key]['gtruth_before_decimal_code']) == list:
                if return_dict[row_key]['pred_before_decimal_code'] == 'J10':
                    raw_res_dict['list_gtruths_before_decimal_code'].append('J10')
                    raw_res_dict['list_preds_before_decimal_code'].append('J10')
                elif return_dict[row_key]['pred_before_decimal_code'] == 'J11':
                    raw_res_dict['list_gtruths_before_decimal_code'].append('J11')
                    raw_res_dict['list_preds_before_decimal_code'].append('J11')
                else:
                    raw_res_dict['list_gtruths_before_decimal_code'].append('J10') #append either, force wrong
                    raw_res_dict['list_preds_before_decimal_code'].append(return_dict[row_key]['pred_before_decimal_code'])
            else:
                raw_res_dict['list_gtruths_before_decimal_code'].append(return_dict[row_key]['gtruth_before_decimal_code'])
                raw_res_dict['list_preds_before_decimal_code'].append(return_dict[row_key]['pred_before_decimal_code'])
            
            if return_dict[row_key]['pred_before_decimal_code'] == '':
                return_dict[row_key]['res_before_decimal_code'] = 'FN'
                raw_res_dict['FN_res_before_decimal_code'] += 1
#                 print(return_dict[row_key]['pred_before_decimal_code'], return_dict[row_key]['gtruth_before_decimal_code'])
            else:
                if return_dict[row_key]['pred_before_decimal_code'] not in return_dict[row_key]['gtruth_before_decimal_code']:
#                     print(return_dict[row_key]['pred_before_decimal_code'], return_dict[row_key]['gtruth_before_decimal_code'])
                    return_dict[row_key]['res_before_decimal_code'] = 'FP'
                    raw_res_dict['FP_res_before_decimal_code'] += 1
                else:
                    return_dict[row_key]['res_before_decimal_code'] = 'TP'
                    raw_res_dict['TP_res_before_decimal_code'] += 1
#                     print(return_dict[row_key]['pred_before_decimal_code'], return_dict[row_key]['gtruth_before_decimal_code'])
        
        # full icd10_code comparison
        if return_dict[row_key]['gtruth_icd10_code'] == '': # executes if 'gtruth_before_decimal_code' maps to ['J10', 'J11'] or a code without decimal value
#             print(return_dict[row_key]['gtruth_before_decimal_code'])
            return_dict[row_key]['res_icd10_code'] = 'M'
            raw_res_dict['M_res_icd10_code'] += 1
            continue
        else:
            raw_res_dict['list_gtruths_icd10_code'].append(return_dict[row_key]['gtruth_icd10_code'])
            raw_res_dict['list_preds_icd10_code'].append(return_dict[row_key]['pred_icd10_code'])
            if return_dict[row_key]['pred_icd10_code'] == '':
                return_dict[row_key]['res_icd10_code'] = 'FN'
                raw_res_dict['FN_res_icd10_code'] += 1
#                 print(return_dict[row_key]['pred_icd10_code'], return_dict[row_key]['gtruth_icd10_code'])
            else:
                if return_dict[row_key]['pred_icd10_code'] != return_dict[row_key]['gtruth_icd10_code']:
#                     print(return_dict[row_key]['pred_icd10_code'], return_dict[row_key]['gtruth_icd10_code'])
                    return_dict[row_key]['res_icd10_code'] = 'FP'
                    raw_res_dict['FP_res_icd10_code'] += 1
                    if return_dict[row_key]['pred_icd10_code'] not in set_icd10_codes:
                        print(row_key)
                        print(return_dict[row_key])
                        tmp_count_outofgtruth += 1
                else:
                    return_dict[row_key]['res_icd10_code'] = 'TP'
                    raw_res_dict['TP_res_icd10_code'] += 1
#                     print(return_dict[row_key]['pred_icd10_code'], return_dict[row_key]['gtruth_icd10_code'])
    print(tmp_count_outofgtruth)
    return return_dict, raw_res_dict

In [161]:
davinci_parsed_dict, davinci_raw_results_dict = compare_preds_and_gtruths(davinci_parsed_dict)

97499
{'pred_letter_code': 'F', 'pred_before_decimal_code': 'F41', 'pred_icd10_code': 'F41.9', 'original_response': ' f41.9 i', 'gtruth_letter_code': 'F', 'gtruth_before_decimal_code': 'F41', 'gtruth_icd10_code': 'F41.0', 'res_letter_code': 'TP', 'res_before_decimal_code': 'TP', 'res_icd10_code': 'FP'}
43398
{'pred_letter_code': 'H', 'pred_before_decimal_code': 'H32', 'pred_icd10_code': 'H32.9', 'original_response': ' h32.9 i', 'gtruth_letter_code': 'G', 'gtruth_before_decimal_code': 'G44', 'gtruth_icd10_code': 'G44.0', 'res_letter_code': 'FP', 'res_before_decimal_code': 'FP', 'res_icd10_code': 'FP'}
63256
{'pred_letter_code': 'G', 'pred_before_decimal_code': 'G70', 'pred_icd10_code': 'G70.0', 'original_response': ' g70.0 i', 'gtruth_letter_code': 'G', 'gtruth_before_decimal_code': 'G24', 'gtruth_icd10_code': 'G24.9', 'res_letter_code': 'TP', 'res_before_decimal_code': 'FP', 'res_icd10_code': 'FP'}
49168
{'pred_letter_code': 'I', 'pred_before_decimal_code': 'I51', 'pred_icd10_code': 'I

In [75]:
#davinci_parsed_dict

In [140]:
davinci_raw_results_dict

{'M_res_letter_code': 0,
 'FN_res_letter_code': 0,
 'FP_res_letter_code': 34,
 'TP_res_letter_code': 1966,
 'M_res_before_decimal_code': 0,
 'FN_res_before_decimal_code': 0,
 'FP_res_before_decimal_code': 142,
 'TP_res_before_decimal_code': 1858,
 'M_res_icd10_code': 338,
 'FN_res_icd10_code': 6,
 'FP_res_icd10_code': 141,
 'TP_res_icd10_code': 1515,
 'list_gtruths_letter_code': ['J',
  'G',
  'J',
  'B',
  'F',
  'J',
  'G',
  'F',
  'J',
  'G',
  'G',
  'G',
  'B',
  'K',
  'I',
  'I',
  'J',
  'C',
  'I',
  'K',
  'H',
  'J',
  'I',
  'J',
  'I',
  'J',
  'J',
  'J',
  'J',
  'J',
  'F',
  'J',
  'J',
  'J',
  'J',
  'J',
  'I',
  'T',
  'J',
  'J',
  'K',
  'D',
  'C',
  'J',
  'C',
  'B',
  'I',
  'J',
  'J',
  'G',
  'G',
  'K',
  'J',
  'J',
  'J',
  'J',
  'D',
  'J',
  'J',
  'J',
  'D',
  'J',
  'J',
  'T',
  'D',
  'D',
  'I',
  'J',
  'J',
  'J',
  'T',
  'J',
  'J',
  'I',
  'J',
  'C',
  'J',
  'C',
  'J',
  'M',
  'K',
  'T',
  'J',
  'F',
  'I',
  'J',
  'G',
  'J',
  '

In [76]:
def compute_metrics(raw_res_dict):
    def manual_precision(raw_res_dict):
        raw_res_dict['precision_letter_code'] = raw_res_dict['TP_res_letter_code'] / (raw_res_dict['TP_res_letter_code'] + raw_res_dict['FP_res_letter_code'])
        raw_res_dict['precision_before_decimal_code'] = raw_res_dict['TP_res_before_decimal_code'] / (raw_res_dict['TP_res_before_decimal_code'] + raw_res_dict['FP_res_before_decimal_code'])
        raw_res_dict['precision_icd10_code'] = raw_res_dict['TP_res_icd10_code'] / (raw_res_dict['TP_res_icd10_code'] + raw_res_dict['FP_res_icd10_code'])
        return raw_res_dict
    def manual_recall(raw_res_dict):
        raw_res_dict['recall_letter_code'] = raw_res_dict['TP_res_letter_code'] / (raw_res_dict['TP_res_letter_code'] + raw_res_dict['FN_res_letter_code'])
        raw_res_dict['recall_before_decimal_code'] = raw_res_dict['TP_res_before_decimal_code'] / (raw_res_dict['TP_res_before_decimal_code'] + raw_res_dict['FN_res_before_decimal_code'])
        raw_res_dict['recall_icd10_code'] = raw_res_dict['TP_res_icd10_code'] / (raw_res_dict['TP_res_icd10_code'] + raw_res_dict['FN_res_icd10_code'])
        return raw_res_dict
    def manual_f1(raw_res_dict):
        raw_res_dict['f1_letter_code'] = 2 * raw_res_dict['precision_letter_code'] * raw_res_dict['recall_letter_code'] / (raw_res_dict['precision_letter_code'] + raw_res_dict['recall_letter_code'])
        raw_res_dict['f1_before_decimal_code'] = 2 * raw_res_dict['precision_before_decimal_code'] * raw_res_dict['recall_before_decimal_code'] / (raw_res_dict['precision_before_decimal_code'] + raw_res_dict['recall_before_decimal_code'])
        raw_res_dict['f1_icd10_code'] = 2 * raw_res_dict['precision_icd10_code'] * raw_res_dict['recall_icd10_code'] / (raw_res_dict['precision_icd10_code'] + raw_res_dict['recall_icd10_code'])
        return raw_res_dict
    raw_res_dict = manual_precision(raw_res_dict)
    raw_res_dict = manual_recall(raw_res_dict)
    raw_res_dict = manual_f1(raw_res_dict)
    return raw_res_dict

In [81]:
def print_results(davinci_raw_results_dict):
    davinci_raw_results_dict = compute_metrics(davinci_raw_results_dict)
    print(davinci_raw_results_dict['precision_letter_code'])
    print(davinci_raw_results_dict['precision_before_decimal_code'])
    print(davinci_raw_results_dict['precision_icd10_code'])
    print()
    print(davinci_raw_results_dict['recall_letter_code'])
    print(davinci_raw_results_dict['recall_before_decimal_code'])
    print(davinci_raw_results_dict['recall_icd10_code'])
    print()
    print(davinci_raw_results_dict['f1_letter_code'])
    print(davinci_raw_results_dict['f1_before_decimal_code'])
    print(davinci_raw_results_dict['f1_icd10_code'])
print_results(davinci_raw_results_dict)

0.983
0.929
0.9148550724637681

1.0
1.0
0.9960552268244576

0.9914271306101865
0.9631933644375325
0.9537299338999056


# Parse ChatGPT Output Zero-Shot

In [83]:
import json
import re
import pandas as pd

In [84]:
def readJSONFile(filePath):
  f = open(filePath)
  data = json.load(f)
  return data #dictionary mapping to a list of ChatGPT API responses

In [85]:
# Testing with sample responses
filePath = './data/chatGPTnext9.json'
data = readJSONFile(filePath)
chatGPTResponses = data['responses']
for dataDict in chatGPTResponses:
    inputPrompt = dataDict['inputPrompt']
    response = dataDict['choices'][0]['message']['content']
    rowKey = dataDict['dataPoint']
    print(response)

G44.1
B22.9
K21.0
I50.9
I20.0
J01.90
K86.1
I20.9
K29.70


In [86]:
def gen_chatgpt_responses_batch_file_name(batch_num):
    return f'./data/chatGPTBatch{batch_num}.json'

In [87]:
def tmp_extract_chatgpt_icd10_response(filepath):
#     num_out_with_code_dict = 0
#     num_correct_format = 0
    print(filepath)
    data = readJSONFile(filepath)
    chatGPTResponses = data['responses']
#     print(len(chatGPTResponses))
#     return_dict = {}
    for dataDict in chatGPTResponses:
        rowKey = dataDict['dataPoint']
        response = dataDict['choices'][0]['message']['content']
        print(response)
#         if "{'code': " in response:
#             num_out_with_code_dict += 1
#             response = response[-1]
        
#         return_dict[rowKey] = response
    return

In [88]:
batch_num = 9
tmp_extract_chatgpt_icd10_response(gen_chatgpt_responses_batch_file_name(batch_num))

./data/chatGPTBatch9.json
I25.10
J44.1
R78.81
M79.1
J01.90
J30.1
G24.4
I50.9
J33.8
D62.9
G70.9
B97.4
B20.
G40.1
J98.4
J44.1
A37.0
J45.909
K86.1
R06.0
I26.99
I20.0
{'code': 'N
J30.1
G61.0
H57.13
K59.00
G61.0
K57.30
I50.9
J40.
{'code': 'D
M79.1
I50.1
M65.9
I50.9
J12.89
G70.0
J06.9
{'code': 'H
J44.1
F41.1
T78.0
J18.9
J02.9
J44.1
J44.1
I48.91
R07.0
B22.2
I50.9
R07.0
M51.26
D50.9
B22.9
C34.9
J12.89
J98.09
J93.83
R50.9
G40.1
{'code': 'J
{'code': 'M
G44.0
J01.90
D50.9
J45.909
J06.9
J18.9
G24.9
J39.8
I26.90
I47.1
J12.89
G70.9
H66.90
J98.4
K21.0
I25.10
K57.30
{'code': 'K
J02.9
B33.3
C34.90
I50.9
G24.9
G44.1
K29.70
M54.5
G40.1
D64.9
{'code': 'G
J44.1
J12.9
G44.0
K21.0
M35.9
F41.0
K57.30
J45.909
F41.1
J45.909
G44.0
K29.01
I51.7
B90.9
J01.90
G44.0
R07.0
Z86.79
J45.909
G70.0
G24.9
R07.0
{'code': 'J
J15.0
I20.9
M79.1
C34.91
Z99.2
K29.01
I30.0
B34.9
I50.9
D64.9
H66.90
D50.9
G44.0
J45.909
F11.5
{'code': 'A
M25.5
D50.9
J02.9
B20
I50.9
J45.909
D50.9
K25.9
J44.1
J44.1
G24.9
I50.9
J02.9
D64.9
{'code': 'J


In [89]:
def clean_extract_chatgpt_icd10_response(filepath, return_dict):
#     out_with_letter_code_dict = []
#     out_with_correct_format = []
#     out_with_wrong_format = []
    print(filepath)
    data = readJSONFile(filepath)
    chatGPTResponses = data['responses']
#     print(len(chatGPTResponses))
#     return_dict = {}
    for dataDict in chatGPTResponses:
        rowKey = dataDict['dataPoint']
        response = dataDict['choices'][0]['message']['content']
        return_dict[rowKey] = {'pred_letter_code': '', 'pred_before_decimal_code': '', 'pred_icd10_code': '', 'original_response': response}
#         print(response)
        if "{'code': " in response:
#             out_with_letter_code_dict.append(response)
            return_dict[rowKey]['pred_letter_code'] = response[-1]
        elif response[0].isupper() and response[1].isdigit():
            return_dict[rowKey]['pred_letter_code'] = response[0]
            return_dict[rowKey]['pred_before_decimal_code'] = response[:3]
            if re.search("^[A-Z][0-9][0-9].[0-9]", response.split(' ')[0]):
                return_dict[rowKey]['pred_icd10_code'] = response[:5]
#             out_with_correct_format.append(response)
#         else:
#             out_with_wrong_format.append(response)
#         return_dict[rowKey]['original_response'] = response
#         if corrected_response != '':
#             return_dict[rowKey]['corrected_response'] = corrected_response
#     print(len(out_with_letter_code_dict), f"out_with_letter_code_dict", out_with_letter_code_dict)
#     print()
#     print(len(out_with_correct_format), f"out_with_correct_format", out_with_correct_format)
#     print()
#     print(len(out_with_wrong_format), f"out_with_wrong_format", out_with_wrong_format)
    return return_dict

In [90]:
chatgpt_parsed_dict = {}
for i in range(1,11):
    chatgpt_parsed_dict = clean_extract_chatgpt_icd10_response(gen_chatgpt_responses_batch_file_name(i), chatgpt_parsed_dict)
print(len(chatgpt_parsed_dict))

./data/chatGPTBatch1.json
./data/chatGPTBatch2.json
./data/chatGPTBatch3.json
./data/chatGPTBatch4.json
./data/chatGPTBatch5.json
./data/chatGPTBatch6.json
./data/chatGPTBatch7.json
./data/chatGPTBatch8.json
./data/chatGPTBatch9.json
./data/chatGPTBatch10.json
10000


In [91]:
# Notes
# if the original response has "diagnose", check manually and count how many times chatgpt says it cannot diagnose
# I cannot diagnose a patient

In [92]:
groundtruth_icd10_df = pd.read_csv('./data/Project_Test_Data.csv')
groundtruth_icd10_df

Unnamed: 0,OriginalRowKey,AGE,DIFFERENTIAL_DIAGNOSIS,SEX,PATHOLOGY,EVIDENCES,INITIAL_EVIDENCE,English Pathology,English Evidences,ICD-10
0,109830,21,"[['Asthme exacerb√© ou bronchospasme', 0.11700...",F,Asthme exacerb√© ou bronchospasme,"['dyspn', 'fam_j45', 'j32', 'j45', 'toux', 'tr...",toux,Bronchospasm / acute asthma exacerbation,{'Are you experiencing shortness of breath or ...,J98.0
1,11402,31,"[['R√©action dystonique a√Øgue', 0.17818166718...",M,R√©action dystonique a√Øgue,"['drogues_stimul', 'dyspn', 'laryngospasme', '...",protu_langue,Acute dystonic reactions,{'Do you regularly take stimulant drugs?': 'Tr...,G24.9
2,30840,3,"[['Rhinosinusite chronique', 0.189267488945076...",F,Rhinosinusite chronique,"['douleurxx', 'douleurxx_carac_@_une_br√ªlure_...",hyponos,Chronic rhinosinusitis,"{'Do you have pain somewhere, related to your ...",J32.9
3,14889,46,"[['VIH (Primo-infection)', 0.48414955175409613...",F,VIH (Primo-infection),"['atcd_its', 'diaph', 'diarrhee', 'douleurxx',...",diaph,HIV (initial infection),{'Have you ever had a sexually transmitted inf...,B20
4,131903,72,"[['Attaque de panique', 0.09172200566348035], ...",F,Attaque de panique,"['anxiete_s', 'atcdpsyfam', 'diaph', 'douleurx...",palpit,Panic attack,"{'Do you feel anxious?': 'True', 'Do any membe...",F41.0
...,...,...,...,...,...,...,...,...,...,...
9995,113311,0,"[['VIH (Primo-infection)', 0.3030252901252109]...",M,VIH (Primo-infection),"['atcd_its', 'diarrhee', 'douleurxx', 'douleur...",fatig_ext,HIV (initial infection),{'Have you ever had a sexually transmitted inf...,B20
9996,46200,35,"[['Possible NSTEMI / STEMI', 0.170304300298058...",M,Possible NSTEMI / STEMI,"['HIV', 'diaph', 'douleurxx', 'douleurxx_carac...",fatig_mod,Possible NSTEMI / STEMI,{'Are you infected with the human immunodefici...,I21.4
9997,54048,71,"[['Pharyngite virale', 0.5232406274878201], ['...",F,Pharyngite virale,"['crach_sg', 'dayc', 'douleurxx', 'douleurxx_c...",crach_sg,Viral pharyngitis,"{'Have you been coughing up blood?': 'True', '...",J02.9
9998,27233,61,"[['Possible NSTEMI / STEMI', 0.479718656849398...",M,Sarco√Ødose,"['convulsion', 'douleurxx', 'douleurxx_carac_@...",dyspn,Sarcoidosis,{'Have you lost consciousness associated with ...,D86.9


In [93]:
def add_groundtruth_icd(return_dict):
    for row_key in return_dict:
        gtruth_icd10 = groundtruth_icd10_df[groundtruth_icd10_df['OriginalRowKey'] == row_key].iloc[0]['ICD-10']
        if 'or' not in gtruth_icd10:
            return_dict[row_key]['gtruth_letter_code'] = gtruth_icd10[0]
            return_dict[row_key]['gtruth_before_decimal_code'] = gtruth_icd10[:3]
            if '.' in gtruth_icd10:
                return_dict[row_key]['gtruth_icd10_code'] = gtruth_icd10[:5]
            else:
                return_dict[row_key]['gtruth_icd10_code'] = ''
        else:
            return_dict[row_key]['gtruth_letter_code'] = 'J'
            return_dict[row_key]['gtruth_before_decimal_code'] = ['J10', 'J11']
            return_dict[row_key]['gtruth_icd10_code'] = ''
    return return_dict
            

In [94]:
chatgpt_parsed_dict = add_groundtruth_icd(chatgpt_parsed_dict)
print(len(chatgpt_parsed_dict))

10000


In [95]:
chatgpt_parsed_dict

{109830: {'pred_letter_code': 'J',
  'pred_before_decimal_code': 'J45',
  'pred_icd10_code': 'J45.9',
  'original_response': 'J45.909',
  'gtruth_letter_code': 'J',
  'gtruth_before_decimal_code': 'J98',
  'gtruth_icd10_code': 'J98.0'},
 11402: {'pred_letter_code': 'F',
  'pred_before_decimal_code': 'F11',
  'pred_icd10_code': 'F11.9',
  'original_response': 'F11.9',
  'gtruth_letter_code': 'G',
  'gtruth_before_decimal_code': 'G24',
  'gtruth_icd10_code': 'G24.9'},
 30840: {'pred_letter_code': 'J',
  'pred_before_decimal_code': 'J01',
  'pred_icd10_code': 'J01.9',
  'original_response': 'J01.90',
  'gtruth_letter_code': 'J',
  'gtruth_before_decimal_code': 'J32',
  'gtruth_icd10_code': 'J32.9'},
 14889: {'pred_letter_code': 'B',
  'pred_before_decimal_code': 'B97',
  'pred_icd10_code': 'B97.3',
  'original_response': 'B97.35',
  'gtruth_letter_code': 'B',
  'gtruth_before_decimal_code': 'B20',
  'gtruth_icd10_code': ''},
 131903: {'pred_letter_code': 'F',
  'pred_before_decimal_code':

In [100]:
def compare_preds_and_gtruths(return_dict):
    # M for missing ground_truth, FN for false negative (includes case for empty string pred), FP for false positive, TP for true positive
    raw_res_dict = {'M_res_letter_code': 0, 'FN_res_letter_code': 0, 'FP_res_letter_code': 0, 'TP_res_letter_code': 0,
               'M_res_before_decimal_code': 0, 'FN_res_before_decimal_code': 0, 'FP_res_before_decimal_code': 0, 'TP_res_before_decimal_code': 0,
               'M_res_icd10_code': 0, 'FN_res_icd10_code': 0, 'FP_res_icd10_code': 0, 'TP_res_icd10_code': 0,
                   'list_gtruths_letter_code': [], 'list_preds_letter_code': [],
                   'list_gtruths_before_decimal_code': [], 'list_preds_before_decimal_code': [],
                   'list_gtruths_icd10_code': [], 'list_preds_icd10_code': []}
    for row_key in return_dict:
        
        # letter code comparison
        if return_dict[row_key]['gtruth_letter_code'] == '': #should not execute
            print("blank gtruth_letter_code")
        else:
            raw_res_dict['list_gtruths_letter_code'].append(return_dict[row_key]['gtruth_letter_code'])
            raw_res_dict['list_preds_letter_code'].append(return_dict[row_key]['pred_letter_code'])
            if return_dict[row_key]['pred_letter_code'] == '':
                return_dict[row_key]['res_letter_code'] = 'FN'
                raw_res_dict['FN_res_letter_code'] += 1
            else:
                if return_dict[row_key]['gtruth_letter_code'] == return_dict[row_key]['pred_letter_code']:
                    return_dict[row_key]['res_letter_code'] = 'TP'
                    raw_res_dict['TP_res_letter_code'] += 1
                else:
                    return_dict[row_key]['res_letter_code'] = 'FP'
                    raw_res_dict['FP_res_letter_code'] += 1
        
        # before decimal comparison
        if return_dict[row_key]['gtruth_before_decimal_code'] == '' or return_dict[row_key]['gtruth_before_decimal_code'] == []: #should not execute
            print("blank gtruth_before_decimal_code")
        else:
#             raw_res_dict['list_gtruths_before_decimal_code'].append(return_dict[row_key]['gtruth_before_decimal_code'])
#             raw_res_dict['list_preds_before_decimal_code'].append(return_dict[row_key]['pred_before_decimal_code'])
            
            if type(return_dict[row_key]['gtruth_before_decimal_code']) == list:
#               print("reached")
                if return_dict[row_key]['pred_before_decimal_code'] == 'J10':
                    raw_res_dict['list_gtruths_before_decimal_code'].append('J10')
                    raw_res_dict['list_preds_before_decimal_code'].append('J10')
                elif return_dict[row_key]['pred_before_decimal_code'] == 'J11':
                    raw_res_dict['list_gtruths_before_decimal_code'].append('J11')
                    raw_res_dict['list_preds_before_decimal_code'].append('J11')
                else:
                    raw_res_dict['list_gtruths_before_decimal_code'].append('J10') #append either, force wrong
                    raw_res_dict['list_preds_before_decimal_code'].append(return_dict[row_key]['pred_before_decimal_code'])
#                     print(return_dict[row_key]['pred_before_decimal_code'])
            else:
                raw_res_dict['list_gtruths_before_decimal_code'].append(return_dict[row_key]['gtruth_before_decimal_code'])
                raw_res_dict['list_preds_before_decimal_code'].append(return_dict[row_key]['pred_before_decimal_code'])
            
            if return_dict[row_key]['pred_before_decimal_code'] == '':
                return_dict[row_key]['res_before_decimal_code'] = 'FN'
                raw_res_dict['FN_res_before_decimal_code'] += 1
#                 print(return_dict[row_key]['pred_before_decimal_code'], return_dict[row_key]['gtruth_before_decimal_code'])
            else:
                if return_dict[row_key]['pred_before_decimal_code'] not in return_dict[row_key]['gtruth_before_decimal_code']:
#                     print(return_dict[row_key]['pred_before_decimal_code'], return_dict[row_key]['gtruth_before_decimal_code'])
                    return_dict[row_key]['res_before_decimal_code'] = 'FP'
                    raw_res_dict['FP_res_before_decimal_code'] += 1
                else:
                    return_dict[row_key]['res_before_decimal_code'] = 'TP'
                    raw_res_dict['TP_res_before_decimal_code'] += 1
#                     print(return_dict[row_key]['pred_before_decimal_code'], return_dict[row_key]['gtruth_before_decimal_code'])
        
        # full icd10_code comparison
        if return_dict[row_key]['gtruth_icd10_code'] == '': # executes if 'gtruth_before_decimal_code' maps to ['J10', 'J11'] or a code without decimal value
#             print(return_dict[row_key]['gtruth_before_decimal_code'])
            return_dict[row_key]['res_icd10_code'] = 'M'
            raw_res_dict['M_res_icd10_code'] += 1
            continue
        else:
            raw_res_dict['list_gtruths_icd10_code'].append(return_dict[row_key]['gtruth_icd10_code'])
            raw_res_dict['list_preds_icd10_code'].append(return_dict[row_key]['pred_icd10_code'])
            if return_dict[row_key]['pred_icd10_code'] == '':
                return_dict[row_key]['res_icd10_code'] = 'FN'
                raw_res_dict['FN_res_icd10_code'] += 1
#                 print(return_dict[row_key]['pred_icd10_code'], return_dict[row_key]['gtruth_icd10_code'])
            else:
                if return_dict[row_key]['pred_icd10_code'] != return_dict[row_key]['gtruth_icd10_code']:
#                     print(return_dict[row_key]['pred_icd10_code'], return_dict[row_key]['gtruth_icd10_code'])
                    return_dict[row_key]['res_icd10_code'] = 'FP'
                    raw_res_dict['FP_res_icd10_code'] += 1
                else:
                    return_dict[row_key]['res_icd10_code'] = 'TP'
                    raw_res_dict['TP_res_icd10_code'] += 1
#                     print(return_dict[row_key]['pred_icd10_code'], return_dict[row_key]['gtruth_icd10_code'])
        
    return return_dict, raw_res_dict

In [101]:
chatgpt_parsed_dict, chatgpt_raw_results_dict = compare_preds_and_gtruths(chatgpt_parsed_dict)

In [103]:
# chatgpt_parsed_dict

In [104]:
# chatgpt_raw_results_dict

{'M_res_letter_code': 0,
 'FN_res_letter_code': 14,
 'FP_res_letter_code': 2803,
 'TP_res_letter_code': 7183,
 'M_res_before_decimal_code': 0,
 'FN_res_before_decimal_code': 839,
 'FP_res_before_decimal_code': 6285,
 'TP_res_before_decimal_code': 2876,
 'M_res_icd10_code': 1756,
 'FN_res_icd10_code': 885,
 'FP_res_icd10_code': 5666,
 'TP_res_icd10_code': 1693,
 'list_gtruths_letter_code': ['J',
  'G',
  'J',
  'B',
  'F',
  'J',
  'G',
  'F',
  'J',
  'G',
  'G',
  'G',
  'B',
  'K',
  'I',
  'I',
  'J',
  'C',
  'I',
  'K',
  'H',
  'J',
  'I',
  'J',
  'I',
  'J',
  'J',
  'J',
  'J',
  'J',
  'F',
  'J',
  'J',
  'J',
  'J',
  'J',
  'I',
  'T',
  'J',
  'J',
  'K',
  'D',
  'C',
  'J',
  'C',
  'B',
  'I',
  'J',
  'J',
  'G',
  'G',
  'K',
  'J',
  'J',
  'J',
  'J',
  'D',
  'J',
  'J',
  'J',
  'D',
  'J',
  'J',
  'T',
  'D',
  'D',
  'I',
  'J',
  'J',
  'J',
  'T',
  'J',
  'J',
  'I',
  'J',
  'C',
  'J',
  'C',
  'J',
  'M',
  'K',
  'T',
  'J',
  'F',
  'I',
  'J',
  'G',


In [64]:
# Metrics to calculate
# precision, recall, f-1
# "I cannot diagnose a patient"
# Top-5 most common FP
# Top-5 most common TP
# Top-5 most common FN

In [105]:
def find_num_cannot_diagnose_responses(return_dict):
    num = 0
    for row_key in return_dict:
        original_response = return_dict[row_key]['original_response']
        if original_response == 'I cannot diagnose a patient':
            num += 1
        elif 'diagnose' in original_response:
            print('found')
    return num, num / len(return_dict) * 100

In [106]:
num_cannot_diagnose_responses, percent_cannot_diagnose_responses = find_num_cannot_diagnose_responses(chatgpt_parsed_dict)
print(num_cannot_diagnose_responses)
print(f'{percent_cannot_diagnose_responses}%')

4
0.04%


In [35]:
def manual_precision(raw_res_dict):
    raw_res_dict['precision_letter_code'] = raw_res_dict['TP_res_letter_code'] / (raw_res_dict['TP_res_letter_code'] + raw_res_dict['FP_res_letter_code'])
    raw_res_dict['precision_before_decimal_code'] = raw_res_dict['TP_res_before_decimal_code'] / (raw_res_dict['TP_res_before_decimal_code'] + raw_res_dict['FP_res_before_decimal_code'])
    raw_res_dict['precision_icd10_code'] = raw_res_dict['TP_res_icd10_code'] / (raw_res_dict['TP_res_icd10_code'] + raw_res_dict['FP_res_icd10_code'])
    return raw_res_dict

In [36]:
raw_results_dict = manual_precision(raw_results_dict)
print(raw_results_dict['precision_letter_code'], raw_results_dict['precision_before_decimal_code'], raw_results_dict['precision_icd10_code'])

0.7193070298417785 0.3139395262525925 0.23005843185215383


In [37]:
def manual_recall(raw_res_dict):
    raw_res_dict['recall_letter_code'] = raw_res_dict['TP_res_letter_code'] / (raw_res_dict['TP_res_letter_code'] + raw_res_dict['FN_res_letter_code'])
    raw_res_dict['recall_before_decimal_code'] = raw_res_dict['TP_res_before_decimal_code'] / (raw_res_dict['TP_res_before_decimal_code'] + raw_res_dict['FN_res_before_decimal_code'])
    raw_res_dict['recall_icd10_code'] = raw_res_dict['TP_res_icd10_code'] / (raw_res_dict['TP_res_icd10_code'] + raw_res_dict['FN_res_icd10_code'])
    return raw_results_dict

In [38]:
raw_results_dict = manual_recall(raw_results_dict)
print(raw_results_dict['recall_letter_code'], raw_results_dict['recall_before_decimal_code'], raw_results_dict['recall_icd10_code'])

0.9980547450326525 0.7741588156123822 0.656710628394104


In [39]:
def manual_f1(raw_res_dict):
    raw_res_dict['f1_letter_code'] = 2 * raw_res_dict['precision_letter_code'] * raw_res_dict['recall_letter_code'] / (raw_res_dict['precision_letter_code'] + raw_res_dict['recall_letter_code'])
    raw_res_dict['f1_before_decimal_code'] = 2 * raw_res_dict['precision_before_decimal_code'] * raw_res_dict['recall_before_decimal_code'] / (raw_res_dict['precision_before_decimal_code'] + raw_res_dict['recall_before_decimal_code'])
    raw_res_dict['f1_icd10_code'] = 2 * raw_res_dict['precision_icd10_code'] * raw_res_dict['recall_icd10_code'] / (raw_res_dict['precision_icd10_code'] + raw_res_dict['recall_icd10_code'])
    return raw_results_dict

In [40]:
raw_results_dict = manual_f1(raw_results_dict)
print(raw_results_dict['f1_letter_code'], raw_results_dict['f1_before_decimal_code'], raw_results_dict['f1_icd10_code'])

0.8360588954198918 0.44672258465361914 0.3407467042366912


In [107]:
def compute_metrics(raw_res_dict):
    def manual_precision(raw_res_dict):
        raw_res_dict['precision_letter_code'] = raw_res_dict['TP_res_letter_code'] / (raw_res_dict['TP_res_letter_code'] + raw_res_dict['FP_res_letter_code'])
        raw_res_dict['precision_before_decimal_code'] = raw_res_dict['TP_res_before_decimal_code'] / (raw_res_dict['TP_res_before_decimal_code'] + raw_res_dict['FP_res_before_decimal_code'])
        raw_res_dict['precision_icd10_code'] = raw_res_dict['TP_res_icd10_code'] / (raw_res_dict['TP_res_icd10_code'] + raw_res_dict['FP_res_icd10_code'])
        return raw_res_dict
    def manual_recall(raw_res_dict):
        raw_res_dict['recall_letter_code'] = raw_res_dict['TP_res_letter_code'] / (raw_res_dict['TP_res_letter_code'] + raw_res_dict['FN_res_letter_code'])
        raw_res_dict['recall_before_decimal_code'] = raw_res_dict['TP_res_before_decimal_code'] / (raw_res_dict['TP_res_before_decimal_code'] + raw_res_dict['FN_res_before_decimal_code'])
        raw_res_dict['recall_icd10_code'] = raw_res_dict['TP_res_icd10_code'] / (raw_res_dict['TP_res_icd10_code'] + raw_res_dict['FN_res_icd10_code'])
        return raw_res_dict
    def manual_f1(raw_res_dict):
        raw_res_dict['f1_letter_code'] = 2 * raw_res_dict['precision_letter_code'] * raw_res_dict['recall_letter_code'] / (raw_res_dict['precision_letter_code'] + raw_res_dict['recall_letter_code'])
        raw_res_dict['f1_before_decimal_code'] = 2 * raw_res_dict['precision_before_decimal_code'] * raw_res_dict['recall_before_decimal_code'] / (raw_res_dict['precision_before_decimal_code'] + raw_res_dict['recall_before_decimal_code'])
        raw_res_dict['f1_icd10_code'] = 2 * raw_res_dict['precision_icd10_code'] * raw_res_dict['recall_icd10_code'] / (raw_res_dict['precision_icd10_code'] + raw_res_dict['recall_icd10_code'])
        return raw_res_dict
    raw_res_dict = manual_precision(raw_res_dict)
    raw_res_dict = manual_recall(raw_res_dict)
    raw_res_dict = manual_f1(raw_res_dict)
    return raw_res_dict

In [108]:
def print_results(chatgpt_raw_results_dict):
    chatgpt_raw_results_dict = compute_metrics(chatgpt_raw_results_dict)
    print(chatgpt_raw_results_dict['precision_letter_code'])
    print(chatgpt_raw_results_dict['precision_before_decimal_code'])
    print(chatgpt_raw_results_dict['precision_icd10_code'])
    print()
    print(chatgpt_raw_results_dict['recall_letter_code'])
    print(chatgpt_raw_results_dict['recall_before_decimal_code'])
    print(chatgpt_raw_results_dict['recall_icd10_code'])
    print()
    print(chatgpt_raw_results_dict['f1_letter_code'])
    print(chatgpt_raw_results_dict['f1_before_decimal_code'])
    print(chatgpt_raw_results_dict['f1_icd10_code'])
print_results(chatgpt_raw_results_dict)

0.7193070298417785
0.3139395262525925
0.23005843185215383

0.9980547450326525
0.7741588156123822
0.656710628394104

0.8360588954198918
0.44672258465361914
0.3407467042366912


In [41]:
import json
with open('./rawresults_ZeroShotGPT3-5_turbo.json', 'w') as fp:
    json.dump(raw_results_dict, fp)

In [42]:
import json
with open('./parsed_output_ZeroShotGPT3-5_turbo.json', 'w') as fp:
    json.dump(chatgpt_parsed_dict, fp)

# sklearn functions

In [95]:
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn import preprocessing

In [106]:
le_letter_code = preprocessing.LabelEncoder()
le_letter_code.fit(raw_results_dict['list_gtruths_letter_code'])
le_letter_code_dict = dict(zip(le_letter_code.classes_, le_letter_code.transform(le_letter_code.classes_)))
raw_results_dict['le_list_preds_letter_code'] = [le_letter_code_dict.get(val, 1+max(le_letter_code.transform(le_letter_code.classes_))) for val in raw_results_dict['list_preds_letter_code']]

In [107]:
le_letter_code.transform(raw_results_dict['le_list_preds_letter_code'])

TypeError: ufunc 'isnan' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''

In [79]:
sklearn_results_dict = {}

In [80]:
sklearn_results_dict['sklearn_macro_precision_letter_code'] = precision_score(raw_results_dict['list_gtruths_letter_code'], raw_results_dict['list_preds_letter_code'], average='macro', zero_division=0)
sklearn_results_dict['sklearn_micro_precision_letter_code'] = precision_score(raw_results_dict['list_gtruths_letter_code'], raw_results_dict['list_preds_letter_code'], average='micro', zero_division=0)
sklearn_results_dict['sklearn_weighted_precision_letter_code'] = precision_score(raw_results_dict['list_gtruths_letter_code'], raw_results_dict['list_preds_letter_code'], average='weighted', zero_division=0)

sklearn_results_dict['sklearn_macro_precision_before_decimal_code'] = precision_score(raw_results_dict['list_gtruths_before_decimal_code'], raw_results_dict['list_preds_before_decimal_code'], average='macro', zero_division=0)
sklearn_results_dict['sklearn_micro_precision_before_decimal_code'] = precision_score(raw_results_dict['list_gtruths_before_decimal_code'], raw_results_dict['list_preds_before_decimal_code'], average='micro', zero_division=0)
sklearn_results_dict['sklearn_weighted_precision_before_decimal_code'] = precision_score(raw_results_dict['list_gtruths_before_decimal_code'], raw_results_dict['list_preds_before_decimal_code'], average='weighted', zero_division=0)

sklearn_results_dict['sklearn_macro_precision_icd10_code'] = precision_score(raw_results_dict['list_gtruths_icd10_code'], raw_results_dict['list_preds_icd10_code'], average='macro', zero_division=0)
sklearn_results_dict['sklearn_micro_precision_icd10_code'] = precision_score(raw_results_dict['list_gtruths_icd10_code'], raw_results_dict['list_preds_icd10_code'], average='micro', zero_division=0)
sklearn_results_dict['sklearn_weighted_precision_icd10_code'] = precision_score(raw_results_dict['list_gtruths_icd10_code'], raw_results_dict['list_preds_icd10_code'], average='weighted', zero_division=0)

sklearn_results_dict['sklearn_macro_recall_letter_code'] = recall_score(raw_results_dict['list_gtruths_letter_code'], raw_results_dict['list_preds_letter_code'], average='macro', zero_division=0)
sklearn_results_dict['sklearn_micro_recall_letter_code'] = recall_score(raw_results_dict['list_gtruths_letter_code'], raw_results_dict['list_preds_letter_code'], average='micro', zero_division=0)
sklearn_results_dict['sklearn_weighted_recall_letter_code'] = recall_score(raw_results_dict['list_gtruths_letter_code'], raw_results_dict['list_preds_letter_code'], average='weighted', zero_division=0)

sklearn_results_dict['sklearn_macro_recall_before_decimal_code'] = recall_score(raw_results_dict['list_gtruths_before_decimal_code'], raw_results_dict['list_preds_before_decimal_code'], average='macro', zero_division=0)
sklearn_results_dict['sklearn_micro_recall_before_decimal_code'] = recall_score(raw_results_dict['list_gtruths_before_decimal_code'], raw_results_dict['list_preds_before_decimal_code'], average='micro', zero_division=0)
sklearn_results_dict['sklearn_weighted_recall_before_decimal_code'] = recall_score(raw_results_dict['list_gtruths_before_decimal_code'], raw_results_dict['list_preds_before_decimal_code'], average='weighted', zero_division=0)

sklearn_results_dict['sklearn_macro_recall_icd10_code'] = recall_score(raw_results_dict['list_gtruths_icd10_code'], raw_results_dict['list_preds_icd10_code'], average='macro', zero_division=0)
sklearn_results_dict['sklearn_micro_recall_icd10_code'] = recall_score(raw_results_dict['list_gtruths_icd10_code'], raw_results_dict['list_preds_icd10_code'], average='micro', zero_division=0)
sklearn_results_dict['sklearn_weighted_recall_icd10_code'] = recall_score(raw_results_dict['list_gtruths_icd10_code'], raw_results_dict['list_preds_icd10_code'], average='weighted', zero_division=0)

sklearn_results_dict['sklearn_macro_f1_letter_code'] = f1_score(raw_results_dict['list_gtruths_letter_code'], raw_results_dict['list_preds_letter_code'], average='macro', zero_division=0)
sklearn_results_dict['sklearn_micro_f1_letter_code'] = f1_score(raw_results_dict['list_gtruths_letter_code'], raw_results_dict['list_preds_letter_code'], average='micro', zero_division=0)
sklearn_results_dict['sklearn_weighted_f1_letter_code'] = f1_score(raw_results_dict['list_gtruths_letter_code'], raw_results_dict['list_preds_letter_code'], average='weighted', zero_division=0)

sklearn_results_dict['sklearn_macro_f1_before_decimal_code'] = f1_score(raw_results_dict['list_gtruths_before_decimal_code'], raw_results_dict['list_preds_before_decimal_code'], average='macro', zero_division=0)
sklearn_results_dict['sklearn_micro_f1_before_decimal_code'] = f1_score(raw_results_dict['list_gtruths_before_decimal_code'], raw_results_dict['list_preds_before_decimal_code'], average='micro', zero_division=0)
sklearn_results_dict['sklearn_weighted_f1_before_decimal_code'] = f1_score(raw_results_dict['list_gtruths_before_decimal_code'], raw_results_dict['list_preds_before_decimal_code'], average='weighted', zero_division=0)

sklearn_results_dict['sklearn_macro_f1_icd10_code'] = f1_score(raw_results_dict['list_gtruths_icd10_code'], raw_results_dict['list_preds_icd10_code'], average='macro', zero_division=0)
sklearn_results_dict['sklearn_micro_f1_icd10_code'] = f1_score(raw_results_dict['list_gtruths_icd10_code'], raw_results_dict['list_preds_icd10_code'], average='micro', zero_division=0)
sklearn_results_dict['sklearn_weighted_f1_icd10_code'] = f1_score(raw_results_dict['list_gtruths_icd10_code'], raw_results_dict['list_preds_icd10_code'], average='weighted', zero_division=0)




In [81]:
sklearn_results_dict

{'sklearn_macro_precision_letter_code': 0.3559527985217898,
 'sklearn_micro_precision_letter_code': 0.7183,
 'sklearn_weighted_precision_letter_code': 0.8066841840581022,
 'sklearn_macro_precision_before_decimal_code': 0.09550954658383723,
 'sklearn_micro_precision_before_decimal_code': 0.2876,
 'sklearn_weighted_precision_before_decimal_code': 0.5692213682226512,
 'sklearn_macro_precision_icd10_code': 0.056607097197205185,
 'sklearn_micro_precision_icd10_code': 0.20536147501213003,
 'sklearn_weighted_precision_icd10_code': 0.5192834323225799,
 'sklearn_macro_recall_letter_code': 0.3710374641345575,
 'sklearn_micro_recall_letter_code': 0.7183,
 'sklearn_weighted_recall_letter_code': 0.7183,
 'sklearn_macro_recall_before_decimal_code': 0.04970196395167424,
 'sklearn_micro_recall_before_decimal_code': 0.2876,
 'sklearn_weighted_recall_before_decimal_code': 0.2876,
 'sklearn_macro_recall_icd10_code': 0.02415425873603472,
 'sklearn_micro_recall_icd10_code': 0.20536147501213003,
 'sklearn_w