In [1]:
# Install Libraries
!pip install -q tensorflow scikit-learn

import pandas as pd
import numpy as np
import re
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.metrics import roc_auc_score, confusion_matrix
from sklearn.model_selection import train_test_split

AUTOTUNE = tf.data.AUTOTUNE

np.random.seed(42)
tf.random.set_seed(42)
print("Environment ready")


Environment ready


In [2]:
# read the csv
df = pd.read_csv("clinical_notes.csv")
# see if data is imported succesfully
print(df.head())
print(df.shape)
print(df.columns)


     age  gender   race     ethnicity language maritalstatus  \
0  56.56  female  black  non-hispanic  english        single   
1  53.91  female  white  non-hispanic  english        single   
2  46.30  female  white  non-hispanic  english        single   
3  66.52    male  white  non-hispanic  english        single   
4  82.52  female  black  non-hispanic  english      divorced   

                                                note  \
0  ms. PERSON is a 56 yo woman presenting to esta...   
1  referred for evaluation of narrow angles ou #p...   
2  1. left upper lid ptosis: occurred after botox...   
3  right plano +0.50 082 left LOCATION -0.50 83 a...   
4  in step. os with nonspecific peripheral defect...   

                                        gpt4_summary glaucoma       use  
0  The 56 y/o female patient has optic nerve head...      yes  training  
1  Patient was referred for narrow angle evaluati...      yes  training  
2  Patient experienced ptosis, ear and eye pain, ...    

In [3]:
# clean data with lowercase, pronunciation
def clean_text(t):
    if pd.isna(t):
      return ""
    t = str(t).lower()
    t = re.sub(r'[^a-z0-9\s]', ' ', t)
    return re.sub(r'\s+', ' ', t).strip()

data = df.copy()
data["clean_text"] = data["note"].apply(clean_text)
data["label"] = data["glaucoma"].map({"yes": 1, "no": 0})
data = data.dropna(subset=["label"])

# Race column for Asian, Black, White
race_col = "race" if "race" in data.columns else "ethnicity"
data[race_col] = data[race_col].astype(str)

print("Race distribution:")
print(data[race_col].value_counts())


Race distribution:
race
white    7690
black    1491
asian     819
Name: count, dtype: int64


In [4]:
X = data["clean_text"].values
y = data["label"].astype(int).values
race = data[race_col].values

# Train, Test split
X_train, X_test, y_train, y_test, race_train, race_test = train_test_split(
    X, y, race, test_size=0.20, stratify=y, random_state=42
)
# Train, validation split
X_train, X_val, y_train, y_val, race_train, race_val = train_test_split(
    X_train, y_train, race_train, test_size=0.20, stratify=y_train, random_state=42
)
# print summary
print(f"Train={len(X_train)}, Val={len(X_val)}, Test={len(X_test)}")


Train=6400, Val=1600, Test=2000


In [5]:
# Count glaucoma vs non-glaucoma samples
glaucoma_counts = data['label'].value_counts().sort_index()

num_no_glaucoma = glaucoma_counts.get(0, 0)
num_glaucoma     = glaucoma_counts.get(1, 0)

print("========== Glaucoma Label Distribution ==========")
print(f"No Glaucoma (0): {num_no_glaucoma}")
print(f"Glaucoma (1):     {num_glaucoma}")
print(f"Total Samples:    {num_no_glaucoma + num_glaucoma}")


No Glaucoma (0): 4952
Glaucoma (1):     5048
Total Samples:    10000


In [7]:
MAX_WORDS = 6000 # vocabulary size
MAX_LEN = 250 # max length
BATCH = 16 # 250 * 16 = 6000

# vectorization
text_vec = layers.TextVectorization(
    max_tokens=MAX_WORDS,
    output_sequence_length=MAX_LEN,
    standardize=None
)
text_vec.adapt(X_train)

def vectorize(text, label, race):
    text = text_vec(text)
    return text, label, race


In [8]:
def make_training_val_dataset(texts, labels):
    ds = tf.data.Dataset.from_tensor_slices((texts, labels))
    ds = ds.batch(BATCH).prefetch(AUTOTUNE)
    return ds.map(lambda t, l: (text_vec(t), l))

train_ds = make_training_val_dataset(X_train, y_train)
val_ds   = make_training_val_dataset(X_val, y_val)

# Create a separate dataset for test predictions, containing only input features
test_ds_for_prediction = tf.data.Dataset.from_tensor_slices(X_test)
test_ds_for_prediction = test_ds_for_prediction.batch(BATCH).map(text_vec).prefetch(AUTOTUNE)


