In [2]:
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR 
import numpy as np
import pandas as pd
import os
from tqdm.notebook import tqdm # Use notebook version of tqdm (ensure ipywidgets is installed)
# from tqdm import tqdm # Alternative if ipywidgets causes issues
from sklearn.metrics import roc_curve 


In [3]:
# --- Configuration ---
# Filesystem layout: every project location hangs off PROJECT_ROOT.
PROJECT_ROOT = "/changed"  # Adjust if needed
DATA_ROOT_DIR = os.path.join(PROJECT_ROOT, "final_data")  # Or your output_dir
RESULTS_DIR = os.path.join(PROJECT_ROOT, "results")

# Base name of the saved run; it has to match the name used by the grid search,
# because the model directory name is built from it.
MODEL_RUN_NAME_BASE = "MortalityLSTM_GridSearch"

# Dataset identifiers: these values are embedded in the data file names
# (token map, arrays, test split) that the notebook loads.
T_HOURS = 48
N_BINS = 20
SEED = 0
MAX_LEN = 10000  # Max sequence length used during training

# Model hyperparameters: passed to init_model and/or encoded in the run name.
MODEL_TYPE = 'Mortality'
LATENT_DIM = 64
HIDDEN_DIM = 256
P_DROPOUT = 0.1
DT = 1.0  # Interval width; the demo code treats it as hours
WEIGHTED = False  # Set to False to test without weighted embeddings
DYNAMIC = True
LR = 0.0005  # Part of the run name, so it must match the trained run

# Training-only settings, kept for reference.
# NOTE(review): not read by the evaluation code visible in this notebook section.
EPOCHS = 150  # Number of epochs to run
BATCH_SIZE = 128
EARLY_STOPPING_PATIENCE = 5  # Set to large number or None to effectively disable for full run



In [4]:
# How the per-interval probabilities of one patient are reduced to a single score:
# 'last' (final interval), 'max' or 'mean'. Any other value is handled like 'last'.
PREDICTION_SCORE_TYPE = 'last'

In [5]:
# --- Device Setup ---
# Use the GPU when CUDA is available, otherwise run everything on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}") # Report which device was selected

# --- Import necessary project modules ---
# The project-local imports are wrapped so that a wrong working directory / PYTHONPATH
# prints a readable hint before the original ImportError is re-raised.
try:
    # Assumes the `utils` and `models` packages are importable from the current environment
    from utils.modelIO import load_model # NOTE(review): not called in the cells shown here - confirm it is still needed
    # Factory used below to build the network before the saved weights are loaded into it
    from models.models import init_model
    # Optional: only needed if the LSTM class itself is used directly (LayerNorm variant)
    # from models.lstms import LSTM
except ImportError as e:
    print(f"Import Error: {e}. Ensure project structure is correct and accessible.")
    raise

# --- Construct Paths ---
# Sub-directories of the processed-data root.
dict_dir, array_dir, split_dir = (
    os.path.join(DATA_ROOT_DIR, subdir)
    for subdir in ('dictionaries', 'arrays', 'splits')
)

# --- Define Model Name ---
# The run directory name encodes the hyperparameters of the saved run, so the
# configuration values must match the run that produced model.pt.
model_name_suffix = (
    f't{T_HOURS}_lr{LR}_z{LATENT_DIM}'
    f'_h{HIDDEN_DIM}_p{P_DROPOUT}_w{WEIGHTED}_d{DYNAMIC}_seed{SEED}'
)
model_run_name = f'{MODEL_RUN_NAME_BASE}_{model_name_suffix}'
model_dir = os.path.join(RESULTS_DIR, model_run_name)
model_path = os.path.join(model_dir, 'model.pt')  # Saved weights of the run

