# MIMIC 3 Dataset

In [None]:
import pandas as pd 
import pyspark
import numpy as np
import matplotlib.pyplot as plt

## utils 
import cleaning_utils

In [None]:
## Load the three raw tables (patients, admissions, ICD diagnoses) from CSV

raw_dir = 'data/raw/'

patients = pd.read_csv(f'{raw_dir}PATIENTS.csv')
admissions = pd.read_csv(f'{raw_dir}ADMISSIONS.csv')
diagnosis = pd.read_csv(f'{raw_dir}DIAGNOSES_ICD.csv')

In [None]:
from cleaning_utils import standardise_col_names

In [None]:
## Apply the shared column-name standardisation to every raw frame, in load order

patients, admissions, diagnosis = (
    standardise_col_names(frame) for frame in (patients, admissions, diagnosis)
)

In [None]:
## Check: display the patients column names after standardisation
patients.columns

### 1. Patients

* Cols: subject ID, gender, date of birth, date of death, date of death recorded in hospital records, date of death from Social Security records, expire flag (death indicator at discharge: 1 for dead, 0 for alive at discharge)

* Expire flag helps with mortality analysis

* Data has been de-identified: dates are shifted to preserve privacy, but relative intervals are kept intact

In [None]:
## Preview the first rows of the patients table
patients.head()

In [None]:
## consolidating the three DOD columns

## explore: collect every column whose name starts with 'dod', then show only the
## rows where at least one of those columns is populated
DOD_cols = [col for col in patients.columns if col.startswith('dod')]
print(DOD_cols)
patients[patients[DOD_cols].notna().any(axis = 1)][DOD_cols]

## method: prioritise the DOD_HOSP column, then fill missing with DOD_SSN and DOD


In [None]:
## Clean patients: run the time conversion helper, then the DOD consolidation helper
## NOTE(review): `patients` is reassigned here, so re-running this cell feeds the
## already-converted frame back into convert_to_time_patients - confirm that is safe
patients = cleaning_utils.convert_to_time_patients(patients)
patients_cleaned = cleaning_utils.handle_DOD_columns(patients)


In [None]:
## Gender split: bar chart of the number of patients per gender value
patients.gender.value_counts().plot(kind = 'bar', title = 'Gender split')
plt.show()

There is a higher proportion of males.

In [None]:
## Get age at death: derive 'age_at_death' from 'DOD_consolidated' and 'dob'

patients_cleaned = cleaning_utils.get_age_via_extraction(patients_cleaned,
                       # (an earlier version restricted the input to rows with EXPIRE_FLAG == 1)
                       'DOD_consolidated', 'dob', 'age_at_death')


In [None]:
## Histogram of age at death (15 bins)
patients_cleaned.age_at_death.plot(kind = 'hist', title = 'Age at death', xlabel = 'Age', 
                                  bins = 15)
## add more x axis ticks: one every 20 years from 0 up to (not including) 300
plt.xticks(range(0, 300, 20))

plt.show()

In [None]:
## Expire flag: bar chart of how many patients carry each expire_flag value
patients_cleaned.expire_flag.value_counts().plot(kind = 'bar', title = 'expire = 1: died in hopsital,  expire = 0: alive at discharge')
plt.suptitle('Patient Expire Flag (Death count)')
plt.show()

Age at death: a very large share of deaths fall between ages 60 and 80, and most occur above age 40.

Expire flag: roughly twice as many patients are alive at discharge as are recorded as dead.

#### Sanity check

In [None]:
## Check 

## Patients with expire flag == 1: are all DODs entries filled for them?
def detect_DOD_missing(patients_cleaned):
    '''
    Print whether any patient flagged as deceased (expire_flag == 1) lacks a
    consolidated date of death (DOD_consolidated is null).

    Only prints a message; nothing is returned.
    '''
    deceased_dod = patients_cleaned.loc[patients_cleaned['expire_flag'] == 1, 'DOD_consolidated']
    n_missing = deceased_dod.isna().sum()
    if n_missing <= 0:
        print('No missing DOD, good to proceed')
        return
    print('There are missing DOD dates for patients who have expire flag == 1')

