# Age Distribution by Therapy Area in MIMIC4 Demo Dataset

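This notebook analyzes how patient age is distributed across therapy areas in the MIMIC-IV (MIMIC4) demo dataset. Each patient is assigned to one or more therapy areas based on the ICD-10 chapter of their diagnosis codes.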

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pyhealth.datasets import MIMIC4Dataset
import warnings
warnings.filterwarnings('ignore')

# Set style for plots
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 8)

In [None]:
# Build the PyHealth dataset object for the MIMIC4 demo release.
print("Loading MIMIC4 demo dataset...")

# Loader arguments are collected in one place so they are easy to adjust.
loader_kwargs = dict(
    ehr_root="https://physionet.org/files/mimic-iv-demo/2.2/",
    ehr_tables=["diagnoses_icd", "prescriptions"],
    dev=True,  # dev mode is requested so only a small subset is used
)
mimic4_demo = MIMIC4Dataset(**loader_kwargs)

# Ask the dataset to report its own summary statistics.
# NOTE(review): newer PyHealth releases may expose this as `stats()` rather
# than `stat()` -- confirm against the installed version.
print("Dataset statistics:")
mimic4_demo.stat()

In [None]:
# Extract patient data: age and ICD-10 diagnosis codes.
#
# Produces `patient_data` (list of dicts) and `df` (one row per patient with
# the columns patient_id, age, icd_codes).
print("Extracting patient data...")

patient_data = []
non_icd10_codes_skipped = 0  # diagnosis rows ignored because they are not ICD-10

for patient in mimic4_demo.iter_patients():
    # Demographics come from the "patients" events; skip patients without any.
    patient_info = patient.get_events(event_type="patients")
    if not patient_info:
        continue

    # anchor_age is read from the first "patients" event. Missing, NaN or
    # otherwise non-numeric values skip the patient instead of aborting the
    # whole extraction.
    age = getattr(patient_info[0], 'anchor_age', None)
    if age is None:
        continue
    try:
        age = int(float(age))
    except (TypeError, ValueError, OverflowError):
        continue

    # Get all diagnoses for this patient
    diagnoses = patient.get_events(event_type="diagnoses_icd")

    # Collect the ICD codes stored in each event's attr_dict
    # (MIMIC4 uses 'icd_code', not 'icd9_code').
    icd_codes = []
    has_coded_diagnosis = False
    for d in diagnoses:
        attrs = getattr(d, 'attr_dict', None) or {}
        code = attrs['icd_code'] if 'icd_code' in attrs else None
        if not code:
            continue
        has_coded_diagnosis = True

        # The therapy-area mapping used later only understands ICD-10, and the
        # source table mixes ICD-9 and ICD-10 rows. When a version is recorded
        # and it is not 10, leave the code out so that e.g. ICD-9 "E" codes are
        # not mistaken for ICD-10 endocrine codes. Rows without a recorded
        # version are kept, as before.
        # NOTE(review): assumes the version is exposed as 'icd_version' in
        # attr_dict -- confirm against the PyHealth table config.
        version = attrs['icd_version'] if 'icd_version' in attrs else None
        if version is not None and str(version).strip() not in ('10', '10.0'):
            non_icd10_codes_skipped += 1
            continue

        icd_codes.append(code)

    # Only include patients with diagnoses. A patient whose diagnoses are all
    # non-ICD-10 is kept with an empty code list.
    if has_coded_diagnosis:
        patient_data.append({
            'patient_id': patient.patient_id,
            'age': age,
            'icd_codes': icd_codes
        })

print(f"Extracted data for {len(patient_data)} patients")
if non_icd10_codes_skipped:
    print(f"Ignored {non_icd10_codes_skipped} diagnosis codes that are not ICD-10")

# Convert to DataFrame. Explicit columns keep the schema stable even when no
# patient was extracted, so later cells fail less mysteriously.
df = pd.DataFrame(patient_data, columns=['patient_id', 'age', 'icd_codes'])
print(f"DataFrame shape: {df.shape}")
if df.empty:
    print("No patients with diagnoses were found; later cells will have nothing to analyze.")
else:
    print("Sample ICD codes:", df['icd_codes'].iloc[0][:5])
print(df.head())

In [None]:
# Define therapy area mapping based on ICD-10 code categories
print("Setting up therapy area mappings...")

# Define mapping from ICD-10 code prefixes to therapy areas
# Based on standard ICD-10 chapter classifications
# ICD-10 chapters that map to a therapy area by their first letter alone.
# 'I' (only I20-I52 counts) and 'D' (split at D50) are handled in the function.
_CHAPTER_LETTER_TO_AREA = {
    'C': 'Oncology',              # Malignant neoplasms (C00-C97)
    'G': 'Neurology',             # Nervous system (G00-G99)
    'E': 'Metabolic/Endocrine',   # Endocrine/metabolic (E00-E89)
    'A': 'Infectious Diseases',   # Infectious and parasitic diseases
    'B': 'Infectious Diseases',
    'J': 'Pulmonary',             # Respiratory system (J00-J99)
    'K': 'Gastrointestinal',      # Digestive system (K00-K93)
    'N': 'Renal',                 # Genitourinary system (N00-N99)
    'Q': 'Rare Diseases',         # Congenital anomalies (Q00-Q99)
}


def map_icd_to_therapy_area(icd_codes):
    """Map a list of ICD-10 codes to therapy areas based on code categories.

    Only the first three characters of each code are inspected, so dotted
    ('I25.10') and undotted ('I2510') forms are treated alike. Surrounding
    whitespace and letter case are ignored.

    Args:
        icd_codes: Iterable of ICD-10 code strings. Entries that are not
            strings, or are shorter than three characters once stripped,
            are ignored.

    Returns:
        Alphabetically sorted list of the distinct therapy areas found, or
        ['Other'] when no code maps to a known area.
    """
    areas = set()
    for code in icd_codes:
        if not isinstance(code, str):
            continue
        prefix = code.strip()[:3].upper()
        if len(prefix) < 3:
            continue

        letter, block = prefix[0], prefix[1:]
        if letter == 'I':
            # Circulatory chapter: only I20-I52 (heart diseases) is counted.
            # isdecimal() guarantees int() cannot fail on the two characters.
            if block.isdecimal() and 20 <= int(block) <= 52:
                areas.add('Cardiovascular')
        elif letter == 'D':
            # The D chapter is split: D00-D49 are neoplasms, while D50-D89 are
            # blood/immune disorders. Testing the first digit also covers
            # alphanumeric categories such as D3A.
            if block[0] in '01234':
                areas.add('Oncology')
            else:
                areas.add('Autoimmune/Inflammatory')
        elif letter in _CHAPTER_LETTER_TO_AREA:
            areas.add(_CHAPTER_LETTER_TO_AREA[letter])
        # Add more mappings as needed

    # Sorting makes the result reproducible from run to run.
    return sorted(areas) if areas else ['Other']

# Smoke-test the mapping function with one code per expected therapy area.
print("Testing mapping function:")
test_codes = ['C50.9', 'I25.10', 'G43.9', 'E11.9', 'J18.9', 'N18.9', 'K29.7']
test_areas = map_icd_to_therapy_area(test_codes)
print(f"Test codes {test_codes} -> {test_areas}")

# Fail loudly if the mapping regresses instead of relying on eyeballing the
# printed list. Compared as sets so the check does not depend on list order.
expected_test_areas = {
    'Oncology', 'Cardiovascular', 'Neurology', 'Metabolic/Endocrine',
    'Pulmonary', 'Renal', 'Gastrointestinal',
}
assert set(test_areas) == expected_test_areas, f"Unexpected mapping result: {test_areas}"

# Apply mapping to DataFrame
print("Mapping ICD codes to therapy areas...")
df['therapy_areas'] = df['icd_codes'].apply(map_icd_to_therapy_area)

# Since patients can have multiple areas, explode the DataFrame to have one
# row per patient per therapy area. explode() repeats the original index
# label on every generated row, so the index is reset to keep labels unique
# for later column assignment and plotting.
df_exploded = df.explode('therapy_areas').reset_index(drop=True)

print(f"After exploding: {df_exploded.shape}")
# Printed on its own line so the counts table stays aligned.
print("Therapy area distribution:")
print(df_exploded['therapy_areas'].value_counts())
print(df_exploded.head())

In [None]:
# Create age groups
print("Creating age groups...")

# Age bins are left-closed (right=False): [0, 18), [18, 30), ..., [75, inf).
# The last edge is infinite so the '75+' label really is open-ended; with a
# finite upper edge of 100, ages of 100 or more would silently become NaN.
age_bins = [0, 18, 30, 45, 60, 75, float('inf')]
age_labels = ['0-17', '18-29', '30-44', '45-59', '60-74', '75+']

# Adds an ordered categorical 'age_group' column to the exploded frame.
df_exploded['age_group'] = pd.cut(df_exploded['age'], bins=age_bins, labels=age_labels, right=False)

print("Age group distribution:")
print(df_exploded['age_group'].value_counts().sort_index())

print("\nTherapy area distribution:")
print(df_exploded['therapy_areas'].value_counts())

In [None]:
# Plot histogram of age distribution by therapy area
print("Creating histogram plot...")

# Filter out 'Other' category for cleaner visualization. The index is reset
# because the exploded frame can carry repeated index labels, which some
# seaborn versions reject when building the plot data.
plot_df = df_exploded[df_exploded['therapy_areas'] != 'Other'].reset_index(drop=True)

if plot_df.empty:
    # Nothing was mapped to a named therapy area; skip drawing an empty figure.
    print("No patients mapped to a named therapy area; nothing to plot.")
else:
    # Create the plot
    plt.figure(figsize=(14, 10))

    # Create histogram with hue for therapy areas
    # Use 'layer' instead of 'stack' to ensure legend appears
    g = sns.histplot(
        data=plot_df,
        x='age',
        hue='therapy_areas',
        multiple='layer',
        bins=30,
        alpha=0.7
    )

    plt.title('Age Distribution by Therapy Area in MIMIC4 Demo Dataset', fontsize=16, fontweight='bold')
    plt.xlabel('Age (years)', fontsize=14)
    # Each patient contributes one row per therapy area, so every layer counts
    # the patients in that area.
    plt.ylabel('Number of Patients', fontsize=14)
    # Let seaborn handle the legend automatically
    plt.tight_layout()
    plt.show()

In [None]:
# Print counts for each therapy area and age group
print("Counts for each therapy area and age group:")
print("=" * 60)

n_patients = len(df)
n_combinations = len(df_exploded)

if n_patients == 0 or n_combinations == 0:
    # Avoid a division by zero and an empty cross-tabulation.
    print("No patient data available; nothing to summarize.")
else:
    # Create cross-tabulation (rows: therapy areas, columns: age groups)
    cross_tab = pd.crosstab(
        df_exploded['therapy_areas'],
        df_exploded['age_group'],
        margins=True,
        margins_name='Total'
    )

    print(cross_tab)

    print("\n" + "=" * 60)
    print("Summary:")
    print(f"Total patients: {n_patients}")
    print(f"Total patient-therapy area combinations: {n_combinations}")
    print(f"Average therapy areas per patient: {n_combinations / n_patients:.2f}")
    # Age statistics use the per-patient frame. In the exploded frame a
    # patient appears once per therapy area, which would weight the median
    # towards patients with many areas.
    print(f"Age range: {df['age'].min()}-{df['age'].max()} years")
    print(f"Median age: {df['age'].median():.1f} years")