# MIMIC 3 Dataset

In [None]:
import pandas as pd 
import pyspark
import numpy as np
import matplotlib.pyplot as plt


## utils 
import cleaning_utils

## read the datasets 

# Load the three MIMIC-III source tables (patients, admissions, diagnoses) from CSV, in that order
patients, admissions, diagnosis = (
    pd.read_csv(csv_path)
    for csv_path in ('PATIENTS.csv', 'ADMISSIONS.csv', 'DIAGNOSES_ICD.csv')
)

### 1. Patients

* Cols: subject ID, gender, date of birth, date of death, date of death recorded in hospital records, date of death from Social Security records, expire flag (death indicator: 1 if the patient died, 0 otherwise)

* Expire flag helps with mortality analysis

* Data has been de-identified: dates are shifted to protect privacy, but relative intervals are kept intact

patients.head()

## consolidating the three DOD columns

## explore: list the DOD* columns, then show them for every patient with at least one DOD recorded
DOD_cols = list(filter(lambda name: name.startswith('DOD'), patients.columns))
print(DOD_cols)
patients.loc[patients[DOD_cols].notna().any(axis=1), DOD_cols]

## method: prioritise the DOD_HOSP column, then fill remaining gaps from DOD_SSN, then DOD


## Clean patients 
patients = cleaning_utils.convert_to_time_patients(patients)
# patients_cleaned is the frame carrying DOD_consolidated (used by the checks below)
patients_cleaned = cleaning_utils.handle_DOD_columns(patients)


## Gender split
patients.GENDER.value_counts().plot(kind = 'bar', title = 'Gender split')
# Render this figure now, as every other plot in the notebook does: when run as a plain
# script, the next pandas Series plot would otherwise be drawn onto these same axes.
plt.show()

## Check 

## Patients with expire flag == 1: are all DODs entries filled for them?
def detect_DOD_missing(patients_cleaned: pd.DataFrame) -> int:
    '''
    Check whether any patient flagged as dead (EXPIRE_FLAG == 1) has no consolidated DOD.

    Prints a one-line verdict and also returns the number of offending patients, so the
    check can be used programmatically (e.g. ``assert detect_DOD_missing(df) == 0``).

    Parameters
    ----------
    patients_cleaned : DataFrame with at least the EXPIRE_FLAG and DOD_consolidated columns.

    Returns
    -------
    int : number of EXPIRE_FLAG == 1 rows whose DOD_consolidated is missing.
    '''
    deceased = patients_cleaned[patients_cleaned.EXPIRE_FLAG == 1]
    # int() so callers get a plain Python int rather than a numpy scalar
    missing_DOD = int(deceased.DOD_consolidated.isna().sum())
    if missing_DOD > 0:
        print('There are missing DOD dates for patients who have expire flag == 1')
    else:
        print('No missing DOD, good to proceed')
    return missing_DOD

detect_DOD_missing(patients_cleaned)

## Get age at death from DOD_consolidated and DOB, stored in a new 'age_at_death' column
# (computation lives in cleaning_utils.get_age_via_extraction)
patients_cleaned = cleaning_utils.get_age_via_extraction(
    patients_cleaned, 'DOD_consolidated', 'DOB', 'age_at_death'
)


patients_cleaned.age_at_death.plot(kind = 'hist', title = 'Age at death', xlabel = 'Age')
plt.show()

### 2. Admission

* one row per hospital admission (each patient can have more than one row)
* admin and clinical metadata about each stay
* cols: admission/discharge times, admission type, language, insurance, ethnicity, religion, ED times, diagnosis, hospital expire flag, chart events flag
* **admission id (HADM_ID) allows joins/connections to diagnoses, procedures, lab events, notes**
* ethnicity can be used for subgroup analysis
* insurance, language, marital status can be used for socioeconomic research

admissions.head()

admissions.dtypes

# Column names in the raw CSV are upper-case; spell them directly
admissions['ADMISSION_TYPE'].value_counts()

admissions['ADMISSION_LOCATION'].value_counts()

## check invalid entries -- discharge time is less than admission time 
# ADMISSIONS.csv was read without parse_dates, so these columns are strings; parse local
# copies so the comparison is chronological rather than lexicographic. `admissions` itself is
# left untouched for cleaning_utils.full_clean_admissions.
admit_times = pd.to_datetime(admissions['ADMITTIME'])
discharge_times = pd.to_datetime(admissions['DISCHTIME'])
# Build the mask once and reuse it for both the summary and the row inspection
incorrect_ad_times_mask = discharge_times < admit_times
incorrect_ad_times_mask.value_counts()