# --- Define Data Paths ---
token_map_path = os.path.join(
    dict_dir, f'{T_HOURS}_{SEED}_{N_BINS}-token2index.npy'
)
array_path = os.path.join(
    array_dir, f'{T_HOURS}_{SEED}_{N_BINS}-arrays.npz'
)
test_split_path = os.path.join(
    split_dir, f'{SEED}-{T_HOURS}-test.csv'
)
# !! Important: point these at your actual MIMIC dictionary CSV locations !!
d_items_path = '/D_ITEMS.csv'
d_labitems_path = '/D_LABITEMS.csv'

# --- 1. Load Mappings (Token & ITEMID) ---
# token2index maps a token string to its integer index; the file is loaded as a pickled
# object and unwrapped with .item() (the code below uses it as a dict).
# NOTE: allow_pickle=True unpickles the file on load - only use it with trusted files.
if not os.path.exists(token_map_path): raise FileNotFoundError(f"Token map not found: {token_map_path}")
token2index = np.load(token_map_path, allow_pickle=True).item()
index2token = {v: k for k, v in token2index.items()} # Reverse lookup: index -> token string
n_tokens = len(token2index) # Vocabulary size handed to init_model
padding_idx = token2index.get('<PAD>', 0) # Falls back to index 0 when '<PAD>' is not in the map

# ITEMID -> LABEL lookup merged from both dictionary CSVs.
# D_LABITEMS is applied second, so its label wins if an ITEMID occurs in both files.
itemid_to_label = {}
d_items_df = pd.read_csv(d_items_path, usecols=['ITEMID', 'LABEL'])
itemid_to_label.update(pd.Series(d_items_df.LABEL.values, index=d_items_df.ITEMID).to_dict())
d_labitems_df = pd.read_csv(d_labitems_path, usecols=['ITEMID', 'LABEL'])
itemid_to_label.update(pd.Series(d_labitems_df.LABEL.values, index=d_labitems_df.ITEMID).to_dict())



# --- 2. Load Model ---
if not os.path.exists(model_path): raise FileNotFoundError(f"Saved model model.pt not found at {model_path}")

# Rebuild the architecture from the configured hyperparameters, then restore the
# trained weights from model.pt into it.
model = init_model(
    model_type=MODEL_TYPE, n_tokens=n_tokens, latent_dim=LATENT_DIM,
    hidden_dim=HIDDEN_DIM, p_dropout=P_DROPOUT, dt=DT,
    weighted=WEIGHTED, dynamic=DYNAMIC#, use_layer_norm=True # Pass if your init_model accepts it
)
# map_location remaps the stored tensors onto `device` (e.g. GPU checkpoint on a CPU machine).
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval() # Switch to evaluation mode for inference


# --- 3. Load Test Set Data and Create Ordered Index Mapping ---
# The split CSV lists the test entities in its 'Paths' column; their order is kept.
if not os.path.exists(test_split_path): raise FileNotFoundError(f"Test split file not found: {test_split_path}")
test_df = pd.read_csv(test_split_path)
if test_df.empty: raise ValueError("Test split file is empty.")
test_entity_paths_ordered = test_df['Paths'].tolist()

# The .npz archive holds the arrays for every entity: 'X', 'Y' and 'paths'.
# NOTE: allow_pickle=True unpickles object arrays on load - only use it with trusted files.
if not os.path.exists(array_path): raise FileNotFoundError(f"Data array file not found: {array_path}")
with np.load(array_path, allow_pickle=True) as data:
    X_all = data['X']
    Paths_all_list = data['paths'].tolist()
    Y_all = data['Y']

# Translate each test path into its row index in the full arrays.
# If a path occurs more than once in Paths_all_list, the last occurrence wins.
path_to_idx_map = {path: i for i, path in enumerate(Paths_all_list)}
ordered_test_indices_in_all = []
for path in tqdm(test_entity_paths_ordered, desc="Mapping Paths"):
    idx = path_to_idx_map.get(path, -1) # -1 marks "path not present in the arrays"
    if idx != -1:
        ordered_test_indices_in_all.append(idx)
    else:
        # Unmatched paths are dropped, so this list can be shorter than the split file.
        print(f"Warning: Path {path} from test split not found in data array. Skipping.")

