Here we study the preproc file with the MIMIC data sampled by Simen

In [2]:
# Importing libraries
import pandas as pd
import numpy as np

In [3]:
# Loading full dataset
data = pd.read_csv(r"Dataset_MIMIC_Sample_SImen.csv")
data.head(5)

Unnamed: 0,subject_id,admission_id,gender,age,admittime_var,len_admission_days,len_life_after_admission_days,no_unique_admissions,main_diagnosis,all_other_diag,outcome
0,3,145834,1,76.57534,20oct2101 19:08:14,10.783099,225.41957,1,389,"78559, 5849, 4275, 41071, 4280, 6826, 4254, 2639",died
1,4,185777,0,47.876713,16mar2191 00:26:27,7.761161,,1,42,"1363, 7994, 2763, 7907, 5715, 04111, V090, E9317",not died
2,11,194540,0,50.18082,16apr2178 06:21:37,25.5287,186.2072,1,1913,,died
3,13,143045,0,39.890411,08jan2167 18:43:28,6.857007,,1,41401,"4111, 25000, 4019, 2720",not died
4,17,161087,0,47.849316,09may2135 14:11:23,4.017114,,1,4239,"5119, 78551, 4589, 311, 7220, 71946, 2724",not died


In [4]:
data.shape

(37762, 11)

In [5]:
# Checking the columns name
data.keys()

Index(['subject_id', 'admission_id', 'gender', 'age', 'admittime_var',
       'len_admission_days', 'len_life_after_admission_days',
       'no_unique_admissions', 'main_diagnosis', 'all_other_diag', 'outcome'],
      dtype='object')

In [6]:
# Function for decoding gender
def convert_gender(gender_code):
    if gender_code == 1:
        return "male"
    else:
        return "female"
    
relevant_columns = ['gender', 'age_admission', 'admission_date', 'len_admission_days', 'all_other_diag', 'main_diagnosis', 'subject_id','outcome']
def process_df(df, relevant_columns=relevant_columns):
    df['gender'] = df['gender'].apply(convert_gender)
    df['age_admission'] = df['age'].round().astype("int")
    df['admission_date'] = pd.to_datetime(df['admittime_var']).dt.strftime('%d %b %Y')
    df['len_admission_days'] = df['len_admission_days'].round().astype("int")
    df['all_other_diag'] = df['all_other_diag'].apply(lambda x: ', '.join(str(x).split(', ')))

    new_df = df[relevant_columns]
    return new_df

df = process_df(data)
df.head(5)

  df['admission_date'] = pd.to_datetime(df['admittime_var']).dt.strftime('%d %b %Y')


Unnamed: 0,gender,age_admission,admission_date,len_admission_days,all_other_diag,main_diagnosis,subject_id,outcome
0,male,77,20 Oct 2101,11,"78559, 5849, 4275, 41071, 4280, 6826, 4254, 2639",389,3,died
1,female,48,16 Mar 2191,8,"1363, 7994, 2763, 7907, 5715, 04111, V090, E9317",42,4,not died
2,female,50,16 Apr 2178,26,,1913,11,died
3,female,40,08 Jan 2167,7,"4111, 25000, 4019, 2720",41401,13,not died
4,female,48,09 May 2135,4,"5119, 78551, 4589, 311, 7220, 71946, 2724",4239,17,not died


Now we have almost the dataset _ready_. We'd like to achieve something like the dementia data case where we had three feature:
 - static
 - event
 - dementia_status

To achieve so let us try to replicate such things with a modified version of the _annual_summary_by_age_death_integrated_ function wrote originally for the MIMIC dataset.

In [7]:
def annual_summary_by_age_admission_death_integrated(patient_id, df) -> dict:
    """
    Generate an annual summary by age, including deatch status for a specific patient.
    
    Parameters:
        patient_id (int): The ID of the patient.
        df (pd.DataFrame): The DataFrame containing patient data.
    
    Returns:
        dict: A dictionary containing static information, event information, and death status.
    """
    # Calculate death status (outcome) for all patients
    death_status_all = df.groupby('subject_id')['outcome'].agg(
        lambda x: x.mode().iloc[0] if not x.mode().empty else np.nan
    )

    # Filter by patient ID
    patient_data = df[df['subject_id'] == patient_id]

    # Static information
    static_summary = ""
    if not patient_data.empty:
        gender = patient_data['gender'].iloc[0]
        #death_date = patient_data['pasient_dodsarmaned'].iloc[0]
        
        static_summary += f"Gender: {gender}\n"
        #if pd.notna(death_date):
        #    static_summary += f"Deceased: {death_date}\n"
    else:
        return static_summary + "No records available.\n"
    
    # Event information
    event_summary = ""
    # patient_data['year'] = patient_data['inndato'].dt.year
    # patient_data['age_at_diagnosis'] = patient_data['year'] - birth_year
    annual_diagnoses = patient_data.groupby('age_admission')['all_other_diag'].agg(list)
    # patient_data['med_year'] = patient_data['treatment_start'].dt.year
    # patient_data['age_at_medication'] = patient_data['med_year'] - birth_year
    # annual_medications = patient_data.groupby('age_at_medication')['atc_name'].agg(list) # we don't have here medications (there are in MIMIC)

    all_ages = sorted(set(annual_diagnoses.index))

    # For all ages placing all diagnoses and medication of a given patient
    for age in all_ages:
        event_summary += f"\nAt the age of {int(age)}:\n"
        if age in annual_diagnoses:
            event_summary += "Diagnoses: " + ", ".join(set(annual_diagnoses.loc[age])) + "\n"

    # combine summaries
    death_status = death_status_all.get(patient_id, "Unknown") # If not found place "Unknown"
    #full_summary = static_summary + event_summary #+ f"{dementia_status}" #add Dementian Status: if needed
    return {'Patient ID' : patient_id,'static': static_summary, 'event': event_summary, 'death_status': death_status}