In [None]:
## Run the DOD completeness check on the cleaned patients table
detect_DOD_missing(patients_cleaned)

### 2. Admission

* one row per hospital admission (each patient can have more than one row)
* admin and clinical metadata about each stay
* cols: admission/discharge dates, admission type, language, insurance, ethnicity, religion, ED times, diagnosis, hospital expire flag, chart events
* **admission id (HADM_ID) allows joins/connections to diagnoses, procedures, lab events, notes**
* ethnicity can be used for subgroup analysis
* insurance, language, marital status can be used for socioeconomic research

In [None]:
## Preview the first rows of the admissions table
admissions.head()

In [None]:
## Inspect column dtypes before cleaning
admissions.dtypes

In [None]:
## Pie chart of the number of admissions per admission type
admissions['admission_type'].value_counts().plot(kind = 'pie', title='Admissions type distribution')
plt.show()

Most admissions are due to emergency

In [None]:
## Bar chart of the number of admissions per admission location
admissions['admission_location'].value_counts().plot(kind = 'bar', title = 'Admission location distrubution')
plt.show()

Most are acute, unscheduled care -- e.g. chest pain, trauma, stroke

HMO: referred by insurance 

#### Sanity check

In [None]:
## check invalid entries -- discharge time is less than admission time
## (the mask is built once; the previous stand-alone value_counts() expression was
##  never displayed because only a cell's last expression is rendered)
## NOTE(review): if these columns are still strings at this point the comparison is
## lexicographic - confirm the dtype/format makes that equivalent to a time comparison
incorrect_ad_times_mask = admissions['dischtime'] < admissions['admittime']

## print the share of affected rows as a percentage
print(f'Proportion of incorrect disch rel. to admit times: {incorrect_ad_times_mask.mean()*100:.2f}%')

## scrutinise a few of the offending rows
admissions[incorrect_ad_times_mask].head()

In [None]:
## Clean admissions via the project helper; later cells read `los_admission` from the result
admissions_cleaned = cleaning_utils.full_clean_admissions(admissions)

In [None]:
## Box plot of length of stay, using the days component of `los_admission`
admissions_cleaned.los_admission.dt.days.plot(kind = 'box', title = 'Length of stay (days)', ylabel = 'Days')
plt.show()

In [None]:
## More: summary statistics of length of stay
admissions_cleaned.los_admission.describe()

Most hospital stays last at least a few days.

In [None]:
## Diseases with longest and shortest admission stays:
## mean length of stay per diagnosis (days component), sorted longest first
admissions_cleaned.groupby('diagnosis')['los_admission'].mean().dt.days.sort_values(ascending= False)

### 3. Diagnosis

In [None]:
## Add ICD-9 chapter information to the diagnoses table via the project helper.
## (The bare `diagnosis` expression that used to open this cell was dropped: it was
##  not the cell's last expression, so it never rendered any output.)
diagnosis_cleaned = cleaning_utils.icd9_chapters(diagnosis)
    

In [None]:
## Preview the diagnoses table after the ICD-9 chapter step
diagnosis_cleaned.head()

In [None]:
## Mean number of diagnosis rows per admission (hadm_id).
## The f-string is double-quoted so the single-quoted column names inside the
## replacement field parse on Python versions before 3.12 as well.
print(f"Average number of diagnoses per admission {diagnosis_cleaned.groupby('hadm_id')['row_id'].count().mean():.1f}")


Understanding the dataset: a single admission can carry multiple diagnoses.
* Hence, subset to the first diagnosis (`seq_num == 1`) before joining, so the join to admissions does not become one-to-many

In [None]:
## Check common diseases

