In [None]:
import os
import pickle
import numpy as np
import pandas as pd
import yaml
import random

from tqdm import tqdm
from sklearn.model_selection import train_test_split,KFold
from sklearn.metrics import confusion_matrix

from pyhealth.datasets import SampleEHRDataset, get_dataloader
from pyhealth.models import Transformer,RNN
from pyhealth.trainer import Trainer
from pyhealth.metrics.binary import binary_metrics_fn
from pyhealth.metrics.fairness import fairness_metrics_fn

# Creating the Dataset

This code works with the public MIMIC-III ICU stay database. Before using the code, please apply, complete training, and download the requisite files from <https://physionet.org>. The required files are:
* 'PATIENTS.csv'
* 'ADMISSIONS.csv'
* 'DIAGNOSES_ICD.csv'
* 'PROCEDURES_ICD.csv'
* 'PRESCRIPTIONS.csv'
* 'CHARTEVENTS.csv'

## Patient Event Generation
Next, to create the patient events you need to perform the mimic3-benchmarks preprocessing according to the repository found at <https://github.com/YerevaNN/mimic3-benchmarks>. That repository has comprehensive documentation, and it will create a series of .csv files containing lab time-series information for each ICU stay. You only need to run the pipeline up to and including the `extract_episodes_from_subjects` step.

From there, edit the `mimic_dir` and `timeseries_dir` variables in the first data-loading cell of this notebook so they point to your local copies. Running the dataset-building cells will then generate all of the base data files for the experiments.

Next, we need to discretize the continuous variables to feed them into the generator. To do so, create a `discretized_data/` directory and run the cell under the 'Discretization' heading.

At this point, the discretized data and corresponding artifacts will be available, and your dataset will be fully processed.

In [None]:
# --- Input locations ---------------------------------------------------------
# NOTE: these paths are machine-specific; edit them to point at your local
# MIMIC-III download and at the mimic3-benchmarks output directory.
mimic_dir = "D:\\Phd\\Prelim\\mimic3\\"
timeseries_dir = "data\\mimic3-benchmarks\\root\\"
valid_subjects = os.listdir(timeseries_dir)
patientsFile = mimic_dir + "PATIENTS.csv"
admissionFile = mimic_dir + "ADMISSIONS.csv"
diagnosisFile = mimic_dir + "DIAGNOSES_ICD.csv"
procedureFile = mimic_dir + "PROCEDURES_ICD.csv"
medicationFile = mimic_dir + "PRESCRIPTIONS.csv"

# --- Lab-channel lookup tables -----------------------------------------------
# Each table is read inside a `with` block so the file handle is closed promptly.
# NOTE(review): pickle.load executes arbitrary code from the file; only load
# artifacts you generated yourself.
with open("data\\channel_to_id.pkl", "rb") as pkl_file:
    channel_to_id = pickle.load(pkl_file)
with open("data\\is_categorical_channel.pkl", "rb") as pkl_file:
    is_categorical_channel = pickle.load(pkl_file)
with open("data\\possible_values.pkl", "rb") as pkl_file:
    possible_values = pickle.load(pkl_file)
with open("data\\begin_pos.pkl", "rb") as pkl_file:
    begin_pos = pickle.load(pkl_file)
with open("data\\end_pos.pkl", "rb") as pkl_file:
    end_pos = pickle.load(pkl_file)

print("Loading CSVs Into Dataframes")
# Every column is read as a string; the date columns are parsed explicitly below.
patientsDf = pd.read_csv(patientsFile, dtype=str).set_index("SUBJECT_ID")
patientsDf['DOB'] = pd.to_datetime(patientsDf['DOB'])
patientsDf['DOD'] = pd.to_datetime(patientsDf['DOD'])
patientsDf['DOD_HOSP'] = pd.to_datetime(patientsDf['DOD_HOSP'])

# Fixed, explicit gender codes. The previous version enumerated a set literal,
# whose iteration order is not guaranteed, so the code assigned to 'M'/'F' could
# differ between runs. 0 = Male / 1 = Female matches the EDA plot labels.
gender_mapping = {'M': 0, 'F': 1}
patientsDf['GENDER_MAP'] = patientsDf['GENDER'].map(gender_mapping)

admissionDf = pd.read_csv(admissionFile, dtype=str)
admissionDf['ADMITTIME'] = pd.to_datetime(admissionDf['ADMITTIME'])
admissionDf['DISCHTIME'] = pd.to_datetime(admissionDf['DISCHTIME'])
admissionDf['DEATHTIME'] = pd.to_datetime(admissionDf['DEATHTIME'])
def map_ethnicity(ethnicity):
    """Collapse a free-text ethnicity string into one of five integer codes.

    The first keyword found in ``ethnicity`` (checked in this order) decides
    the code: WHITE -> 0, BLACK -> 1, HISPANIC -> 2, ASIAN -> 3. Anything else
    maps to 4.

    Args:
        ethnicity: Raw ethnicity value. Non-string values (for example the NaN
            pandas uses for a missing cell) are treated as "other" instead of
            raising ``TypeError`` on the substring test.

    Returns:
        int: Code in the range 0-4.
    """
    if not isinstance(ethnicity, str):
        return 4

    # Order matters: the first matching keyword wins, as in the original chain.
    for code, keyword in enumerate(("WHITE", "BLACK", "HISPANIC", "ASIAN")):
        if keyword in ethnicity:
            return code
    return 4

# Integer-coded demographic columns used later as sensitive attributes.
admissionDf["ETHNICITY_MAP"] = admissionDf['ETHNICITY'].apply(map_ethnicity)
insurance_mapping = {
    'Private': 0,
    'Medicare': 1,
    'Medicaid': 2,
    'Self Pay': 3,
    'Government': 4
}
# Insurance values missing from the mapping become NaN.
admissionDf["INSURANCE_MAP"] = admissionDf['INSURANCE'].map(insurance_mapping)
# Chronological order: the dataset-building loop numbers a patient's visits in
# the order the admissions are iterated.
admissionDf = admissionDf.sort_values('ADMITTIME')
admissionDf = admissionDf.reset_index(drop=True)

# Diagnoses: keep only rows whose ICD-9 code starts with "428".
diagnosisDf = pd.read_csv(diagnosisFile, dtype=str).set_index("HADM_ID")
diagnosisDf = diagnosisDf[diagnosisDf['ICD9_CODE'].notnull()]
diagnosisDf = diagnosisDf[diagnosisDf['ICD9_CODE'].str.startswith("428")]
diagnosisDf = diagnosisDf[['ICD9_CODE']]

procedureDf = pd.read_csv(procedureFile, dtype=str).set_index("HADM_ID")
procedureDf = procedureDf[procedureDf['ICD9_CODE'].notnull()]
procedureDf = procedureDf[['ICD9_CODE']]

medicationDf = pd.read_csv(medicationFile, dtype=str).set_index("HADM_ID")
medicationDf = medicationDf[medicationDf['NDC'].notnull()]
# The CSV is read with dtype=str, so the "no NDC" placeholder is the string '0'.
# Comparing against the integer 0 (as before) never removed any row.
medicationDf = medicationDf[medicationDf['NDC'] != '0']
# .copy() so the column assignments below modify an independent frame rather
# than a slice of the filtered one.
medicationDf = medicationDf[['NDC', 'DRUG']].copy()
medicationDf['NDC'] = medicationDf['NDC'].astype(np.int64).astype(str)
# Left-pad to the 11-digit form, then hyphenate as 5-4-2. The last group is
# digits 9-10; the previous slice c[10:12] kept only the final digit and
# silently dropped digit 9, merging distinct codes.
medicationDf['NDC'] = [c.zfill(11) for c in medicationDf['NDC']]
medicationDf['NDC'] = [c[0:5] + '-' + c[5:9] + '-' + c[9:11] for c in medicationDf['NDC']]
print("DONE LAODING")

# Exploratory Data Analysis

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Join admissions with patient demographics. admissionDf carries SUBJECT_ID as
# a column while patientsDf carries it as its index; the merge key is that name.
# NOTE: rows are admissions, so a patient with several admissions is counted
# once per admission in every panel below.
merged_data = pd.merge(admissionDf, patientsDf, on='SUBJECT_ID')

# Gender distribution
gender_counts = merged_data['GENDER_MAP'].value_counts()
gender_deaths = merged_data.groupby(['GENDER_MAP', 'HOSPITAL_EXPIRE_FLAG']).size().unstack(fill_value=0)

# Racial distribution
race_counts = merged_data['ETHNICITY_MAP'].value_counts()
race_deaths = merged_data.groupby(['ETHNICITY_MAP', 'HOSPITAL_EXPIRE_FLAG']).size().unstack(fill_value=0)


def _annotate_bars(ax):
    """Write each bar's height just above the bar (once per bar)."""
    for patch in ax.patches:
        height = patch.get_height()
        if height != height or height == 0:  # skip NaN and empty segments
            continue
        ax.annotate(f'{height:.0f}',
                    (patch.get_x() + patch.get_width() / 2., patch.get_y() + height),
                    ha='center', va='center', fontsize=10, color='black', xytext=(0, 5),
                    textcoords='offset points')


def _plot_counts(counts, ax, title, xlabel):
    """Draw a simple bar chart of group sizes."""
    sns.barplot(x=counts.index, y=counts.values, color='skyblue', ax=ax)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Number of Patients')
    _annotate_bars(ax)


def _plot_outcomes(outcome_counts, ax, title, xlabel):
    """Draw a stacked bar chart of in-hospital outcome counts per group.

    ``outcome_counts`` has one row per group and one column per
    HOSPITAL_EXPIRE_FLAG value. The flag is read from the CSV as a string, so
    columns are matched via ``str(flag)`` rather than the integers 0/1.
    """
    colors = ['red' if str(flag) == '1' else 'green' for flag in outcome_counts.columns]
    outcome_counts.plot(kind='bar', stacked=True, color=colors, rot=0, ax=ax)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Number of Patients')
    ax.legend(title='HOSPITAL_EXPIRE_FLAG')
    _annotate_bars(ax)


# Create 2x2 subplots: group sizes on the top row, outcomes on the bottom row.
fig, axs = plt.subplots(2, 2, figsize=(12, 10))

_plot_counts(gender_counts, axs[0, 0], 'Gender Distribution', 'Gender: 0 is Male and 1 is Female')
_plot_counts(race_counts, axs[0, 1], 'Racial Distribution', 'Race')
_plot_outcomes(gender_deaths, axs[1, 0], 'Gender Distribution with Deaths', 'Gender')
_plot_outcomes(race_deaths, axs[1, 1], 'Racial Distribution with Deaths', 'Race')

plt.tight_layout()
plt.show()

In [None]:
# Smaller dataset: down-sample the admissions so the experiments stay tractable.
print(len(admissionDf))
# Randomly sample (up to) 1,000 admissions; random_state makes the draw
# reproducible. .sample() returns rows in shuffled order, so the frame is
# re-sorted by ADMITTIME afterwards: the dataset-building loop numbers each
# patient's visits in iteration order and therefore needs chronological rows.
admissionDf = (
    admissionDf.sample(n=min(1000, len(admissionDf)), random_state=42)
    .sort_values('ADMITTIME')
    .reset_index(drop=True)
)

print(len(admissionDf))

In [None]:
print("Building Dataset")
# data maps SUBJECT_ID -> {"visits": [...], "gender", "ethnicity", "insurance", "isDead"}.
# Each visit is the tuple (diagnoses, procedures, medications, age, labs).
data = {}
for row in tqdm(admissionDf.itertuples(), total=len(admissionDf)):
    hadm_id = row.HADM_ID
    subject_id = row.SUBJECT_ID
    admit_time = row.ADMITTIME
    ethnicity = row.ETHNICITY_MAP
    insurance = row.INSURANCE_MAP

    if subject_id not in patientsDf.index:
        continue
    # 1-based position of this admission among the patient's visits kept so far;
    # it selects which episode file to read.
    visit_count = (0 if subject_id not in data else len(data[subject_id]["visits"])) + 1

    episode_path = f"{timeseries_dir}{subject_id}/episode{visit_count}_timeseries.csv"
    tsDf = pd.read_csv(episode_path) if os.path.exists(episode_path) else None
    # A file with a header but no rows carries no lab events; treat it like a
    # missing file instead of failing on tsDf.iloc[0] below.
    if tsDf is not None and tsDf.empty:
        tsDf = None

    patientRow = patientsDf.loc[[subject_id]].iloc[0]
    # Python datetimes are used for the subtraction rather than pandas Timestamps.
    age = (admit_time.to_pydatetime() - patientRow["DOB"].to_pydatetime()).days / 365
    if age > 120:
        continue

    # Extracting the Diagnoses
    if hadm_id in diagnosisDf.index:
        diagnoses = list(set(diagnosisDf.loc[[hadm_id]]["ICD9_CODE"]))
    else:
        diagnoses = []

    # Extracting the Procedures
    if hadm_id in procedureDf.index:
        procedures = list(set(procedureDf.loc[[hadm_id]]["ICD9_CODE"]))
    else:
        procedures = []

    # Extracting the Medications
    if hadm_id in medicationDf.index:
        medications = list(set(medicationDf.loc[[hadm_id]]["NDC"]))
    else:
        medications = []

    # Extract the lab timeseries. Rows sharing the same integer hour are merged
    # into one entry (mask, values, [hours since the previous entry]):
    #   mask   - positions in the flattened lab vector that were observed
    #   values - value at each of those positions (one-hot for categorical channels)
    labs = []
    prevTime = 0
    currTime = int(tsDf.iloc[0]["Hours"]) if tsDf is not None else 0
    currMask = []
    currValues = []
    if tsDf is not None:
        # ts_row (not `row`) so the admission tuple of the outer loop is not shadowed.
        for _, ts_row in tsDf.iterrows():
            rowTime = int(ts_row["Hours"])

            # A new hour starts: flush the accumulated entry and reset.
            if rowTime != currTime:
                labs.append((currMask, currValues, [currTime - prevTime]))
                prevTime = currTime
                currTime = rowTime
                currMask = []
                currValues = []

            for col, value in ts_row.items():
                # value != value is True only for NaN (missing measurement).
                if value != value or col == "Hours":
                    continue

                channel_start = begin_pos[channel_to_id[col]]
                if is_categorical_channel[col]:
                    # Normalise to the string form used in possible_values.
                    if col == "Glascow coma scale total":
                        value = str(int(value))
                    elif col == "Capillary refill rate":
                        value = str(value)

                    if channel_start in currMask:
                        # Channel already seen in this hour: switch on the slot
                        # of the new value.
                        # NOTE(review): the slot set by the earlier reading is
                        # not cleared, so two slots can be 1 - confirm intended.
                        currValues[
                            currMask.index(
                                channel_start + possible_values[col].index(value)
                            )
                        ] = 1
                    else:
                        # First reading this hour: add every slot of the channel
                        # as a one-hot vector.
                        hot_slot = possible_values[col].index(value)
                        for j in range(channel_start, end_pos[channel_to_id[col]]):
                            currMask.append(j)
                            currValues.append(1 if j - channel_start == hot_slot else 0)
                else:
                    if channel_start in currMask:
                        # Later reading in the same hour overwrites the earlier one.
                        currValues[currMask.index(channel_start)] = value
                    else:
                        currMask.append(channel_start)
                        currValues.append(value)

        # Flush the final hour.
        labs.append((currMask, currValues, [currTime - prevTime]))

    # Building the hospital admission data point
    if subject_id not in data:
        data[subject_id] = {
            "visits": [(diagnoses, procedures, medications, age, labs)],
            "gender": patientRow.GENDER_MAP,
            "ethnicity": ethnicity,
            "insurance": insurance,
            "isDead": patientRow.EXPIRE_FLAG,
        }
    else:
        data[subject_id]["visits"].append(
            (diagnoses, procedures, medications, age, labs)
        )
with open("./data/data_genDatasetContinuous.pkl", "wb") as out_file:
    pickle.dump(data, out_file)

# Build the label mapping
print("Adding Labels")
# NOTE(review): yaml.full_load can construct more than plain data; keep the
# definitions file a trusted, local artifact.
with open("../hcup_ccs_2015_definitions_benchmark.yaml") as definitions_file:
    definitions = yaml.full_load(definitions_file)

# ICD code -> benchmark group name, for every group not flagged out of the benchmark.
code_to_group = {}
for group in definitions:
    if definitions[group]["use_in_benchmark"] == False:
        continue
    codes = definitions[group]["codes"]
    for code in codes:
        if code in code_to_group:
            # A code may be listed twice only if both listings agree on the group.
            assert code_to_group[code] == group
        else:
            code_to_group[code] = group

# Alphabetically sorted benchmark groups and their integer ids.
id_to_group = sorted(
    name for name in definitions if definitions[name]["use_in_benchmark"] == True
)
group_to_id = {name: idx for idx, name in enumerate(id_to_group)}

