In [16]:

import json
from prettytable import PrettyTable
import json
import os
import random
import pandas as pd
import glob
import numpy as np

def load_jsonl(jsonl_file_path):
    """Read a JSON Lines file and return its records as a list.

    Args:
        jsonl_file_path: Path to a UTF-8 encoded file with one JSON document per line.

    Returns:
        A list holding one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    with open(jsonl_file_path, 'r', encoding='utf-8') as file:
        for line in file:
            stripped = line.strip()
            # Blank lines (e.g. an extra newline at the end of the file) hold no
            # record; feeding them to json.loads would raise.
            if stripped:
                records.append(json.loads(stripped))
    return records

def save_as_json(json_list, output_file_path):
    """Write ``json_list`` to ``output_file_path`` as one JSON document indented by 4 spaces."""
    with open(output_file_path, 'w', encoding='utf-8') as sink:
        sink.write(json.dumps(json_list, indent=4))

def save_as_jsonl(json_list, output_file_path):
    """Write every element of ``json_list`` to ``output_file_path``, one JSON document per line."""
    with open(output_file_path, 'w', encoding='utf-8') as sink:
        sink.writelines(json.dumps(record) + '\n' for record in json_list)


# Recursively collect every file under ./data whose name ends in "test.jsonl"
test_files = glob.glob('./data/**/*test.jsonl', recursive=True)

# Report the matches, one tab-indented path per line
print("Found test files with options:")
for test_path in test_files:
    print('\t' + test_path)


Found test files with options:
	./data/mmlu-pro/test.jsonl
	./data/pubmedqa/test.jsonl
	./data/medqa/test.jsonl
	./data/afrimedqa/test.jsonl
	./data/medbullets/test.jsonl
	./data/medqa_5options/test.jsonl
	./data/mmlu/test.jsonl
	./data/medmcqa/test.jsonl


## MedQA

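Please follow [https://github.com/jind11/MedQA](https://github.com/jind11/MedQA) to download the data. Place it in `./data/medqa` and provide both a `train.jsonl` and a `test.jsonl` file; the cells below read both (and rewrite them with a `realidx` field). The 5-option variant is read from `./data/medqa_5options/test.jsonl`.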

In [7]:
MEDQA_DIR = './data/medqa'
os.makedirs(MEDQA_DIR, exist_ok=True)

# Read both splits and tag every record with its position in the file (`realidx`).
medqa_train_set = [
    {'realidx': position, **record}
    for position, record in enumerate(load_jsonl(os.path.join(MEDQA_DIR, 'train.jsonl')))
]
medqa_test_set = [
    {'realidx': position, **record}
    for position, record in enumerate(load_jsonl(os.path.join(MEDQA_DIR, 'test.jsonl')))
]

# Round-trip each split through a DataFrame and write it back over its source file.
df_train = pd.DataFrame(medqa_train_set)
medqa_train_set = df_train.to_dict(orient='records')
save_as_jsonl(medqa_train_set, os.path.join(MEDQA_DIR, 'train.jsonl'))

df_test = pd.DataFrame(medqa_test_set)
medqa_test_set = df_test.to_dict(orient='records')
save_as_jsonl(medqa_test_set, os.path.join(MEDQA_DIR, 'test.jsonl'))

# Fixed-seed 50-question subset of the test split, put back into file order.
sampled_df = df_test.sample(n=50, random_state=42).sort_values('realidx')
sampled_50_medqa = sampled_df.to_dict(orient='records')
save_as_jsonl(sampled_50_medqa, os.path.join(MEDQA_DIR, 'sampled_50.jsonl'))

sampled_df.head(10)


Unnamed: 0,realidx,question,answer,options,meta_info,answer_idx,metamap_phrases
23,23,A 62-year-old patient has been hospitalized fo...,Staphylococcus aureus,"{'A': 'Streptococcus pneumoniae', 'B': 'Mycoba...",step1,D,"[62 year old patient, hospitalized, week, stro..."
43,43,A healthy 23-year-old male is undergoing an ex...,Coronary sinus,"{'A': 'Inferior vena cava', 'B': 'Coronary sin...",step1,B,"[healthy 23 year old male, exercise stress tes..."
51,51,A 56-year-old man with a history of hypertensi...,Aldosterone excess,"{'A': 'Aldosterone excess', 'B': 'Catecholamin...",step2&3,A,"[year old man, history of hypertension present..."
63,63,An 80-year-old man is transferred from a step-...,Insert a ‘straight cath’ into the patient’s bl...,{'A': 'Insert a ‘straight cath’ into the patie...,step2&3,A,"[80 year old man, transferred, step-down unit,..."
76,76,A 62-year old female comes to the physician be...,Biopsy of the mass,"{'A': 'Pap smear', 'B': 'Biopsy of the mass', ...",step2&3,B,"[62 year old female, physician, vaginal spotti..."
101,101,A 65-year old man presents with gradually wors...,Amantadine,"{'A': 'Amantadine', 'B': 'Ribavirin', 'C': 'Le...",step1,A,"[65 year old man presents, worsening rigidity,..."
123,123,A 41-year-old G3P1 woman presents with a sudde...,Mixing study,"{'A': 'Mixing study', 'B': 'INR', 'C': 'Ristoc...",step1,A,"[year old, woman presents, sudden onset throbb..."
128,128,A 47-year-old woman comes to the physician bec...,Intrafascicular infiltration on muscle biopsy,{'A': 'Intrafascicular infiltration on muscle ...,step2&3,A,"[year old woman, physician, of progressive mus..."
155,155,A 19-year-old man is brought to the emergency ...,Synthetic cathinone intoxication,"{'A': 'Brief psychotic disorder', 'B': 'Neurol...",step2&3,D,"[year old man, brought, emergency department, ..."
168,168,A 56-year-old man is brought to the emergency ...,Undergo upper GI endoscopy,"{'A': 'Undergo colonoscopy', 'B': 'Undergo upp...",step2&3,B,"[year old man, brought, emergency department, ..."


In [4]:
MEDQA_5OPTIONS_DIR = './data/medqa_5options'
os.makedirs(MEDQA_5OPTIONS_DIR, exist_ok=True)

# Load the 5-option test split, tagging each record with its position in the file.
medqa_5options = [
    {'realidx': position, **record}
    for position, record in enumerate(load_jsonl(os.path.join(MEDQA_5OPTIONS_DIR, 'test.jsonl')))
]

# Draw the fixed-seed 50-question subset and restore file order before saving.
df = pd.DataFrame(medqa_5options)
sampled_df = df.sample(n=50, random_state=42).sort_values('realidx')
sampled_50_medqa_5options = sampled_df.to_dict(orient='records')
save_as_jsonl(sampled_50_medqa_5options, os.path.join(MEDQA_5OPTIONS_DIR, 'sampled_50.jsonl'))

sampled_df.head(10)


Unnamed: 0,realidx,question,answer,options,meta_info,answer_idx
23,23,A 62-year-old patient has been hospitalized fo...,Staphylococcus aureus,"{'A': 'Pseudomona aeruginosa', 'B': 'Streptoco...",step1,E
43,43,A healthy 23-year-old male is undergoing an ex...,Coronary sinus,"{'A': 'Superior vena cava', 'B': 'Inferior ven...",step1,C
51,51,A 56-year-old man with a history of hypertensi...,Aldosterone excess,"{'A': 'Aldosterone excess', 'B': 'Catecholamin...",step2&3,A
63,63,An 80-year-old man is transferred from a step-...,Insert a ‘straight cath’ into the patient’s bl...,{'A': 'Insert a ‘straight cath’ into the patie...,step2&3,A
76,76,A 62-year old female comes to the physician be...,Biopsy of the mass,"{'A': 'Pap smear', 'B': 'Biopsy of the mass', ...",step2&3,B
101,101,A 65-year old man presents with gradually wors...,Amantadine,"{'A': 'Amantadine', 'B': 'Ribavirin', 'C': 'Ac...",step1,A
123,123,A 41-year-old G3P1 woman presents with a sudde...,Mixing study,"{'A': 'Mixing study', 'B': 'INR', 'C': 'D-dime...",step1,A
128,128,A 47-year-old woman comes to the physician bec...,Intrafascicular infiltration on muscle biopsy,{'A': 'Intrafascicular infiltration on muscle ...,step2&3,A
155,155,A 19-year-old man is brought to the emergency ...,Synthetic cathinone intoxication,"{'A': 'Brief psychotic disorder', 'B': 'Neurol...",step2&3,E
168,168,A 56-year-old man is brought to the emergency ...,Undergo upper GI endoscopy,"{'A': 'Undergo colonoscopy', 'B': 'Undergo upp...",step2&3,B


## PubmedQA
Please follow [https://github.com/pubmedqa/pubmedqa](https://github.com/pubmedqa/pubmedqa) to download the data. Place it in `./data/pubmedqa` and provide both a `train_set.json` and a `test_set.json` file; the cell below reads both.

In [28]:
PUBMEDQA_DIR = './data/pubmedqa'
# Columns split off into `df_meta`; they are not written to the JSONL output.
PUBMEDQA_META_COLUMNS = ['reasoning_required_pred', 'reasoning_free_pred', 'YEAR', 'MESHES', 'LABELS']


def _load_pubmedqa_split(path):
    """Load a PubMedQA JSON file (a mapping keyed by record id) as a list of records.

    Each record is the mapped value with its key added as `realidx`. The file is
    opened in a `with` block so the handle is closed as soon as parsing finishes.
    """
    with open(path, 'r', encoding='utf-8') as infile:
        raw = json.load(infile)
    return [{'realidx': idx, **item} for idx, item in raw.items()]


def _pubmedqa_to_mcq(df_raw):
    """Convert a raw PubMedQA frame into the multiple-choice layout used by this notebook.

    Args:
        df_raw: Frame built from `_load_pubmedqa_split` records.

    Returns:
        (df_mcq, df_meta): the converted frame and the metadata columns that were split off.
    """
    df_meta = df_raw[PUBMEDQA_META_COLUMNS]
    df_mcq = df_raw.drop(columns=PUBMEDQA_META_COLUMNS).rename(
        columns={'QUESTION': 'question', 'CONTEXTS': 'context', 'final_decision': 'answer', 'LONG_ANSWER': 'answer_rationale'}
    )
    df_mcq['options'] = [{'A': 'yes', 'B': 'no', 'C': 'maybe'} for _ in range(len(df_mcq))]
    df_mcq['answer_idx'] = df_mcq['answer'].map({'yes': 'A', 'no': 'B', 'maybe': 'C'})
    # `context` is a list of strings; the final question is the joined context followed by the question.
    df_mcq['question'] = df_mcq['context'].apply(lambda x: '\n'.join(x)) + '\n' + df_mcq['question']
    df_mcq = df_mcq.drop(columns=['context'])
    # Re-derive `answer` from the option text so it always agrees with `answer_idx`.
    df_mcq['answer'] = df_mcq.apply(lambda row: row['options'][row['answer_idx']], axis=1)
    return df_mcq, df_meta


os.makedirs(PUBMEDQA_DIR, exist_ok=True)

# Load both splits up front so a missing input fails before anything is written.
pubmedqa_train_set = _load_pubmedqa_split(os.path.join(PUBMEDQA_DIR, 'train_set.json'))
pubmedqa_test_set = _load_pubmedqa_split(os.path.join(PUBMEDQA_DIR, 'test_set.json'))

# save pubmedqa train set
df_train, _ = _pubmedqa_to_mcq(pd.DataFrame(pubmedqa_train_set))
save_as_jsonl(df_train.to_dict(orient='records'), os.path.join(PUBMEDQA_DIR, 'train.jsonl'))

# save pubmedqa test set
df_test, df_meta = _pubmedqa_to_mcq(pd.DataFrame(pubmedqa_test_set))
pubmedqa_test_set = df_test.to_dict(orient='records')
save_as_jsonl(pubmedqa_test_set, os.path.join(PUBMEDQA_DIR, 'test.jsonl'))

# Fixed-seed 50-question subset of the test split, sorted by `realidx`.
sampled_df = df_test.sample(50, random_state=42).sort_values(by='realidx')
sampled_50_pubmedqa = sampled_df.to_dict(orient='records')
save_as_jsonl(sampled_50_pubmedqa, os.path.join(PUBMEDQA_DIR, 'sampled_50.jsonl'))

sampled_df.head(10)

Unnamed: 0,realidx,question,answer,answer_rationale,options,answer_idx
209,10135926,Patients transported by helicopter often requi...,yes,Oral endotracheal intubation in the in-flight ...,"{'A': 'yes', 'B': 'no', 'C': 'maybe'}",A
336,10973547,"It is generally assumed, that patients with We...",no,Patients with WD may possibly undergo cardiac ...,"{'A': 'yes', 'B': 'no', 'C': 'maybe'}",B
63,12094116,The purpose of this study was to identify the ...,yes,The relationships between leg muscle power and...,"{'A': 'yes', 'B': 'no', 'C': 'maybe'}",A
0,12377809,Dyschesia can be provoked by inappropriate def...,yes,Linear anorectal endosonography demonstrated i...,"{'A': 'yes', 'B': 'no', 'C': 'maybe'}",A
490,12407608,To investigate whether prepuncture ultrasound ...,maybe,Prepuncture ultrasound evaluation did not impr...,"{'A': 'yes', 'B': 'no', 'C': 'maybe'}",C
194,12595848,Implementation of the complex treatment strate...,yes,We found an improved survival associated with ...,"{'A': 'yes', 'B': 'no', 'C': 'maybe'}",A
491,14599616,Lymphedema may be identified by simpler circum...,maybe,An increase of 5% in circumference measurement...,"{'A': 'yes', 'B': 'no', 'C': 'maybe'}",C
84,15208005,Low intakes or blood levels of eicosapentaenoi...,yes,"The Omega-3 Index may represent a novel, physi...","{'A': 'yes', 'B': 'no', 'C': 'maybe'}",A
93,15489384,Spasticity and loss of function in an affected...,yes,"Using a targeted meta-analytic approach, it is...","{'A': 'yes', 'B': 'no', 'C': 'maybe'}",A
76,15528969,Current guidelines include a recommendation th...,yes,Expert breast pathology assessments continue t...,"{'A': 'yes', 'B': 'no', 'C': 'maybe'}",A


## MedMCQA

We use the dev set as the test set. The data contains both multi-choice and single-choice questions; we keep only the single-choice ones.

In [19]:
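# One-off conversion, kept for reference: re-save the MedMCQA train/dev/test JSONL files as JSON.
# The next cell reads the resulting `dev.json` and `train.json`.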
# medmcqa_train_set = load_jsonl(os.path.join('./data/medmcqa', 'train.jsonl'))
# medmcqa_dev_set = load_jsonl(os.path.join('./data/medmcqa', 'dev.jsonl'))
# medmcqa_test_set = load_jsonl(os.path.join('./data/medmcqa', 'test.jsonl'))

# save_as_json(medmcqa_train_set, os.path.join('./data/medmcqa', 'train.json'))
# save_as_json(medmcqa_dev_set, os.path.join('./data/medmcqa', 'dev.json'))
# save_as_json(medmcqa_test_set, os.path.join('./data/medmcqa', 'test.json'))

In [27]:
MEDMCQA_DIR = './data/medmcqa'
# `cop` values 1-4 are mapped to the option letters A-D.
cop_map = {1: 'A', 2: 'B', 3: 'C', 4: 'D'}


def _load_medmcqa_json(path):
    """Load one MedMCQA JSON file, closing the file handle once it is parsed."""
    with open(path, 'r', encoding='utf-8') as infile:
        return json.load(infile)


def _medmcqa_to_mcq(records):
    """Convert raw MedMCQA records into the multiple-choice layout used by this notebook.

    Keeps only rows whose `choice_type` is 'single', packs `opa`..`opd` into an
    `options` dict, maps `cop` to a letter, renames `id` to `realidx`, and sets
    `answer` to the text of the correct option.
    """
    df = pd.DataFrame(records)
    # .copy() so the column assignments below act on an independent frame
    # rather than a filtered slice (avoids SettingWithCopyWarning).
    df = df[df['choice_type'] == 'single'].copy()
    df['options'] = df.apply(lambda row: {'A': row['opa'], 'B': row['opb'], 'C': row['opc'], 'D': row['opd']}, axis=1)
    df['answer_idx'] = df['cop'].map(cop_map)
    df = df.rename(columns={'id': 'realidx'})
    df = df.drop(columns=['opa', 'opb', 'opc', 'opd', 'cop', 'exp', 'choice_type', 'topic_name', 'subject_name'])
    df['answer'] = df.apply(lambda row: row['options'][row['answer_idx']], axis=1)
    return df


os.makedirs(MEDMCQA_DIR, exist_ok=True)

# The dev split is used as the test set.
medmcqa_test_set = _load_medmcqa_json(os.path.join(MEDMCQA_DIR, 'dev.json'))
medmcqa_train_set = _load_medmcqa_json(os.path.join(MEDMCQA_DIR, 'train.json'))

# save medmcqa train set
df_train = _medmcqa_to_mcq(medmcqa_train_set)
save_as_jsonl(df_train.to_dict(orient='records'), os.path.join(MEDMCQA_DIR, 'train.jsonl'))

# save medmcqa test set
df_test = _medmcqa_to_mcq(medmcqa_test_set)
medmcqa_test_set = df_test.to_dict(orient='records')
save_as_jsonl(medmcqa_test_set, os.path.join(MEDMCQA_DIR, 'test.jsonl'))

# Fixed-seed 50-question subset of the test split, sorted by `realidx`.
sampled_df = df_test.sample(50, random_state=42).sort_values(by='realidx')
sampled_50_medmcqa = sampled_df.to_dict(orient='records')
save_as_jsonl(sampled_50_medmcqa, os.path.join(MEDMCQA_DIR, 'sampled_50.jsonl'))

sampled_df.head(10)

Unnamed: 0,question,realidx,options,answer_idx,answer
3563,An 11-year-old boy complains of spacing betwee...,12a8e0d0-21d1-4edf-905d-5e9c415b1a80,"{'A': 'Hawley's appliance', 'B': 'Fixed applia...",C,No treatment
1663,Major determinant of loading dose of a drug is:-,156f862e-9e92-4070-b0f0-7beacc93d11b,"{'A': 'Half life', 'B': 'Clearance', 'C': 'Vol...",C,Volume of distribution
3102,Which of the following amino acids does not in...,185c4942-7886-4e49-b242-6634e83b0efb,"{'A': 'Selenocysteine', 'B': 'Triiodothyronine...",A,Selenocysteine
2641,Most common site of esophageal carcinoma?,1e6928e0-01e6-4346-8ea7-25cecbb99932,"{'A': 'Middle 1/3rd of esophagus', 'B': 'Upper...",A,Middle 1/3rd of esophagus
684,"The disturbances occurred during ""Calcificatio...",2cacbd66-ae8e-45cc-85b8-6242487724b6,"{'A': 'Peg laterals', 'B': 'Microdontia', 'C':...",D,Interglobular dentin
268,Which of the following is characterized by App...,315212cd-d605-4f61-8d4e-535ef7847059,"{'A': 'Scrofula', 'B': 'Lupus vulgaris', 'C': ...",B,Lupus vulgaris
293,23 serotypes pneumococcal vaccine Most useful in,44483815-3319-493d-b156-d3663a4d61a1,"{'A': 'Cystic fibrosis', 'B': 'Recurrent otiti...",D,Sickle cell anaemia
1416,T-lymphocytes play a primary role in,4e3061f9-0a14-4878-9abd-6a0459b268b8,"{'A': 'Production of Antibodies', 'B': 'Produc...",B,Production\tof\tlymphokines\tand\tdelayed hype...
2681,Which of the following drug is used to counter...,4fbd9ccb-2efb-4e4a-bcea-2db337a825ff,"{'A': 'Roxatidine', 'B': 'Pirenzipine', 'C': '...",D,Misoprostol
670,Most common phobia in chilhood:,4ffc9c91-2230-44f7-826f-91b8d683ab20,"{'A': 'Zoophobia', 'B': 'Nyclophobia', 'C': 'X...",A,Zoophobia


## AfriMedQA

In [24]:
import ast

from datasets import load_dataset

AFRIMEDQA_DIR = './data/afrimedqa'
os.makedirs(AFRIMEDQA_DIR, exist_ok=True)

# Maps the dataset's `optionN` keys to answer letters.
options_map = {'option1': 'A', 'option2': 'B', 'option3': 'C', 'option4': 'D', 'option5': 'E', 'option6': 'F', 'option7': 'G', 'option8': 'H', 'option9': 'I', 'option10': 'J'}

afrimedqa_all = load_dataset('intronhealth/afrimedqa_v2')['train'].to_pandas()

# Keep only MCQ rows and the columns needed downstream. `rename` returns a new
# frame, so the assignments below do not touch a filtered slice of `afrimedqa_all`.
afrimedqa_test_set = afrimedqa_all.loc[
    afrimedqa_all['question_type'] == 'mcq',
    ['sample_id', 'question_clean', 'answer_options', 'correct_answer', 'answer_rationale'],
].rename(columns={'sample_id': 'realidx', 'question_clean': 'question', 'answer_options': 'options', 'correct_answer': 'answer_idx', 'answer_rationale': 'reason'})

# `options` arrives as a string holding a dict literal. It is parsed with
# ast.literal_eval instead of eval: the text comes from a downloaded dataset, and
# literal_eval accepts only literals, so it cannot execute code embedded in a row.
afrimedqa_test_set['options'] = afrimedqa_test_set['options'].apply(ast.literal_eval)
afrimedqa_test_set['options'] = afrimedqa_test_set['options'].apply(lambda opts: {options_map[k]: v for k, v in opts.items()})

# Answers that are not a single `optionN` key map to NaN here and are dropped next
# (this is how multi-answer questions get removed).
# NOTE(review): dropna() also removes rows with a missing value in any other
# column (e.g. `reason`) — confirm that is intended.
afrimedqa_test_set['answer_idx'] = afrimedqa_test_set['answer_idx'].map(options_map)
afrimedqa_test_set = afrimedqa_test_set.dropna()
afrimedqa_test_set['answer'] = afrimedqa_test_set.apply(lambda row: row['options'][row['answer_idx']], axis=1)

# save afrimedqa test set
save_as_jsonl(afrimedqa_test_set.to_dict(orient='records'), os.path.join(AFRIMEDQA_DIR, 'test.jsonl'))

# Fixed-seed 50-question subset, sorted by `realidx`.
sampled_df = afrimedqa_test_set.sample(50, random_state=42).sort_values(by='realidx')
sampled_50_afrimedqa = sampled_df.to_dict(orient='records')
save_as_jsonl(sampled_50_afrimedqa, os.path.join(AFRIMEDQA_DIR, 'sampled_50.jsonl'))

sampled_df.head(10)


298


Unnamed: 0,realidx,question,options,answer_idx,reason,answer
5585,013a76385bffe9bbc150095f79b78ec1d2affd4fa7eabb...,Which variety of HIV is most common in West Af...,"{'A': 'HIV 1', 'B': 'HIV 2', 'C': 'HIV 5', 'D'...",B,HIV 2 is commonest in West Africa,HIV 2
5561,020d17d0a81fab48c9c94335789a7ce7e420ef3bfd76b0...,Which of the following is a common presentatio...,"{'A': 'Rash resembling measles', 'B': 'Hemolyt...",A,African tick-bite fever often presents with a ...,Rash resembling measles
14236,03905375f23fcdbf704053df35e70f593e2d38df3a0f6c...,What is the commonest cause of intestinal obst...,"{'A': 'Acute appendicitis', 'B': 'Stangulated ...",B,External stangulated inguinal hernias are know...,Stangulated inguinal hernia
5612,03fa5fdd2a877ec7f05f5ff8ee0de48c31723f0a2858c6...,"In a rural African community, a patient presen...","{'A': 'Generalized anxiety disorder (GAD)', 'B...",D,In African contexts where mental health resour...,Major depressive disorder (MDD)
5595,05139259d5adda160966fd94081026b7c7caf94c61219d...,In a rural African setting with limited access...,"{'A': 'Aortic regurgitation', 'B': 'Mitral reg...",A,Infective endocarditis may lead to aortic regu...,Aortic regurgitation
5509,09cbeee45181ee6a8b98880fd5beb1e7e63bce5cdaad8b...,"Which of the following conditions, prevalent i...","{'A': 'Eczema', 'B': 'Psoriasis', 'C': 'Dermat...",B,Psoriasis is a chronic inflammatory skin diso...,Psoriasis
5524,0f46c8b52dd21fc4d659bdd37e543323612714bf5e9706...,Which of the following is a common complicatio...,"{'A': 'Cirrhosis', 'B': 'Hepatocellular carcin...",B,Chronic hepatitis B infection is a major risk ...,Hepatocellular carcinoma (HCC)
5576,104cbc9b5d91699136f65ddd5f7ca29205653dcc9b5ac6...,Which of the following is a common presentatio...,"{'A': 'Painful genital ulcers', 'B': 'Large, p...",B,Buruli ulcer often presents with painless skin...,"Large, painless skin ulcers"
14108,16d2b6f9953b72c6286eaa6438df8b942a42c3171b546a...,The most appropriate treatment for Supraventri...,"{'A': 'Digoxin', 'B': 'Flecainide', 'C': 'Aden...",C,N\A,Adenosine
14092,1af1e2c3fac36ae2c1fd1c4b35407a9bd371a3d174ef3f...,"\r\n28. About Kawasaki disease, which of the f...","{'A': 'A. Vasculitis of large arteries', 'B': ...",D,N/a,D. Aspirin is contraindicated


## MMLU

We follow the Med-PaLM setting and keep only the following subjects:

clinical_knowledge, professional_medicine, college_medicine, medical_genetics, anatomy, college_biology


In [26]:
from datasets import load_dataset

MMLU_DIR = './data/mmlu'
# Subjects kept from MMLU.
MMLU_SUBJECTS = ['clinical_knowledge', 'professional_medicine', 'college_medicine', 'medical_genetics', 'anatomy', 'college_biology']
os.makedirs(MMLU_DIR, exist_ok=True)

# Position -> option letter.
options_map = list('ABCDEFGHIJKLMNOPQRSTUVWXYZ')

mmlu_all = load_dataset('cais/mmlu', 'all')['test'].to_pandas()

# .copy() so the column assignments below act on an independent frame rather
# than a filtered slice (avoids SettingWithCopyWarning).
mmlu_test_set = mmlu_all[mmlu_all['subject'].isin(MMLU_SUBJECTS)].copy()
print(len(mmlu_test_set))

mmlu_test_set = mmlu_test_set.rename(columns={'choices': 'options', 'answer': 'answer_idx'})
# Turn the positional choices into a {letter: text} dict and the integer answer into its letter.
mmlu_test_set['options'] = mmlu_test_set['options'].apply(lambda x: {options_map[i]: x[i] for i in range(len(x))})
mmlu_test_set['answer_idx'] = mmlu_test_set['answer_idx'].apply(lambda x: options_map[x])
# `realidx` is the row's index label in the full (unfiltered) test split.
mmlu_test_set['realidx'] = mmlu_test_set.index
mmlu_test_set = mmlu_test_set[['realidx', 'question', 'options', 'answer_idx', 'subject']].copy()
mmlu_test_set['answer'] = mmlu_test_set.apply(lambda row: row['options'][row['answer_idx']], axis=1)

# save mmlu test set
save_as_jsonl(mmlu_test_set.to_dict(orient='records'), os.path.join(MMLU_DIR, 'test.jsonl'))

# Fixed-seed 50-question subset, sorted by `realidx`.
sampled_df = mmlu_test_set.sample(50, random_state=42).sort_values(by='realidx')
sampled_50_mmlu = sampled_df.to_dict(orient='records')
save_as_jsonl(sampled_50_mmlu, os.path.join(MMLU_DIR, 'sampled_50.jsonl'))

sampled_df.head(10)


1089


Unnamed: 0,realidx,question,options,answer_idx,subject,answer
151,151,Which of the following structures accompany th...,"{'A': 'The phrenic nerves', 'B': 'The splanchn...",D,anatomy,The vagus nerves
156,156,The infraorbital nerve,{'A': 'is a terminal branch of the maxillary d...,A,anatomy,is a terminal branch of the maxillary division...
170,170,The major concentrations of proprioceptive rec...,{'A': 'the capsule and ligaments of the TMJ an...,B,anatomy,the capsule and ligaments of the TMJ and the l...
186,186,Which of one of the following statements about...,{'A': 'Pneumatisation by enlargement of the de...,C,anatomy,The zygomaticomaxillary sutures contribute to ...
188,188,Which of the following paranasal sinuses open ...,"{'A': 'The anterior ethmoidal sinuses', 'B': '...",C,anatomy,"The anterior ethmoidal, frontal and maxillary ..."
196,196,Parasympathetic preganglionic nerves leave the...,"{'A': 'third cranial nerves.', 'B': 'fourth cr...",A,anatomy,third cranial nerves.
201,201,The lateral pterygoid muscle,{'A': 'is attached to the coronoid process and...,D,anatomy,is attached to the condylar process and protru...
491,491,In what situation are closed pouches applied?,{'A': 'The patient has a semi-formed or liquid...,B,clinical_knowledge,The patient has a colostomy.
570,570,Dopamine is prescribed at a rate of 4 microgra...,"{'A': '156', 'B': '15.6', 'C': '1.56', 'D': '1...",B,clinical_knowledge,15.6
572,572,Why can't a patient talk if the cuff is inflated?,{'A': 'They are unable to breathe in sufficien...,D,clinical_knowledge,They are unable to pass air through their voca...


## MMLU-Pro

**Law & Business:**

ori_mmlu-professional_law, ori_mmlu-jurisprudence, ori_mmlu-business_ethics, ori_mmlu-management,
ori_mmlu-professional_accounting, ori_mmlu-marketing, ori_mmlu-public_relations

**Science & Engineering:**

stemez-Chemistry, stemez-Physics, stemez-Mechanics, scibench-chemmc, stemez-OrganicChemistry, stemez-Biology,
stemez-PhysicalChemistry, stemez-Optics, stemez-Thermodynamics, stemez-TransportPhenomena, stemez-Genetics,
stemez-ElectronicCommunications, stemez-ElectricalMachines, stemez-Electromagnetics, stemez-MachineDesign,
stemez-ElectricCircuits, stemez-HeatTransfer, stemez-FluidMechanics, stemez-ComputerScience

**Mathematics & Computing:**

theoremQA-Math, scibench-quan, scibench-class, scibench-calculus, scibench-diff, scibench-matter,
ori_mmlu-computer_security, theoremQA-EECS, ori_mmlu-college_computer_science, scibench-stat, scibench-fund,
scibench-thermo

**Social Sciences & Humanities:**

ori_mmlu-us_foreign_policy, ori_mmlu-professional_psychology, stemez-Psychology, ori_mmlu-sociology,
ori_mmlu-high_school_government_and_politics, ori_mmlu-prehistory, ori_mmlu-moral_disputes, ori_mmlu-philosophy,
ori_mmlu-security_studies, ori_mmlu-world_religions, ori_mmlu-international_law

**General Education:**

ori_mmlu-elementary_mathematics, ori_mmlu-miscellaneous, ori_mmlu-high_school_mathematics,
ori_mmlu-high_school_macroeconomics, ori_mmlu-nutrition, ori_mmlu-virology, ori_mmlu-human_sexuality,
ori_mmlu-high_school_psychology, ori_mmlu-high_school_statistics, ori_mmlu-high_school_microeconomics,
ori_mmlu-high_school_biology, ori_mmlu-conceptual_physics, ori_mmlu-high_school_physics,
ori_mmlu-high_school_chemistry, ori_mmlu-human_aging, ori_mmlu-high_school_world_history, ori_mmlu-global_facts,
ori_mmlu-abstract_algebra, ori_mmlu-high_school_us_history, ori_mmlu-astronomy, ori_mmlu-anatomy,
ori_mmlu-college_biology, ori_mmlu-college_mathematics, ori_mmlu-logical_fallacies,
ori_mmlu-high_school_european_history, ori_mmlu-college_physics, ori_mmlu-electrical_engineering,
ori_mmlu-high_school_geography, ori_mmlu-college_chemistry, ori_mmlu-high_school_computer_science

Only keeping health fields (clinical knowledge, professional medicine, college medicine, medical genetics, nutrition, human aging, anatomy, virology)




In [29]:
from datasets import load_dataset

MMLU_PRO_DIR = './data/mmlu-pro'
os.makedirs(MMLU_PRO_DIR, exist_ok=True)
# The train split is also written under ./data/mmlu below, so make sure that folder exists.
os.makedirs('./data/mmlu', exist_ok=True)

# Position -> option letter.
options_map = list('ABCDEFGHIJKLMNOPQRSTUVWXYZ')


def _mmlu_pro_health_split(split):
    """Convert one MMLU-Pro split into the multiple-choice layout used by this notebook.

    Keeps only rows whose `category` is 'health', renames `question_id` to
    `realidx` and `answer` to `answer_idx`, turns the positional options into a
    {letter: text} dict, adds `answer` (the text of the correct option) and drops
    `cot_content` and `answer_index`.

    Args:
        split: A dataset split object exposing `.to_pandas()`.
    """
    df = split.to_pandas()
    # .copy() so the column assignments below act on an independent frame
    # rather than a filtered slice (avoids SettingWithCopyWarning).
    df = df[df['category'].isin(['health'])].copy()
    df = df.rename(columns={'question_id': 'realidx', 'answer': 'answer_idx'})
    df['options'] = df['options'].apply(lambda x: {options_map[i]: x[i] for i in range(len(x))})
    df['answer'] = df.apply(lambda row: row['options'][row['answer_idx']], axis=1)
    return df.drop(columns=['cot_content', 'answer_index'])


# Load the dataset once and reuse it for both splits.
mmlu_pro = load_dataset('TIGER-Lab/MMLU-Pro')

# The 'validation' split serves as the train set.
mmlu_pro_train_set = _mmlu_pro_health_split(mmlu_pro['validation'])

# save mmlu pro train set
save_as_jsonl(mmlu_pro_train_set.to_dict(orient='records'), os.path.join(MMLU_PRO_DIR, 'train.jsonl'))
# NOTE(review): the same MMLU-Pro records are also written as MMLU's train.jsonl — confirm this is intended.
save_as_jsonl(mmlu_pro_train_set.to_dict(orient='records'), os.path.join('./data/mmlu', 'train.jsonl'))

mmlu_pro_test_set = _mmlu_pro_health_split(mmlu_pro['test'])

# save mmlu pro test set
save_as_jsonl(mmlu_pro_test_set.to_dict(orient='records'), os.path.join(MMLU_PRO_DIR, 'test.jsonl'))

# Fixed-seed 50-question subset, sorted by `realidx`.
sampled_df = mmlu_pro_test_set.sample(50, random_state=42).sort_values(by='realidx')
sampled_50_mmlu_pro = sampled_df.to_dict(orient='records')
save_as_jsonl(sampled_50_mmlu_pro, os.path.join(MMLU_PRO_DIR, 'sampled_50.jsonl'))

sampled_df.head(10)

Unnamed: 0,realidx,question,options,answer_idx,category,src,answer
5865,6024,Which of the following anatomical regions of a...,"{'A': 'Pectoral', 'B': 'Iliac', 'C': 'Subcosta...",F,health,ori_mmlu-anatomy,Epigastric
5872,6032,A patient with damage to their cervical sympat...,{'A': 'Pupillary constriction and vasodilation...,A,health,ori_mmlu-anatomy,Pupillary constriction and vasodilation of fac...
5881,6041,A 25-year-old man is brought to the emergency ...,"{'A': 'DNA helicase', 'B': 'Ribosomal assembly...",B,health,ori_mmlu-professional_medicine,Ribosomal assembly
5907,6067,How are new polyomaviruses detailed,"{'A': 'Shot gun sequencing', 'B': 'Cultivation...",A,health,ori_mmlu-virology,Shot gun sequencing
5908,6068,Describe the coronavirus structure.,{'A': 'Club shaped glycoprotein spikes protrud...,C,health,ori_mmlu-virology,An icosahedral large pleomorphic virus
5909,6069,Disease can most often be prevented by which o...,"{'A': 'Sunscreen', 'B': 'Vaccines', 'C': 'Anti...",B,health,ori_mmlu-virology,Vaccines
5920,6080,The energy released from the breakdown of the ...,"{'A': '20-30 minutes.', 'B': '1-2 seconds.', '...",H,health,ori_mmlu-college_medicine,5-10 seconds.
5928,6089,Which of the following is true about the carpa...,"{'A': 'It causes numbness in the entire arm', ...",B,health,ori_mmlu-clinical_knowledge,It can be caused by rheumatoid arthritis
5962,6123,Which of the following statements is not true?\n,{'A': 'Vegan diets are likely to be deficient ...,A,health,ori_mmlu-nutrition,Vegan diets are likely to be deficient in protein
5979,6140,Glycogen breakdown in muscle initially results...,"{'A': 'glucose-6-phosphate.', 'B': 'glucose-1,...",I,health,ori_mmlu-college_medicine,glucose-1-phosphate.


## MedBullets

In [32]:
# ls data/medbullets/
# sampled_50_hard.jsonl  test_bad.jsonl  test_easy.jsonl  test_good.jsonl  test_hard.jsonl
# (test.jsonl is read and rewritten below as well)

MEDBULLETS_DIR = './data/medbullets'


def _medbullets_frame(records):
    """Build a frame from MedBullets records, renaming `explanation` -> `reason` and `id` -> `realidx`."""
    return pd.DataFrame(records).rename(columns={'explanation': 'reason', 'id': 'realidx'})


def _save_medbullets(df, filename):
    """Write `df`, sorted by `realidx`, to `filename` inside the MedBullets folder."""
    save_as_jsonl(df.sort_values(by='realidx').to_dict(orient='records'), os.path.join(MEDBULLETS_DIR, filename))


medbullets_test_set_bad = load_jsonl(os.path.join(MEDBULLETS_DIR, 'test_bad.jsonl'))
medbullets_test_set_easy = load_jsonl(os.path.join(MEDBULLETS_DIR, 'test_easy.jsonl'))
medbullets_test_set_good = load_jsonl(os.path.join(MEDBULLETS_DIR, 'test_good.jsonl'))
medbullets_test_set_hard = load_jsonl(os.path.join(MEDBULLETS_DIR, 'test_hard.jsonl'))
medbullets_test_set_sampled = load_jsonl(os.path.join(MEDBULLETS_DIR, 'sampled_50_hard.jsonl'))
# `df_test_all` was previously used without ever being defined (NameError).
# NOTE(review): it is assumed to be the existing test.jsonl, processed like the
# other splits — confirm this matches how that file was originally produced.
medbullets_test_set_all = load_jsonl(os.path.join(MEDBULLETS_DIR, 'test.jsonl'))

df_test_hard = _medbullets_frame(medbullets_test_set_hard)
df_test_easy = _medbullets_frame(medbullets_test_set_easy)
df_test_good = _medbullets_frame(medbullets_test_set_good)
df_test_bad = _medbullets_frame(medbullets_test_set_bad)
df_test_all = _medbullets_frame(medbullets_test_set_all)
df_test_sampled = _medbullets_frame(medbullets_test_set_sampled)

# save medbullets test set
_save_medbullets(df_test_hard, 'test_hard.jsonl')
_save_medbullets(df_test_easy, 'test_easy.jsonl')
_save_medbullets(df_test_good, 'test_good.jsonl')
_save_medbullets(df_test_bad, 'test_bad.jsonl')
_save_medbullets(df_test_all, 'test.jsonl')
_save_medbullets(df_test_sampled, 'sampled_50_hard.jsonl')



In [18]:
# How many MMLU-Pro test questions have each number of answer options
mmlu_pro_test_set['options'].map(len).value_counts()

options
10    597
4      72
9      67
8      42
7      21
5       8
6       8
3       3
Name: count, dtype: int64

In [31]:
# Create LaTeX table from results using Python libraries
import glob
import pandas as pd
from tabulate import tabulate

def analyze_dataset(base_path):
    """Summarize the JSONL splits found in one dataset directory.

    Looks for `test.jsonl`, `test_good.jsonl` and `test_hard.jsonl` under
    `base_path`; a file that does not exist contributes zeros.

    Returns:
        dict with keys, in this order:
            'total'       - number of records in test.jsonl
            'num_options' - len() of the first test record's 'options' (0 if absent)
            'good'        - number of records in test_good.jsonl
            'hard'        - number of records in test_hard.jsonl
    """
    def record_count(file_name):
        # Number of records in base_path/file_name, or 0 when the file is missing.
        path = os.path.join(base_path, file_name)
        return len(load_jsonl(path)) if os.path.exists(path) else 0

    test_path = os.path.join(base_path, 'test.jsonl')
    test_records = load_jsonl(test_path) if os.path.exists(test_path) else []

    # The option count is read from the first question only.
    first_has_options = len(test_records) > 0 and 'options' in test_records[0]

    return {
        'total': len(test_records),
        'num_options': len(test_records[0]['options']) if first_has_options else 0,
        'good': record_count('test_good.jsonl'),
        'hard': record_count('test_hard.jsonl'),
    }

def _as_percent(part, whole):
    """Format part/whole as a one-decimal percentage, or "N/A" when `whole` is not positive."""
    return f"{part/whole*100:.1f}%" if whole > 0 else "N/A"


datasets = glob.glob('./data/*')
results = []

for dataset in datasets:
    # Scripts and notebooks sitting next to the dataset folders are not datasets.
    if dataset.endswith(('.py', '.ipynb')):
        continue
    stats = analyze_dataset(dataset)
    row = {
        'Dataset': os.path.basename(dataset),
        'Total Tests': stats['total'],
        'Good Tests': stats['good'],
        'Good %': _as_percent(stats['good'], stats['total']),
        'Hard Tests': stats['hard'],
        # Hard % is relative to the good tests, not to the total.
        'Hard %': _as_percent(stats['hard'], stats['good']),
        'Num Options': stats['num_options'],
    }
    results.append(row)

# Render the summary as a LaTeX (booktabs) table.
results_df = pd.DataFrame(results)
print(tabulate(results_df, headers='keys', tablefmt='latex_booktabs', showindex=False))

\begin{tabular}{lrrlrlr}
\toprule
 Dataset        &   Total Tests &   Good Tests & Good \%   &   Hard Tests & Hard \%   &   Num Options \\
\midrule
 mmlu-pro       &           818 &          813 & 99.4\%    &          303 & 37.3\%    &             7 \\
 pubmedqa       &           500 &          495 & 99.0\%    &          119 & 24.0\%    &             3 \\
 medqa          &          1273 &         1156 & 90.8\%    &          302 & 26.1\%    &             4 \\
 afrimedqa      &           298 &          296 & 99.3\%    &           72 & 24.3\%    &             5 \\
 medbullets     &           550 &          192 & 34.9\%    &           84 & 43.8\%    &             5 \\
 medqa\_5options &          1273 &         1156 & 90.8\%    &          357 & 30.9\%    &             5 \\
 mmlu           &          1089 &         1087 & 99.8\%    &          173 & 15.9\%    &             4 \\
 medmcqa        &          2816 &         2736 & 97.2\%    &          913 & 33.4\%    &             4 \\
\bottomrule

## Subsample hard set

In [5]:
# For every dataset folder that has a hard split, save a fixed-seed subset of up to 50 questions.
datasets = glob.glob('./data/*')
for dataset in datasets:
    hard_file = os.path.join(dataset, 'test_hard.jsonl')
    # './data/*' can also match loose files or datasets that have no hard split;
    # skip them instead of crashing inside load_jsonl.
    if not os.path.isfile(hard_file):
        continue
    hard_set = load_jsonl(hard_file)
    # random.sample raises ValueError when asked for more items than exist, so cap
    # the sample size; sets with at least 50 records get the same draw as before.
    sample_size = min(50, len(hard_set))
    sampled_hard_set = random.Random(42).sample(hard_set, sample_size)
    save_as_jsonl(sampled_hard_set, os.path.join(dataset, 'sampled_50_hard.jsonl'))