## scrutinise
admissions[incorrect_ad_times_mask]

## Clean admissions
admissions_cleaned = cleaning_utils.full_clean_admissions(admissions)

# LOS is read through the .dt accessor (.days), so it is a timedelta column in the cleaned frame
admissions_cleaned['LOS'].dt.days.plot(
    kind='box',
    title='Length of stay (days)',
    ylabel='Days',
)
plt.show()

### Diagnosis

diagnosis

# Map ICD-9 codes to chapter groups; the result carries the `icd_group` column used below
diagnosis_cleaned = cleaning_utils.icd9_chapters(diagnosis)

diagnosis_cleaned

## Check common diseases

## Across all sequence numbers 
# Number of diagnosis rows per ICD-9 group, most frequent first
print(
    diagnosis_cleaned
    .groupby('icd_group')['ROW_ID']
    .count()
    .sort_values(ascending=False)
)
print('==' * 20)

## Which is usually the main diagnosis (i.e. lower sequence number)
print(
    diagnosis_cleaned
    .groupby('icd_group')['SEQ_NUM']
    .mean()
    .sort_values()
)
print('==' * 20)

## Which is the most common disease per sequence number (first five sequence numbers shown)
print(
    diagnosis_cleaned
    .groupby('SEQ_NUM')['icd_group']
    .agg(lambda group: group.value_counts().idxmax())
    .to_frame()
    .head()
)
print('==' * 20)



# Transform

## Join patients and admissions
# Join the cleaned patient table (with DOD_consolidated / age_at_death), not the raw `patients`
# frame, so the patient-level cleaning above is carried into the joined table.
patients_admissions = admissions_cleaned.merge(patients_cleaned, on = 'SUBJECT_ID',
                                              suffixes=('_admissions', '_patients'))

# Drop the per-table surrogate row ids. Select them by prefix: merge only adds the suffixes
# when both sides carry ROW_ID, so this works whichever side(s) still have it.
patients_admissions = patients_admissions.drop(
    columns=[col for col in patients_admissions.columns if col.startswith('ROW_ID')]
)

## Generate some useful columns 

## Create age during admission 

patients_admissions = cleaning_utils.get_age_via_extraction(patients_admissions, 'ADMITTIME', 'DOB', 'age_during_admission')
# NOTE(review): age groups are assigned here, before the masked (~300) ages are pulled back
# further down; confirm age_categorise places those ages in the intended top bucket.
patients_admissions = cleaning_utils.age_categorise(patients_admissions)


## Check
patients_admissions.age_during_admission.plot(kind = 'hist', title = 'Age at the point of admission', xlabel = 'Years')
# Render now: when run as a plain script, the post-cleaning histogram below would otherwise
# be drawn onto these same axes.
plt.show()
patients_admissions.age_during_admission.describe()
## Ages of ~300 are expected: to meet HIPAA, MIMIC-III masks the age of patients aged 90+

## Closer examination of those of unusually high age
patients_admissions[patients_admissions.age_during_admission > 200].age_during_admission.value_counts().sort_index()

## implement 
# pull_back_deid_ages is not defined anywhere in this notebook, and every other cleaning helper
# comes from cleaning_utils, so call it through the module (the bare name raises NameError).
# NOTE(review): confirm cleaning_utils exports pull_back_deid_ages.
patients_admissions = cleaning_utils.pull_back_deid_ages(patients_admissions)

## check ages
patients_admissions['age_during_admission'].plot(kind = 'hist', title = 'Post age cleaning: Age at the point of admission', xlabel = 'Years')
plt.show()

## Explore neonates

## Large number of age_during_admission = 0
neonates_mask = patients_admissions.age_during_admission == 0
neonates_mask.value_counts()

## Ensure that it makes sense --> check the reasons for admission
(
    patients_admissions.loc[neonates_mask, 'ADMISSION_TYPE']
    .value_counts()
    .plot(kind='bar', title='Neonates (age = 0) admission types')
)
plt.show()

## Merge in PRIMARY DIAGNOSIS ONLY 

# One row per admission: keep only the primary (SEQ_NUM == 1) diagnosis.
# Column names (ICD9_CODE, icd_group) are kept as-is because the analysis below groups on
# `icd_group`. Note that DataFrame.rename with a positional dict targets the *index* labels;
# renaming columns requires rename(columns=...).
primary_diagnosis = diagnosis_cleaned[diagnosis_cleaned.SEQ_NUM == 1]
print(primary_diagnosis.shape)