Now we need to create another function to use the previous one

In [8]:
def generate_annual_summaries(data) -> pd.DataFrame:
    """
    Generate annual summaries, including dementia statuses, for all patients in the provided data.

    Parameters:
        data (pd.DataFrame): The DataFrame containing patient data.
    
    Returns:
        pd.DataFrame: A DataFrame with annual summaries for all patients.
    """
    grouped_data = {k: v for k, v in data.groupby('subject_id')}
    
    # dementia status
    # Not needed
    # dementia_status = data.groupby('lopenr')['flag_all'].agg(
    #     lambda x: decode_dementia_status(x.mode().iloc[0] if not x.mode().empty else np.nan)
    # )

    summaries = []
    for patient_id, patient_data in grouped_data.items():
        patient_summary = annual_summary_by_age_admission_death_integrated(patient_id, patient_data)
        summaries.append(patient_summary)
    
    # Giving a list of dictionaries into a pandas DF creates a DF 
    return pd.DataFrame(summaries)


Let us use the function

In [9]:
df.head(5)

Unnamed: 0,gender,age_admission,admission_date,len_admission_days,all_other_diag,main_diagnosis,subject_id,outcome
0,male,77,20 Oct 2101,11,"78559, 5849, 4275, 41071, 4280, 6826, 4254, 2639",389,3,died
1,female,48,16 Mar 2191,8,"1363, 7994, 2763, 7907, 5715, 04111, V090, E9317",42,4,not died
2,female,50,16 Apr 2178,26,,1913,11,died
3,female,40,08 Jan 2167,7,"4111, 25000, 4019, 2720",41401,13,not died
4,female,48,09 May 2135,4,"5119, 78551, 4589, 311, 7220, 71946, 2724",4239,17,not died


In [10]:

# annual summaries dataframe
df_summarized = generate_annual_summaries(df)
df_summarized.head(5)

Unnamed: 0,Patient ID,static,event,death_status
0,3,Gender: male\n,"\nAt the age of 77:\nDiagnoses: 78559, 5849, 4...",died
1,4,Gender: female\n,"\nAt the age of 48:\nDiagnoses: 1363, 7994, 27...",not died
2,11,Gender: female\n,\nAt the age of 50:\nDiagnoses: nan\n,died
3,13,Gender: female\n,"\nAt the age of 40:\nDiagnoses: 4111, 25000, 4...",not died
4,17,Gender: female\n,"\nAt the age of 48:\nDiagnoses: 5119, 78551, 4...",not died


Now, let us start with the more complicated parts (i.e. MedData and so on)

In [11]:
import datasets
from datasets import DatasetDict

ds_all = datasets.Dataset.from_pandas(df_summarized) # TODO: why is this necessary?
split_ratio = [0.8, 0.10, 0.10]

ds_split = ds_all.train_test_split(test_size=split_ratio[1]+split_ratio[2], seed=42) # split into train and (val+test)

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
# Further split (val+test)
ds_val_test = ds_split['test'].train_test_split(test_size=split_ratio[2]/(split_ratio[1]+split_ratio[2]), seed=42)

# Combine all splits into a DatasetDict
ds = DatasetDict({
    'train': ds_split['train'],
    'validation': ds_val_test['train'],
    'test': ds_val_test['test']
})

ds.save_to_disk("./")

Saving the dataset (0/1 shards):   0%|          | 0/23397 [00:00<?, ? examples/s]

Saving the dataset (1/1 shards): 100%|██████████| 23397/23397 [00:00<00:00, 218042.41 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 2925/2925 [00:00<00:00, 162746.76 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 2925/2925 [00:00<00:00, 168255.35 examples/s]