## Across all sequence numbers: number of diagnosis rows per ICD group, largest first
print('ROW COUNT PER ICD GROUP =========')
print(diagnosis_cleaned.groupby('icd_group')['row_id'].count().sort_values(ascending= False))
print('=='*20)

## Which is usually the main diagnosis (i.e. lower sequence number):
## mean seq_num per ICD group, smallest first
print('AVERAGE SEQ NUMBER =========')
print(diagnosis_cleaned.groupby('icd_group')['seq_num'].mean().sort_values(ascending= True))
print('=='*20)

## Which is the most common disease per sequence number:
## most frequent ICD group for each seq_num (first five seq_num values shown)
print('MOST COMMON SEQ NUMBER =========')
print(pd.DataFrame(diagnosis_cleaned.groupby('seq_num')['icd_group'].agg(lambda x: x.value_counts().idxmax())).head())
print('=='*20)



# Transform

In [None]:
## Join patients and admissions: default (inner) merge on subject_id, then drop the
## two suffixed row_id columns left over from the source tables
## NOTE(review): this merges `patients`, not `patients_cleaned`; a later cell selects
## 'dod_consolidated', which is only present here if handle_DOD_columns added it to
## `patients` in place - confirm which frame is intended
patients_admissions = (
    admissions_cleaned
    .merge(patients, on='subject_id', suffixes=('_admissions', '_patients'))
    .drop(columns=['row_id_admissions', 'row_id_patients'])
)

In [None]:
## sanity checks: compare shapes before and after the merge
## (an equal row count means the merge neither dropped nor duplicated admissions)
print(admissions_cleaned.shape)
print(patients_admissions.shape)

In [None]:
## Generate some useful columns 

## Create age during admission: derive 'age_during_admission' from 'admittime' and 'dob'
patients_admissions = cleaning_utils.get_age_via_extraction(patients_admissions,'admittime', 'dob', 'age_during_admission')
## Categorise the age into bins (later cells group by the resulting 'age_group' column)
patients_admissions = cleaning_utils.age_categorise(patients_admissions)


In [None]:
## Check: histogram of age at admission (15 bins, x ticks every 20 years up to 300),
## followed by summary statistics
patients_admissions.age_during_admission.plot(kind = 'hist', title = 'Age at the point of admission', xlabel = 'Years', 
                                             bins = 15)
plt.xticks(np.arange(0, 300+20, 20))
plt.show()
patients_admissions.age_during_admission.describe()
## Recall that we are seeing age = 300 due to the HIPAA requirement for masking the age of those that are 90+

In [None]:
## Closer examination of those of unusually high age:
## frequency of each age value above 200, ordered by age
patients_admissions[patients_admissions.age_during_admission > 200].age_during_admission.value_counts().sort_index()

In [None]:
## implement: apply the project helper that adjusts the de-identified (masked) ages
patients_admissions = cleaning_utils.pull_back_deid_ages(patients_admissions)

## check ages: re-plot the distribution after the adjustment
patients_admissions['age_during_admission'].plot(kind = 'hist', title = 'Post age cleaning: Age at the point of admission', xlabel = 'Years')
plt.show()

In [None]:
## Number each patient's admissions in chronological order (1 = earliest)
patients_admissions['nth_visit'] = (
    patients_admissions.groupby('subject_id')['admittime']
    .rank(method='first')
    .astype(int)
)
## Rank in reverse chronological order, then keep a 0/1 flag: 1 marks the most recent admission
patients_admissions['latest_visit'] = (
    patients_admissions.groupby('subject_id')['admittime']
    .rank(method='first', ascending=False)
    .astype(int)
)
patients_admissions['latest_visit'] = (patients_admissions['latest_visit'] == 1).astype(int)

## Spot check: show the visit columns for the patient in the second row
example_id = patients_admissions['subject_id'].iloc[1]
visit_check_cols = ['subject_id', 'admittime', 'hadm_id', 'nth_visit', 'latest_visit']
patients_admissions.loc[patients_admissions['subject_id'] == example_id, visit_check_cols]