In [9]:
# Building the LSTM model
def build_lstm():
    inputs = keras.Input(shape=(MAX_LEN,), dtype=tf.int32)
    x = layers.Embedding(MAX_WORDS, 32)(inputs) # 32 dims
    x = layers.SpatialDropout1D(0.3)(x)
    x = layers.LSTM(32)(x)
    x = layers.Dense(16, activation="relu")(x) # 16 unit dense layer
    x = layers.Dropout(0.4)(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(5e-4),
        loss="binary_crossentropy",
        metrics=["accuracy", keras.metrics.AUC(name="auc")]
    )
    return model


In [10]:
# Building the GRU model
def build_gru():
    inputs = keras.Input(shape=(MAX_LEN,), dtype=tf.int32)
    x = layers.Embedding(MAX_WORDS, 32)(inputs) # 32 dims
    x = layers.SpatialDropout1D(0.3)(x)
    x = layers.GRU(32)(x)
    x = layers.Dense(16, activation="relu")(x) # 16 unit
    x = layers.Dropout(0.4)(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(5e-4),
        loss="binary_crossentropy",
        metrics=["accuracy", keras.metrics.AUC(name="auc")]
    )
    return model


In [11]:
# Building CNN model
def build_cnn():
    inputs = keras.Input(shape=(MAX_LEN,), dtype=tf.int32)
    x = layers.Embedding(MAX_WORDS, 32)(inputs) # dims 32
    x = layers.Conv1D(64, 5, activation="relu")(x)
    x = layers.GlobalMaxPooling1D()(x)
    x = layers.Dense(32, activation="relu")(x)
    x = layers.Dropout(0.4)(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(5e-4),
        loss="binary_crossentropy",
        metrics=["accuracy", keras.metrics.AUC(name="auc")]
    )
    return model


In [12]:
callbacks = [
    keras.callbacks.EarlyStopping(monitor="val_auc", patience=3,
                                  mode="max", restore_best_weights=True),
    keras.callbacks.ReduceLROnPlateau(monitor="val_auc", patience=2,
                                      factor=0.5, mode="max"),
    keras.callbacks.ModelCheckpoint("best_model.h5",
                                    monitor="val_auc", mode="max",
                                    save_best_only=True)
]

models = {
    "LSTM": build_lstm(),
    "GRU": build_gru(),
    "CNN": build_cnn()
}

histories = {}
# Training
for name, model in models.items():
    print(f"\n=== Training {name} ===")
    hist = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=20, # training for 20 epochs with early stopping
        callbacks=callbacks,
        verbose=1
    )
    histories[name] = hist



=== Training LSTM ===
Epoch 1/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 83ms/step - accuracy: 0.5091 - auc: 0.5150 - loss: 0.6927



