In [1]:
import pandas as pd
import json

We use the MedQA-Open dataset from the paper "Few shot chain-of-thought driven reasoning to prompt LLMs for open ended medical question answering".  
The dataset is available in the ancillary files section at:  
https://arxiv.org/abs/2403.04890  

The authors of the paper used the USMLE-MedQA dataset (Jin et al., 2021), a medical exam dataset that consists of questions  
sourced from professional medical board exams in the USA.  

The MedQA dataset (Zhang et al., 2018) is a publicly available collection of complex multiple-choice medical questions  
based on the United States medical license exams. To emulate real-world medical scenarios,  
the authors converted these multiple-choice questions into open-ended questions by (1) removing the multiple-choice options and  
(2) rephrasing each question to be open-ended using an LLM, creating MedQA-Open.  

The dataset contains more than 10k questions covering different medical fields, including psychiatry.  
The following script uses an LLM to classify questions as psychiatry-related or not psychiatry-related.  
It is a screening step that uses a fast, cost-efficient LLM to filter out non-psychiatry questions  
before more expensive LLMs are used for detailed analysis.  

In [2]:
ORIGINAL_DATASET_PATH = "MedQA_open_dataset.xlsx"
original_df = pd.read_excel(ORIGINAL_DATASET_PATH)
print(len(original_df), "rows in the original dataset")

10178 rows in the original dataset


I used an LLM (the gemini 2.0 flash model) with a rather simple prompt to screen for psychiatry-related questions. Google was chosen as the LLM provider because it offers a generous free-tier limit for experimenting with LLMs:

In [3]:
question = ""
answer = "" # Just placeholders

prompt = f"""
    Question: {question}
    Answer: {answer}
    Is this question related to psychiatry? Respond with only 'psychiatry' or 'non-psychiatry'."""

If either the question or the answer column is empty, we classify the question as invalid and do not send an empty query to the LLM.
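
The screening rule can be sketched roughly as follows. This is a minimal sketch, not the script that produced the classified file: `is_blank`, `build_prompt`, `screen_row` and `classify_fn` are hypothetical names, and `classify_fn` stands in for whatever client call sends the prompt to the model. Raising an error on an unexpected reply is also an assumption of mine.

```python
from typing import Callable

VALID_LABELS = {"psychiatry", "non-psychiatry"}


def is_blank(value) -> bool:
    """True for None, NaN (how pandas reads empty cells) and whitespace-only text."""
    if value is None:
        return True
    if isinstance(value, float) and value != value:  # NaN never equals itself
        return True
    return str(value).strip() == ""


def build_prompt(question: str, answer: str) -> str:
    """Fill the screening prompt with one question/answer pair."""
    return f"""
    Question: {question}
    Answer: {answer}
    Is this question related to psychiatry? Respond with only 'psychiatry' or 'non-psychiatry'."""


def screen_row(question, answer, classify_fn: Callable[[str], str]) -> str:
    """Return 'invalid' without calling the model when a field is empty."""
    if is_blank(question) or is_blank(answer):
        return "invalid"
    reply = classify_fn(build_prompt(question, answer))
    # Tolerate surrounding whitespace, quotes, a trailing period and case differences
    label = reply.strip().strip("'\".").lower()
    if label not in VALID_LABELS:
        raise ValueError(f"Unexpected model reply: {reply!r}")
    return label
```

With a DataFrame, this could be applied row by row, passing the real model call as `classify_fn`.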

In [4]:
SCREENED_QUESTIONS_PATH = "MedQA_open_dataset_classified.xlsx"
screened_df = pd.read_excel(SCREENED_QUESTIONS_PATH)
print(len(screened_df), "rows in the screened dataset")
screened_df["classification"].value_counts()

10178 rows in the screened dataset


classification
non-psychiatry    8780
psychiatry         886
invalid            512
Name: count, dtype: int64

After the initial screening, only 886 questions were considered psychiatry-related. However, the prompt was weak, and the model was selected for speed rather than precision. As the next step, we verified that each question actually asks about psychiatry rather than just mentioning psychiatric concepts in an unrelated case vignette. To do so, we prompted a newer model, gemini 2.5 flash, with the following prompt:

In [5]:
prompt = f"""
    Act as an experienced clinical psychiatrist and medical educator. 

    Question: {question}

    Evaluate provided question and reasoning based on the following criteria:

    CLINICAL PSYCHIATRY FOCUS: Is this question primarily testing knowledge of clinical psychiatry, mental health disorders, psychiatric treatments, or psychological concepts? 
    - Questions that merely mention mental health terms but primarily test other medical knowledge (like diabetes, cardiology, etc.) should be excluded
    - Questions should focus on psychiatric diagnosis, treatment, symptoms, or mental health concepts as the main learning objective

    Provide your response in the following JSON format:
    {{
        "classification": "INCLUDE" or "EXCLUDE",
        "reasoning": "Brief explanation of why you included or excluded this question"
    }}

    Classification options:
    - "INCLUDE" if the question is primarily focused on clinical psychiatry AND the reasoning is useful
    - "EXCLUDE" if the first criteria fail
    """

In [6]:
with open("MedQA_open_dataset_classified_psychiatry_evaluation.json", 'r', encoding='utf-8') as file:
    data = json.load(file)

verified_df = pd.DataFrame(data)
print("Classification value counts:")
print(verified_df["psychiatry_classification"].value_counts())


Classification value counts:
psychiatry_classification
include    737
exclude    147
re-do        2
Name: count, dtype: int64
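
One way the JSON replies to the verification prompt could be turned into the labels counted above is sketched here. It rests on assumptions: `parse_verification` is a hypothetical helper, and treating `re-do` as the label for replies that cannot be parsed or carry an unexpected classification is my guess, since the notebook does not say how those two rows arose.

```python
import json
import re

REDO = {"psychiatry_classification": "re-do", "reasoning": ""}


def parse_verification(reply: str) -> dict:
    """Extract the classification from a model reply that should contain JSON."""
    # Models often wrap JSON in extra text or a fence, so take the outermost {...} span
    match = re.search(r"\{.*\}", reply, flags=re.DOTALL)
    if match is None:
        return dict(REDO)
    try:
        payload = json.loads(match.group(0))
    except json.JSONDecodeError:
        return dict(REDO)
    # The prompt asks for upper case; the stored labels are lower case
    label = str(payload.get("classification", "")).strip().lower()
    if label not in {"include", "exclude"}:
        return dict(REDO)
    return {
        "psychiatry_classification": label,
        "reasoning": str(payload.get("reasoning", "")),
    }
```

Rows that come back as `re-do` can then be sent to the model again or reviewed by hand.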


Overall, we have 737 psychiatry-related questions. Now we prompt an LLM to classify each psychiatric question into one of the following categories.

In [None]:
prompt = f"""
    Act as an experienced mental health specialist. Your task is to classify a provided psychiatric question.

    Question: "{question}"

    Instructions:
    1. Carefully analyze the question to identify the mental health condition or disorder being discussed
    2. Classify it into ONE of the following categories:
       • Anxiety Disorders
       • Bipolar Disorders  
       • Depressive Disorders
       • Dissociative Disorders
       • Eating Disorders
       • Obsessive-Compulsive Disorders
       • Personality Disorders
       • Schizophrenia Spectrum and Other Psychotic Disorders
       • Somatic Disorders
       • Trauma and Stressor Related Disorders
       • Other Mental Disorders

    3. Provide your response in the following JSON format:
    {{
        "reasoning": "Brief explanation of why this question fits the selected category, including any key symptoms, conditions, or diagnostic criteria mentioned",
        "category": "Selected category name exactly as listed above",
        "confidence": "high/medium/low"
    }}

    Important notes:
    - If the question doesn't clearly fit into categories 1-10, use "Other Mental Disorders"
    - Focus on the primary disorder being discussed
    - Use exact category names as provided
    - Be concise but thorough in your reasoning
    """

In [7]:
with open("MedQA_open_dataset_classified_psychiatry_evaluation_with_categories.json", 'r', encoding='utf-8') as file:
    data = json.load(file)

df_with_categories = pd.DataFrame(data)
category_counts = df_with_categories["psychiatric_category"].value_counts().to_dict()
print("Percentage of categories:")
for category, count in category_counts.items():
    percentage = (count / len(df_with_categories)) * 100
    print(f"{category:<40} {count:>4} ({percentage:>5.1f}%)")



Percentage of categories:
Other Mental Disorders                    220 ( 29.9%)
Depressive Disorders                      115 ( 15.6%)
Schizophrenia Spectrum and Other Psychotic Disorders  107 ( 14.5%)
Bipolar Disorders                          73 (  9.9%)
Anxiety Disorders                          52 (  7.1%)
Trauma and Stressor Related Disorders      50 (  6.8%)
Personality Disorders                      45 (  6.1%)
Eating Disorders                           24 (  3.3%)
Somatic Disorders                          24 (  3.3%)
Obsessive-Compulsive Disorders             22 (  3.0%)
Dissociative Disorders                      5 (  0.7%)
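
The tallies above are only meaningful if every stored label matches one of the names in the prompt exactly. A minimal sketch of that check, combined with a table whose first column is sized to the longest name so that long category names do not push the counts out of line, could look like this. `tally_categories` is a hypothetical helper and is not part of the pipeline that produced the output above.

```python
from collections import Counter

ALLOWED_CATEGORIES = [
    "Anxiety Disorders",
    "Bipolar Disorders",
    "Depressive Disorders",
    "Dissociative Disorders",
    "Eating Disorders",
    "Obsessive-Compulsive Disorders",
    "Personality Disorders",
    "Schizophrenia Spectrum and Other Psychotic Disorders",
    "Somatic Disorders",
    "Trauma and Stressor Related Disorders",
    "Other Mental Disorders",
]


def tally_categories(labels):
    """Return formatted count/percentage lines, most frequent category first."""
    labels = list(labels)
    if not labels:
        return []
    unknown = sorted(set(labels) - set(ALLOWED_CATEGORIES))
    if unknown:
        raise ValueError(f"Labels not in the prompt's list: {unknown}")
    counts = Counter(labels)
    width = max(len(name) for name in counts)  # fit the longest name present
    total = len(labels)
    return [
        f"{name:<{width}} {count:>4} ({100 * count / total:>5.1f}%)"
        for name, count in counts.most_common()
    ]
```

For the DataFrame above, the call would be `tally_categories(df_with_categories["psychiatric_category"])`, printing each returned line.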


I need to split the dataset 50:50. One half will be used to produce the final results; the other half will be further split 80:20 (roughly 40% and 10% of the full dataset) for experimenting with the evaluation and tuning it.
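
As a rough check of the sizes to expect: when `test_size` is a float and `train_size` is left unset, `train_test_split` rounds the test share up and gives the remaining rows to the training side. The sketch below reproduces that arithmetic without scikit-learn for the 737 verified questions and the `test_size` values 0.5 and 0.2 used in the cells that follow. `split_sizes` is a hypothetical helper introduced only for this illustration.

```python
import math


def split_sizes(n_samples: int, test_size: float) -> tuple:
    """Return (n_train, n_test) for a fractional test_size."""
    n_test = math.ceil(test_size * n_samples)  # the test share is rounded up
    return n_samples - n_test, n_test


n_rest, n_test = split_sizes(737, 0.5)     # first split: 50:50
n_train, n_val = split_sizes(n_rest, 0.2)  # second split of the remaining half

print(f"Test: {n_test}, Train: {n_train}, Val: {n_val}")
# prints "Test: 369, Train: 294, Val: 74"
```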

In [8]:
# Randomly shuffle and split the data 50:50
from sklearn.model_selection import train_test_split

# Set random_state for reproducibility
train_dataset, test_dataset = train_test_split(df_with_categories, test_size=0.5, random_state=42)

print(f"Test dataset: {len(test_dataset)}")
print(f"Train dataset: {len(train_dataset)}")

category_counts = test_dataset["psychiatric_category"].value_counts().to_dict()
print("Percentage of categories in test dataset:")
for category, count in category_counts.items():
    percentage = (count / len(test_dataset)) * 100
    print(f"{category:<40} {count:>4} ({percentage:>5.1f}%)")

Test dataset: 369
Train dataset: 368
Percentage of categories in test dataset:
Other Mental Disorders                    114 ( 30.9%)
Depressive Disorders                       62 ( 16.8%)
Schizophrenia Spectrum and Other Psychotic Disorders   57 ( 15.4%)
Bipolar Disorders                          33 (  8.9%)
Trauma and Stressor Related Disorders      24 (  6.5%)
Personality Disorders                      23 (  6.2%)
Anxiety Disorders                          22 (  6.0%)
Eating Disorders                           13 (  3.5%)
Somatic Disorders                          10 (  2.7%)
Obsessive-Compulsive Disorders             10 (  2.7%)
Dissociative Disorders                      1 (  0.3%)


In [9]:
data = test_dataset.to_dict('records')
with open("test_dataset.json", 'w', encoding='utf-8') as f:
    json.dump(data, f, indent=2, ensure_ascii=False)

In [10]:
# Set random_state for reproducibility
train_dataset, val_dataset = train_test_split(train_dataset, test_size=0.2, random_state=42)

print(f"Val dataset: {len(val_dataset)}")
print(f"Train dataset: {len(train_dataset)}")

category_counts = val_dataset["psychiatric_category"].value_counts().to_dict()
print("Percentage of categories in val dataset:")
for category, count in category_counts.items():
    percentage = (count / len(val_dataset)) * 100
    print(f"{category:<40} {count:>4} ({percentage:>5.1f}%)")


category_counts = train_dataset["psychiatric_category"].value_counts().to_dict()
print("********************************")
print("Percentage of categories in train dataset:")
for category, count in category_counts.items():
    percentage = (count / len(train_dataset)) * 100
    print(f"{category:<40} {count:>4} ({percentage:>5.1f}%)")

Val dataset: 74
Train dataset: 294
Percentage of categories in val dataset:
Other Mental Disorders                     25 ( 33.8%)
Depressive Disorders                       13 ( 17.6%)
Schizophrenia Spectrum and Other Psychotic Disorders    8 ( 10.8%)
Trauma and Stressor Related Disorders       7 (  9.5%)
Bipolar Disorders                           5 (  6.8%)
Anxiety Disorders                           4 (  5.4%)
Obsessive-Compulsive Disorders              3 (  4.1%)
Eating Disorders                            3 (  4.1%)
Personality Disorders                       3 (  4.1%)
Somatic Disorders                           2 (  2.7%)
Dissociative Disorders                      1 (  1.4%)
********************************
Percentage of categories in train dataset:
Other Mental Disorders                     81 ( 27.6%)
Schizophrenia Spectrum and Other Psychotic Disorders   42 ( 14.3%)
Depressive Disorders                       40 ( 13.6%)
Bipolar Disorders                          35 ( 11.9%)

In [11]:
data = train_dataset.to_dict('records')
with open("train_dataset.json", 'w', encoding='utf-8') as f:
    json.dump(data, f, indent=2, ensure_ascii=False)

data = val_dataset.to_dict('records')
with open("val_dataset.json", 'w', encoding='utf-8') as f:
    json.dump(data, f, indent=2, ensure_ascii=False)
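
Since the test half is meant to produce the final results, it may be worth confirming that no question text appears in more than one of the saved files. `train_test_split` never puts the same row in two splits, but identical questions stored in different rows would still leak. The sketch below assumes each record has a `"question"` key, which is a guess about the column name; `overlapping_questions` is a hypothetical helper.

```python
from itertools import combinations


def overlapping_questions(splits: dict, key: str = "question") -> dict:
    """Map each pair of split names to the question texts they share, if any."""
    seen = {
        name: {record[key] for record in records}
        for name, records in splits.items()
    }
    overlaps = {}
    for first, second in combinations(seen, 2):
        shared = seen[first] & seen[second]
        if shared:
            overlaps[(first, second)] = shared
    return overlaps
```

The three JSON files can be loaded with `json.load` and passed in as `{"train": ..., "val": ..., "test": ...}`; an empty result means no shared question text was found.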