### Sanity checks

In [None]:
# Lowercase all column names (from here on every column must be referenced in lowercase)
patients_admissions.columns = patients_admissions.columns.str.lower()

## since patient can have more than 1 admission, filter for latest_visit
latest_visit_mask = patients_admissions.latest_visit == 1
latest_visit_expire_flag = patients_admissions.loc[latest_visit_mask, 'expire_flag']
latest_visit_hosp_expire_flag = patients_admissions.loc[latest_visit_mask, 'hospital_expire_flag']

## ALL EQUAL? True only if the two flags agree on every latest-visit row
expire_flags_align = np.all(latest_visit_expire_flag == latest_visit_hosp_expire_flag)
print(f'Do expire flags align (between patients expire flag and admissions expire flag?: {expire_flags_align}')

## check where not equal: first latest-visit rows whose two flags disagree
## NOTE(review): 'dod_consolidated' must exist in patients_admissions for this selection - confirm
patients_admissions[latest_visit_mask][latest_visit_expire_flag != latest_visit_hosp_expire_flag][[
    'subject_id', 'hadm_id', 'admittime', 'expire_flag', 'hospital_expire_flag', 'dod_consolidated', 'deathtime'
]].head()

Finding: we see that DEATHTIME (admissions) does not align with date of death (patients) --> it is likely that patients data can be updated post admission, hence a patient could be 'alive' in the admissions df but 'dead' in the patients df 

In [None]:
## Subset for main diagnosis: keep only rows with seq_num == 1
main_diagnosis = diagnosis_cleaned[diagnosis_cleaned.seq_num == 1]

## Merge diagnosis: left join keeps every admission; the row counts printed before
## and after should match if no admission matched more than one diagnosis row
## NOTE: `patients_admissions_diagnosis` is assigned again by a later merge cell
print(patients_admissions.shape[0])
patients_admissions_diagnosis = patients_admissions.merge(main_diagnosis, on = ['subject_id', 'hadm_id'], how  = 'left')
print(patients_admissions_diagnosis.shape[0])

In [None]:
## Large number of age_during_admission = 0: count them explicitly
## (printed, because a bare expression that is not the cell's last line is never shown)
neonates_mask = patients_admissions.age_during_admission == 0
print(neonates_mask.value_counts())

## Ensure that it makes sense --> check the reasons for admission.
## Column names were lowercased earlier in the notebook, so the column is
## 'admission_type' (the uppercase attribute no longer exists on the frame).
patients_admissions.loc[neonates_mask, 'admission_type'].value_counts().plot(kind = 'bar', title = 'Neonates (age = 0) admission types')
plt.show()

In [None]:
## Merge in PRIMARY DIAGNOSIS ONLY

## Keep only the primary diagnosis (seq_num == 1). Both frames use lowercase column
## names at this point, so the filter and the join key are lowercase too.
## The earlier `.rename({...})` call is dropped: without `columns=` it targeted the
## row index and renamed nothing, and later cells group by the unrenamed 'icd_group'.
primary_diagnosis = diagnosis_cleaned[diagnosis_cleaned['seq_num'] == 1]
print(primary_diagnosis.shape)

## Merge: left join keeps every admission; overlapping diagnosis columns get '_diagnosis'
print(f'before: {patients_admissions.shape}')
patients_admissions_diagnosis = patients_admissions.merge(primary_diagnosis, on = 'hadm_id', suffixes = ('', '_diagnosis'), how = 'left')
print(f'after: {patients_admissions_diagnosis.shape}') ## ensure no loss of admissions information


# Exploring/ Analysis

In [None]:
## DuckDB is used below to run SQL over the pandas frames
import duckdb

## connect() with no arguments; this connection is reused by every SQL cell below
con = duckdb.connect()