if not ordered_test_indices_in_all:
    raise ValueError("Could not map any test paths to the data array.")

# --- 4. Get Predictions and Interval Counts for the Entire Test Set ---
# The three lists are filled in lock-step: position k in each of them belongs to the
# entity ordered_test_indices_in_all[k].
all_test_preds = []
all_test_trues = []
all_test_num_intervals = []

for full_idx in tqdm(ordered_test_indices_in_all, desc="Predicting"):
    X_single_np = X_all[full_idx]
    Y_true = Y_all[full_idx]
    # One entity at a time: unsqueeze(0) adds a leading batch dimension of size 1.
    X_single_tensor = torch.tensor(X_single_np, dtype=torch.float32).unsqueeze(0).to(device)

    with torch.no_grad():
        # Assuming model.forward now handles potential NaNs internally and raises errors
        output_logits = model(X_single_tensor)
        # Drop the batch dimension again when the model returned one.
        if output_logits.dim() > 1 and output_logits.shape[0] == 1:
             output_logits = output_logits.squeeze(0)
        output_probs = torch.sigmoid(output_logits).cpu().numpy()

    # presumably one probability per time interval - TODO confirm against the model's output shape
    num_intervals = len(output_probs)
    all_test_num_intervals.append(num_intervals)

    # Reduce the per-interval probabilities to one score per patient.
    # An entity for which the model returned no intervals keeps the score 0.0.
    pred_score = 0.0
    if num_intervals > 0:
        if PREDICTION_SCORE_TYPE == 'last': pred_score = output_probs[-1]
        elif PREDICTION_SCORE_TYPE == 'max': pred_score = np.max(output_probs)
        elif PREDICTION_SCORE_TYPE == 'mean': pred_score = np.mean(output_probs)
        else: pred_score = output_probs[-1] # Unknown type: fall back to 'last'
    all_test_preds.append(pred_score)
    all_test_trues.append(Y_true)

all_test_preds = np.array(all_test_preds)
all_test_trues = np.array(all_test_trues)
all_test_num_intervals = np.array(all_test_num_intervals)

if len(all_test_preds) == 0:
    raise ValueError("Failed to generate any predictions for the test set.")

# --- 5. Calculate Optimal Threshold using Youden's J ---
# Youden's J = TPR - FPR; the threshold with the largest J is used as the operating point.
# NOTE(review): the threshold is tuned on the same test set that is classified with it
# afterwards, so the resulting TP/FP/FN/TN split is optimistic. That is acceptable for
# picking demo patients, but it should not be reported as model performance.
fpr, tpr, thresholds = roc_curve(all_test_trues, all_test_preds)
j_scores = tpr - fpr
# Ignore non-finite J values (they fall through to the 0.5 default when none are left).
valid_indices = np.where(np.isfinite(j_scores))[0]
if len(valid_indices) == 0:
    print("Warning: Could not calculate valid Youden's J scores. Using default threshold 0.5")
    optimal_threshold = 0.5
    ix = -1
else:
    ix = valid_indices[np.argmax(j_scores[valid_indices])]
    # Only interior thresholds are used as-is; the first and last entries are replaced
    # by their neighbour (or by 0.5 when there is no neighbour).
    if ix > 0 and ix < len(thresholds) -1 :
         optimal_threshold = thresholds[ix]
    elif ix == 0: # First entry: in scikit-learn this is a sentinel above every score (inf in recent versions)
        optimal_threshold = thresholds[1] if len(thresholds) > 1 else 0.5
        print(f"Warning: Optimal Youden's J index is 0 (threshold={thresholds[0]}). Using next threshold {optimal_threshold:.4f} instead.")
    else: # Last entry: the lowest threshold returned by roc_curve
        optimal_threshold = thresholds[-2] if len(thresholds) > 1 else 0.5
        print(f"Warning: Optimal Youden's J index is last (threshold={thresholds[-1]}). Using previous threshold {optimal_threshold:.4f} instead.")