# Per-patient label vector: one multi-hot slot per benchmark group (set when any
# visit carries a diagnosis of that group), followed by insurance, ethnicity,
# gender and isDead, appended one at a time in that order.
for p in data:
    label = np.zeros(len(group_to_id))
    for v in data[p]["visits"]:
        for d in map(str, v[0]):
            if d in code_to_group:
                label[group_to_id[code_to_group[d]]] = 1

    for attribute in ("insurance", "ethnicity", "gender", "isDead"):
        label = np.append(label, data[p][attribute])
    data[p]["labels"] = label

# Convert diagnoses, procedures, and medications to text
print("Converting Codes to Text")
# NDC -> drug name; when an NDC occurs on several rows the last row wins.
medMapping = dict(zip(medicationDf["NDC"], medicationDf["DRUG"]))
for p in data:
    new_visits = []
    for dx_codes, proc_codes, med_codes, visit_age, visit_labs in data[p]["visits"]:
        # One flat code list per admission: diagnoses, then procedures, then
        # medications (drug name when the NDC is known, the raw NDC otherwise).
        new_visit = list(dx_codes) + list(proc_codes)
        new_visit.extend(medMapping.get(c, c) for c in med_codes)
        new_visits.append((new_visit, [], [], [visit_age]))

        # Each lab entry becomes its own pseudo-visit with an empty code list.
        new_visits.extend(([], lab_v[0], lab_v[1], lab_v[2]) for lab_v in visit_labs)
    data[p]["visits"] = new_visits

# Convert diagnoses, procedures, and medications to indices

print("Converting Codes to Indices")
# Reproducible vocabulary: the unique codes are first put in a deterministic
# order (set iteration order is not stable across runs) and then shuffled with
# a fixed seed, so the saved index <-> code tables are the same on every run.
# key=str keeps the sort well-defined even if a non-string entry slipped in.
CODE_SHUFFLE_SEED = 42
allCodes = sorted(
    {c for p in data for v in data[p]["visits"] for c in v[0]}, key=str
)
np.random.default_rng(CODE_SHUFFLE_SEED).shuffle(allCodes)
code_to_index = {c: i for i, c in enumerate(allCodes)}
for p in data:
    new_visits = []
    for v in data[p]["visits"]:
        new_visit = [code_to_index[c] for c in v[0]]
        new_visits.append((new_visit, v[1], v[2], v[3]))
    data[p]["visits"] = new_visits

index_to_code = {v: k for k, v in code_to_index.items()}
# From here on `data` is a list of patient records (subject ids are dropped).
data = list(data.values())

MAX_TIME_STEPS = 150
# Keep only the labels and the first MAX_TIME_STEPS - 2 entries per patient.
data = [
    {"labels": data[i]["labels"], "visits": data[i]["visits"][: MAX_TIME_STEPS - 2]}
    for i in range(len(data))
]  # 2 for the start and label visits

In [None]:
# Train-Val-Test Split
print("Splitting Datasets")
# 40% of the patients are held out for testing; 10% of the remaining 60% become
# the validation set. Effective proportions: train 54%, val 6%, test 40%.
train_dataset, test_dataset = train_test_split(
    data, test_size=0.4, random_state=4, shuffle=True
)
train_dataset, val_dataset = train_test_split(
    train_dataset, test_size=0.1, random_state=4, shuffle=True
)

# Save Everything
print("Saving Everything")
print(len(index_to_code))
print(len(data[0]["labels"]))
# Each artifact is written inside a `with` block so the file is flushed and
# closed before the next one is opened.
for payload, path in (
    (dict((i, x) for (x, i) in list(group_to_id.items())), "./data/idToLabel.pkl"),
    (index_to_code, "./data/indexToCode.pkl"),
    (data, "./data/allData_1000.pkl"),
    (train_dataset, "./data/trainData.pkl"),
    (val_dataset, "./data/valData.pkl"),
    (test_dataset, "./data/testData.pkl"),
):
    with open(path, "wb") as out_file:
        pickle.dump(payload, out_file)

In [None]:
import pickle
import random
import numpy as np

from sklearn.metrics import confusion_matrix
from sklearn.model_selection import KFold
from pyhealth.metrics.binary import binary_metrics_fn
from pyhealth.metrics.fairness import fairness_metrics_fn
from pyhealth.models import RNN, Transformer
from pyhealth.trainer import Trainer
from pyhealth.datasets import SampleEHRDataset, get_dataloader, split_by_patient
from pyhealth.datasets.splitter import split_by_patient

# NOTE(review): allData_5000.pkl is not written by the cells shown in this
# notebook (the cell above saves allData_1000.pkl) - confirm where it comes from.
with open("./data/allData_5000.pkl", "rb") as in_file:
    allData_5000 = pickle.load(in_file)
allData_1000 = allData_5000[0:1000]
# Train-Val-Test Split
print("Splitting Datasets")
# 10% of the patients are held out for testing; 10% of the remaining 90% become
# the validation set. Effective proportions: train 81%, val 9%, test 10%.
train_ehr_data, test_ehr_data = train_test_split(
    allData_1000, test_size=0.1, random_state=4, shuffle=True
)
train_ehr_data, val_ehr_data = train_test_split(
    train_ehr_data, test_size=0.1, random_state=4, shuffle=True
)


def transform_data(ehr_dataset, num_disease_labels=25):
    """Flatten patient records into one sample dictionary per visit entry.

    The trailing entries of ``patient["labels"]`` are read from the end of the
    vector: ``[-3]`` ethnicity, ``[-2]`` gender, ``[-1]`` mortality flag. For a
    28-entry vector (25 disease flags + ethnicity, gender, isDead) this is the
    same as the former hard-coded indices 25/26/27; it also stays correct when
    an insurance code sits between the disease flags and the ethnicity, which
    is the layout the label-building cell of this notebook produces.

    Args:
        ehr_dataset: Iterable of patient dicts with keys ``"visits"`` (sequence
            of tuples whose first element is a list of code indices) and
            ``"labels"`` (sequence of numbers or numeric strings).
        num_disease_labels: How many leading label entries form the
            ``"disease_label"`` feature.

    Returns:
        list[dict]: One dict per visit entry with keys ``visit_id`` (position
        within the patient), ``patient_id`` (position of the patient in
        ``ehr_dataset``), ``visit_codes``, ``gender``, ``ethnicity``,
        ``disease_label`` and ``label``.
    """
    final_data = []

    for patient_id, patient in enumerate(ehr_dataset):
        labels = patient["labels"]
        # Labels may be stored as strings such as '1.0', hence int(float(...)).
        gender = int(float(labels[-2]))
        ethnicity = int(float(labels[-3]))
        mortality = int(float(labels[-1]))
        disease_flags = [int(float(x)) for x in labels[0:num_disease_labels]]

        for visit_id, visit in enumerate(patient["visits"]):
            final_data.append(
                {
                    "visit_id": visit_id,
                    "patient_id": patient_id,
                    "visit_codes": [[int(x) for x in visit[0]]],
                    # Fresh lists per sample so later in-place edits of one
                    # sample cannot leak into another.
                    "gender": [[gender]],
                    "ethnicity": [[ethnicity]],
                    "disease_label": [list(disease_flags)],
                    "label": mortality,
                }
            )
    return final_data


def calculate_wtpr(y_true, y_prob, sensitive_attribute, threshold=0.5):
    """Average the true-positive rate over the subgroups of a sensitive attribute.

    Predictions are ``y_prob >= threshold``. For every distinct value of
    ``sensitive_attribute`` the TPR is ``tp / (tp + fn)`` (0 when the subgroup
    has no positive ground-truth sample), and the unweighted mean of those
    per-subgroup TPRs is returned.

    True positives and false negatives are counted directly, so no special
    case is needed when a subgroup contains a single class. Inputs are
    flattened first, so column vectors of shape (N, 1) are accepted as well.

    NOTE(review): the result is a plain mean across subgroups; if "wtpr" is
    meant to be a worst-case or size-weighted TPR, the aggregation differs.

    Args:
        y_true: Binary ground-truth labels (0/1).
        y_prob: Predicted probabilities, same number of entries as ``y_true``.
        sensitive_attribute: Subgroup identifier per sample.
        threshold: Probability cut-off for a positive prediction.

    Returns:
        Mean of the per-subgroup true-positive rates.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = (np.asarray(y_prob).ravel() >= threshold).astype(int)
    groups = np.asarray(sensitive_attribute).ravel()

    tpr_scores = {}
    for subgroup in np.unique(groups):
        in_group = groups == subgroup
        actual_positive = y_true[in_group] == 1
        predicted_positive = y_pred[in_group] == 1

        tp = int(np.sum(actual_positive & predicted_positive))
        fn = int(np.sum(actual_positive & ~predicted_positive))

        tpr_scores[subgroup] = tp / (tp + fn) if (tp + fn) != 0 else 0

    wtpr = np.mean(list(tpr_scores.values()))
    return wtpr

transformed_train_ehr_dataset = transform_data(train_ehr_data)
transformed_val_ehr_dataset = transform_data(val_ehr_data)
transformed_test_ehr_dataset = transform_data(test_ehr_data)
transformed_allData_1000 = transform_data(allData_1000)

# Longest code list over all samples; every split is padded to this length.
max_visit_codes_length = max(
    len(sample["visit_codes"][0]) for sample in transformed_allData_1000
)


def _pad_visit_codes(samples, target_length, pad_value=0):
    """Right-pad ``sample["visit_codes"][0]`` of every sample, in place.

    Plain list padding keeps every entry an int, including for samples whose
    code list is empty.
    NOTE(review): the pad value 0 is also a valid code index - confirm the
    model is meant to see padding and code 0 as the same token.
    """
    for sample in samples:
        codes = list(sample["visit_codes"][0])
        sample["visit_codes"][0] = codes + [pad_value] * (target_length - len(codes))


for split_samples in (
    transformed_train_ehr_dataset,
    transformed_val_ehr_dataset,
    transformed_test_ehr_dataset,
):
    _pad_visit_codes(split_samples, max_visit_codes_length)

formatted_train_ehr_dataset = SampleEHRDataset(samples=transformed_train_ehr_dataset)
formatted_val_ehr_dataset = SampleEHRDataset(samples=transformed_val_ehr_dataset)
formatted_test_ehr_dataset = SampleEHRDataset(samples=transformed_test_ehr_dataset)

k = 5  # Number of folds
fairness_scores = {
    "disparate_impact": [],
    "statistical_parity_difference": [],
    "wtpr": [],
}

kf = KFold(n_splits=k, shuffle=True, random_state=42)
# Cross-validation runs over train + val; the test split is left untouched.
# NOTE(review): folds are drawn over individual samples (one per visit entry),
# so entries of the same patient can land in both the train and the val fold.
formatted_combined_ehr_dataset = formatted_train_ehr_dataset + formatted_val_ehr_dataset

for train_index, val_index in kf.split(formatted_combined_ehr_dataset):
    fold_train_dataset = [formatted_combined_ehr_dataset[i] for i in train_index]
    fold_val_dataset = [formatted_combined_ehr_dataset[i] for i in val_index]

    # A fresh model per fold.
    transformermodel = Transformer(
        dataset=formatted_train_ehr_dataset,
        # look up what are available for "feature_keys" and "label_keys" in dataset.samples[0]
        feature_keys=["visit_codes", "disease_label", "ethnicity", "gender"],
        label_key="label",
        mode="binary",
    )

    train_loader = get_dataloader(fold_train_dataset, batch_size=64, shuffle=True)
    # The validation loader must NOT shuffle: the sensitive-attribute array
    # below is built in fold_val_dataset order, and the predictions returned by
    # trainer.inference() have to line up with it row by row.
    val_loader = get_dataloader(fold_val_dataset, batch_size=64, shuffle=False)

    trainer = Trainer(model=transformermodel)
    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=30,
        monitor="pr_auc",
    )

    y_true, y_prob, loss = trainer.inference(val_loader)

    # Prepare the sensitive attribute array for the validation set:
    # 0 = unprotected group (ethnicity code 0, "white"), 1 = everyone else.
    # To audit gender instead, read visit["gender"][0][0] and flag the
    # protected value (1, female) with 1.
    unprotected_group = 0  # white in eth
    sensitive_attribute_array = np.zeros(len(fold_val_dataset), dtype=int)
    for idx, visit in enumerate(fold_val_dataset):
        sensitive_attribute_value = visit["ethnicity"][0][0]
        if sensitive_attribute_value != unprotected_group:
            sensitive_attribute_array[idx] = 1

    # Calculate fairness metrics for the current fold
    fold_fairness_metrics = fairness_metrics_fn(
        y_true,
        y_prob,
        sensitive_attributes=sensitive_attribute_array,
        favorable_outcome=1,
        metrics=None,
        threshold=0.5,
    )
    wtpr = calculate_wtpr(y_true, y_prob, sensitive_attribute_array)

    # Append the fairness metrics for the current fold
    fairness_scores["disparate_impact"].append(
        fold_fairness_metrics["disparate_impact"]
    )
    fairness_scores["statistical_parity_difference"].append(
        fold_fairness_metrics["statistical_parity_difference"]
    )
    fairness_scores["wtpr"].append(wtpr)

In [None]:
# Summarise every fairness metric across the folds: print mean and standard
# deviation, and keep the mean in fairness_metrics.
fairness_metrics = {}
for metric in fairness_scores:
    values = fairness_scores[metric]
    mean, std = np.mean(values), np.std(values)
    print(f"{metric}: Mean = {mean:.4f}, Std = {std:.4f}")
    fairness_metrics[metric] = mean

# Discretization

In [None]:
import pickle

def _load_pickle(path):
    """Read one pickled artifact and close the file afterwards.

    NOTE(review): pickle.load executes arbitrary code from the file; only load
    artifacts you generated yourself.
    """
    with open(path, "rb") as in_file:
        return pickle.load(in_file)


trainData = _load_pickle("data/trainData.pkl")
valData = _load_pickle("data/valData.pkl")
# Position in the flattened lab vector -> lab name.
idToLab = _load_pickle("./data/idx_to_lab.pkl")
# Lab name -> channel number (its position in id_to_channel).
labToNumber = {
    l: i for (i, l) in enumerate(_load_pickle("./data/id_to_channel.pkl"))
}
isCategorical = _load_pickle("./data/is_categorical_channel.pkl")
beginPos = _load_pickle("./data/begin_pos.pkl")
possibleValues = _load_pickle("./data/possible_values.pkl")
variableRanges = _load_pickle("./data/variable_ranges.pkl")
discretization = {
    "Diastolic blood pressure": [
        0,
        40,
        50,
        60,
        65,
        70,
        75,
        80,
        85,
        90,
        95,
        100,
        105,
        110,
        120,
        130,
        375,
    ],
    "Fraction inspired oxygen": [
        0.2,
        0.3,
        0.4,
        0.5,
        0.6,
        0.7,
        0.8,
        0.9,
        1.0,
        1.001,
        1.1,
    ],
    "Glucose": [
        0,
        40,
        60,
        80,
        100,
        110,
        120,
        130,
        140,
        150,
        160,
        170,
        180,
        200,
        225,
        275,
        325,
        400,
        600,
        800,
        1000,
        2200,
    ],
    "Heart Rate": [0, 40, 50, 60, 70, 80, 90, 100, 110, 120, 140, 160, 180, 200, 390],
    "Height": [0, 145, 150, 155, 160, 165, 170, 175, 180, 185, 190, 195, 230],
    "Mean blood pressure": [
        0,
        40,
        50,
        60,
        70,
        80,
        90,
        100,
        110,
        120,
        130,
        140,
        150,
        160,
        180,
        200,
        375,
    ],
    "Oxygen saturation": [
        0,
        30,
        40,
        50,
        55,
        60,
        65,
        70,
        75,
        80,
        85,
        90,
        100,
        100.001,
        150,
    ],
    "pH": [6.3, 6.7, 7.1, 7.35, 7.45, 7.6, 8.0, 8.3, 10],
    "Respiratory rate": [0, 6, 8, 10, 12, 14, 16, 18, 20, 25, 30, 35, 330],
    "Systolic blood pressure": [
        0,
        40,
        50,
        60,
        70,
        80,
        90,
        100,
        110,
        120,
        130,
        140,
        150,
        160,
        170,
        180,
        190,
        200,
        210,
        230,
        375,
    ],
    "Temperature": [
        14.2,
        30,
        32,
        33,
        33.5,
        34,
        34.5,
        35,
        35.5,
        36,
        36.5,
        37,
        37.5,
        38,
        38.5,
        39,
        39.5,
        40,
        47,
    ],
    "Weight": [
        0,
        30,
        40,
        45,
        50,
        55,
        60,
        65,
        70,
        75,
        80,
        85,
        90,
        95,
        100,
        105,
        110,
        115,
        120,
        125,
        130,
        135,
        140,
        145,
        150,
        160,
        170,
        190,
        210,
        250,
    ],
    "Age": [18, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90],
    "Days": [0, 11, 16, 21, 25, 30.1, 35.1, 43, 48, 54, 60, 66, 72, 81, 90, 100.1],
    "Hours": [
        0,
        0.5,
        1.5,
        2.5,
        3.5,
        6.5,
        10.5,
        16.5,
        26.5,
        48.0,
        48.1,
        60.1,
        80.1,
        110.1,
        150.1,
        200.1,
    ],
}

# Per-variable display settings: (format spec, Python type) for each of the
# variables that has bin edges in the discretization table.
formatMap = dict(
    [
        ("Diastolic blood pressure", (".0f", int)),
        ("Fraction inspired oxygen", (".2f", float)),
        ("Glucose", (".0f", int)),
        ("Heart Rate", (".0f", int)),
        ("Height", (".0f", int)),
        ("Mean blood pressure", (".0f", int)),
        ("Oxygen saturation", (".0f", int)),
        ("pH", (".2f", float)),
        ("Respiratory rate", (".0f", int)),
        ("Systolic blood pressure", (".0f", int)),
        ("Temperature", (".1f", float)),
        ("Weight", (".1f", float)),
        ("Age", (".2f", float)),
        ("Days", (".2f", float)),
        ("Hours", (".1f", float)),
    ]
)


def get_index(mapping, key, value):
    """Return the index of the bin of ``mapping[key]`` that ``value`` falls in.

    ``mapping[key]`` is a list of bin edges. The result is the first ``i`` for
    which ``value < edges[i + 1]``, so values below the first edge land in
    bin 0. When ``value`` is not below any edge, a message is printed and the
    index of the last bin (``len(edges) - 2``) is returned.
    """
    edges = mapping[key]
    bin_index = next(
        (i for i, upper_edge in enumerate(edges[1:]) if value < upper_edge), None
    )
    if bin_index is not None:
        return bin_index

    print(f"{value} for {key} not in {edges}")
    return len(edges) - 2


# Convert to New Data Format
# Every entry of p["visits"] is (codes, lab_mask, lab_values, [continuous value]).
# Entries with an empty lab mask are treated as admission entries; the others
# are lab entries. All continuous values are replaced by bin indices.
for p in trainData + valData:
    new_visits = []
    firstVisit = True
    for v in p["visits"]:
        if v[1] == []:
            # Admission entry: the first one is binned with the "Age" edges,
            # later ones with the "Days" edges.
            new_cont = get_index(
                discretization, "Age" if firstVisit else "Days", v[3][-1]
            )
            firstVisit = False
            new_visits.append((v[0], [], [], [new_cont]))
        else:
            new_labs = []
            new_values = []
            for lab_pos, val in zip(v[1], v[2]):
                lab_name = idToLab[lab_pos]
                channel = labToNumber[lab_name]
                if isCategorical[lab_name]:
                    # Only the active one-hot slot is kept; its value is the
                    # slot's offset from the start of the channel. The offset is
                    # position minus channel start (the previous expression had
                    # the operands swapped, which yields values <= 0).
                    if val == 1:
                        new_labs.append(channel)
                        new_values.append(lab_pos - beginPos[channel])
                else:
                    # Drop measurements outside [low, high) of the variable.
                    if (
                        val < variableRanges[lab_name][0]
                        or val >= variableRanges[lab_name][1]
                    ):
                        continue

                    new_labs.append(channel)
                    new_values.append(get_index(discretization, lab_name, val))

            # A lab entry with nothing left after filtering is dropped entirely.
            if not new_labs:
                continue
            new_cont = get_index(discretization, "Hours", v[3][-1])
            new_visits.append((v[0], new_labs, new_values, [new_cont]))

    p["visits"] = new_visits

with open("./discretized_data/trainDataset.pkl", "wb") as out_file:
    pickle.dump(trainData, out_file)
with open("./discretized_data/valDataset.pkl", "wb") as out_file:
    pickle.dump(valData, out_file)

# Channel number -> lab name.
newIdToLab = {i: l for (l, i) in labToNumber.items()}

# Start position of every channel in the discretized layout. Channels up to and
# including the first continuous one keep their original start; from then on
# each channel starts where the previous one ends, a continuous channel taking
# one position per bin (number of edges - 1).
# NOTE(review): this assumes no categorical channel follows the first
# continuous one (it would be looked up in `discretization`) - confirm.
newBeginPos = []
seenContinuous = False
for i in range(len(newIdToLab)):
    if not seenContinuous:
        newBeginPos.append(beginPos[i])
        if not isCategorical[newIdToLab[i]]:
            seenContinuous = True
            currPos = newBeginPos[i] + len(discretization[newIdToLab[i]]) - 1
    else:
        newBeginPos.append(currPos)
        currPos += len(discretization[newIdToLab[i]]) - 1

# The last channel has no successor, so its end is derived from its bin count.
# The channel is looked up by position instead of the hard-coded id 16, so the
# printed total stays correct if the number of channels changes.
last_channel = len(newBeginPos) - 1
num_lab_positions = newBeginPos[-1] + len(discretization[newIdToLab[last_channel]]) - 1

# Position in the discretized layout -> channel number.
newIdxToId = {}
for i in range(last_channel):
    for j in range(newBeginPos[i], newBeginPos[i + 1]):
        newIdxToId[j] = i
for j in range(newBeginPos[-1], num_lab_positions):
    newIdxToId[j] = last_channel

# Each artifact is written inside a `with` block so the file is closed promptly.
for payload, path in (
    (newIdxToId, "discretized_data/idxToId.pkl"),
    (formatMap, "discretized_data/formatMap.pkl"),
    (newIdToLab, "discretized_data/idToLab.pkl"),
    (newBeginPos, "discretized_data/beginPos.pkl"),
    (isCategorical, "discretized_data/isCategorical.pkl"),
    (possibleValues, "discretized_data/possibleValues.pkl"),
    (discretization, "discretized_data/discretization.pkl"),
):
    with open(path, "wb") as out_file:
        pickle.dump(payload, out_file)

print(f"NUM LABS: {num_lab_positions}")
print(f"NUM CONTINUOUS: {len(discretization['Age']) - 1}")

# Model

In [None]:
'''
    code by Brandon Theodorou
    Original GPT-2 Paper and repository here: https://github.com/openai/gpt-2
    Original GPT-2 Pytorch Model: https://github.com/huggingface/pytorch-pretrained-BERT
    GPT-2 Pytorch Model Derived From: https://github.com/graykode/gpt-2-Pytorch
'''
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# GELU Activation and Layer Normalization:
# gelu(x): Gaussian Error Linear Unit (GELU) activation function.
# LayerNorm: Layer normalization module with learnable parameters.
def gelu(x):
    """Tanh approximation of the Gaussian Error Linear Unit, applied element-wise."""
    cubic_term = x + 0.044715 * torch.pow(x, 3)
    gate = 1 + torch.tanh(math.sqrt(2 / math.pi) * cubic_term)
    return 0.5 * x * gate

class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and shift."""

    def __init__(self, hidden_size, eps=1e-12):
        """Create the module in the TF style: epsilon is added inside the square root.

        Args:
            hidden_size: Size of the last (normalised) dimension.
            eps: Constant added to the variance before taking the square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalise ``x`` along its last dimension, then apply weight and bias."""
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias
    