In [None]:
## Expose both pandas frames to DuckDB under these table names so the SQL below can query them
con.register("patients_admissions", patients_admissions)
con.register("patients_admissions_diagnosis", patients_admissions_diagnosis)

## check cols 
patients_admissions.columns

### Explore neonates

### Number of admissions per patient

In [None]:
## Number of distinct admissions per patient, most frequently admitted first
con.sql('''
    SELECT SUBJECT_ID, count(distinct HADM_ID) as admission_count 
    FROM patients_admissions
    group by SUBJECT_ID
    order by admission_count desc 
''').show()

In [None]:
## Finding key stats of number of visits per patient:
## average, max, min and mode of the per-patient distinct admission count

con.sql('''
with PATIENT_ADM_COUNT AS (
    SELECT SUBJECT_ID, count(distinct HADM_ID) as admission_count 
    FROM patients_admissions
    group by SUBJECT_ID
    order by admission_count desc)
select 
    avg(admission_count) as average_count, 
    max(admission_count) as max_count, 
    min(admission_count) as min_count, 
    mode(admission_count) as mode_count 
FROM PATIENT_ADM_COUNT
''').show()

Most patients only visit once --> however, the most a patient has been admitted is 42 times.

## Most common diseases per ethnicity?

In [None]:
## Five most frequent free-text diagnoses per ethnicity category,
## spread into one column per rank ('Top 1' .. 'Top 5')
top_diseases_per_ethnicity = patients_admissions_diagnosis.groupby('ethnicity_categorised')['diagnosis'].agg(lambda x: x.value_counts().sort_values(ascending = False).head(5).index)\
        .apply(pd.Series)

top_diseases_per_ethnicity.columns = [f'Top {i}' for i in range(1,6)]

top_diseases_per_ethnicity

In [None]:
## Same ranking as the previous cell, but on ICD group instead of free-text diagnosis
## NOTE: this reuses (and overwrites) the name `top_diseases_per_ethnicity`
top_diseases_per_ethnicity = patients_admissions_diagnosis.groupby('ethnicity_categorised')['icd_group'].agg(lambda x: x.value_counts().sort_values(ascending = False).head(5).index)\
        .apply(pd.Series)

top_diseases_per_ethnicity.columns = [f'Top {i}' for i in range(1,6)]

top_diseases_per_ethnicity

## Mortality rates 

In [None]:
## Ten ICD groups with the highest in-hospital mortality rate:
## mean of hospital_expire_flag per group, as a percentage rounded to 2 decimals
pd.DataFrame(patients_admissions_diagnosis.groupby('icd_group')['hospital_expire_flag'].mean().\
             sort_values(ascending=False).head(10) * 100).round(2)\
            .rename(columns = {'hospital_expire_flag': 'Hospital Mortality Rate (%)'})

In [None]:
## In-hospital mortality rate per ethnicity category (top 10):
## mean of hospital_expire_flag per category, as a percentage rounded to 2 decimals
pd.DataFrame(patients_admissions_diagnosis.groupby('ethnicity_categorised')['hospital_expire_flag'].mean()\
             .sort_values(ascending=False).head(10) * 100).round(2)\
             .rename(columns = {'hospital_expire_flag':'Hospital Mortality Rate (%)'})


## Age group analysis

In [None]:
## In-hospital mortality rate (%) per age group.
## The rename key must match the lowercase column name for the label to be applied.
pd.DataFrame(patients_admissions_diagnosis.groupby('age_group')['hospital_expire_flag'].mean() * 100).rename(columns = {'hospital_expire_flag':'Hosp Mortality Rate'})


In [None]:
## Most frequent ICD group per age group
## (head(3) keeps the three largest counts, but idxmax() then returns only the single top label)
pd.DataFrame(patients_admissions_diagnosis.groupby('age_group')['icd_group'].agg(lambda x: x.value_counts().sort_values(ascending = False).head(3).idxmax()))