# --- 6. Classify Test Set and Identify Categories ---
# A patient is predicted positive when the score reaches the threshold (>=).
predicted_labels = (all_test_preds >= optimal_threshold).astype(int)
# Each *_test_indices array holds positions into the prediction arrays (all_test_preds etc.),
# not row indices of the full dataset; ordered_test_indices_in_all converts between the two.
tp_test_indices = np.where((all_test_trues == 1) & (predicted_labels == 1))[0]
fp_test_indices = np.where((all_test_trues == 0) & (predicted_labels == 1))[0]
fn_test_indices = np.where((all_test_trues == 1) & (predicted_labels == 0))[0]
tn_test_indices = np.where((all_test_trues == 0) & (predicted_labels == 0))[0]

# print(f"\nClassification counts at threshold {optimal_threshold:.4f}:")
# print(f"  True Positives (TP): {len(tp_test_indices)}")
# print(f"  False Positives (FP): {len(fp_test_indices)}")
# print(f"  False Negatives (FN): {len(fn_test_indices)}")
# print(f"  True Negatives (TN): {len(tn_test_indices)}")

# --- Modified Helper Function ---
def count_intervals_with_events(full_idx, num_intervals_for_patient, X_all_data, dt_val, pad_idx):
    """Return how many of a patient's intervals contain a real (non-padding) event.

    Args:
        full_idx: Index of the patient in ``X_all_data``.
        num_intervals_for_patient: Number of consecutive intervals to inspect,
            starting at time 0. Values <= 0 yield 0 without touching the data.
        X_all_data: Indexable collection of per-patient 2-D arrays; column 0 is
            compared against the interval bounds, column 1 is cast to int and
            compared against ``pad_idx``.
        dt_val: Width of one interval, in the same unit as column 0.
        pad_idx: Token index that marks padding rows.

    Returns:
        int: Number of intervals ``k`` for which at least one row has its time in
        ``[k*dt_val - 1e-6, (k+1)*dt_val - 1e-6)`` and a token other than ``pad_idx``.
    """
    eps = 1e-6  # Shifts both window edges slightly to the left

    if num_intervals_for_patient <= 0:
        return 0

    events = X_all_data[full_idx]
    timestamps = events[:, 0]

    def has_real_event(k):
        # Rows whose time falls inside the k-th (shifted, half-open) window.
        lower = k * dt_val
        upper = (k + 1) * dt_val
        in_window = (timestamps >= (lower - eps)) & (timestamps < (upper - eps))
        if not in_window.any():
            return False
        token_ids = events[in_window, 1].astype(int)
        return bool((token_ids != pad_idx).any())

    return sum(1 for k in range(num_intervals_for_patient) if has_real_event(k))

# --- 7. Select One Patient Index per Category (New Logic) ---
# -1 is the "no patient selected" sentinel; the demo function skips a category whose index is -1.
selected_tp_idx, selected_fp_idx, selected_fn_idx, selected_tn_idx = -1, -1, -1, -1

# print("\nSelecting representative patients (Prioritizing events in all intervals)...")

# --- Generic Selection Function ---
def select_best_patient(category_indices, category_label):
    """Pick the demo patient for one confusion-matrix category.

    The first patient whose every interval contains a non-padding event is
    returned immediately. If there is none, the patient with the highest number
    of such intervals is used instead (the earliest one on ties) and a message
    is printed.

    Relies on the notebook globals ``ordered_test_indices_in_all``,
    ``all_test_num_intervals``, ``X_all``, ``DT`` and ``padding_idx``.

    Args:
        category_indices: Positions into the test-set prediction arrays.
        category_label: Short label used in the printed messages (e.g. "TP").

    Returns:
        Index of the chosen patient in the full dataset, or -1 when
        ``category_indices`` is empty.
    """
    runner_up_idx = -1
    runner_up_active = -1

    for test_pos in category_indices:
        dataset_idx = ordered_test_indices_in_all[test_pos]
        n_intervals = all_test_num_intervals[test_pos]
        n_active = count_intervals_with_events(dataset_idx, n_intervals, X_all, DT, padding_idx)

        # Perfect match: every interval of this patient has a real event.
        if n_intervals > 0 and n_active == n_intervals:
            return dataset_idx

        # Otherwise remember the patient with the most active intervals so far.
        if n_active > runner_up_active:
            runner_up_active = n_active
            runner_up_idx = dataset_idx

    # No perfect match in this category: report what is used instead.
    if runner_up_idx == -1:
        print(f"{category_label} search result: No suitable patient found (neither perfect nor imperfect)")
    else:
        print(f"{category_label} search result: No perfect match. Using best imperfect (Index {runner_up_idx}, {runner_up_active} active intervals)")

    return runner_up_idx

