In [None]:
!pip install tensorflow scikit-learn pandas matplotlib imbalanced-learn -q

import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.metrics import roc_auc_score, classification_report, precision_recall_curve, confusion_matrix
import matplotlib.pyplot as plt
from google.colab import drive
from imblearn.over_sampling import SMOTE

# Make Google Drive visible inside the Colab runtime; the dataset lives there.
drive.mount('/content/gdrive')

# Read the clinical-notes CSV from the mounted drive into a DataFrame.
notes_csv_path = '/content/gdrive/MyDrive/datamining/clinical_notes.csv'
data = pd.read_csv(notes_csv_path)

# Data Preprocessing
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import re

# Function to clean text data
def clean_text(text):
    """Normalize one note for tokenization.

    Punctuation is removed, every run of whitespace (including newlines)
    is collapsed to a single space, and the result is lower-cased.

    Args:
        text: The raw note. ``None`` and NaN (how pandas represents a missing
            cell) are treated as an empty note; any other non-string value is
            converted with ``str()`` first.

    Returns:
        The cleaned, lower-cased string.
    """
    # Missing cells would otherwise make re.sub raise TypeError.
    # `text != text` is only true for NaN.
    if text is None or (isinstance(text, float) and text != text):
        return ''
    if not isinstance(text, str):
        text = str(text)

    text = re.sub(r'[^\w\s]', '', text)  # Remove punctuation
    # Collapse whitespace AFTER stripping punctuation: removing a symbol that
    # sits between two spaces ("a - b") would otherwise leave a double space.
    # `\s+` also covers newlines, so no separate newline pass is needed.
    text = re.sub(r'\s+', ' ', text)
    return text.lower()

# Clean the notes
data['note'] = data['note'].apply(clean_text)

# Tokenization
MAX_WORDS = 5000        # vocabulary cap passed to the Tokenizer (and used as Embedding input size)
MAX_SEQ_LENGTH = 500    # every note is padded/truncated to this many tokens

# NOTE(review): the out-of-vocabulary token is an empty string; it still works
# as a placeholder, but a visible marker such as "<OOV>" may have been
# intended — confirm before changing.
tokenizer = Tokenizer(num_words=MAX_WORDS, oov_token="")

# Build the vocabulary from the training notes only. Fitting on every row
# would let word statistics from the validation and test notes leak into
# preprocessing; words unseen in training are mapped to the OOV token instead.
training_notes = data.loc[data['use'] == 'training', 'note']
tokenizer.fit_on_texts(training_notes)

# Encode ALL notes with the training vocabulary, then pad with zeros at the
# end so every sequence has exactly MAX_SEQ_LENGTH entries.
sequences = tokenizer.texts_to_sequences(data['note'])
padded_sequences = pad_sequences(sequences, maxlen=MAX_SEQ_LENGTH, padding='post')

# Turn the textual glaucoma label into a binary target (yes -> 1, no -> 0).
label_map = {'yes': 1, 'no': 0}
data['glaucoma'] = data['glaucoma'].map(label_map)

# The 'use' column assigns each row to a split; build one boolean mask per split.
split_column = data['use']
train_indices = split_column == 'training'
validation_indices = split_column == 'validation'
test_indices = split_column == 'test'

# Slice features and targets with the masks.
glaucoma_labels = data['glaucoma'].values
X_train = padded_sequences[train_indices]
y_train = glaucoma_labels[train_indices]
X_val = padded_sequences[validation_indices]
y_val = glaucoma_labels[validation_indices]
X_test = padded_sequences[test_indices]
y_test = glaucoma_labels[test_indices]

# Race of every test row, kept for the per-group evaluation.
race_test = data['race'].values[test_indices]

# Handle class imbalance by randomly duplicating minority-class training notes.
# SMOTE is deliberately not used: it creates synthetic samples by interpolating
# between feature vectors, but the features here are padded token-id sequences
# (categorical vocabulary indices), so interpolated ids would point at
# unrelated words and produce meaningless notes.
from imblearn.over_sampling import RandomOverSampler

oversampler = RandomOverSampler(random_state=42)
X_train_resampled, y_train_resampled = oversampler.fit_resample(X_train, y_train)

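# Fairness-Aware Training
# NOTE: no fairness-aware training is implemented in this section. The models
# below are trained with a plain binary cross-entropy objective; fairness is
# only measured afterwards, by comparing AUC across racial groups on the test set.
# Making training itself fairness-aware (i.e. encouraging equal performance
# across groups) would require additional setup, such as adversarial debiasing
# networks or explicit fairness constraints.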

# Define Models
def build_lstm_model():
    """Build and compile the LSTM classifier.

    Architecture: token-id input of length ``MAX_SEQ_LENGTH`` -> 128-dim
    embedding over ``MAX_WORDS`` ids -> LSTM(64), last state only ->
    Dense(64, relu) -> Dense(1, sigmoid).

    Returns:
        A ``tf.keras.Sequential`` model compiled with the Adam optimizer,
        binary cross-entropy loss, and accuracy as a metric.
    """
    model = tf.keras.Sequential([
        # Declare the input shape with an explicit Input layer; the Embedding
        # `input_length` argument is deprecated in current Keras releases.
        tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,)),
        # mask_zero=True marks id 0 (the padding value) as "no token", so the
        # LSTM skips the trailing padding instead of running its state through
        # it before emitting the final output.
        tf.keras.layers.Embedding(MAX_WORDS, 128, mask_zero=True),
        tf.keras.layers.LSTM(64, return_sequences=False),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

def build_cnn_model():
    """Build and compile the 1D convolutional classifier.

    Architecture: token-id input of length ``MAX_SEQ_LENGTH`` -> 128-dim
    embedding over ``MAX_WORDS`` ids -> Conv1D(128 filters, kernel size 5,
    relu) -> global max pooling -> Dense(64, relu) -> Dense(1, sigmoid).

    Returns:
        A ``tf.keras.Sequential`` model compiled with the Adam optimizer,
        binary cross-entropy loss, and accuracy as a metric.
    """
    model = tf.keras.Sequential([
        # Declare the input shape with an explicit Input layer; the Embedding
        # `input_length` argument is deprecated in current Keras releases.
        tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,)),
        tf.keras.layers.Embedding(MAX_WORDS, 128),
        tf.keras.layers.Conv1D(128, 5, activation='relu'),
        tf.keras.layers.GlobalMaxPooling1D(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

def build_transformer_model():
    """Build and compile the attention-based classifier.

    The network embeds the token ids (128 dims over ``MAX_WORDS`` ids),
    applies a single multi-head self-attention layer (4 heads, key size 128,
    dropout 0.1) with the embeddings as both query and value, averages the
    result over the sequence, and finishes with Dense(64, relu) and a
    one-unit sigmoid output. Only the attention layer is used: this function
    adds no positional encoding, residual connection, or feed-forward
    sub-layer.

    Returns:
        A functional ``tf.keras.Model`` compiled with the Adam optimizer,
        binary cross-entropy loss, and accuracy as a metric.
    """
    token_ids = tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,))
    embedded = tf.keras.layers.Embedding(MAX_WORDS, 128)(token_ids)

    # Self-attention: the same tensor is passed as query and value.
    self_attention = tf.keras.layers.MultiHeadAttention(
        num_heads=4, key_dim=128, dropout=0.1
    )
    attended = self_attention(embedded, embedded)

    # Reduce the per-token outputs to one vector per note.
    pooled = tf.keras.layers.GlobalAveragePooling1D()(attended)
    hidden = tf.keras.layers.Dense(64, activation='relu')(pooled)
    probability = tf.keras.layers.Dense(1, activation='sigmoid')(hidden)

    model = tf.keras.Model(inputs=token_ids, outputs=probability)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

# Candidate hyperparameter values for a grid search.
# NOTE(review): nothing in this section passes `param_grid` to GridSearchCV or
# to the training loop — confirm whether the search runs elsewhere.
from sklearn.model_selection import GridSearchCV

param_grid = {
    'batch_size': [16, 32, 64],
    'epochs': [10, 20, 30],
    'optimizer': ['adam', 'rmsprop'],
    'learning_rate': [0.001, 0.01, 0.1],
}

# Training callbacks, both driven by the validation loss:
# stop after 5 epochs without improvement and roll back to the best weights...
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)
# ...and halve the learning rate after 3 epochs without improvement.
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=3,
)

# Training and Evaluation
# Instantiate one compiled model per architecture, keyed by its display name.
model_builders = {
    'LSTM': build_lstm_model,
    '1D CNN': build_cnn_model,
    'Transformer': build_transformer_model,
}
models = {name: build() for name, build in model_builders.items()}

# Result stores, each keyed by model name.
history, aucs, race_aucs, precision_recall = {}, {}, {}, {}

# Train every model on the resampled training set, then evaluate it on the
# test set: overall AUC, AUC per racial group, a classification report at a
# 0.5 threshold, and a precision-recall curve.
for model_name, model in models.items():
    print(f"Training {model_name}...")
    # NOTE(review): with epochs=5, an EarlyStopping patience of 5 can
    # presumably never trigger — confirm the intended epoch budget.
    history[model_name] = model.fit(
        X_train_resampled, y_train_resampled, validation_data=(X_val, y_val),
        epochs=5, batch_size=32, verbose=1,
        callbacks=[early_stopping, lr_scheduler]
    )

    # Compute Overall AUC
    y_pred = model.predict(X_test).ravel()  # flatten predictions to 1-D scores
    aucs[model_name] = roc_auc_score(y_test, y_pred)
    print(f"Overall AUC for {model_name}: {aucs[model_name]}")

    # Compute AUC per racial group
    race_aucs[model_name] = {}
    for race in ['asian', 'black', 'white']:
        race_mask = race_test == race
        y_true_race = y_test[race_mask]
        # ROC AUC is undefined when a group is empty or contains a single
        # class. Record NaN for that group rather than letting
        # roc_auc_score abort the whole evaluation.
        if np.unique(y_true_race).size < 2:
            race_auc = float('nan')
            print(f"  {race} AUC for {model_name}: undefined "
                  f"(group has fewer than two classes in the test set)")
        else:
            race_auc = roc_auc_score(y_true_race, y_pred[race_mask])
            print(f"  {race} AUC for {model_name}: {race_auc}")
        race_aucs[model_name][race] = race_auc

    # Additional metrics: Precision, Recall, F1-Score
    print(f"Classification Report for {model_name}:\n")
    print(classification_report(y_test, (y_pred > 0.5).astype(int)))

    # Plot Precision-Recall curve, and keep the curve so it can be reused
    # after the loop (the `precision_recall` store was otherwise never filled).
    precision, recall, _ = precision_recall_curve(y_test, y_pred)
    precision_recall[model_name] = (precision, recall)
    plt.plot(recall, precision, label=f'{model_name} Precision-Recall')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.legend()
    plt.show()

# Visualization: overlay the per-epoch validation loss of every trained model.
for name in history:
    val_losses = history[name].history['val_loss']
    plt.plot(val_losses, label=f'{name} Loss')
plt.title('Validation Loss Comparison')
plt.legend()
plt.show()

# Summary: overall and per-race AUC for each model.
print("Overall AUC Scores:", aucs)
print("Race AUC Scores:", race_aucs)