# 1D convolutional layer with learnable weight and bias parameters.
class Conv1D(nn.Module):
    """Affine map over the last dimension: ``x @ weight + bias``.

    ``weight`` has shape (nx, nf) and is drawn from a normal distribution with
    std 0.02; ``bias`` has shape (nf,) and starts at zero.
    """

    def __init__(self, nf, nx):
        """
        Args:
            nf: Number of output features.
            nx: Number of input features.
        """
        super().__init__()
        self.nf = nf
        self.weight = nn.Parameter(nn.init.normal_(torch.empty(nx, nf), std=0.02))
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        """Project the last dimension of ``x`` from nx to nf features."""
        output_shape = x.size()[:-1] + (self.nf,)
        # Collapse all leading dimensions so a single addmm handles every row.
        rows = x.view(-1, x.size(-1))
        projected = torch.addmm(self.bias, rows, self.weight)
        return projected.view(*output_shape)
    
# Self-attention mechanism with scaled dot-product attention. It includes convolutional layers for query, key, and value projections.
class Attention(nn.Module):
    """Causal multi-head self-attention with fused query/key/value projection."""

    def __init__(self, nx, n_ctx, config, scale=False):
        """Build the projections and the lower-triangular causal mask.

        Args:
            nx: embedding width; must be divisible by ``config.n_head``.
            n_ctx: side length of the square causal mask buffer.
            config: object providing ``n_head``.
            scale: if True, attention scores are divided by sqrt(head width).
        """
        super().__init__()
        n_state = nx
        assert n_state % config.n_head == 0
        causal_mask = torch.tril(torch.ones(n_ctx, n_ctx))
        self.register_buffer("bias", causal_mask.view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)

    def _attn(self, q, k, v):
        """Return softmax(q @ k) @ v with future positions masked out."""
        scores = torch.matmul(q, k)
        if self.scale:
            scores = scores / math.sqrt(v.size(-1))
        n_query, n_key = scores.size(-2), scores.size(-1)
        # Take the mask rows of the newest n_query positions (they may follow cached keys).
        mask = self.bias[:, :, n_key - n_query:n_key, :n_key]
        # Masked entries are replaced by -1e10 so softmax assigns them ~zero weight.
        scores = scores * mask - 1e10 * (1 - mask)
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, v)

    def merge_heads(self, x):
        """Inverse of ``split_heads``: (batch, head, seq, feat) -> (batch, seq, head * feat)."""
        x = x.permute(0, 2, 1, 3).contiguous()
        leading = x.size()[:-2]
        return x.view(*leading, x.size(-2) * x.size(-1))

    def split_heads(self, x, k=False):
        """Split the last dimension into heads; keys are returned pre-transposed when ``k``."""
        head_features = x.size(-1) // self.n_head
        x = x.view(*x.size()[:-1], self.n_head, head_features)
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None):
        """Attend over ``x`` (and cached ``layer_past`` if given).

        Returns:
            ``(output, present)`` where ``present`` stacks the keys and values used.
        """
        qkv = self.c_attn(x)
        query, key, value = qkv.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            # Cached keys are stored un-transposed (see `present`), so transpose them back.
            past_key = layer_past[0].transpose(-2, -1)
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        # Transpose keys so keys and values share a shape and can be stacked.
        present = torch.stack((key.transpose(-2, -1), value))
        attended = self.merge_heads(self._attn(query, key, value))
        return self.c_proj(attended), present
    
# Multi-Layer Perceptron module with a fully connected layer, activation function (GELU), and another fully connected layer.
class MLP(nn.Module):
    """Position-wise feed-forward network: expand, apply GELU, project back."""

    def __init__(self, n_state, config):
        """Create the two projections.

        Args:
            n_state: width of the hidden layer (the caller in this file passes 4 * n_embd).
            config: object providing ``n_embd``.
        """
        super().__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = gelu

    def forward(self, x):
        """Return ``c_proj(act(c_fc(x)))``."""
        return self.c_proj(self.act(self.c_fc(x)))

# A block containing layer normalization, attention mechanism, and an MLP. These blocks are stacked to form the transformer model.
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each with a residual add."""

    def __init__(self, n_ctx, config, scale=False):
        """Build the sub-layers from ``config.n_embd`` and ``config.layer_norm_epsilon``."""
        super().__init__()
        nx = config.n_embd
        eps = config.layer_norm_epsilon
        self.ln_1 = LayerNorm(nx, eps=eps)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = LayerNorm(nx, eps=eps)
        self.mlp = MLP(4 * nx, config)

    def forward(self, x, layer_past=None):
        """Return ``(hidden, present)`` after the attention and MLP residual updates."""
        attn_out, present = self.attn(self.ln_1(x), layer_past=layer_past)
        x = x + attn_out
        x = x + self.mlp(self.ln_2(x))
        return x, present

# The main transformer model composed of stacked blocks. It includes positional and visit embeddings.
class CoarseTransformerModel(nn.Module):
    """Visit-level transformer: embeds multi-hot visit vectors and runs the block stack."""

    def __init__(self, config):
        """Build embeddings, ``config.n_layer`` blocks and the final layer norm."""
        super().__init__()
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd
        self.n_vocab = config.total_vocab_size

        # A bias-free linear layer embeds a multi-hot visit as the sum of its code embeddings.
        self.vis_embed_mat = nn.Linear(config.total_vocab_size, config.n_embd, bias=False)
        self.pos_embed_mat = nn.Embedding(config.n_positions, config.n_embd)
        # Every layer starts as a deep copy of one template block (identical initial weights).
        template = Block(config.n_ctx, config, scale=True)
        self.h = nn.ModuleList([copy.deepcopy(template) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def forward(self, input_visits, position_ids=None, past=None):
        """Return the final hidden states for ``input_visits``.

        Args:
            input_visits: tensor whose first two dimensions are (batch, visits).
            position_ids: optional position indices; defaults to consecutive positions
                starting after the cached length.
            past: optional per-layer cache; when None, no cache is used.
        """
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)

        if position_ids is None:
            n_batch, n_visits = input_visits.size(0), input_visits.size(1)
            position_ids = torch.arange(
                past_length, n_visits + past_length, dtype=torch.long, device=input_visits.device
            )
            position_ids = position_ids.unsqueeze(0).expand(n_batch, n_visits)

        hidden_states = self.vis_embed_mat(input_visits) + self.pos_embed_mat(position_ids)
        for block, layer_past in zip(self.h, past):
            # The per-layer `present` cache is discarded here.
            hidden_states, _ = block(hidden_states, layer_past)
        return self.ln_f(hidden_states)

# Linear layer with a configurable mask on the weights, ensuring an autoregressive property.    
class AutoregressiveLinear(nn.Linear):
    """Linear layer whose weight is masked so output ``i`` only sees inputs ``j <= i``.

    The fixed lower-triangular ``mask`` buffer is multiplied into the weight on
    every forward pass, so masked weights never contribute to the output.
    """

    def __init__(self, in_features, out_features, bias=True):
        """Create the layer and its mask.

        Args:
            in_features: size of each input vector.
            out_features: size of each output vector.
            bias: whether to learn an additive bias.
        """
        super().__init__(in_features, out_features, bias)
        # nn.Linear stores its weight as (out_features, in_features), so the mask must
        # use that layout too; for square layers this equals the previous mask.
        self.register_buffer('mask', torch.tril(torch.ones(out_features, in_features)).int())

    def forward(self, input):
        """Apply the masked affine transformation to ``input``."""
        return F.linear(input, self.mask * self.weight, self.bias)

# Fine-grained (code-level) head that uses autoregressive linear layers for generating synthetic EHR data.
class FineAutoregressiveHead(nn.Module):
    """Code-level head: predicts each visit's codes from the history and the visit itself."""

    def __init__(self, config):
        """Build two masked linear layers of width ``n_embd + total_vocab_size``."""
        super().__init__()
        self.n_embd = config.n_embd
        self.total_vocab_size = config.total_vocab_size

        width = config.n_embd + self.total_vocab_size
        self.auto1 = AutoregressiveLinear(width, width)
        self.auto2 = AutoregressiveLinear(width, width)

    def _pair_history_with_next_visit(self, history, input_visits):
        """Concatenate the hidden state at step t with the visit vector at step t + 1."""
        return torch.cat((history[:, :-1, :], input_visits[:, 1:, :]), dim=2)

    def _code_logits(self, features):
        """Run both masked layers and keep the slice ``[n_embd - 1 : -1]`` of the output."""
        hidden = torch.relu(self.auto1(features))
        return self.auto2(hidden)[:, :, self.n_embd - 1:-1]

    def forward(self, history, input_visits):
        """Return code logits for every visit after the first."""
        features = self._pair_history_with_next_visit(history, input_visits)
        return self._code_logits(features)

    def sample(self, history, input_visits):
        """Return code logits for the last visit only, keeping a length-1 visit axis."""
        features = self._pair_history_with_next_visit(history, input_visits)
        latest = features[:, -1, :].unsqueeze(1)
        return self._code_logits(latest)