# --- Select for each category ---
# One call per confusion-matrix category, evaluated in TP, FP, FN, TN order.
selected_tp_idx, selected_fp_idx, selected_fn_idx, selected_tn_idx = (
    select_best_patient(members, tag)
    for members, tag in (
        (tp_test_indices, "TP"),
        (fp_test_indices, "FP"),
        (fn_test_indices, "FN"),
        (tn_test_indices, "TN"),
    )
)

# # Keep final selected indices print
# print("\nSelected Patient Indices (in full dataset) for Demo:")
# print(f"  TP Example Index: {selected_tp_idx}" + (f" (Path: {Paths_all_list[selected_tp_idx]})" if selected_tp_idx !=-1 else " (None available)"))
# print(f"  FP Example Index: {selected_fp_idx}" + (f" (Path: {Paths_all_list[selected_fp_idx]})" if selected_fp_idx !=-1 else " (None available)"))
# print(f"  FN Example Index: {selected_fn_idx}" + (f" (Path: {Paths_all_list[selected_fn_idx]})" if selected_fn_idx !=-1 else " (None available)"))
# print(f"  TN Example Index: {selected_tn_idx}" + (f" (Path: {Paths_all_list[selected_tn_idx]})" if selected_tn_idx !=-1 else " (None available)"))

Using device: cpu
Initializing model of type Mortality...


Mapping Paths:   0%|          | 0/4885 [00:00<?, ?it/s]

Predicting:   0%|          | 0/4885 [00:00<?, ?it/s]