## Merge
print(f'before: {patients_admissions.shape}')
patients_admissions_diagnosis = patients_admissions.merge(primary_diagnosis, on = 'HADM_ID', suffixes = ('', '_diagnosis'), how = 'left')
# A left join keeps every admission, so the row count should be unchanged; a larger count
# would mean some HADM_ID has more than one SEQ_NUM == 1 row.
print(f'after: {patients_admissions_diagnosis.shape}')


## Exploring/ Analysis

import duckdb

con = duckdb.connect()

# Expose the joined pandas frame to SQL under the same name
con.register("patients_admissions", patients_admissions)

## check cols
patients_admissions.columns

### Number of admissions per patient

# Per-patient admission counts, shared by both queries below so they cannot drift apart
admission_count_sql = '''
    SELECT SUBJECT_ID, count(distinct HADM_ID) as admission_count
    FROM patients_admissions
    group by SUBJECT_ID
'''

con.sql(admission_count_sql + '''
    order by admission_count desc
''').show()

## Finding key stats of number of visits per patient
# No ORDER BY inside the CTE: row order does not affect avg/max/min (mode only on exact ties).
con.sql(f'''
with PATIENT_ADM_COUNT AS ({admission_count_sql})
select
    avg(admission_count) as average_count,
    max(admission_count) as max_count,
    min(admission_count) as min_count,
    mode(admission_count) as mode_count
FROM PATIENT_ADM_COUNT
''').show()

Most patients are admitted only once; however, the highest number of admissions for a single patient is 42.

## Most common diseases per ethnicity?

def _top_n_per_group(df, group_col, value_col, n=5):
    '''
    Return the n most frequent `value_col` values within each `group_col` group.

    One row per group, with columns 'Top 1' .. 'Top k', where k <= n is the largest number of
    distinct values any group has. Groups with fewer values are padded with NaN, and missing
    values in `value_col` are not counted.
    '''
    top = (
        df.groupby(group_col)[value_col]
        # value_counts() is already sorted by descending count
        .agg(lambda values: values.value_counts().head(n).index)
        .apply(pd.Series)
    )
    # Name columns from the actual width, so fewer than n distinct values cannot break this
    top.columns = [f'Top {i}' for i in range(1, top.shape[1] + 1)]
    return top


## By free-text admission diagnosis
top_diseases_per_ethnicity = _top_n_per_group(patients_admissions_diagnosis, 'ETHNICITY_CATEGORISED', 'DIAGNOSIS')

top_diseases_per_ethnicity

## By ICD-9 group of the primary diagnosis
top_diseases_per_ethnicity = _top_n_per_group(patients_admissions_diagnosis, 'ETHNICITY_CATEGORISED', 'icd_group')

top_diseases_per_ethnicity

patients_admissions_diagnosis.columns

## Mortality rates
def _hospital_mortality_by(group_col, top_k):
    '''Mean HOSPITAL_EXPIRE_FLAG per `group_col` value, highest `top_k` first.'''
    rates = patients_admissions_diagnosis.groupby(group_col)['HOSPITAL_EXPIRE_FLAG'].mean()
    return rates.sort_values(ascending=False).head(top_k)

_hospital_mortality_by('icd_group', 20)

# NOTE(review): raw ETHNICITY may contain small categories whose rates rest on few admissions
_hospital_mortality_by('ETHNICITY', 5)

_hospital_mortality_by('ETHNICITY_CATEGORISED', 5)

## Age group analysis

pd.DataFrame(patients_admissions_diagnosis.groupby('age_group')['HOSPITAL_EXPIRE_FLAG'].mean() * 100).rename(columns = {'HOSPITAL_EXPIRE_FLAG':'Hosp Mortality Rate'})


def _most_common(values):
    '''
    Return the most frequent non-missing value in `values`, or NaN when there is none.

    value_counts() is already sorted by descending count, so idxmax() picks the top entry.
    The empty guard avoids idxmax() raising on a group whose values are all missing.
    '''
    counts = values.value_counts()
    if counts.empty:
        return np.nan
    return counts.idxmax()


## Most common primary-diagnosis ICD-9 group per age group
pd.DataFrame(patients_admissions_diagnosis.groupby('age_group')['icd_group'].agg(_most_common))

## Most common admission DIAGNOSIS per age group
patients_admissions_diagnosis.groupby('age_group')['DIAGNOSIS'].agg(_most_common)

patients_admissions_diagnosis.columns