class HALOModel(nn.Module):
    """Visit-level transformer plus code-level autoregressive head, with a fairness term.

    ``forward`` computes a (optionally weighted and masked) binary cross-entropy
    over the next-visit code probabilities and can add a disparate-impact term.
    ``sample`` fills in the last visit of a sequence code by code.
    """

    def __init__(self, config):
        super(HALOModel, self).__init__()
        self.transformer = CoarseTransformerModel(config)  # visit level
        self.ehr_head = FineAutoregressiveHead(config)  # code level
        self.total_vocab_size = config.total_vocab_size

    @staticmethod
    def _disparate_impact_loss(code_probs, group_ids):
        """Return ``|1 - rate(group 0) / rate(group 1)|`` for the first two groups.

        A group's rate is the number of probabilities above 0.5, summed over its
        patients, divided by its patient count. Groups are ordered as returned
        by ``torch.unique``; only the first two are compared.

        NOTE: the hard ``> 0.5`` threshold yields counts, so this term carries no
        gradient with respect to the model parameters; it shifts the reported
        loss value only.

        Args:
            code_probs: tensor of probabilities with patients on dimension 0
                (at least 2-D).
            group_ids: one group label per patient (flattened to 1-D).

        Returns:
            0-dim tensor on ``code_probs.device``. Zero when fewer than two groups
            are present, or when neither group has a positive prediction.
        """
        group_ids = group_ids.reshape(-1).to(code_probs.device)
        group_labels, membership = torch.unique(group_ids, return_inverse=True)

        # Per-patient count of positive predictions.
        positives = (code_probs > 0.5).flatten(1).sum(dim=1).to(torch.float32)
        n_groups = group_labels.numel()
        pos_sum = torch.zeros(n_groups, device=code_probs.device).index_add_(0, membership, positives)
        pos_count = torch.zeros(n_groups, device=code_probs.device).index_add_(
            0, membership, torch.ones_like(positives)
        )
        pos_rate = pos_sum / pos_count

        if pos_rate.numel() > 1:
            first, second = pos_rate[0], pos_rate[1]
            # 0 / 0 means neither group received a positive prediction: treat that as
            # parity instead of propagating NaN into the loss.
            no_positives = (first == 0) & (second == 0)
            di_ratio = torch.where(no_positives, torch.ones_like(first), first / second)
        else:
            # A single group cannot be disparate with respect to itself.
            di_ratio = torch.tensor(1.0, device=code_probs.device)

        return torch.abs(1 - di_ratio)

    def disparate_impact_loss_gen(self, code_probs, input_genders):
        """Disparate-impact loss between the first two gender groups in the batch."""
        return self._disparate_impact_loss(code_probs, input_genders)

    def disparate_impact_loss_eth(self, code_probs, input_ethnicities):
        """Disparate-impact loss between the first two ethnicity groups in the batch."""
        return self._disparate_impact_loss(code_probs, input_ethnicities)

    def forward(self, input_visits, input_eth=None, fairness_metrics=None, input_ethnicities=None, position_ids=None, ehr_labels=None, ehr_masks=None, past=None, pos_loss_weight=None):
        """Predict next-visit code probabilities and, if labels are given, the loss.

        Args:
            input_visits: multi-hot visit tensor fed to the transformer.
            input_eth: optional per-patient group labels for the fairness term.
            fairness_metrics: optional mapping with key ``'disparate_impact'``;
                when given, that value is added to the loss instead of the
                batch-level disparate-impact term.
            input_ethnicities: accepted for call compatibility; not used here.
            position_ids: optional positions forwarded to the transformer.
            ehr_labels: optional target visits; shifted by one visit internally.
            ehr_masks: optional mask multiplied into probabilities, labels and weights.
            past: optional cache forwarded to the transformer.
            pos_loss_weight: optional weight applied to positive labels in the BCE.

        Returns:
            ``(loss, code_probs, shift_labels)`` when ``ehr_labels`` is given,
            otherwise ``code_probs``.
        """
        hidden_states = self.transformer(input_visits, position_ids, past)
        code_logits = self.ehr_head(hidden_states, input_visits)
        sig = nn.Sigmoid()
        code_probs = sig(code_logits)
        if ehr_labels is not None:
            shift_labels = torch.clamp(ehr_labels[..., 1:, :].contiguous(), min=0.0, max=1.0)
            loss_weights = None
            if pos_loss_weight is not None:
                loss_weights = torch.ones(code_probs.shape, device=code_probs.device)
                loss_weights = loss_weights + (pos_loss_weight - 1) * shift_labels
            if ehr_masks is not None:
                code_probs = code_probs * ehr_masks
                shift_labels = shift_labels * ehr_masks
                if pos_loss_weight is not None:
                    loss_weights = loss_weights * ehr_masks

            bce = nn.BCELoss(weight=loss_weights)
            loss = bce(code_probs, shift_labels)
            if input_eth is not None:
                di_loss = self.disparate_impact_loss_eth(code_probs, input_eth)
                if fairness_metrics is None:
                    loss = loss + 1 * di_loss
                else:
                    # Build the tensor on the loss's own device so CPU runs work too.
                    disparate_impact = torch.tensor(
                        fairness_metrics['disparate_impact'], dtype=loss.dtype, device=loss.device
                    )
                    loss = (loss + disparate_impact).mean()
            return loss, code_probs, shift_labels
        return code_probs

    def sample(self, input_visits, random=True):
        """Fill the last visit of ``input_visits`` code by code and return it.

        ``input_visits`` is modified in place. With ``random`` true, codes are
        drawn from a Bernoulli distribution; otherwise probabilities are rounded.
        """
        sig = nn.Sigmoid()
        hidden_states = self.transformer(input_visits)
        i = 0
        while i < self.total_vocab_size:
            next_logits = self.ehr_head.sample(hidden_states, input_visits)
            next_probs = sig(next_logits)
            if random:
                visit = torch.bernoulli(next_probs)
            else:
                visit = torch.round(next_probs)

            remaining_visit = visit[:, 0, i:]
            nonzero = torch.nonzero(remaining_visit, as_tuple=True)[1]
            if nonzero.numel() == 0:
                break

            # Commit the earliest newly sampled code (across the batch), then resume
            # sampling from the position right after it.
            first_nonzero = nonzero.min()
            input_visits[:, -1, i + first_nonzero] = visit[:, 0, i + first_nonzero]
            i = i + first_nonzero + 1

        return input_visits

# Model Configuration

Depending on any dataset changes, you may need to adjust the `HALOConfig` class below to match the dataset you are using. Specifically, set `code_vocab_size` and `label_vocab_size` from the values printed at the end of the data generation step, and set `lab_vocab_size` and `continuous_vocab_size` from the values printed at the end of the discretization step. Keep `total_vocab_size` equal to the sum of the code, lab, continuous, label and special vocabulary sizes.

In [None]:
#config

class HALOConfig(object):
    """Plain container for the model, vocabulary and training hyperparameters.

    Every constructor argument is stored unchanged as an attribute of the same name.
    """

    def __init__(
            self,
            total_vocab_size=4101,
            code_vocab_size=3817,
            lab_vocab_size=237,
            continuous_vocab_size=15,
            label_vocab_size=29,
            special_vocab_size=3,

            categorical_lab_vocab_size=47,
            continuous_lab_vocab_size=190,

            phenotype_labels=25,
            ethnicity_labels=10,
            gender_labels=2,

            hidden_size = 128,

            fairness_weight = 1.0,

            n_positions=150,
            n_ctx=150, #context size
            n_embd=1440,
            n_layer=12,
            n_head=18,
            layer_norm_epsilon=1e-5,
            initializer_range=0.02,

            batch_size=10,
            sample_batch_size=25,
            epoch=1,
            lr=1e-4,
    ):
        # Vocabulary layout.
        self.total_vocab_size = total_vocab_size
        self.code_vocab_size = code_vocab_size
        self.lab_vocab_size = lab_vocab_size
        self.continuous_vocab_size = continuous_vocab_size
        self.label_vocab_size = label_vocab_size
        self.special_vocab_size = special_vocab_size
        self.categorical_lab_vocab_size = categorical_lab_vocab_size
        self.continuous_lab_vocab_size = continuous_lab_vocab_size

        # Label groups and fairness settings.
        self.phenotype_labels = phenotype_labels
        self.ethnicity_labels = ethnicity_labels
        self.gender_labels = gender_labels
        self.hidden_size = hidden_size
        self.fairness_weight = fairness_weight

        # Transformer architecture.
        self.n_positions = n_positions
        self.n_ctx = n_ctx
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range

        # Training and sampling.
        self.batch_size = batch_size
        self.sample_batch_size = sample_batch_size
        self.epoch = epoch
        self.lr = lr

# Training the Model

The following cell trains the model. Before running it, create an empty `save/` directory (the best checkpoint is written to `save/generator_model`) and make sure the Model and Model Configuration cells above have been run.


In [None]:
import random
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# Seed every RNG used below so runs are repeatable.
SEED = 4
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.empty_cache()

config = HALOConfig()

# local_rank == -1 selects the single-process path; any other value takes the
# distributed (NCCL) path below.
local_rank = -1
fp16 = False
if local_rank == -1:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
else:
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    n_gpu = 1
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    torch.distributed.init_process_group(backend="nccl")
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
print(device)

# Load the discretized splits; `with` closes each file handle after reading.
with open("discretized_data/trainDataset.pkl", "rb") as fh:
    train_ehr_dataset = pickle.load(fh)
with open("discretized_data/valDataset.pkl", "rb") as fh:
    val_ehr_dataset = pickle.load(fh)
with open("discretized_data/beginPos.pkl", "rb") as fh:
    beginPos = pickle.load(fh)

# Convert every visit to a flat list of vocabulary indices:
# its codes, then one index per lab value, then the index of its continuous bin.
for p in train_ehr_dataset + val_ehr_dataset:
    new_visits = []
    for v in p["visits"]:
        # Copy the code list so the loaded visit is not appended to through an alias.
        new_idx = list(v[0])
        for l, val in zip(v[1], v[2]):
            new_idx.append(config.code_vocab_size + beginPos[l] + val)
        new_idx.append(config.code_vocab_size + config.lab_vocab_size + v[3][-1])
        new_visits.append(new_idx)

    p["visits"] = new_visits

def get_batch(loc, batch_size, mode):
    """Build one dense multi-hot batch from the flattened patient records.

    Row layout per patient (second axis of ``batch_ehr``):
      * row 0: start token,
      * row 1: patient labels,
      * rows 2..: one row per visit; the end token is set on row ``len(visits) + 1``,
      * remaining rows: pad token.

    Args:
        loc: index of the first patient of the batch.
        batch_size: maximum number of patients in the batch.
        mode: ``"train"`` or ``"valid"``; any other value reads ``test_ehr_dataset``.

    Returns:
        ``(batch_ehr, batch_mask, batch_eth)`` as NumPy arrays of shape
        ``(n, n_ctx, total_vocab_size)``, ``(n, n_ctx - 1, 1)`` and ``(n,)``.
        The mask is already shifted by one row to line up with the model's
        shifted labels and predictions.
    """
    if mode == "train":
        ehr = train_ehr_dataset[loc : loc + batch_size]
    elif mode == "valid":
        ehr = val_ehr_dataset[loc : loc + batch_size]
    else:
        ehr = test_ehr_dataset[loc : loc + batch_size]

    # Vocabulary offsets: [codes | labs | continuous | labels | start, end, pad].
    label_start = (
        config.code_vocab_size + config.lab_vocab_size + config.continuous_vocab_size
    )
    start_token = label_start + config.label_vocab_size
    end_token = start_token + 1
    pad_token = start_token + 2

    batch_eth = np.zeros((len(ehr)))
    batch_ehr = np.zeros((len(ehr), config.n_ctx, config.total_vocab_size))
    batch_mask = np.zeros((len(ehr), config.n_ctx, 1))

    # One pass per patient: visits, labels, group flag, end token and padding are all
    # written for the same patient `i`.
    for i, p in enumerate(ehr):
        visits = p["visits"]
        for j, v in enumerate(visits):
            try:
                batch_ehr[i, j + 2][v] = 1
            except IndexError:
                # Best effort: skip visits that do not fit into the array.
                print(f"Warning: Index {v} is out of bounds for batch_ehr[{i}, {j+2}]")
                continue
            batch_mask[i, j + 2] = 1

        batch_ehr[i, 1, label_start:start_token] = np.array(p["labels"])  # Set the patient labels

        # NOTE(review): this reads the third-from-last label position as the group flag;
        # confirm that position holds the ethnicity flag in the label layout.
        batch_eth[i] = batch_ehr[i, 1, start_token - 3]

        # Set the final visit to have the end token; records too long for n_ctx are
        # left without one, in line with the best-effort handling above.
        if len(visits) + 1 < config.n_ctx:
            batch_ehr[i, len(visits) + 1, end_token] = 1
        batch_ehr[i, len(visits) + 2 :, pad_token] = 1  # Set the rest to the padded visit token

    batch_mask[:, 1] = 1  # Set the mask to cover the labels
    batch_ehr[:, 0, start_token] = 1  # Set the first visits to be the start token
    # Shift the mask to match the shifted labels and predictions the model will return
    batch_mask = batch_mask[:, 1:, :]

    return batch_ehr, batch_mask, batch_eth


def shuffle_training_data(train_ehr_dataset):
    """Shuffle ``train_ehr_dataset`` in place using NumPy's global RNG; returns None."""
    np.random.shuffle(train_ehr_dataset)
    return None

# Build the generator and its optimizer on the selected device.
model = HALOModel(config).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
# Optional resume from an earlier checkpoint (disabled). Note that it reads
# "./save/halo_model" while the loop below writes "./save/generator_model".
# if os.path.exists("./save/halo_model"):
#     print("Loading previous model")
#     checkpoint = torch.load("./save/halo_model", map_location=torch.device(device))
#     model.load_state_dict(checkpoint["model"])
#     optimizer.load_state_dict(checkpoint["optimizer"])

# Train Model

# Best validation loss seen so far; starts high so the first validation pass always saves.
global_loss = 1e10
# NOTE(review): `iteration` is assigned here but never updated in this cell.
iteration = 0
for e in tqdm(range(config.epoch)):
    shuffle_training_data(train_ehr_dataset)
    # `i` is the offset of the batch's first patient, so it advances by batch_size.
    for i in range(0, len(train_ehr_dataset), config.batch_size):
        model.train()

        batch_ehr, batch_mask, batch_eth = get_batch(i, config.batch_size, "train")
        batch_eth = torch.tensor(batch_eth, dtype=torch.float32).to(device)
        batch_ehr = torch.tensor(batch_ehr, dtype=torch.float32).to(device)
        batch_mask = torch.tensor(batch_mask, dtype=torch.float32).to(device)
        #print(batch_mask.shape)
        optimizer.zero_grad()
        # The batch is both input and target; the model shifts the labels by one visit.
        loss, _, _ = model(
            batch_ehr,
            batch_eth,
            fairness_metrics = None,
            input_ethnicities=None,
            position_ids=None,
            ehr_labels=batch_ehr,
            ehr_masks=batch_mask,
        )
        # print(loss)
        loss.backward()
        optimizer.step()

        # if i % (50 * config.batch_size) == 0:
        #     print("Epoch %d, Iter %d: Training Loss:%.6f" % (e, i, loss))
   
        # if i % (250 * len(train_ehr_dataset)) == 0:  # Condition based on iterations
        #     if i == 0:
        #         continue

        # Log the training loss every 25 batches.
        if i % (25*config.batch_size) == 0:
            print("Epoch %d, Iter %d: Training Loss:%.6f"%(e, i, loss))
        # Validate every 100 batches, skipping the very first batch of each epoch.
        if i % (100*config.batch_size) == 0:
            if i == 0:
               continue 
            print("I am entering eval stage")
            model.eval()

            with torch.no_grad():
                val_l = []
                # The validation pass reuses the batch_* names; they are rebuilt from
                # get_batch at the start of the next training iteration.
                for v_i in range(0, len(val_ehr_dataset), config.batch_size):
                    batch_ehr, batch_mask, batch_eth= get_batch(
                        v_i, config.batch_size, "valid"
                    )
                    batch_eth = torch.tensor(batch_eth, dtype=torch.float32).to(
                        device
                    )
                    # batch_ethnicities = torch.tensor(batch_ethnicities,dtype=torch.float32).to(device)
                    batch_ehr = torch.tensor(batch_ehr, dtype=torch.float32).to(device)
                    batch_mask = torch.tensor(batch_mask, dtype=torch.float32).to(
                        device
                    )

                    val_loss, _, _ = model(
                        batch_ehr,
                        batch_eth,
                        fairness_metrics = None,
                        input_ethnicities=None,
                        position_ids=None,
                        ehr_labels=batch_ehr,
                        ehr_masks=batch_mask,
                    )
                    val_l.append((val_loss).cpu().detach().numpy())

                # Unweighted mean over validation batches (the last batch may be smaller).
                cur_val_loss = np.mean(val_l)
                print("Epoch %d Validation Loss:%.7f" % (e, cur_val_loss))
                # Keep only the checkpoint with the lowest validation loss so far.
                if cur_val_loss < global_loss:
                    global_loss = cur_val_loss
                    state = {
                        "model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "iteration": i,
                    }
                    torch.save(state, "./save/generator_model")
                    print("\n------------ Save best model ------------\n")

# Generate
Ensure that the directory `results/datasets` exists before running the following cell and the cell under 'Convert discrete data to continuous'.

Note: by default the cell generates a fixed number of synthetic records. If you want a different amount of data (for example, the size of the training dataset), set the `totEHRs` variable near the end of the following cell.

In [None]:
import json
import torch
import pickle
import random
import numpy as np
from sys import argv
from tqdm import tqdm

from model import HALOModel
from config import HALOConfig