[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 91ms/step - accuracy: 0.5091 - auc: 0.5150 - loss: 0.6927 - val_accuracy: 0.5050 - val_auc: 0.5392 - val_loss: 0.6921 - learning_rate: 5.0000e-04
Epoch 2/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 93ms/step - accuracy: 0.5111 - auc: 0.5274 - loss: 0.6922



[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 98ms/step - accuracy: 0.5111 - auc: 0.5274 - loss: 0.6922 - val_accuracy: 0.5337 - val_auc: 0.5428 - val_loss: 0.6904 - learning_rate: 5.0000e-04
Epoch 3/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 95ms/step - accuracy: 0.5333 - auc: 0.5319 - loss: 0.6891



[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 101ms/step - accuracy: 0.5332 - auc: 0.5319 - loss: 0.6891 - val_accuracy: 0.5431 - val_auc: 0.5496 - val_loss: 0.6884 - learning_rate: 5.0000e-04
Epoch 4/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 83ms/step - accuracy: 0.5337 - auc: 0.5323 - loss: 0.6853



[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m36s[0m 89ms/step - accuracy: 0.5337 - auc: 0.5324 - loss: 0.6853 - val_accuracy: 0.5450 - val_auc: 0.5554 - val_loss: 0.6823 - learning_rate: 5.0000e-04
Epoch 5/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 82ms/step - accuracy: 0.5503 - auc: 0.5828 - loss: 0.6706 - val_accuracy: 0.5069 - val_auc: 0.5465 - val_loss: 0.6876 - learning_rate: 5.0000e-04
Epoch 6/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 82ms/step - accuracy: 0.5530 - auc: 0.6032 - loss: 0.6511 - val_accuracy: 0.5462 - val_auc: 0.5542 - val_loss: 0.6848 - learning_rate: 5.0000e-04
Epoch 7/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 76ms/step - accuracy: 0.5853 - auc: 0.6324 - loss: 0.6353



[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 81ms/step - accuracy: 0.5853 - auc: 0.6324 - loss: 0.6353 - val_accuracy: 0.5600 - val_auc: 0.5806 - val_loss: 0.6812 - learning_rate: 2.5000e-04
Epoch 8/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 82ms/step - accuracy: 0.5865 - auc: 0.6417 - loss: 0.6278 - val_accuracy: 0.5194 - val_auc: 0.5691 - val_loss: 0.6921 - learning_rate: 2.5000e-04
Epoch 9/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 79ms/step - accuracy: 0.5830 - auc: 0.6380 - loss: 0.6255 - val_accuracy: 0.5587 - val_auc: 0.5747 - val_loss: 0.6962 - learning_rate: 2.5000e-04
Epoch 10/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 78ms/step - accuracy: 0.6031 - auc: 0.6732 - loss: 0.6079



[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 83ms/step - accuracy: 0.6031 - auc: 0.6732 - loss: 0.6079 - val_accuracy: 0.5631 - val_auc: 0.5892 - val_loss: 0.6995 - learning_rate: 1.2500e-04
Epoch 11/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 81ms/step - accuracy: 0.6090 - auc: 0.6600 - loss: 0.6067



[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 86ms/step - accuracy: 0.6090 - auc: 0.6600 - loss: 0.6067 - val_accuracy: 0.5713 - val_auc: 0.6030 - val_loss: 0.7087 - learning_rate: 1.2500e-04
Epoch 12/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 76ms/step - accuracy: 0.6117 - auc: 0.6824 - loss: 0.5963



[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 80ms/step - accuracy: 0.6117 - auc: 0.6824 - loss: 0.5963 - val_accuracy: 0.5725 - val_auc: 0.6038 - val_loss: 0.7106 - learning_rate: 1.2500e-04
Epoch 13/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 82ms/step - accuracy: 0.6025 - auc: 0.6749 - loss: 0.5963



[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m35s[0m 87ms/step - accuracy: 0.6026 - auc: 0.6750 - loss: 0.5963 - val_accuracy: 0.5838 - val_auc: 0.6117 - val_loss: 0.6993 - learning_rate: 1.2500e-04
Epoch 14/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 91ms/step - accuracy: 0.6066 - auc: 0.6910 - loss: 0.5886



[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 96ms/step - accuracy: 0.6066 - auc: 0.6910 - loss: 0.5886 - val_accuracy: 0.6488 - val_auc: 0.6701 - val_loss: 0.6910 - learning_rate: 1.2500e-04
Epoch 15/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 95ms/step - accuracy: 0.6284 - auc: 0.7017 - loss: 0.5857 - val_accuracy: 0.6006 - val_auc: 0.6384 - val_loss: 0.6897 - learning_rate: 1.2500e-04
Epoch 16/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 92ms/step - accuracy: 0.6378 - auc: 0.7070 - loss: 0.5880 - val_accuracy: 0.5844 - val_auc: 0.6265 - val_loss: 0.7149 - learning_rate: 1.2500e-04
Epoch 17/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 94ms/step - accuracy: 0.6638 - auc: 0.7338 - loss: 0.5670 - val_accuracy: 0.6137 - val_auc: 0.6466 - val_loss: 0.7112 - learning_rate: 6.2500e-05

=== Training GRU ===




[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 14ms/step - accuracy: 0.5717 - auc: 0.5883 - loss: 0.6787 - val_accuracy: 0.7256 - val_auc: 0.8131 - val_loss: 0.5328 - learning_rate: 5.0000e-04
Epoch 2/20
[1m397/400[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 14ms/step - accuracy: 0.7546 - auc: 0.8280 - loss: 0.5063



[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.7547 - auc: 0.8282 - loss: 0.5061 - val_accuracy: 0.7769 - val_auc: 0.8625 - val_loss: 0.4525 - learning_rate: 5.0000e-04
Epoch 3/20
[1m397/400[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 12ms/step - accuracy: 0.8161 - auc: 0.8989 - loss: 0.4039



[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 13ms/step - accuracy: 0.8162 - auc: 0.8990 - loss: 0.4037 - val_accuracy: 0.7800 - val_auc: 0.8734 - val_loss: 0.4438 - learning_rate: 5.0000e-04
Epoch 4/20
[1m397/400[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 14ms/step - accuracy: 0.8619 - auc: 0.9360 - loss: 0.3288



[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.8621 - auc: 0.9361 - loss: 0.3285 - val_accuracy: 0.7862 - val_auc: 0.8753 - val_loss: 0.4502 - learning_rate: 5.0000e-04
Epoch 5/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 13ms/step - accuracy: 0.9023 - auc: 0.9633 - loss: 0.2544 - val_accuracy: 0.7837 - val_auc: 0.8733 - val_loss: 0.4855 - learning_rate: 5.0000e-04
Epoch 6/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 16ms/step - accuracy: 0.9346 - auc: 0.9797 - loss: 0.1909 - val_accuracy: 0.7856 - val_auc: 0.8718 - val_loss: 0.5190 - learning_rate: 5.0000e-04
Epoch 7/20
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 13ms/step - accuracy: 0.9608 - auc: 0.9896 - loss: 0.1361 - val_accuracy: 0.7862 - val_auc: 0.8712 - val_loss: 0.5409 - learning_rate: 2.5000e-04


In [13]:
TARGET_GROUPS = ["asian", "black", "white"]

# fairness report evaluation
def fairness_report(model, test_texts, test_labels, test_races, test_ds_for_prediction):
    preds = model.predict(test_ds_for_prediction).flatten()
    binary = (preds >= 0.5).astype(int)

    # Overall
    auc = roc_auc_score(test_labels, preds)
    tn, fp, fn, tp = confusion_matrix(test_labels, binary).ravel()
    sens = tp/(tp+fn)
    spec = tn/(tn+fp)

    print(f"\nOverall AUC={auc:.4f}, Sens={sens:.4f}, Spec={spec:.4f}")

    # Per group
    results = {"Overall": {"auc": auc, "sens": sens, "spec": spec}}

    for grp in TARGET_GROUPS:
        mask = (test_races == grp)
        if mask.sum() < 5:
            print(f"{grp}: too few samples")
            continue

        auc_g = roc_auc_score(test_labels[mask], preds[mask])
        tn, fp, fn, tp = confusion_matrix(test_labels[mask],
                                         binary[mask]).ravel()
        sens_g = tp/(tp+fn)
        spec_g = tn/(tn+fp)

        print(f"{grp} \u2014 n={mask.sum()} \u2014 AUC={auc_g:.4f}, Sens={sens_g:.4f}, Spec={spec_g:.4f}")
        results[grp] = {"auc": auc_g, "sens": sens_g, "spec": spec_g}

    return results

In [14]:
# Performance of the 3 models
lstm_results = fairness_report(models["LSTM"], X_test, y_test, race_test, test_ds_for_prediction)
gru_results  = fairness_report(models["GRU"], X_test, y_test, race_test, test_ds_for_prediction)
cnn_results  = fairness_report(models["CNN"], X_test, y_test, race_test, test_ds_for_prediction)

[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 19ms/step

Overall AUC=0.7247, Sens=0.4109, Spec=0.8818
asian — n=157 — AUC=0.7546, Sens=0.3735, Spec=0.8919
black — n=301 — AUC=0.7316, Sens=0.4241, Spec=0.8909
white — n=1542 — AUC=0.7198, Sens=0.4117, Spec=0.8797
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 19ms/step

Overall AUC=0.5605, Sens=0.1317, Spec=0.9202
asian — n=157 — AUC=0.5314, Sens=0.0964, Spec=0.8919
black — n=301 — AUC=0.5735, Sens=0.1257, Spec=0.8818
white — n=1542 — AUC=0.5611, Sens=0.1372, Spec=0.9280
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 5ms/step

Overall AUC=0.8777, Sens=0.8663, Spec=0.6990
asian — n=157 — AUC=0.8881, Sens=0.8675, Spec=0.7703
black — n=301 — AUC=0.8840, Sens=0.8534, Spec=0.6818
white — n=1542 — AUC=0.8767, Sens=0.8696, Spec=0.6948


In [15]:
print("\n========== FINAL SUMMARY ==========")

for name, res in [("LSTM", lstm_results),
                  ("GRU", gru_results),
                  ("CNN", cnn_results)]:
    print(f"\n{name} Overall AUC = {res['Overall']['auc']:.4f}")

    asian_auc = res.get("asian", {}).get("auc", None)
    black_auc = res.get("black", {}).get("auc", None)
    white_auc = res.get("white", {}).get("auc", None)

    print(f"{name} Asian AUC: {asian_auc:.4f}")
    print(f"{name} Black AUC: {black_auc:.4f}")
    print(f"{name} White AUC: {white_auc:.4f}")




LSTM Overall AUC = 0.7247
LSTM Asian AUC: 0.7546
LSTM Black AUC: 0.7316
LSTM White AUC: 0.7198

GRU Overall AUC = 0.5605
GRU Asian AUC: 0.5314
GRU Black AUC: 0.5735
GRU White AUC: 0.5611

CNN Overall AUC = 0.8777
CNN Asian AUC: 0.8881
CNN Black AUC: 0.8840
CNN White AUC: 0.8767