In [None]:
## Most frequent free-text diagnosis per age group
## (head(3) keeps the three largest counts, but idxmax() then returns only the single top label)
pd.DataFrame(patients_admissions_diagnosis.groupby('age_group')['diagnosis'].agg(lambda x: x.value_counts().sort_values(ascending = False).head(3).idxmax()))

In [None]:
## Average length of stay per age group, in days.
## `los_admission` is the length-of-stay column used on the cleaned admissions earlier
## in the notebook; .dt.days takes its days component. The new column is written and
## read under the same lowercase name so the groupby below finds it.
patients_admissions_diagnosis['los_days'] = patients_admissions_diagnosis['los_admission'].dt.days
pd.DataFrame(patients_admissions_diagnosis.groupby('age_group')['los_days'].mean()).rename(columns={'los_days':'Average Length of Stay (Days)'})

In [None]:
## List the columns available on the merged frame
patients_admissions_diagnosis.columns

## Readmission analysis

In [None]:
## find the number of days between each patient's admission and their previous one
## (lag over admissions ordered by ADMITTIME within each SUBJECT_ID),
## then keep admissions that start at most 30 days (<= 30) after the previous one
## and join back to pick up the ICD group and diagnosis of the readmission

readmission_30_days = con.sql(
    '''
with admissions_lag as (
    select SUBJECT_ID, HADM_ID, date(ADMITTIME) as ADMITTIME, 
        lag(DATE(ADMITTIME)) OVER (PARTITION BY SUBJECT_ID ORDER BY ADMITTIME) as previous_admit
    FROM patients_admissions_diagnosis
--    where SUBJECT_ID = 13033
    ORDER BY SUBJECT_ID, ADMITTIME desc), 
admissions_subset as (
select *,
    date_diff('day',  previous_admit, ADMITTIME) AS days_between_admit
    from admissions_lag
    where date_diff('day', previous_admit, ADMITTIME) <= 30)
select 
    p.SUBJECT_ID, p.HADM_ID, days_between_admit,
    icd_group, DIAGNOSIS
    from patients_admissions_diagnosis p 
    inner join admissions_subset a on
        p.HADM_ID = a.HADM_ID
        
    ''').df()

In [None]:
## Preview the readmission rows returned by the query
readmission_30_days.head()

### Most common ICD group

In [None]:
## Ten most frequent ICD groups among the 30-day readmissions
print(readmission_30_days.icd_group.value_counts().head(10))

### Most common diagnosis

In [None]:
## Ten most frequent diagnoses among the 30-day readmissions.
## DuckDB matches identifiers case-insensitively but returns them in the case they
## were defined with, so the result column carries the frame's lowercase name.
print(readmission_30_days['diagnosis'].value_counts().head(10))

### Most immediate readmission 

In [None]:
## Ten diagnoses with the shortest mean gap (days) between admissions.
## DuckDB matches identifiers case-insensitively but returns them in the case they
## were defined with, so the result column carries the frame's lowercase name.
readmission_30_days.groupby('diagnosis')['days_between_admit'].mean().sort_values(ascending = True).head(10)

## Comorbidity Profiling 

# Save and store the cleaned data 

In [None]:
import os 

## Create the output folders if they do not exist yet (exist_ok makes re-runs safe)
os.makedirs('data/cleaned', exist_ok=True)
os.makedirs('data/curated', exist_ok=True)

In [None]:
## Clean data: one parquet file per cleaned source table
patients_cleaned.to_parquet("data/cleaned/patients_cleaned.parquet")
admissions_cleaned.to_parquet("data/cleaned/admissions_cleaned.parquet")
diagnosis_cleaned.to_parquet("data/cleaned/diagnosis_cleaned.parquet")

## Combined/transformed data: the merged frames built in the Transform section
patients_admissions.to_parquet("data/curated/patients_admissions.parquet")
patients_admissions_diagnosis.to_parquet("data/curated/patients_adm_diag.parquet")