In [6]:
# --- 8. Prediction and Display Function (Assume it's defined correctly as before) ---
def run_demo_for_entity(entity_idx, category_label, model, X_all, Y_all, Paths_all_list, index2token, itemid_to_label, DT, device, padding_idx):
    """Print an interval-by-interval walkthrough of the model's output for one entity.

    For every interval returned by the model, the non-padding events that fall
    into that interval are listed (token string plus clinical label), followed
    by the predicted probability for that interval.

    Args:
        entity_idx: Index of the entity in ``X_all`` / ``Y_all`` / ``Paths_all_list``.
            The sentinel -1 means "nothing selected" and only prints a skip notice.
        category_label: Text shown in the headers (e.g. "True Positive (TP)").
        model: Callable that takes a float32 tensor with a leading batch
            dimension of 1 and returns the logits.
        X_all: Indexable collection of per-entity 2-D arrays; column 0 is compared
            against the interval bounds, column 1 is cast to int and looked up in
            ``index2token``.
        Y_all: Indexable collection of true labels.
        Paths_all_list: Indexable collection of entity paths (used for display).
        index2token: Mapping from token index to token string.
        itemid_to_label: Mapping from integer ITEMID to its clinical label.
        DT: Width of one interval, in the same unit as column 0 of the data.
        device: Torch device the input tensor is moved to.
        padding_idx: Token index that marks padding rows (never displayed).

    Returns:
        None. All results are printed.
    """
    if entity_idx == -1:
        print(f"\n--- Skipping Demo for Category: {category_label} (No suitable patient selected) ---")
        return

    X_single_np = X_all[entity_idx]
    Y_true_label = Y_all[entity_idx] # True mortality label
    entity_relative_path = Paths_all_list[entity_idx]

    X_single_tensor = torch.tensor(X_single_np, dtype=torch.float32).unsqueeze(0).to(device)
    with torch.no_grad():
        try:
            output_logits = model(X_single_tensor)
            # Drop the batch dimension again when the model returned one.
            if output_logits.dim() > 1 and output_logits.shape[0] == 1:
                output_logits = output_logits.squeeze(0)
            output_probs = torch.sigmoid(output_logits).cpu().numpy() # One probability per interval
            num_intervals = len(output_probs)
        except Exception as e:
            # Best effort: a failing prediction must not abort the remaining demo cells.
            print(f"\nError during prediction within run_demo_for_entity for {entity_relative_path}: {e}")
            return

    print(f"\n--- Demo for [{category_label}] Entity: {entity_relative_path} ---")
    print(f"    True Mortality: {Y_true_label}")

    if num_intervals == 0:
        print("Model produced 0 intervals. Cannot display results.")
        return

    tolerance = 1e-6 # Shifts both interval edges slightly to the left
    for interval_idx in range(num_intervals):
        t_start = interval_idx * DT
        t_end = (interval_idx + 1) * DT
        interval_label = f"Hour {int(t_start):<2d}-{int(t_end):<2d}"

        # Rows whose time lies in this interval, with padding rows removed.
        interval_events_mask = (X_single_np[:, 0] >= (t_start - tolerance)) & (X_single_np[:, 0] < (t_end - tolerance))
        interval_event_indices = X_single_np[interval_events_mask, 1].astype(int)
        interval_event_indices = interval_event_indices[interval_event_indices != padding_idx]

        interval_prob = output_probs[interval_idx]

        print(f"\n{interval_label}:")
        if len(interval_event_indices) > 0:
            print("  Events:")
            for token_idx in interval_event_indices:
                token_string = index2token.get(token_idx, f"<Unknown Index: {token_idx}>")
                clinical_label = "<Special Token>"
                # A special token is wrapped in angle brackets on BOTH sides (e.g. '<PAD>').
                # A regular token that merely ends with '>' or starts with '<' still gets
                # its ITEMID looked up.
                is_special_token = token_string.startswith('<') and token_string.endswith('>')
                if not is_special_token:
                    try:
                        # The ITEMID is the part of the token before the first underscore.
                        itemid = int(token_string.split('_')[0])
                        clinical_label = itemid_to_label.get(itemid, "<Unknown ITEMID>")
                    except ValueError:
                        clinical_label = "<Invalid ITEMID format>"
                    except Exception:
                        clinical_label = "<Label Lookup Error>"
                print(f"    - Token: {token_string:<45} | Meaning: {clinical_label}")
        else:
            print("  (No non-padding events in this interval)")

        print(f"  Predicted Mortality Probability at end of hour {int(t_end)}: {interval_prob:.4f}")

    print(f"--- Finished Demo for [{category_label}] Entity: {entity_relative_path} ---")

In [7]:
# --- 9. Run Demo for Selected Patients ---
# True positive example: the patient died and the model flagged them.
run_demo_for_entity(
    selected_tp_idx, "True Positive (TP)",
    model, X_all, Y_all, Paths_all_list,
    index2token, itemid_to_label,
    DT, device, padding_idx,
)


--- Demo for [True Positive (TP)] Entity: 32725/episode1_timeseries_48.csv ---
    True Mortality: 1

Hour 0 -1 :
  Events:
    - Token: 226253_%:1                                    | Meaning: SpO2 Desat Limit
    - Token: 224162_insp/min:1                             | Meaning: Resp Alarm - Low
    - Token: 224161_insp/min:1                             | Meaning: Resp Alarm - High
    - Token: 220047_bpm:2                                  | Meaning: Heart Rate Alarm - Low
    - Token: 220046_bpm:2                                  | Meaning: Heart rate Alarm - High
    - Token: 223752_mmHg:3                                 | Meaning: Non-Invasive Blood Pressure Alarm - Low
    - Token: 223769_%:1                                    | Meaning: O2 Saturation Pulseoxymetry Alarm - High
    - Token: 223770_%:1                                    | Meaning: O2 Saturation Pulseoxymetry Alarm - Low
    - Token: 223751_mmHg:3                                 | Meaning: Non-Invasive Blood Pressu