def _load_pickle(path):
    """Read one pickled artifact and close the file handle afterwards."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Rebuild the model and restore the best checkpoint written by the training cell.
config = HALOConfig()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HALOModel(config).to(device)
checkpoint = torch.load("./save/generator_model", map_location=torch.device(device))
model.load_state_dict(checkpoint["model"])

# Artifacts produced by the discretization step.
idxToId = _load_pickle("discretized_data/idxToId.pkl")
idToLab = _load_pickle("discretized_data/idToLab.pkl")
beginPos = _load_pickle("discretized_data/beginPos.pkl")
isCategorical = _load_pickle("discretized_data/isCategorical.pkl")
possible_values = _load_pickle("discretized_data/possibleValues.pkl")
discretization = _load_pickle("discretized_data/discretization.pkl")

def sample_sequence(model, length, context, batch_size, device="cuda", sample=True):
    """Generate a batch of synthetic records visit by visit.

    Starting from ``context`` (one visit vector, repeated for every record), an
    empty visit is appended and filled in by ``model.sample`` until either
    ``length`` visits exist or every record contains the end token.

    Args:
        model: object exposing ``sample(visits, random)``.
        length: maximum number of visits, including the initial ``context`` visit.
        context: 1-D array-like of size ``config.total_vocab_size``.
        batch_size: number of records to generate in parallel.
        device: torch device for the working tensors.
        sample: forwarded to ``model.sample`` as its ``random`` argument.

    Returns:
        NumPy array of shape ``(batch_size, n_visits, total_vocab_size)``.
    """
    # The end token sits right after the start token, i.e. after the code, lab,
    # continuous and label sections of the vocabulary.
    end_token = (
        config.code_vocab_size
        + config.lab_vocab_size
        + config.continuous_vocab_size
        + config.label_vocab_size
        + 1
    )
    empty = torch.zeros(
        (batch_size, 1, config.total_vocab_size), device=device, dtype=torch.float32
    )
    prev = (
        torch.tensor(context, device=device, dtype=torch.float32)
        .unsqueeze(0)
        .repeat(batch_size, 1)
        .unsqueeze(1)
    )
    with torch.no_grad():
        for _ in range(length - 1):
            prev = model.sample(torch.cat((prev, empty), dim=1), sample)
            # Stop early once every record has emitted the end token in some visit.
            finished = prev[:, :, end_token].sum(dim=1).bool()
            if bool(finished.all()):
                break
    return prev.cpu().detach().numpy()
def convert_ehr(ehrs, index_to_code=None):
    """Decode dense multi-hot samples into the visit-tuple record format.

    Args:
        ehrs: sequence of 2-D arrays (visits x vocabulary), e.g. the output of
            ``sample_sequence``. Row 1 holds the labels; visits start at row 2.
        index_to_code: optional mapping from code index to code name. When given,
            codes and labels are translated instead of being returned as indices.

    Returns:
        List of dicts with keys ``"visits"``, ``"labels"``, ``"gender"`` and
        ``"ethnicity"``. Each visit is ``(codes, lab_ids, lab_values, [cont_bin])``.
    """
    # Vocabulary offsets: [codes | labs | continuous | labels | start, end, pad].
    lab_start = config.code_vocab_size
    cont_start = lab_start + config.lab_vocab_size
    label_start = cont_start + config.continuous_vocab_size
    label_end = label_start + config.label_vocab_size
    end_token = label_end + 1

    ehr_outputs = []
    for ehr in ehrs:
        # NOTE(review): these are whole rows of the sample (rows 3 and 2), not scalar
        # attributes - confirm what "gender"/"ethnicity" are meant to contain.
        # Samples shorter than that yield None instead of raising IndexError.
        ethnicity_output = ehr[3] if len(ehr) > 3 else None
        gender_output = ehr[2] if len(ehr) > 2 else None
        labels_output = ehr[1][label_start:label_end]
        if index_to_code is not None:
            # NOTE(review): `idToLabel` is not defined in this cell (only `idToLab` is
            # loaded) - confirm it exists before passing `index_to_code`.
            labels_output = [idToLabel[idx] for idx in np.nonzero(labels_output)[0]]

        ehr_output = []
        for j in range(2, len(ehr)):
            visit_output = []
            lab_mask = []
            lab_values = []
            cont_idx = -1  # -1 means "no continuous bin seen yet"
            end = False
            for idx in np.nonzero(ehr[j])[0]:
                if idx < lab_start:
                    visit_output.append(
                        index_to_code[idx] if index_to_code is not None else idx
                    )
                elif idx < cont_start:
                    lab_idx = idx - lab_start
                    lab_num = idxToId[lab_idx]
                    # Keep only the first sampled value of each lab.
                    if lab_num not in lab_mask:
                        lab_mask.append(lab_num)
                        lab_values.append(lab_idx - beginPos[lab_num])
                elif idx < label_start:
                    # Keep only the first continuous bin of the visit.
                    if cont_idx == -1:
                        cont_idx = idx - cont_start
                elif idx == end_token:
                    end = True

            if cont_idx == -1:
                # No bin was sampled: draw a valid one uniformly from
                # [0, continuous_vocab_size - 1].
                cont_idx = random.randrange(config.continuous_vocab_size)
            if visit_output != [] or lab_mask != []:
                ehr_output.append((visit_output, lab_mask, lab_values, [cont_idx]))
            if end:
                break

        ehr_outputs.append(
            {
                "visits": ehr_output,
                "labels": labels_output,
                "gender": gender_output,
                "ethnicity": ethnicity_output,
            }
        )
    return ehr_outputs
# NOTE(review): `pakEHRs` is not used anywhere else in this cell - confirm whether
# this load is still needed.
with open("discretized_data/trainDataset.pkl", "rb") as fh:
    pakEHRs = pickle.load(fh)

# Generate Synthetic EHR dataset
# totEHRs = len(pickle.load(open("discretized_data/trainDataset.pkl", "rb")))
totEHRs = 1000

# Start-token visit vector used as the context for every generated record.
stoken = np.zeros(config.total_vocab_size)
stoken[
    config.code_vocab_size
    + config.lab_vocab_size
    + config.continuous_vocab_size
    + config.label_vocab_size
] = 1

synthetic_ehr_dataset = []
for i in tqdm(range(0, totEHRs, config.sample_batch_size)):
    # The last batch may be smaller than sample_batch_size.
    bs = min([totEHRs - i, config.sample_batch_size])
    batch_synthetic_ehrs = sample_sequence(
        model, config.n_ctx, stoken, batch_size=bs, device=device, sample=True
    )
    batch_synthetic_ehrs = convert_ehr(batch_synthetic_ehrs)
    synthetic_ehr_dataset += batch_synthetic_ehrs

In [None]:
# Persist the generated records; `with` flushes and closes the file handle.
with open("./results/datasets/haloDataset_eth_1000.pkl", "wb") as fh:
    pickle.dump(synthetic_ehr_dataset, fh)

# Convert discrete data to continuous

In [None]:
import pickle
import random

# Artifacts produced by the discretization step (each file is closed after reading).
with open("discretized_data/idToLab.pkl", "rb") as fh:
    idToLab = pickle.load(fh)
with open("discretized_data/isCategorical.pkl", "rb") as fh:
    isCategorical = pickle.load(fh)
with open("discretized_data/discretization.pkl", "rb") as fh:
    discretization = pickle.load(fh)
with open("discretized_data/possibleValues.pkl", "rb") as fh:
    possibleValues = pickle.load(fh)
with open("discretized_data/formatMap.pkl", "rb") as fh:
    formatMap = pickle.load(fh)

# NOTE(review): the generation cell writes "haloDataset_eth_1000.pkl" while this cell
# reads "haloDataset_feedback_2000.pkl" - confirm which dataset is meant to be converted.
with open("./results/datasets/haloDataset_feedback_2000.pkl", "rb") as fh:
    dataset = pickle.load(fh)


def formatCont(value, key):
    """Format ``value`` with the spec stored for ``key`` and convert the text back.

    ``formatMap[key]`` is read as ``(format_spec, converter)``: the value is
    rendered with the format spec and the resulting string is passed to the converter.
    """
    spec, converter = formatMap[key][0], formatMap[key][1]
    rendered = ("{:" + spec + "}").format(value)
    return converter(rendered)


# Replace every discretized value by a concrete one: categorical labs map to their
# stored value, continuous quantities are drawn uniformly from within their bin.
for p in dataset:
    new_visits = []
    firstVisit = True
    for v in p["visits"]:
        new_labs = []
        new_values = []
        for i in range(len(v[1])):
            lab_name = idToLab[v[1][i]]
            bin_idx = v[2][i]
            new_labs.append(lab_name)
            if isCategorical[lab_name]:
                new_values.append(possibleValues[lab_name][bin_idx])
            else:
                # Bin `bin_idx` spans the edges [bin_idx, bin_idx + 1] of the discretization.
                new_values.append(
                    formatCont(
                        random.uniform(
                            discretization[lab_name][bin_idx],
                            discretization[lab_name][bin_idx + 1],
                        ),
                        lab_name,
                    )
                )
        # Visits with labs use the "Hours" bins; the first visit without labs uses
        # "Age"; every later visit without labs uses "Days".
        contType = "Hours" if new_labs != [] else "Age" if firstVisit else "Days"
        if contType == "Age":
            firstVisit = False
        new_cont = formatCont(
            random.uniform(
                discretization[contType][v[3][-1]],
                discretization[contType][v[3][-1] + 1],
            ),
            contType,
        )
        new_visits.append((v[0], new_labs, new_values, [new_cont]))
    p["visits"] = new_visits

# `with` flushes and closes the output file.
with open("results/datasets/haloDataset_converted_feedback_2000.pkl", "wb") as fh:
    pickle.dump(dataset, fh)

# Training the prediction model and calculating the fairness metrics

The following cells run the experiments described in the paper, using the PyHealth library.

In [None]:
#5000 real + 1000 synth
import pickle
import random
import numpy as np
from pyhealth.datasets import SampleEHRDataset, get_dataloader, split_by_patient
from pyhealth.datasets.splitter import split_by_patient

# Load the real records and the converted synthetic records (files closed after reading).
with open("./data/allData_5000.pkl", "rb") as fh:
    all_data_5000 = pickle.load(fh)
with open("./results/datasets/haloDataset_converted_2000.pkl", "rb") as fh:
    halo_ehr_dataset = pickle.load(fh)

#all_data_2500 = random.sample(all_data_5000,2500)
# NOTE(review): no seed is set in this cell, so the sampled subset differs between
# runs unless `random` was seeded earlier in the session.
halo_1000 = random.sample(halo_ehr_dataset, 1000)

combined_data = (
    all_data_5000 + halo_1000
)

# Labels are stored as arrays; convert them to plain lists of ints.
for patient in combined_data:
    patient["labels"] = [int(float(label)) for label in patient["labels"].tolist()]

def transform_data(ehr_dataset):
    """Flatten patient records into one sample dict per visit.

    Patients are numbered from 0 in input order and visits from 0 within each
    patient. Each sample carries the visit's codes plus patient-level fields
    read from fixed label positions: 0-24 disease labels, 25 ethnicity,
    26 gender, 27 the prediction label.

    Returns:
        List of sample dicts; patients without visits contribute no samples.
    """
    final_data = []
    for patient_id, patient in enumerate(ehr_dataset):
        for visit_id, visit in enumerate(patient["visits"]):
            final_data.append(
                {
                    "visit_id": visit_id,
                    "patient_id": patient_id,
                    "visit_codes": [[int(code) for code in visit[0]]],
                    "gender": [[int(patient["labels"][26])]],
                    "ethnicity": [[int(patient["labels"][25])]],
                    "disease_label": [[int(flag) for flag in patient["labels"][0:25]]],
                    "label": int(patient["labels"][27]),
                }
            )
    return final_data


def calculate_wtpr(y_true, y_prob, sensitive_attribute, threshold=0.5):
    """Compute the true-positive rate per sensitive subgroup and their mean.

    Args:
        y_true: array of ground-truth labels.
        y_prob: array of predicted probabilities; values >= ``threshold``
            count as positive predictions.
        sensitive_attribute: array assigning each sample to a subgroup.
        threshold: decision threshold applied to ``y_prob``.

    Returns:
        ``(wtpr, tpr_scores)`` where ``tpr_scores`` maps each subgroup value to
        its TPR (0 when the subgroup has no positives) and ``wtpr`` is the mean
        of those TPRs.
    """
    predictions = (y_prob >= threshold).astype(int)
    per_group_tpr = {}

    for group in np.unique(sensitive_attribute):
        in_group = sensitive_attribute == group
        group_truth = y_true[in_group]
        cm = confusion_matrix(group_truth, predictions[in_group])

        if cm.size != 1:
            _, _, fn, tp = cm.ravel()
        elif group_truth[0] == 1:
            # Only one class present and it is the positive one.
            tp, fn = cm[0, 0], 0
        else:
            # Only one class present and it is not the positive one; the
            # resulting TPR is 0 either way because tp is 0.
            tp, fn = 0, cm[0, 0]

        positives = tp + fn
        per_group_tpr[group] = tp / positives if positives != 0 else 0

    return np.mean(list(per_group_tpr.values())), per_group_tpr


# Shuffle the patients in place, then cut the list 80% / 10% / 10% into
# train / validation / test partitions.
random.shuffle(combined_data)

total_length = len(combined_data)
train_split, val_split = int(0.8 * total_length), int(0.9 * total_length)

train_ehr_data = combined_data[:train_split]
val_ehr_data = combined_data[train_split:val_split]
test_ehr_data = combined_data[val_split:]

# Flatten each partition (and the full list) into per-visit samples.
transformed_train_ehr_dataset = transform_data(train_ehr_data)
transformed_val_ehr_dataset = transform_data(val_ehr_data)
transformed_test_ehr_dataset = transform_data(test_ehr_data)
transformed_combined_ehr_dataset = transform_data(combined_data)

# Longest code list over all visits; every visit is right-padded with zeros
# up to this length.
max_visit_codes_length = max(
    len(sample["visit_codes"][0]) for sample in transformed_combined_ehr_dataset
)
for transformed_split in (
    transformed_train_ehr_dataset,
    transformed_val_ehr_dataset,
    transformed_test_ehr_dataset,
    transformed_combined_ehr_dataset,
):
    for sample in transformed_split:
        visit_codes = sample["visit_codes"][0]
        padded_visit_codes = np.pad(
            visit_codes, (0, max_visit_codes_length - len(visit_codes)), mode="constant"
        )
        sample["visit_codes"][0] = padded_visit_codes.tolist()

# Wrap the padded samples in PyHealth sample datasets.
formatted_train_ehr_dataset = SampleEHRDataset(samples=transformed_train_ehr_dataset)
formatted_val_ehr_dataset = SampleEHRDataset(samples=transformed_val_ehr_dataset)
formatted_test_ehr_dataset = SampleEHRDataset(samples=transformed_test_ehr_dataset)
formatted_combined_ehr_dataset = SampleEHRDataset(
    samples=transformed_combined_ehr_dataset
)

# train_loader = get_dataloader(formatted_train_ehr_dataset, batch_size=64, shuffle=True)
# val_loader = get_dataloader(formatted_val_ehr_dataset, batch_size=64, shuffle=False)
# test_loader = get_dataloader(formatted_test_ehr_dataset, batch_size=64, shuffle=False)

# 5-fold cross-validation of a Transformer mortality-style binary classifier,
# collecting fairness metrics (ethnicity as the sensitive attribute) per fold.
k = 5  # Number of folds
fairness_scores = {
    "disparate_impact": [],
    "statistical_parity_difference": [],
    "wtpr": [],
}

kf = KFold(n_splits=k, shuffle=True, random_state=42)

for train_index, val_index in kf.split(formatted_combined_ehr_dataset):
    fold_train_dataset = [formatted_combined_ehr_dataset[i] for i in train_index]
    fold_val_dataset = [formatted_combined_ehr_dataset[i] for i in val_index]

    # NOTE(review): the model is initialised from formatted_train_ehr_dataset
    # while the folds are drawn from formatted_combined_ehr_dataset — confirm
    # this is intended.
    transformermodel = Transformer(
        dataset=formatted_train_ehr_dataset,
        # look up what are available for "feature_keys" and "label_keys" in dataset.samples[0]
        feature_keys=["visit_codes", "disease_label", "ethnicity", "gender"],
        label_key="label",
        mode="binary",
    )

    train_loader = get_dataloader(fold_train_dataset, batch_size=64, shuffle=True)
    # The validation loader must NOT shuffle: the sensitive-attribute array
    # below is built in fold_val_dataset order, and the predictions returned by
    # trainer.inference() have to line up with it sample by sample.
    val_loader = get_dataloader(fold_val_dataset, batch_size=64, shuffle=False)

    trainer = Trainer(model=transformermodel)
    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=30,
        monitor="pr_auc",
    )

    y_true, y_prob, loss = trainer.inference(val_loader)

    # protected_group = 1  # female in gender
    # # Prepare the sensitive attribute array for the validation set
    # sensitive_attribute_array = np.zeros(len(fold_val_dataset), dtype=int)
    # for idx, visit in enumerate(fold_val_dataset):
    #     sensitive_attribute_value = visit["gender"][0][0]
    #     if sensitive_attribute_value == protected_group:
    #         sensitive_attribute_array[idx] = 1

    # Prepare the sensitive attribute array for the validation set:
    # 0 for the unprotected group, 1 for every other ethnicity code.
    unprotected_group = 0  # white in eth
    sensitive_attribute_array = np.zeros(len(fold_val_dataset), dtype=int)
    for idx, visit in enumerate(fold_val_dataset):
        sensitive_attribute_value = visit["ethnicity"][0][0]
        if sensitive_attribute_value != unprotected_group:
            sensitive_attribute_array[idx] = 1

    # Calculate fairness metrics for the current fold
    fold_fairness_metrics = fairness_metrics_fn(
        y_true,
        y_prob,
        sensitive_attributes=sensitive_attribute_array,
        favorable_outcome=1,
        metrics=None,
        threshold=0.5,
    )
    # calculate_wtpr as defined in this cell returns (mean TPR, per-group TPRs);
    # keep only the scalar so the score list stays numeric.
    wtpr, _tpr_by_group = calculate_wtpr(y_true, y_prob, sensitive_attribute_array)

    # Append the fairness metrics for the current fold
    fairness_scores["disparate_impact"].append(
        fold_fairness_metrics["disparate_impact"]
    )
    fairness_scores["statistical_parity_difference"].append(
        fold_fairness_metrics["statistical_parity_difference"]
    )
    fairness_scores["wtpr"].append(wtpr)

# Training with feedback-generated metrics

In [None]:
import random
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# Seed every RNG used below so training runs are repeatable.
SEED = 4
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.empty_cache()

config = HALOConfig()

# local_rank == -1 selects the single-process path; any other value takes the
# distributed (NCCL) path.
local_rank = -1
fp16 = False
if local_rank == -1:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
else:
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    n_gpu = 1
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    torch.distributed.init_process_group(backend="nccl")
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
print(device)

# Load the discretized datasets; context managers close the files promptly.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("discretized_data/trainDataset.pkl", "rb") as train_file:
    train_ehr_dataset = pickle.load(train_file)
with open("discretized_data/valDataset.pkl", "rb") as val_file:
    val_ehr_dataset = pickle.load(val_file)

# Convert every visit to a single flat list of vocabulary indices:
# codes, then lab-value tokens, then one continuous-value token.
with open("discretized_data/beginPos.pkl", "rb") as begin_pos_file:
    beginPos = pickle.load(begin_pos_file)
for p in train_ehr_dataset + val_ehr_dataset:
    new_visits = []
    for v in p["visits"]:
        # new_idx aliases the visit's own code list, so the appends below
        # extend that list in place.
        new_idx = v[0]
        for l, val in zip(v[1], v[2]):
            # Lab tokens sit after the code vocabulary, offset by the lab's
            # begin position plus its value index.
            new_idx.append(config.code_vocab_size + beginPos[l] + val)
        # The last entry of v[3] is placed after the code and lab vocabularies.
        new_idx.append(config.code_vocab_size + config.lab_vocab_size + v[3][-1])
        new_visits.append(new_idx)

    p["visits"] = new_visits


def get_batch(loc, batch_size, mode):
    """Build the multi-hot input, mask and gender arrays for one batch.

    Args:
        loc: start index of the batch within the chosen dataset.
        batch_size: maximum number of patients in the batch.
        mode: "train" or "valid" select the matching global dataset; any other
            value selects ``test_ehr_dataset``.

    Returns:
        ``(batch_ehr, batch_mask, batch_gender)``:
        ``batch_ehr`` has shape (n, config.n_ctx, config.total_vocab_size) with
        row 0 holding the start token, row 1 the patient labels and rows 2+ the
        visits; ``batch_mask`` has shape (n, config.n_ctx - 1, 1) after the
        one-step shift; ``batch_gender`` has shape (n,).
    """
    if mode == "train":
        source = train_ehr_dataset
    elif mode == "valid":
        source = val_ehr_dataset
    else:
        source = test_ehr_dataset
    records = source[loc : loc + batch_size]
    n_records = len(records)

    # Offsets of the label slice and of the special tokens that follow it.
    label_start = (
        config.code_vocab_size + config.lab_vocab_size + config.continuous_vocab_size
    )
    label_stop = label_start + config.label_vocab_size
    start_token = label_stop
    end_token = label_stop + 1
    pad_token = label_stop + 2

    batch_gender = np.zeros((n_records))
    batch_ehr = np.zeros((n_records, config.n_ctx, config.total_vocab_size))
    batch_mask = np.zeros((n_records, config.n_ctx, 1))

    for row, patient in enumerate(records):
        visits = patient["visits"]
        n_visits = len(visits)
        for step, visit in enumerate(visits):
            # Visits start at position 2 (after the start token and the labels).
            batch_ehr[row, step + 2][visit] = 1
            batch_mask[row, step + 2] = 1

        # Position 1 carries the patient labels.
        batch_ehr[row, 1, label_start:label_stop] = np.array(patient["labels"])

        # Gender is read back from the second-to-last label slot.
        batch_gender[row] = batch_ehr[row, 1, label_stop - 2]
        # batch_ethnicities[i] = batch_ehr[i,1,config.code_vocab_size + config.lab_vocab_size + config.continuous_vocab_size + config.label_vocab_size - 3]

        # The final visit row also carries the end token; everything after it
        # is marked with the padding token.
        batch_ehr[row, n_visits + 1, end_token] = 1
        batch_ehr[row, n_visits + 2 :, pad_token] = 1

    batch_mask[:, 1] = 1  # the label position is always unmasked
    batch_ehr[:, 0, start_token] = 1  # position 0 is the start token
    # Drop the first step so the mask matches the shifted labels/predictions.
    batch_mask = batch_mask[:, 1:, :]

    return batch_ehr, batch_mask, batch_gender


def shuffle_training_data(train_ehr_dataset):
    """Shuffle ``train_ehr_dataset`` in place using NumPy's global RNG.

    Nothing is returned; the caller's sequence is reordered directly.
    """
    np.random.shuffle(train_ehr_dataset)
    return None


model = HALOModel(config).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
# if os.path.exists("./save/halo_model"):
#     print("Loading previous model")
#     checkpoint = torch.load("./save/halo_model", map_location=torch.device(device))
#     model.load_state_dict(checkpoint["model"])
#     optimizer.load_state_dict(checkpoint["optimizer"])


# Train the model. Every 250 batches (except the first) the full validation
# set is scored, and the checkpoint with the lowest validation loss is kept.
# NOTE(review): `fairness_metrics` is not defined in this cell; it must already
# exist in the kernel before this cell runs.
global_loss = 1e10
for e in tqdm(range(config.epoch)):
    shuffle_training_data(train_ehr_dataset)
    for i in range(0, len(train_ehr_dataset), config.batch_size):
        model.train()

        batch_ehr, batch_mask, batch_gender = get_batch(i, config.batch_size, "train")
        batch_gender = torch.tensor(batch_gender, dtype=torch.float32).to(device)
        batch_ehr = torch.tensor(batch_ehr, dtype=torch.float32).to(device)
        batch_mask = torch.tensor(batch_mask, dtype=torch.float32).to(device)

        optimizer.zero_grad()
        loss, _, _ = model(
            batch_ehr,
            batch_gender,
            fairness_metrics,
            input_ethnicities=None,
            position_ids=None,
            ehr_labels=batch_ehr,
            ehr_masks=batch_mask,
        )
        loss.backward()
        optimizer.step()

        if i % (50 * config.batch_size) == 0:
            print("Epoch %d, Iter %d: Training Loss:%.6f" % (e, i, loss))
        if i % (250 * config.batch_size) == 0:
            if i == 0:
                continue

            model.eval()

            with torch.no_grad():
                val_l = []
                for v_i in range(0, len(val_ehr_dataset), config.batch_size):
                    # Walk the validation set batch by batch. The previous
                    # version requested the current *training* batch here, so
                    # the "validation" loss never looked at validation data.
                    batch_ehr, batch_mask, batch_gender = get_batch(
                        v_i, config.batch_size, "valid"
                    )
                    batch_gender = torch.tensor(batch_gender, dtype=torch.float32).to(
                        device
                    )
                    # batch_ethnicities = torch.tensor(batch_ethnicities,dtype=torch.float32).to(device)
                    batch_ehr = torch.tensor(batch_ehr, dtype=torch.float32).to(device)
                    batch_mask = torch.tensor(batch_mask, dtype=torch.float32).to(
                        device
                    )

                    val_loss, _, _ = model(
                        batch_ehr,
                        batch_gender,
                        fairness_metrics,
                        input_ethnicities=None,
                        position_ids=None,
                        ehr_labels=batch_ehr,
                        ehr_masks=batch_mask,
                    )
                    val_l.append((val_loss).cpu().detach().numpy())

                cur_val_loss = np.mean(val_l)
                print("Epoch %d Validation Loss:%.7f" % (e, cur_val_loss))
                if cur_val_loss < global_loss:
                    # New best validation loss: persist model and optimizer state.
                    global_loss = cur_val_loss
                    state = {
                        "model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "iteration": i,
                    }
                    torch.save(state, "./save/generated_model_from_downstream")
                    print("\n------------ Save best model ------------\n")

In [None]:
import pickle

# Rebuild the model and restore the best checkpoint written during training.
config = HALOConfig()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HALOModel(config).to(device)
checkpoint = torch.load("./save/generated_model_from_downstream", map_location=torch.device(device))
model.load_state_dict(checkpoint["model"])
# NOTE(review): the model is left in its default mode here (model.eval() is
# not called before sampling) — confirm that is intended.

# Load the discretization artifacts; context managers close each file promptly.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("discretized_data/idxToId.pkl", "rb") as artifact_file:
    idxToId = pickle.load(artifact_file)
with open("discretized_data/idToLab.pkl", "rb") as artifact_file:
    idToLab = pickle.load(artifact_file)
with open("discretized_data/beginPos.pkl", "rb") as artifact_file:
    beginPos = pickle.load(artifact_file)
with open("discretized_data/isCategorical.pkl", "rb") as artifact_file:
    isCategorical = pickle.load(artifact_file)
with open("discretized_data/possibleValues.pkl", "rb") as artifact_file:
    possible_values = pickle.load(artifact_file)
with open("discretized_data/discretization.pkl", "rb") as artifact_file:
    discretization = pickle.load(artifact_file)

def sample_sequence(model, length, context, batch_size, device="cuda", sample=True):
    """Autoregressively generate a batch of synthetic EHR sequences.

    Args:
        model: object exposing ``sample(tensor, sample)``, which is fed the
            sequence so far plus one empty step and returns the extended
            sequence.
        length: maximum sequence length; at most ``length - 1`` steps are
            generated after the initial context.
        context: 1-D vector of size ``config.total_vocab_size`` used as the
            first step of every sequence (the start token).
        batch_size: number of sequences generated in parallel.
        device: torch device on which the tensors are created.
        sample: forwarded unchanged to ``model.sample``.

    Returns:
        NumPy array holding the generated sequences, one per batch entry.
    """
    # Index of the end-of-record token, laid out exactly as in get_batch and
    # convert_ehr: codes, labs, continuous values, labels, start token, then
    # the end token. The earlier expression omitted the lab and continuous
    # vocabulary sizes, so the early-stop check inspected an unrelated column.
    end_token_idx = (
        config.code_vocab_size
        + config.lab_vocab_size
        + config.continuous_vocab_size
        + config.label_vocab_size
        + 1
    )

    empty = torch.zeros(
        (1, 1, config.total_vocab_size), device=device, dtype=torch.float32
    ).repeat(batch_size, 1, 1)
    context = (
        torch.tensor(context, device=device, dtype=torch.float32)
        .unsqueeze(0)
        .repeat(batch_size, 1)
    )
    prev = context.unsqueeze(1)

    with torch.no_grad():
        for _ in range(length - 1):
            prev = model.sample(torch.cat((prev, empty), dim=1), sample)
            # A sequence is finished once any of its steps carries the end token.
            finished = torch.sum(prev[:, :, end_token_idx], dim=1).bool().int()
            if torch.sum(finished, dim=0).item() == batch_size:
                break

    return prev.cpu().detach().numpy()
def convert_ehr(ehrs, index_to_code=None):
    """Decode generated multi-hot sequences into patient dictionaries.

    Args:
        ehrs: sequence of 2-D arrays (steps x vocabulary); row 1 carries the
            label slice and rows 2+ carry the visits.
        index_to_code: optional mapping from code index to code; when omitted
            the raw indices are kept.

    Returns:
        A list with one dict per input sequence, with keys "visits" (a list of
        ``(codes, lab_ids, lab_values, [cont_idx])`` tuples), "labels",
        "gender" and "ethnicity".
    """
    # Vocabulary layout: codes | labs | continuous | labels | start, end, pad.
    lab_start = config.code_vocab_size
    cont_start = lab_start + config.lab_vocab_size
    label_start = cont_start + config.continuous_vocab_size
    label_stop = label_start + config.label_vocab_size
    end_token_idx = label_stop + 1

    ehr_outputs = []
    for i in range(len(ehrs)):
        ehr = ehrs[i]
        ehr_output = []
        # NOTE(review): these store the whole rows 3 and 2 of the sequence
        # (visit rows), not entries of the label slice — confirm intent.
        ethnicity_output = ehr[3]
        gender_output = ehr[2]
        labels_output = ehr[1][label_start:label_stop]
        if index_to_code is not None:
            # NOTE(review): idToLabel is not defined in the visible part of
            # the notebook; it must exist in the kernel for this branch.
            labels_output = [idToLabel[idx] for idx in np.nonzero(labels_output)[0]]

        for j in range(2, len(ehr)):
            visit = ehr[j]
            visit_output = []
            lab_mask = []
            lab_values = []
            cont_idx = -1  # -1 means "no continuous token seen yet"
            end = False
            for idx in np.nonzero(visit)[0]:
                if idx < lab_start:
                    visit_output.append(
                        index_to_code[idx] if index_to_code is not None else idx
                    )
                elif idx < cont_start:
                    lab_idx = idx - lab_start
                    lab_num = idxToId[lab_idx]
                    # Keep only the first value token seen for each lab.
                    if lab_num not in lab_mask:
                        lab_mask.append(lab_num)
                        lab_values.append(lab_idx - beginPos[lab_num])
                elif idx < label_start:
                    # Keep only the first continuous token of the visit.
                    if cont_idx == -1:
                        cont_idx = idx - cont_start
                elif idx == end_token_idx:
                    end = True

            if cont_idx == -1:
                # No continuous token generated: draw a valid bucket uniformly
                # from 0..continuous_vocab_size-1. The earlier
                # randint(0, N) - 1 could return -1, which is the sentinel and
                # not a valid bucket.
                cont_idx = random.randrange(config.continuous_vocab_size)
            if visit_output != [] or lab_mask != []:
                ehr_output.append((visit_output, lab_mask, lab_values, [cont_idx]))
            if end:
                break

        ehr_outputs.append({"visits": ehr_output, "labels": labels_output, "gender":gender_output,"ethnicity":ethnicity_output})
    return ehr_outputs
# NOTE(review): pakEHRs is not used in the visible cells — confirm it is
# still needed.
with open("discretized_data/trainDataset.pkl", "rb") as train_file:
    pakEHRs = pickle.load(train_file)

# Generate Synthetic EHR dataset
# totEHRs = len(pickle.load(open("discretized_data/trainDataset.pkl", "rb")))
totEHRs = 2000

# Start-token vector: a one-hot at the position right after the label slice.
stoken = np.zeros(config.total_vocab_size)
stoken[
    config.code_vocab_size
    + config.lab_vocab_size
    + config.continuous_vocab_size
    + config.label_vocab_size
] = 1

synthetic_ehr_dataset = []
for i in tqdm(range(0, totEHRs, config.sample_batch_size)):
    # The last batch may be smaller than config.sample_batch_size.
    bs = min([totEHRs - i, config.sample_batch_size])
    batch_synthetic_ehrs = sample_sequence(
        model, config.n_ctx, stoken, batch_size=bs, device=device, sample=True
    )
    batch_synthetic_ehrs = convert_ehr(batch_synthetic_ehrs)
    synthetic_ehr_dataset += batch_synthetic_ehrs

# Write through a context manager so the file is flushed and closed.
with open("./results/datasets/haloDataset_downstream_optimized.pkl", "wb") as out_file:
    pickle.dump(synthetic_ehr_dataset, out_file)

In [None]:
import pickle

# Load the real splits and the feedback-optimized synthetic cohort. Context
# managers make sure every file handle is closed after reading.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("./data/trainData.pkl", "rb") as data_file:
    train_ehr_dataset = pickle.load(data_file)
with open("./data/valData.pkl", "rb") as data_file:
    val_ehr_dataset = pickle.load(data_file)
with open("./data/testData.pkl", "rb") as data_file:
    test_ehr_dataset = pickle.load(data_file)
with open("./results/datasets/haloDataset_downstream_optimized.pkl", "rb") as data_file:
    halo_ehr_dataset = pickle.load(data_file)

combined_data = (
    train_ehr_dataset + val_ehr_dataset + test_ehr_dataset + halo_ehr_dataset
)

# Normalise every label vector to a plain list of ints.
# Assumes each patient["labels"] exposes .tolist() (e.g. a NumPy array) — TODO confirm.
for patient in combined_data:
    patient["labels"] = [int(float(label)) for label in patient["labels"].tolist()]

def transform_data(ehr_dataset):
    """Flatten patient records into one sample dictionary per visit.

    Args:
        ehr_dataset: iterable of patient dicts with a "visits" sequence (the
            first element of each visit holds the visit codes) and an indexable
            "labels" sequence with at least 28 entries.

    Returns:
        A list of dicts with keys "visit_id", "patient_id", "visit_codes",
        "gender" (labels[26]), "ethnicity" (labels[25]), "disease_label"
        (labels[0:25]) and "label" (labels[27]). Patient ids are the position
        of the patient in ``ehr_dataset``.
    """
    samples = []
    for pid, record in enumerate(ehr_dataset):
        for visit_idx, visit in enumerate(record["visits"]):
            labels = record["labels"]
            samples.append(
                {
                    "visit_id": visit_idx,
                    "patient_id": pid,
                    "visit_codes": [[int(code) for code in visit[0]]],
                    "gender": [[int(labels[26])]],
                    "ethnicity": [[int(labels[25])]],
                    "disease_label": [[int(flag) for flag in labels[0:25]]],
                    "label": int(labels[27]),
                }
            )
    return samples


def calculate_wtpr(y_true, y_pred, sensitive_attribute):
    """Return the worst (minimum) true-positive rate across subgroups.

    Args:
        y_true: array-like of binary ground-truth labels (1 = positive).
        y_pred: array-like of predicted scores; values > 0.5 count as positive
            predictions.
        sensitive_attribute: array-like assigning each sample to a subgroup.

    Returns:
        The smallest per-subgroup TPR. A subgroup without any positive sample
        gets a TPR of 0. The per-subgroup TPRs are also printed.
    """
    # Binarise the scores passed in by the caller. The earlier version read a
    # global named y_prob here and ignored the y_pred argument entirely.
    y_true = np.asarray(y_true).ravel()
    y_pred = (np.asarray(y_pred).ravel() > 0.5).astype(int)
    sensitive_attribute = np.asarray(sensitive_attribute).ravel()

    tpr_scores = {}
    for subgroup in np.unique(sensitive_attribute):
        subgroup_mask = sensitive_attribute == subgroup
        y_true_subgroup = y_true[subgroup_mask]
        y_pred_subgroup = y_pred[subgroup_mask]

        # Only TP and FN are needed for the TPR; counting them directly also
        # works when a subgroup contains a single class, where unpacking a
        # full 2x2 confusion matrix would fail.
        positives = y_true_subgroup == 1
        tp = int(np.sum(positives & (y_pred_subgroup == 1)))
        fn = int(np.sum(positives & (y_pred_subgroup == 0)))

        # Avoid dividing by zero when the subgroup has no positives.
        tpr = tp / (tp + fn) if (tp + fn) != 0 else 0
        tpr_scores[subgroup] = tpr

    wtpr = min(tpr_scores.values())

    print("TPR scores for each subgroup:")
    for subgroup, tpr in tpr_scores.items():
        print(f"Subgroup {subgroup}: TPR = {tpr:.3f}")

    return wtpr
import random

import numpy as np
from pyhealth.datasets import SampleEHRDataset, get_dataloader, split_by_patient
from pyhealth.datasets.splitter import split_by_patient

# Shuffle the patients in place, then cut the list 80% / 10% / 10% into
# train / validation / test partitions.
random.shuffle(combined_data)

total_length = len(combined_data)
train_split, val_split = int(0.8 * total_length), int(0.9 * total_length)

train_ehr_data = combined_data[:train_split]
val_ehr_data = combined_data[train_split:val_split]
test_ehr_data = combined_data[val_split:]

# Flatten each partition (and the full list) into per-visit samples.
transformed_train_ehr_dataset = transform_data(train_ehr_data)
transformed_val_ehr_dataset = transform_data(val_ehr_data)
transformed_test_ehr_dataset = transform_data(test_ehr_data)
transformed_combined_ehr_dataset = transform_data(combined_data)

# Longest code list over all visits; every visit is right-padded with zeros
# up to this length.
max_visit_codes_length = max(
    len(sample["visit_codes"][0]) for sample in transformed_combined_ehr_dataset
)
for transformed_split in (
    transformed_train_ehr_dataset,
    transformed_val_ehr_dataset,
    transformed_test_ehr_dataset,
    transformed_combined_ehr_dataset,
):
    for sample in transformed_split:
        visit_codes = sample["visit_codes"][0]
        padded_visit_codes = np.pad(
            visit_codes, (0, max_visit_codes_length - len(visit_codes)), mode="constant"
        )
        sample["visit_codes"][0] = padded_visit_codes.tolist()

# Wrap the padded samples in PyHealth sample datasets.
formatted_train_ehr_dataset = SampleEHRDataset(samples=transformed_train_ehr_dataset)
formatted_val_ehr_dataset = SampleEHRDataset(samples=transformed_val_ehr_dataset)
formatted_test_ehr_dataset = SampleEHRDataset(samples=transformed_test_ehr_dataset)
formatted_combined_ehr_dataset = SampleEHRDataset(
    samples=transformed_combined_ehr_dataset
)

In [None]:
import numpy as np
from pyhealth.datasets import SampleEHRDataset, get_dataloader
from pyhealth.metrics.binary import binary_metrics_fn
from pyhealth.metrics.fairness import fairness_metrics_fn
from pyhealth.models import RNN, Transformer
from pyhealth.trainer import Trainer
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import KFold

# 5-fold cross-validation of a Transformer binary classifier, collecting
# fairness metrics (gender as the sensitive attribute) per fold.
k = 5  # Number of folds
fairness_scores = {
    "disparate_impact": [],
    "statistical_parity_difference": [],
    "wtpr": [],
}

kf = KFold(n_splits=k, shuffle=True, random_state=42)

for train_index, val_index in kf.split(formatted_combined_ehr_dataset):
    fold_train_dataset = [formatted_combined_ehr_dataset[i] for i in train_index]
    fold_val_dataset = [formatted_combined_ehr_dataset[i] for i in val_index]

    # NOTE(review): the model is initialised from formatted_train_ehr_dataset
    # while the folds are drawn from formatted_combined_ehr_dataset — confirm
    # this is intended.
    transformermodel = Transformer(
        dataset=formatted_train_ehr_dataset,
        # look up what are available for "feature_keys" and "label_keys" in dataset.samples[0]
        feature_keys=["visit_codes", "disease_label", "ethnicity", "gender"],
        label_key="label",
        mode="binary",
    )

    train_loader = get_dataloader(fold_train_dataset, batch_size=64, shuffle=True)
    # The validation loader must NOT shuffle: the sensitive-attribute array
    # below is built in fold_val_dataset order, and the predictions returned by
    # trainer.inference() have to line up with it sample by sample.
    val_loader = get_dataloader(fold_val_dataset, batch_size=64, shuffle=False)

    trainer = Trainer(model=transformermodel)
    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=50,
        monitor="pr_auc",
    )

    y_true, y_prob, loss = trainer.inference(val_loader)

    protected_group = 1  # female in gender
    # Prepare the sensitive attribute array for the validation set:
    # 1 for the protected group, 0 otherwise.
    sensitive_attribute_array = np.zeros(len(fold_val_dataset), dtype=int)
    for idx, visit in enumerate(fold_val_dataset):
        sensitive_attribute_value = visit["gender"][0][0]
        if sensitive_attribute_value == protected_group:
            sensitive_attribute_array[idx] = 1

    # Calculate fairness metrics for the current fold
    fold_fairness_metrics = fairness_metrics_fn(
        y_true,
        y_prob,
        sensitive_attributes=sensitive_attribute_array,
        favorable_outcome=1,
        metrics=None,
        threshold=0.5,
    )
    wtpr = calculate_wtpr(y_true, y_prob, sensitive_attribute_array)

    # Append the fairness metrics for the current fold
    fairness_scores["disparate_impact"].append(
        fold_fairness_metrics["disparate_impact"]
    )
    fairness_scores["statistical_parity_difference"].append(
        fold_fairness_metrics["statistical_parity_difference"]
    )
    fairness_scores["wtpr"].append(wtpr)

# Calculate the mean and standard deviation for each fairness metric
for metric, scores in fairness_scores.items():
    mean = np.mean(scores)
    std = np.std(scores)
    print(f"{metric}: Mean = {mean:.4f}, Std = {std:.4f}")

In [None]:
from pyhealth.models import RNN

# 5-fold cross-validation of an RNN binary classifier, collecting fairness
# metrics (gender as the sensitive attribute) per fold.
k = 5  # Number of folds

kf = KFold(n_splits=k, shuffle=True, random_state=42)
fairness_scores = {
    "disparate_impact": [],
    "statistical_parity_difference": [],
    "wtpr": [],
}

for train_index, val_index in kf.split(formatted_combined_ehr_dataset):
    fold_train_dataset = [formatted_combined_ehr_dataset[i] for i in train_index]
    fold_val_dataset = [formatted_combined_ehr_dataset[i] for i in val_index]

    # NOTE(review): the model is initialised from formatted_train_ehr_dataset
    # while the folds are drawn from formatted_combined_ehr_dataset — confirm
    # this is intended.
    rnnModel = RNN(
        dataset=formatted_train_ehr_dataset,
        # look up what are available for "feature_keys" and "label_keys" in dataset.samples[0]
        feature_keys=["visit_codes", "disease_label", "ethnicity", "gender"],
        label_key="label",
        mode="binary",
    )

    train_loader = get_dataloader(fold_train_dataset, batch_size=64, shuffle=True)
    # The validation loader must NOT shuffle: the sensitive-attribute array
    # below is built in fold_val_dataset order, and the predictions returned by
    # trainer.inference() have to line up with it sample by sample.
    val_loader = get_dataloader(fold_val_dataset, batch_size=64, shuffle=False)

    trainer = Trainer(model=rnnModel)
    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=50,
        monitor="pr_auc",
    )

    y_true, y_prob, loss = trainer.inference(val_loader)

    protected_group = 1  # female in gender
    # Prepare the sensitive attribute array for the validation set:
    # 1 for the protected group, 0 otherwise.
    sensitive_attribute_array = np.zeros(len(fold_val_dataset), dtype=int)
    for idx, visit in enumerate(fold_val_dataset):
        sensitive_attribute_value = visit["gender"][0][0]
        if sensitive_attribute_value == protected_group:
            sensitive_attribute_array[idx] = 1

    # Calculate fairness metrics for the current fold
    fold_fairness_metrics = fairness_metrics_fn(
        y_true,
        y_prob,
        sensitive_attributes=sensitive_attribute_array,
        favorable_outcome=1,
        metrics=None,
        threshold=0.5,
    )
    wtpr = calculate_wtpr(y_true, y_prob, sensitive_attribute_array)

    # Append the fairness metrics for the current fold
    fairness_scores["disparate_impact"].append(
        fold_fairness_metrics["disparate_impact"]
    )
    fairness_scores["statistical_parity_difference"].append(
        fold_fairness_metrics["statistical_parity_difference"]
    )
    fairness_scores["wtpr"].append(wtpr)
    print(fairness_scores)

# Calculate the mean and standard deviation for each fairness metric
for metric, scores in fairness_scores.items():
    mean = np.mean(scores)
    std = np.std(scores)
    print(f"{metric}: Mean = {mean:.4f}, Std = {std:.4f}")

# Real Data + Double Prioritized Bias Correction

In [None]:
import pandas as pd
import pickle

In [None]:
# Forward slashes resolve on every platform; the previous "data\\..." path only
# worked on Windows. The context manager closes the file after reading.
with open("./data/allData_5000.pkl", "rb") as all_data_file:
    allData_5000 = pickle.load(all_data_file)

In [None]:
# Tally how many entries of allData_5000 carry each ethnicity value
# (the ethnicity is read from position 25 of the labels).
ethnicity_counts = {}
for patient_visit in allData_5000:
    ethnicity = patient_visit['labels'][25]
    ethnicity_counts[ethnicity] = ethnicity_counts.get(ethnicity, 0) + 1

# Report the tally, one line per ethnicity value.
for ethnicity, count in ethnicity_counts.items():
    print(f"Ethnicity {ethnicity}: {count} patients")

In [None]:
# Number of extra (duplicated) patients to draw for ethnicity codes 2, 3 and 4.
additional_patients_2 = 100
additional_patients_3 = 500
additional_patients_4 = 400

# The loop below is rejection sampling: it would never terminate if one of the
# requested ethnicity codes did not occur in allData_5000, so check up front.
available_ethnicities = {entry['labels'][25] for entry in allData_5000}
missing_ethnicities = [
    code
    for code, needed in (
        ('2.0', additional_patients_2),
        ('3.0', additional_patients_3),
        ('4.0', additional_patients_4),
    )
    if needed > 0 and code not in available_ethnicities
]
if missing_ethnicities:
    raise ValueError(
        f"No patients with ethnicity code(s) {missing_ethnicities} in allData_5000; "
        "cannot oversample them."
    )

oversampled_1000 = []
# Draw random patients and keep those whose ethnicity still needs extra
# samples, until every quota is filled. The kept entries are references to the
# original patient dicts (not copies) and are collected in oversampled_1000.
while additional_patients_2 > 0 or additional_patients_3 > 0 or additional_patients_4 > 0:
    patient_visit = random.choice(allData_5000)
    ethnicity = patient_visit['labels'][25]  # Assuming ethnicity is at index 25

    if ethnicity == '3.0' and additional_patients_3 > 0:
        oversampled_1000.append(patient_visit)
        additional_patients_3 -= 1
        print(f"add_pat_3: {additional_patients_3}")
    elif ethnicity == '4.0' and additional_patients_4 > 0:
        oversampled_1000.append(patient_visit)
        additional_patients_4 -= 1
        print(f"add_pat_4: {additional_patients_4}")
    elif ethnicity == '2.0' and additional_patients_2 > 0:
        oversampled_1000.append(patient_visit)
        additional_patients_2 -= 1
        print(f"add_pat_2: {additional_patients_2}")

In [None]:
# Tally the ethnicity values (labels position 25) of the oversampled entries
# and report them together with the overall size.
ethnicity_counts = {}
for patient_visit in oversampled_1000:
    ethnicity = patient_visit['labels'][25]
    ethnicity_counts[ethnicity] = ethnicity_counts.get(ethnicity, 0) + 1

for ethnicity, count in ethnicity_counts.items():
    print(f"Ethnicity {ethnicity}: {count} patients")
print(f"TOTAL: {len(oversampled_1000)}")

# Training Prediction Model with the Oversampled Data

In [None]:
import pickle
import random
import numpy as np
from pyhealth.datasets import SampleEHRDataset, get_dataloader, split_by_patient
from pyhealth.datasets.splitter import split_by_patient

def transform_data(ehr_dataset):
    """Flatten patient records into one sample dictionary per visit.

    Args:
        ehr_dataset: iterable of patient dicts with a "visits" sequence (the
            first element of each visit holds the visit codes) and an indexable
            "labels" sequence with at least 28 entries.

    Returns:
        A list of dicts with keys "visit_id", "patient_id", "visit_codes",
        "gender" (labels[26]), "ethnicity" (labels[25]), "disease_label"
        (labels[0:25]) and "label" (labels[27]). Patient ids are the position
        of the patient in ``ehr_dataset``.
    """
    samples = []
    for pid, record in enumerate(ehr_dataset):
        for visit_idx, visit in enumerate(record["visits"]):
            labels = record["labels"]
            samples.append(
                {
                    "visit_id": visit_idx,
                    "patient_id": pid,
                    "visit_codes": [[int(code) for code in visit[0]]],
                    "gender": [[int(labels[26])]],
                    "ethnicity": [[int(labels[25])]],
                    "disease_label": [[int(flag) for flag in labels[0:25]]],
                    "label": int(labels[27]),
                }
            )
    return samples


def calculate_wtpr(y_true, y_prob, sensitive_attribute, threshold=0.5):
    """Average the true-positive rate over the sensitive-attribute subgroups.

    Predictions are obtained by thresholding ``y_prob`` (``>= threshold`` is
    positive). For every distinct value in ``sensitive_attribute`` the
    subgroup TPR ``tp / (tp + fn)`` is computed; a subgroup without any
    positive ground-truth sample contributes 0.

    NOTE(review): despite the name, the result is the unweighted mean of the
    subgroup TPRs -- confirm this is the intended "wtpr" definition.

    Args:
        y_true: binary ground-truth labels (1 = positive); any array-like,
            flattened to 1-D so ``(N,)`` and ``(N, 1)`` inputs both work.
        y_prob: predicted probabilities, same number of elements as ``y_true``.
        sensitive_attribute: subgroup id per sample, same number of elements.
        threshold: probability cut-off for a positive prediction.

    Returns:
        Mean of the per-subgroup TPR values (NumPy float).
    """
    # Flatten so boolean masks line up regardless of (N,) vs (N, 1) inputs.
    y_true = np.asarray(y_true).ravel()
    y_pred = (np.asarray(y_prob).ravel() >= threshold).astype(int)
    sensitive_attribute = np.asarray(sensitive_attribute).ravel()

    tpr_scores = {}
    for subgroup in np.unique(sensitive_attribute):
        subgroup_mask = sensitive_attribute == subgroup
        actual_positive = y_true[subgroup_mask] == 1
        predicted_positive = y_pred[subgroup_mask] == 1

        # Count TP / FN directly; this avoids special-casing the 1x1
        # confusion matrix produced when a subgroup holds a single class.
        tp = int(np.sum(actual_positive & predicted_positive))
        fn = int(np.sum(actual_positive & ~predicted_positive))

        tpr_scores[subgroup] = tp / (tp + fn) if (tp + fn) != 0 else 0

    wtpr = np.mean(list(tpr_scores.values()))
    return wtpr

In [None]:
# Persist the oversampled cohort; the context manager guarantees the file
# handle is flushed and closed even if pickling fails.
with open("./results/datasets/oversampled_1000.pkl", "wb") as oversampled_file:
    pickle.dump(oversampled_1000, oversampled_file)

In [None]:
#5000 real + 1000 synth
# NOTE: pickle.load can execute arbitrary code embedded in the file; only
# load pickles that were produced by this project.
with open("./data/allData_5000.pkl", "rb") as real_data_file:
    all_data_5000 = pickle.load(real_data_file)

# Random subsets of the real data (drawn without replacement)
allData_1000 = random.sample(all_data_5000, 1000)
allData_2500 = random.sample(all_data_5000, 2500)

In [None]:
# Train a Transformer on real + oversampled data with 5-fold cross-validation
# and report fairness metrics (ethnicity as the sensitive attribute).
real_datasets = [allData_2500]
for real_dataset in real_datasets:
    combined_data = (real_dataset + oversampled_1000)

    # Labels may be stored as numeric strings such as '4.0'; normalise them to
    # ints in place before building samples.
    for patient in combined_data:
        patient["labels"] = [int(float(label)) for label in patient["labels"]]

    random.shuffle(combined_data)

    # 80 / 10 / 10 split boundaries
    total_length = len(combined_data)
    train_split = int(0.8 * total_length)
    val_split = int(0.9 * total_length)

    # Split the combined list into train, validation, and test sets
    train_ehr_data = combined_data[:train_split]
    val_ehr_data = combined_data[train_split:val_split]
    test_ehr_data = combined_data[val_split:]

    transformed_train_ehr_dataset = transform_data(train_ehr_data)
    transformed_val_ehr_dataset = transform_data(val_ehr_data)
    transformed_test_ehr_dataset = transform_data(test_ehr_data)
    transformed_combined_ehr_dataset = transform_data(combined_data)

    # Right-pad every visit's code list with zeros up to the longest visit in
    # the combined data so all samples share one length.
    max_visit_codes_length = max(
        len(sample["visit_codes"][0]) for sample in transformed_combined_ehr_dataset
    )
    for transformed_split in (
        transformed_train_ehr_dataset,
        transformed_val_ehr_dataset,
        transformed_test_ehr_dataset,
        transformed_combined_ehr_dataset,
    ):
        for sample in transformed_split:
            visit_codes = sample["visit_codes"][0]
            padded_visit_codes = np.pad(
                visit_codes, (0, max_visit_codes_length - len(visit_codes)), mode="constant"
            )
            sample["visit_codes"][0] = padded_visit_codes.tolist()

    formatted_train_ehr_dataset = SampleEHRDataset(samples=transformed_train_ehr_dataset)
    formatted_val_ehr_dataset = SampleEHRDataset(samples=transformed_val_ehr_dataset)
    formatted_test_ehr_dataset = SampleEHRDataset(samples=transformed_test_ehr_dataset)
    formatted_combined_ehr_dataset = SampleEHRDataset(
        samples=transformed_combined_ehr_dataset
    )

    k = 5  # Number of folds
    fairness_scores = {
        "disparate_impact": [],
        "statistical_parity_difference": [],
        "wtpr": [],
    }

    kf = KFold(n_splits=k, shuffle=True, random_state=42)

    for train_index, val_index in kf.split(formatted_combined_ehr_dataset):
        fold_train_dataset = [formatted_combined_ehr_dataset[i] for i in train_index]
        fold_val_dataset = [formatted_combined_ehr_dataset[i] for i in val_index]

        # NOTE(review): the model is built from the 80% train dataset while the
        # folds are drawn from the combined dataset -- confirm this is intended.
        transformermodel = Transformer(
            dataset=formatted_train_ehr_dataset,
            # look up what are available for "feature_keys" and "label_keys" in dataset.samples[0]
            feature_keys=["visit_codes", "disease_label", "ethnicity", "gender"],
            label_key="label",
            mode="binary",
        )

        train_loader = get_dataloader(fold_train_dataset, batch_size=64, shuffle=True)
        # The validation loader must NOT shuffle: the sensitive-attribute array
        # below is built in fold_val_dataset order and has to line up row by
        # row with the y_true / y_prob returned by trainer.inference.
        val_loader = get_dataloader(fold_val_dataset, batch_size=64, shuffle=False)

        trainer = Trainer(model=transformermodel)
        trainer.train(
            train_dataloader=train_loader,
            val_dataloader=val_loader,
            epochs=30,
            monitor="pr_auc",
        )

        y_true, y_prob, loss = trainer.inference(val_loader)

        # Sensitive attribute for the validation fold: 1 for every sample whose
        # ethnicity differs from the unprotected group, 0 otherwise.
        # (To study gender instead, build the array from visit["gender"][0][0]
        # with the protected value 1.)
        unprotected_group = 0  # white in eth
        sensitive_attribute_array = np.zeros(len(fold_val_dataset), dtype=int)
        for idx, visit in enumerate(fold_val_dataset):
            sensitive_attribute_value = visit["ethnicity"][0][0]
            if sensitive_attribute_value != unprotected_group:
                sensitive_attribute_array[idx] = 1

        # Calculate fairness metrics for the current fold
        fold_fairness_metrics = fairness_metrics_fn(
            y_true,
            y_prob,
            sensitive_attributes=sensitive_attribute_array,
            favorable_outcome=1,
            metrics=None,
            threshold=0.5,
        )
        wtpr = calculate_wtpr(y_true, y_prob, sensitive_attribute_array)

        # Append the fairness metrics for the current fold
        fairness_scores["disparate_impact"].append(
            fold_fairness_metrics["disparate_impact"]
        )
        fairness_scores["statistical_parity_difference"].append(
            fold_fairness_metrics["statistical_parity_difference"]
        )
        fairness_scores["wtpr"].append(wtpr)

    # Summarise each metric across the folds for this dataset
    for metric, scores in fairness_scores.items():
        mean = np.mean(scores)
        std = np.std(scores)
        print(f"{metric}: Mean = {mean:.4f}, Std = {std:.4f}")

# Experiments with Lambda

In [None]:
def transform_data(ehr_dataset):
    """Expand patient records into a flat list of per-visit sample dicts.

    For every patient (numbered by position) and every visit of that patient
    (numbered by position) one dict is emitted. Visit codes come from index 0
    of the visit; label entries 0-24 form ``disease_label``, entry 25 is
    ``ethnicity``, entry 26 is ``gender`` and entry 27 is the target ``label``.
    Label values are converted via ``int(float(value))``.
    """

    def as_int(raw):
        # Accepts ints, floats and float-formatted strings such as '4.0'.
        return int(float(raw))

    return [
        {
            "visit_id": visit_idx,
            "patient_id": patient_idx,
            "visit_codes": [[int(code) for code in visit[0]]],
            "gender": [[as_int(record["labels"][26])]],
            "ethnicity": [[as_int(record["labels"][25])]],
            "disease_label": [[as_int(flag) for flag in record["labels"][0:25]]],
            "label": as_int(record["labels"][27]),
        }
        for patient_idx, record in enumerate(ehr_dataset)
        for visit_idx, visit in enumerate(record["visits"])
    ]

def calculate_wtpr(y_true, y_prob, sensitive_attribute, threshold=0.5):
    """Return the mean true-positive rate across sensitive subgroups.

    ``y_prob`` is thresholded (``>= threshold`` counts as a positive
    prediction). For each distinct value of ``sensitive_attribute`` the TPR
    ``tp / (tp + fn)`` is computed on that subgroup; subgroups with no
    positive ground-truth samples are scored 0.

    NOTE(review): the value returned is the unweighted mean of the subgroup
    TPRs -- confirm that matches the intended meaning of "wtpr".

    Args:
        y_true: binary ground-truth labels (1 = positive); array-like,
            flattened to 1-D so ``(N,)`` and ``(N, 1)`` shapes are accepted.
        y_prob: predicted probabilities with as many elements as ``y_true``.
        sensitive_attribute: subgroup id for each sample.
        threshold: cut-off applied to ``y_prob``.

    Returns:
        Mean of the per-subgroup TPR values (NumPy float).
    """
    # Work on flat arrays so the subgroup mask indexes samples one-to-one.
    truth = np.asarray(y_true).ravel()
    y_pred = (np.asarray(y_prob).ravel() >= threshold).astype(int)
    group_ids = np.asarray(sensitive_attribute).ravel()

    tpr_scores = {}
    for subgroup in np.unique(group_ids):
        members = group_ids == subgroup
        is_positive = truth[members] == 1
        is_hit = y_pred[members] == 1

        # Direct TP / FN counts: no confusion-matrix shape special cases when
        # a subgroup contains only one class.
        tp = int(np.sum(is_positive & is_hit))
        fn = int(np.sum(is_positive & ~is_hit))

        tpr_scores[subgroup] = tp / (tp + fn) if (tp + fn) != 0 else 0

    wtpr = np.mean(list(tpr_scores.values()))
    return wtpr

def formatter(data, max_visit_codes_length):
    """Turn raw patient records into a padded ``SampleEHRDataset``.

    Side effects on ``data``: every record's ``"labels"`` list is replaced by
    its integer-converted form, and the list itself is shuffled in place.

    Args:
        data: list of patient record dicts (mutated, see above).
        max_visit_codes_length: length every visit-code list is zero-padded to
            on the right.

    Returns:
        ``SampleEHRDataset`` built from the transformed, padded samples.
    """
    # Normalise label values (e.g. '4.0' -> 4) directly on the records.
    for record in data:
        record["labels"] = [int(float(value)) for value in record["labels"]]

    random.shuffle(data)

    samples = transform_data(data)
    for entry in samples:
        codes = entry["visit_codes"][0]
        pad_width = (0, max_visit_codes_length - len(codes))
        entry["visit_codes"][0] = np.pad(codes, pad_width, mode="constant").tolist()

    return SampleEHRDataset(samples=samples)

In [None]:
def fairness_acc_calc(fixed_dataset, var_datasets, num_epochs):
    """Cross-validate a Transformer on fixed + variable data and print scores.

    For every dataset in ``var_datasets`` the records are concatenated with
    ``fixed_dataset``, formatted, and evaluated with 5-fold cross-validation.
    Per fold a fresh Transformer is trained; F1, disparate impact, statistical
    parity difference and WTPR are computed on the validation fold with
    ethnicity as the sensitive attribute. Mean and standard deviation over the
    folds are printed per variable dataset. Nothing is returned.

    Args:
        fixed_dataset: list of patient records used in every run.
        var_datasets: iterable of lists of patient records; one run per item.
        num_epochs: number of training epochs per fold.
    """
    for var_dataset in var_datasets:

        f1_scores = []
        fairness_scores = {
            "disparate_impact": [],
            "statistical_parity_difference": [],
            "wtpr": [],
        }
        combined_data = fixed_dataset + var_dataset
        train_ehr_data = combined_data[:int(0.8 * len(combined_data))]
        max_visit_codes_length = max(len(sample["visit_codes"][0]) for sample in transform_data(combined_data))
        formatted_combined_ehr_dataset = formatter(combined_data, max_visit_codes_length)
        formatted_train_ehr_dataset = formatter(train_ehr_data, max_visit_codes_length)

        k = 5  # Number of folds
        kf = KFold(n_splits=k, shuffle=True, random_state=42)

        for train_index, val_index in kf.split(formatted_combined_ehr_dataset):
            fold_train_dataset = [formatted_combined_ehr_dataset[i] for i in train_index]
            fold_val_dataset = [formatted_combined_ehr_dataset[i] for i in val_index]

            transformermodel = Transformer(
                dataset=formatted_train_ehr_dataset,
                feature_keys=["visit_codes", "disease_label", "ethnicity", "gender"],
                label_key="label",
                mode="binary",
            )
            train_loader = get_dataloader(fold_train_dataset, batch_size=32, shuffle=True)
            # No shuffling for validation: predictions returned by
            # trainer.inference must stay in fold_val_dataset order so they
            # line up with sensitive_attribute_array built below.
            val_loader = get_dataloader(fold_val_dataset, batch_size=32, shuffle=False)

            trainer = Trainer(model=transformermodel)
            trainer.train(
                train_dataloader=train_loader,
                val_dataloader=val_loader,
                epochs=num_epochs,
                optimizer_params={'lr': 1e-3, 'weight_decay': 1e-4},  # Experiment with different learning rates and weight decay
                weight_decay=1e-4,
                max_grad_norm=1.0,  # Gradient clipping
                monitor='pr_auc',  # Monitoring PR AUC
                monitor_criterion='max',
                load_best_model_at_last=True  # Load the best model at the end
            )

            # Predictions for THIS fold must exist before any metric is computed.
            y_true, y_prob, loss = trainer.inference(val_loader)

            # 1 marks samples outside the unprotected ethnicity group.
            unprotected_group = 1  # han in eth
            sensitive_attribute_array = np.zeros(len(fold_val_dataset), dtype=int)
            for idx, visit in enumerate(fold_val_dataset):
                sensitive_attribute_value = visit["ethnicity"][0][0]
                if sensitive_attribute_value != unprotected_group:
                    sensitive_attribute_array[idx] = 1

            # Calculate fairness metrics for the current fold. If they cannot
            # be computed, record NaN so the nan-aware summary below skips the
            # fold instead of silently losing it.
            try:
                fold_fairness_metrics = fairness_metrics_fn(
                    y_true,
                    y_prob,
                    sensitive_attributes=sensitive_attribute_array,
                    favorable_outcome=1,
                    metrics=None,
                    threshold=0.5,
                )
                disparate_impact = fold_fairness_metrics["disparate_impact"]
                statistical_parity = fold_fairness_metrics["statistical_parity_difference"]
            except Exception as exc:
                print(f"DI and SPD undefined for this fold: {exc!r}")
                disparate_impact = statistical_parity = np.nan
            fairness_scores["disparate_impact"].append(disparate_impact)
            fairness_scores["statistical_parity_difference"].append(statistical_parity)

            # Exactly one WTPR value per fold.
            wtpr = calculate_wtpr(y_true, y_prob, sensitive_attribute_array)
            fairness_scores["wtpr"].append(wtpr)

            # Calculate F1-score
            score = binary_metrics_fn(y_true, y_prob, metrics=["f1"])
            f1 = score['f1']
            f1_scores.append(f1)

        f1_mean = np.mean(f1_scores)
        f1_std = np.std(f1_scores)
        print(f"VARIABLE DATA SIZE: {len(var_dataset)}")
        print(f"F1-score: {f1_mean:.2f}±{f1_std:.2f}")
        for metric, scores in fairness_scores.items():
            mean = np.nanmean(scores)  # Use nanmean to ignore NaN values
            std = np.nanstd(scores)  # Use nanstd to ignore NaN values
            print(f"{metric}: {mean:.2f}±{std:.2f}")

In [None]:
# Real + FAIRSYNTH EXPERIMENT
# NOTE: pickle.load can execute arbitrary code embedded in the file; only
# load pickles that were produced by this project.
with open("./data/allData_pic.pkl", "rb") as real_data_file:
    all_data_5000 = pickle.load(real_data_file)
allData_1000 = all_data_5000[0:1000]
allData_2500 = all_data_5000[0:2500]
allData_5000 = all_data_5000[0:5000]

# Use the 5000-record slice, not the unsliced list, for the largest run.
real_datasets = [allData_1000, allData_2500, allData_5000]

# TODO: set this to the location of the synthetic (FairSynth) dataset.
# Assumes the file is a pickled list of patient records in the same format
# as the real data -- confirm before running.
SYNTH_DATA_PATH = None
if SYNTH_DATA_PATH is None:
    raise ValueError(
        "Set SYNTH_DATA_PATH to the synthetic dataset location before running this cell"
    )
with open(SYNTH_DATA_PATH, "rb") as synth_data_file:
    synth_datasets = pickle.load(synth_data_file)

fairness_acc_calc(fixed_dataset=synth_datasets, var_datasets=real_datasets, num_epochs=50)