In [8]:
# False positive example: the patient survived but the model flagged them.
run_demo_for_entity(
    selected_fp_idx, "False Positive (FP)",
    model, X_all, Y_all, Paths_all_list,
    index2token, itemid_to_label,
    DT, device, padding_idx,
)



--- Demo for [False Positive (FP)] Entity: 32803/episode1_timeseries_48.csv ---
    True Mortality: 0

Hour 0 -1 :
  Events:
    - Token: 4171_:13                                      | Meaning: Time
    - Token: 4185_:nan                                     | Meaning: 5 min
    - Token: 4184_:nan                                     | Meaning: Apgar      1 min
    - Token: 4175_:NEWBORN                                 | Meaning: Admit Reason
    - Token: 4171_:nan                                     | Meaning: Time
    - Token: 3446_:nan                                     | Meaning: Gestational Age
    - Token: 4169_:nan                                     | Meaning: EDC
    - Token: 926_:U                                        | Meaning: Religion
    - Token: 4186_:nan                                     | Meaning: 10 min
    - Token: 4180_:nan                                     | Meaning: Current Level
    - Token: 4181_:nan                                     | Meaning: Highest 

In [9]:
# False negative example: the patient died but the model did not flag them.
run_demo_for_entity(
    selected_fn_idx, "False Negative (FN)",
    model, X_all, Y_all, Paths_all_list,
    index2token, itemid_to_label,
    DT, device, padding_idx,
)



--- Demo for [False Negative (FN)] Entity: 77927/episode1_timeseries_48.csv ---
    True Mortality: 1

Hour 0 -1 :
  Events:
    - Token: 220277_%:3                                    | Meaning: O2 saturation pulseoxymetry
    - Token: 220210_insp/min:10                            | Meaning: Respiratory Rate
    - Token: 220045_bpm:18                                 | Meaning: Heart Rate
    - Token: 223876_sec:1                                  | Meaning: Apnea Interval
    - Token: 224697_cmH2O:1                                | Meaning: Mean Airway Pressure
    - Token: 224695_cmH2O:3                                | Meaning: Peak Insp. Pressure
    - Token: 224689_insp/min:7                             | Meaning: Respiratory Rate (spontaneous)
    - Token: 224687_L/min:18                               | Meaning: Minute Volume
    - Token: 220339_cmH2O:1                                | Meaning: PEEP set
    - Token: 223835_:3                                     | Meaning: Inspired

In [10]:
# True negative example: the patient survived and the model did not flag them.
run_demo_for_entity(
    selected_tn_idx, "True Negative (TN)",
    model, X_all, Y_all, Paths_all_list,
    index2token, itemid_to_label,
    DT, device, padding_idx,
)


--- Demo for [True Negative (TN)] Entity: 19150/episode1_timeseries_48.csv ---
    True Mortality: 0

Hour 0 -1 :
  Events:
    - Token: 1530_:11                                      | Meaning: INR
    - Token: 51222_g/dL:0                                  | Meaning: Hemoglobin
    - Token: 51214_mg/dL:0                                 | Meaning: Fibrinogen, Functional
    - Token: 51221_%:0                                     | Meaning: Hematocrit
    - Token: 51266_:LOW                                    | Meaning: Platelet Smear
    - Token: 51265_K/uL:6                                  | Meaning: Platelet Count
    - Token: 1528_:0                                       | Meaning: Fibrinogen
    - Token: 51237_:10                                     | Meaning: INR(PT)
    - Token: 51279_m/uL:0                                  | Meaning: Red Blood Cells
    - Token: 51277_%:0                                     | Meaning: RDW
    - Token: 51275_sec:17                                