In [1]:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils # For gradient clipping
import torch.autograd # For anomaly detection
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR 
import numpy as np
import pandas as pd
import os
import sys
import logging
from collections import defaultdict
from tqdm.notebook import tqdm # Use notebook version of tqdm (ensure ipywidgets is installed)
# from tqdm import tqdm # Alternative if ipywidgets causes issues
from sklearn.metrics import roc_auc_score
from timeit import default_timer
from datetime import datetime
import csv


from utils.modelIO import load_metadata, load_model, save_model, save_metadata
from utils.datasets import get_dataloaders # Use corrected version
from utils.helpers import get_n_param, new_model_dir, set_seed, array # Import array helper
from models.losses import BCE # Use corrected version
from models.models import MODELS, init_model # Use corrected version


In [2]:
# --- Configuration ---
# Every tunable value for the run lives in this cell; later cells only read these names.
T_HOURS = 48   # Passed to get_dataloaders as t_hours and embedded in data/split file names
N_BINS = 20    # Embedded in the array and token-map file names
SEED = 0       # Passed to set_seed and embedded in data/split file names
MAX_LEN = 10000 # Max sequence length used during training (not referenced by the visible cells)

# Paths
# The project root can be overridden without editing the notebook by exporting
# MORTALITY_PROJECT_ROOT; when the variable is unset the original default is used.
PROJECT_ROOT = os.environ.get("MORTALITY_PROJECT_ROOT", "/changed")
DATA_ROOT_DIR = os.path.join(PROJECT_ROOT, "final_data") # Or your output_dir
RESULTS_DIR = os.path.join(PROJECT_ROOT, "results")
MODEL_RUN_NAME_BASE = "MortalityLSTM" # Give a unique name for this run

# Hyperparameters
LR = 0.0001
EPOCHS = 150 # Upper bound on the number of epochs (early stopping may end the run sooner)
BATCH_SIZE = 128
LATENT_DIM = 32
HIDDEN_DIM = 256
P_DROPOUT = 0.0
EARLY_STOPPING_PATIENCE = 5 # None or a value <= 0 disables early stopping in the training cell
MODEL_TYPE = 'Mortality'
DT = 1.0
WEIGHTED = False # Set to False to test without weighted embeddings
DYNAMIC = True




In [3]:

# --- Setup Logging ---
# Root logger emits INFO and above as "<timestamp> [<LEVEL>] <message>";
# the notebook logs through its own named logger.
logging.basicConfig(
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("FullRunNotebook")

# --- Enable Anomaly Detection (Optional) ---
# Uncomment the two lines below to turn on autograd anomaly detection while debugging.
# torch.autograd.set_detect_anomaly(True)
# logger.info("PyTorch anomaly detection enabled.")

# --- Redefine LossesLogger (or import from train.py) ---
class LossesLogger(object):
    """Append one row of per-epoch metrics to a CSV file.

    Any existing file at ``file_path_name`` is deleted on construction, so each
    run starts from an empty log. The header row is written together with the
    first data row.

    Attributes:
        file_path_name: Path of the CSV file being written.
        header_written: True once at least one row (and therefore the header)
            has been written by this instance.
    """

    # Column order of the CSV file.
    FIELDNAMES = ['Epoch', 'Train_Loss', 'Valid_Loss', 'Valid_AUROC']

    def __init__(self, file_path_name):
        self.file_path_name = file_path_name
        self.header_written = False
        # Ensure directory exists (dirname is '' for a bare file name).
        log_dir = os.path.dirname(file_path_name)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        # Clear file if exists
        if os.path.isfile(file_path_name):
            os.remove(file_path_name)

    @staticmethod
    def _mean_or_nan(values):
        """Return the NaN-ignoring mean of ``values``, or NaN when nothing usable is present.

        ``None``, an empty sequence, an all-NaN sequence and values that cannot
        be converted to floats all map to NaN, without emitting numpy warnings.
        """
        if values is None:
            return float('nan')
        try:
            flat = np.asarray(values, dtype=float).ravel()
        except (TypeError, ValueError):
            return float('nan')
        if flat.size == 0 or np.isnan(flat).all():
            return float('nan')
        return float(np.nanmean(flat))

    @staticmethod
    def _format_metric(value):
        """Format a metric with four decimals; non-finite values become the string 'nan'."""
        return f"{value:.4f}" if np.isfinite(value) else "nan"

    def log(self, epoch, storer):
        """Append the metrics of one epoch to the CSV file.

        Args:
            epoch: Zero-based epoch index; the row stores ``epoch + 1``.
            storer: Mapping that may contain the keys ``'train_loss'``,
                ``'valid_loss'`` and ``'auroc'``. Each value is averaged; a
                missing, ``None`` or empty value is recorded as ``nan``.

        Logging is best effort: a failure is reported through the notebook
        logger and never propagates into the training loop.
        """
        try:
            result_data = {
                'Epoch': epoch + 1,
                'Train_Loss': self._format_metric(self._mean_or_nan(storer.get('train_loss'))),
                'Valid_Loss': self._format_metric(self._mean_or_nan(storer.get('valid_loss'))),
                'Valid_AUROC': self._format_metric(self._mean_or_nan(storer.get('auroc'))),
            }

            # Decide before opening: opening in append mode creates the file.
            needs_header = (not os.path.isfile(self.file_path_name)
                            or os.path.getsize(self.file_path_name) == 0)
            with open(self.file_path_name, 'a', newline='', encoding='utf-8') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=self.FIELDNAMES)
                if needs_header:
                    writer.writeheader()
                writer.writerow(result_data)
            self.header_written = True
        except Exception as e:
            # Same logger name as the rest of the notebook, so the message lands in the same stream.
            logging.getLogger("FullRunNotebook").error(
                f"Error logging epoch {epoch+1} results: {e}", exc_info=True)


# Seed via the project helper (presumably seeds the RNGs in use; verify in utils.helpers),
# then use CUDA when it is available and fall back to the CPU otherwise.
set_seed(SEED)
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
logger.info(f"Using device: {device}")

2025-04-30 13:01:12 [INFO] Using device: cpu


In [4]:
# ## 3. Define Paths and Create Directories

# The run directory name encodes the hyperparameters of this run.
model_name_suffix = (
    f't{T_HOURS}_lr{LR}_z{LATENT_DIM}'
    f'_h{HIDDEN_DIM}_p{P_DROPOUT}_w{WEIGHTED}_d{DYNAMIC}_seed{SEED}'  # Added weighted/dynamic
)
model_run_name = f'{MODEL_RUN_NAME_BASE}_{model_name_suffix}'
model_dir = os.path.join(RESULTS_DIR, model_run_name)

# Input data layout below DATA_ROOT_DIR
array_dir = os.path.join(DATA_ROOT_DIR, 'arrays')
dict_dir = os.path.join(DATA_ROOT_DIR, 'dictionaries')
split_dir = os.path.join(DATA_ROOT_DIR, 'splits')

# Arrays and token map are keyed by "<t_hours>_<seed>_<n_bins>", splits by "<seed>-<t_hours>".
array_path = os.path.join(array_dir, f'{T_HOURS}_{SEED}_{N_BINS}-arrays.npz')
token_map_path = os.path.join(dict_dir, f'{T_HOURS}_{SEED}_{N_BINS}-token2index.npy')
train_split_path = os.path.join(split_dir, f'{SEED}-{T_HOURS}-train.csv')
valid_split_path = os.path.join(split_dir, f'{SEED}-{T_HOURS}-valid.csv')
test_split_path = os.path.join(split_dir, f'{SEED}-{T_HOURS}-test.csv')

# Fail fast on the first missing input, before any results directory is created.
required_files = [
    array_path,
    token_map_path,
    train_split_path,
    valid_split_path,
    test_split_path,
]
for f_path in required_files:
    if os.path.exists(f_path):
        continue
    raise FileNotFoundError(f"Required data file not found: {f_path}")

# Results directory for this run, plus the per-epoch CSV logger that writes into it.
new_model_dir(model_dir, logger=logger)
losses_logger = LossesLogger(os.path.join(model_dir, 'epoch_losses.csv')) # Init logger

In [5]:
# ## 4. Load token map and build the train / validation / test dataloaders
logger.info("Loading token map...")

# NOTE(review): allow_pickle=True unpickles the file contents, which can execute
# arbitrary code; only load token maps from a trusted source.
# .item() unwraps the 0-d object array into the stored Python object
# (used as a sized mapping below).
token2index = np.load(token_map_path, allow_pickle=True).item()
n_tokens = len(token2index)
logger.info(f"Vocabulary size (n_tokens): {n_tokens}")
if n_tokens == 0: raise ValueError("Token map is empty.")


# With validation=True the helper's two return values are used as the
# training and validation loaders.
logger.info("Loading train/validation dataloaders...")
train_loader, valid_loader = get_dataloaders(
    array_path=array_path, token_map_path=token_map_path,
    train_split_path=train_split_path, valid_split_path=valid_split_path,
    validation=True, t_hours=T_HOURS, dt=DT, dynamic=DYNAMIC,
    batch_size=BATCH_SIZE, logger=logger, shuffle=True
)
logger.info(f'Loaded {len(train_loader.dataset)} training samples, {len(valid_loader.dataset)} validation samples.')


# Second call builds the test loader; the second return value is discarded.
# NOTE(review): presumably validation=False makes get_dataloaders read
# test_split_path — confirm against utils.datasets.
logger.info("Loading test dataloader...")
test_loader, _ = get_dataloaders(
    array_path=array_path, token_map_path=token_map_path,
    test_split_path=test_split_path, # Provide test split path
    validation=False, # Load test set
    t_hours=T_HOURS, dt=DT, dynamic=DYNAMIC,
    batch_size=BATCH_SIZE, logger=logger, shuffle=False # No shuffle for test
)


2025-04-30 13:01:12 [INFO] Loading token map...
2025-04-30 13:01:12 [INFO] Vocabulary size (n_tokens): 39727
2025-04-30 13:01:12 [INFO] Loading train/validation dataloaders...
2025-04-30 13:01:12 [INFO] Loading full dataset arrays from: /changed/final_data/arrays/48_0_20-arrays.npz
2025-04-30 13:02:06 [INFO]  Loaded data shapes: X=(24424, 10000, 2), Y=(24424,), Paths=(24424,)
2025-04-30 13:02:06 [INFO] Loading token map from: /changed/final_data/dictionaries/48_0_20-token2index.npy
2025-04-30 13:02:06 [INFO]  Token map loaded. Vocabulary size (n_tokens): 39727
2025-04-30 13:02:06 [INFO] Loaded training set: 9539 samples based on /changed/final_data/splits/0-48-train.csv
2025-04-30 13:02:06 [INFO] Loaded validation set: 10000 samples based on /changed/final_data/splits/0-48-valid.csv
2025-04-30 13:02:07 [INFO] 🔍 Initial Token ID range in loaded data: min=0, max=39725
2025-04-30 13:02:07 [INFO]   Expected vocabulary size (n_tokens): 39727
2025-04-30 13:02:07 [INFO]   Token IDs are within

In [6]:
# ## 5. Initialize Model, Optimizer, Loss

logger.info(f"Initializing model: {MODEL_TYPE}")

# Ensure using models.py that doesn't have ReLU after embedder if testing that fix
model = init_model(
    model_type=MODEL_TYPE,
    n_tokens=n_tokens,
    latent_dim=LATENT_DIM,
    hidden_dim=HIDDEN_DIM,
    p_dropout=P_DROPOUT,
    dt=DT,
    weighted=WEIGHTED,  # Pass the flag
    dynamic=DYNAMIC,
)
# Report the parameter count, then move the model to the selected device.
logger.info(f'# trainable parameters: {get_n_param(model):,}')
model = model.to(device)


# Optimizer and Loss
optimizer = optim.Adam(model.parameters(), lr=LR)
loss_f = BCE()  # Use corrected version expecting logits



2025-04-30 13:02:10 [INFO] Initializing model: Mortality


Initializing model of type Mortality...


2025-04-30 13:02:10 [INFO] # trainable parameters: 1,572,097


In [7]:
# ## 6. Define Evaluation Function

def run_evaluation(model, loader, loss_f, device, phase='valid'):
    """Evaluate ``model`` on every batch of ``loader`` with gradients disabled.

    Args:
        model: Module called as ``model(data)``. Its output is treated as
            logits: a sigmoid is applied before AUROC is computed.
        loader: Sized iterable yielding ``(data, y_true)`` pairs.
        loss_f: Callable invoked as
            ``loss_f(y_pred, y_true, is_train=False, storer=storer)`` and
            expected to return a scalar tensor.
        device: Device the batch tensors are moved to.
        phase: Label used in log messages, in the progress bar and as the
            storer key ``f'{phase}_loss'``.

    Returns:
        Tuple ``(final_metrics, y_preds_all, y_trues_all)``.
        ``final_metrics`` is ``{'loss': float, 'auroc': float or None}``; the
        loss is NaN when no finite batch loss is available, and AUROC is
        ``None`` when it cannot be computed. The other two items are the
        concatenated CPU predictions (logits) and labels, or empty tensors
        when no batch could be evaluated.

    Side effects:
        Leaves ``model`` in evaluation mode.
    """
    model.eval() # Set model to evaluation mode
    finite_batch_losses = [] # Finite per-batch losses only, for the fallback average
    y_preds_list = []
    y_trues_list = []
    num_batches = len(loader)
    storer = defaultdict(list) # Storer for this evaluation run

    iterator = tqdm(enumerate(loader), total=num_batches, desc=f"Evaluating ({phase})", leave=False)

    with torch.no_grad(): # Disable gradients
        for i, batch in iterator:
            try:
                data, y_true = batch
            except ValueError:
                logger.error(f"Unexpected batch format at eval iteration {i}. Skipping batch.")
                continue

            data = data.to(device)
            y_true = y_true.to(device)

            try:
                y_pred = model(data) # Get logits
                iter_loss = loss_f(y_pred, y_true, is_train=False, storer=storer) # Pass storer

                loss_is_finite = bool(torch.isfinite(iter_loss))
                if loss_is_finite:
                    finite_batch_losses.append(iter_loss.item())
                else:
                    logger.warning(f"NaN or Inf loss detected during {phase} iteration {i+1}.")

                # Store predictions (logits) and true labels
                y_preds_list.append(y_pred.detach().cpu())
                y_trues_list.append(y_true.detach().cpu())

                iterator.set_postfix(loss=f"{iter_loss.item():.4f}" if loss_is_finite else "nan")

            except Exception as e:
                # Best effort: a failing batch is reported and skipped.
                logger.error(f"Error during {phase} iteration {i+1}: {e}", exc_info=True)

    # Prefer losses recorded by loss_f under f'{phase}_loss' (if it records any there).
    avg_loss = float('nan')
    stored_losses = [l for l in storer.get(f'{phase}_loss', []) if l is not None]
    if stored_losses:
        stored_arr = np.asarray(stored_losses, dtype=float)
        if not np.isnan(stored_arr).all():
            avg_loss = np.nanmean(stored_arr)
    # Fallback: average only batches that produced a finite loss, so skipped or
    # failed batches neither deflate the mean nor turn it into NaN.
    if not np.isfinite(avg_loss) and finite_batch_losses:
        avg_loss = sum(finite_batch_losses) / len(finite_batch_losses)

    # --- Compute Metrics ---
    final_metrics = {'loss': avg_loss, 'auroc': None} # Default metrics
    # Defined up front so the return statement is valid even when nothing was
    # evaluated or concatenation fails.
    y_preds_all = torch.empty(0)
    y_trues_all = torch.empty(0)
    if not y_preds_list:
        logger.warning(f"No batches were evaluated during {phase}; metrics are unavailable.")
        return final_metrics, y_preds_all, y_trues_all

    try:
        y_preds_all = torch.cat(y_preds_list, dim=0)
        y_trues_all = torch.cat(y_trues_list, dim=0)

        # Convert to numpy for sklearn
        y_pred_np = array(y_preds_all)
        y_true_np = array(y_trues_all)

        # Handle dynamic output shapes (take last time step)
        if y_pred_np.ndim == 2: y_pred_np = y_pred_np[:, -1]
        if y_true_np.ndim == 2: y_true_np = y_true_np[:, -1]

        # Sigmoid written as exp(-log(1 + exp(-x))) so large negative logits
        # cannot overflow inside exp().
        probs = np.exp(-np.logaddexp(0, -y_pred_np))

        if np.isnan(probs).any() or np.isnan(y_true_np).any():
            logger.warning(f"NaN values detected in {phase} predictions or labels before metric calculation.")
        elif len(np.unique(y_true_np)) > 1:
            final_metrics['auroc'] = roc_auc_score(y_true_np, probs)
        else:
            logger.warning(f"Only one class present in true labels during {phase}. AUROC is not defined.")

    except Exception as e:
        logger.error(f"Error computing {phase} metrics: {e}")

    return final_metrics, y_preds_all, y_trues_all # Return preds/trues as well

In [8]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR, CosineAnnealingLR
import numpy as np
from tqdm import tqdm
from timeit import default_timer
import logging
import os
from collections import defaultdict


# --- Scheduler and Optimizer Setup (from your code) ---
SCHEDULER_TYPE = 'ReduceLROnPlateau'
SCHEDULER_FACTOR = 0.5 # Factor by which the LR is reduced (new_lr = lr * factor)
SCHEDULER_PATIENCE = 2 # Number of epochs with no improvement after which LR is reduced
SCHEDULER_MODE = 'max' # 'max' monitors validation AUROC, 'min' monitors validation loss
SCHEDULER_MIN_LR = 1e-6 # Minimum learning rate


# A fresh optimizer is created here, so re-running this cell restarts Adam's state.
optimizer = optim.Adam(model.parameters(), lr=LR)
scheduler = None
if SCHEDULER_TYPE == 'ReduceLROnPlateau':
    # `verbose` is deliberately not passed: it is a deprecated scheduler argument.
    # Learning-rate changes are logged explicitly after scheduler.step() instead.
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode=SCHEDULER_MODE,
                                  factor=SCHEDULER_FACTOR,
                                  patience=SCHEDULER_PATIENCE,
                                  min_lr=SCHEDULER_MIN_LR)
# Add other scheduler types if needed...

logger.info("--- Starting Full Training & Validation Loop ---")

start_time = default_timer()
max_v_auroc = -np.inf
early_stopping_counter = 0
# None or a non-positive patience disables early stopping.
patience = float('inf') if EARLY_STOPPING_PATIENCE is None or EARLY_STOPPING_PATIENCE <= 0 else EARLY_STOPPING_PATIENCE

# --- Calculate Total Batches and Initialize Overall Progress Bar ---
num_train_batches_per_epoch = len(train_loader)
total_batches = EPOCHS * num_train_batches_per_epoch
overall_progress_bar = tqdm(total=total_batches, desc="Overall Training Progress", unit="batch")
# -------------------------------------------------------------------

try: # Use try...finally to ensure the progress bar is closed
    for epoch in range(EPOCHS):
        model.train() # Set model to training mode (run_evaluation switches it to eval)
        epoch_train_loss = 0.
        batches_completed = 0 # Batches that finished an optimizer step this epoch
        epoch_storer = defaultdict(list) # Storer for metrics logged by loss_f
        nan_detected_in_epoch = False

        # --- Update Progress Bar Description for the Current Epoch ---
        overall_progress_bar.set_description(f"Epoch {epoch+1}/{EPOCHS}")
        # -------------------------------------------------------------

        # --- Training Batch Loop (Iterate directly, NO inner tqdm) ---
        for i, batch in enumerate(train_loader):
            try:
                data, y_true = batch
            except ValueError:
                logger.error(f"Train batch format error iter {i}. Skipping.")
                continue

            data = data.to(device)
            y_true = y_true.to(device)
            optimizer.zero_grad()

            try:
                y_pred = model(data)
                iter_loss = loss_f(y_pred, y_true, is_train=True, storer=epoch_storer)

                if not torch.isfinite(iter_loss):
                    logger.error(f"NaN or Inf loss detected during training iteration {i+1} in Epoch {epoch+1}. Stopping epoch.")
                    nan_detected_in_epoch = True
                    break

                iter_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                # Count the batch only once its update has actually been applied.
                epoch_train_loss += iter_loss.item()
                batches_completed += 1

                # --- Update Overall Progress Bar ---
                current_lr = optimizer.param_groups[0]['lr']
                overall_progress_bar.set_postfix(loss=f"{iter_loss.item():.4f}", lr=f"{current_lr:.1e}", refresh=True) # refresh=True might be needed
                overall_progress_bar.update(1) # Increment the overall progress bar
                # -----------------------------------

            except Exception as e:
                logger.error(f"Error during training iteration {i+1} in Epoch {epoch+1}: {e}", exc_info=True)
                nan_detected_in_epoch = True
                break # Stop epoch on error

        # Average over the batches that actually contributed, so an aborted or
        # partially skipped epoch does not report a deflated loss.
        avg_train_loss = epoch_train_loss / batches_completed if batches_completed > 0 else float('nan')
        if nan_detected_in_epoch:
            logger.warning(f"Epoch {epoch+1} was cut short after {batches_completed}/{num_train_batches_per_epoch} batches; "
                           f"the reported train loss covers the completed batches only.")
        epoch_storer['train_loss'] = [avg_train_loss]

        # --- Validation Step ---
        val_metrics, _, _ = run_evaluation(model, valid_loader, loss_f, device, phase='valid')
        val_loss = val_metrics.get('loss', float('nan'))
        val_auroc = val_metrics.get('auroc', None)
        epoch_storer['valid_loss'] = [val_loss]
        if val_auroc is not None:
            epoch_storer['auroc'] = [val_auroc]


        # --- Logging ---
        log_msg = f"Epoch {epoch+1} End: Train Loss={avg_train_loss:.4f}, Valid Loss={val_loss:.4f}"
        if val_auroc is not None: log_msg += f", Valid AUROC={val_auroc:.4f}"
        else: log_msg += ", Valid AUROC=N/A"
        logger.info(log_msg)
        if losses_logger: losses_logger.log(epoch, epoch_storer)

        # --- Scheduler Step (End of Epoch) ---
        if scheduler:
            lr_before_step = optimizer.param_groups[0]['lr']
            if isinstance(scheduler, ReduceLROnPlateau):
                # Monitor the metric that matches the mode; an unavailable metric
                # is replaced by the worst possible value so it counts as "no improvement".
                if SCHEDULER_MODE == 'max':
                    metric_to_monitor = val_auroc if val_auroc is not None else -np.inf
                else:
                    metric_to_monitor = val_loss if np.isfinite(val_loss) else np.inf
                scheduler.step(metric_to_monitor)
            else:
                scheduler.step()
            lr_after_step = optimizer.param_groups[0]['lr']
            if lr_after_step != lr_before_step:
                logger.info(f"  Learning rate changed from {lr_before_step:.1e} to {lr_after_step:.1e}.")


        # --- Checkpointing & Early Stopping ---
        improved = False
        if val_auroc is not None and np.isfinite(val_auroc):
            if val_auroc > max_v_auroc:
                max_v_auroc = val_auroc
                # Keep only the best weights seen so far, together with the run metadata.
                model_save_path = os.path.join(model_dir, 'model.pt')
                torch.save(model.state_dict(), model_save_path)
                run_metadata = { # Collect metadata for this run
                     't_hours': T_HOURS, 'n_bins': N_BINS, 'seed': SEED, 'lr': LR,
                     'epochs': EPOCHS, 'batch_size': BATCH_SIZE, 'latent_dim': LATENT_DIM,
                     'hidden_dim': HIDDEN_DIM, 'p_dropout': P_DROPOUT, 'dt': DT,
                     'weighted': WEIGHTED, 'dynamic': DYNAMIC, 'n_tokens': n_tokens,
                     'model_type': MODEL_TYPE,
                     'best_epoch': epoch + 1,
                     'best_auroc': max_v_auroc
                 }
                save_metadata(run_metadata, model_dir, filename='meta.json')
                early_stopping_counter = 0
                improved = True

        if not improved:
            early_stopping_counter += 1
            if val_auroc is not None:
                logger.info(f"  Validation AUROC did not improve from {max_v_auroc:.4f}. Early stopping counter: {early_stopping_counter}/{patience}")
            else:
                logger.info(f"  Validation AUROC is N/A. Not checking improvement. Early stopping counter: {early_stopping_counter}/{patience}")


        if early_stopping_counter >= patience:
            logger.info(f"Early stopping triggered after {patience} epochs without improvement.")
            break # Break the main epoch loop

finally: # Ensure the progress bar is closed even if errors occur
    overall_progress_bar.close()
    # --- End of Training Loop ---
    delta_time = (default_timer() - start_time) / 60
    logger.info(f'Finished training loop after {delta_time:.1f} minutes. Best Valid AUROC: {max_v_auroc:.4f}')

2025-04-30 13:02:44 [INFO] --- Starting Full Training & Validation Loop ---
Epoch 1/150:   0%|          | 0/11250 [00:00<?, ?batch/s]              

Epoch 1/150:   0%|          | 5/11250 [00:17<10:40:50,  3.42s/batch, loss=0.6705, lr=1.0e-04]
2025-04-30 13:03:02 [INFO] Finished training loop after 0.3 minutes. Best Valid AUROC: -inf


KeyboardInterrupt: 

In [None]:
# ## 8. Final Test Set Evaluation

logger.info("--- Starting Final Test Set Evaluation ---")

# Rebuild the architecture with the training hyperparameters, then restore the
# weights from the checkpoint stored in the run directory.
logger.info(f"Loading best model from {model_dir}...")
try:
    best_model = init_model(
        model_type=MODEL_TYPE,
        n_tokens=n_tokens,
        latent_dim=LATENT_DIM,
        hidden_dim=HIDDEN_DIM,
        p_dropout=P_DROPOUT,
        dt=DT,
        weighted=WEIGHTED,
        dynamic=DYNAMIC,
    )
    model_path = os.path.join(model_dir, 'model.pt')
    if not os.path.exists(model_path):
        raise FileNotFoundError("Best model file not found.")
    best_model.load_state_dict(torch.load(model_path, map_location=device))
    best_model.to(device)
    best_model.eval()
    logger.info("Best model loaded successfully.")
except Exception as e:
    # Log for the notebook record, then let the failure stop the cell.
    logger.error(f"Error loading best model for testing: {e}")
    raise

# Evaluate the restored model on the test loader.
test_metrics, test_preds, test_trues = run_evaluation(best_model, test_loader, loss_f, device, phase='test')
test_loss, test_auroc = test_metrics['loss'], test_metrics['auroc']

logger.info("--- Test Set Results ---")
if np.isfinite(test_loss):
    logger.info(f"✅ Test Loss : {test_loss:.4f}")
else:
    logger.info("Test Loss : nan")
if test_auroc is not None and np.isfinite(test_auroc):
    logger.info(f"✅ Test AUROC: {test_auroc:.4f}")
else:
    logger.info("Test AUROC: N/A")

2025-04-22 18:14:14 [INFO] --- Starting Final Test Set Evaluation ---
2025-04-22 18:14:14 [INFO] Loading best model from /changed/results/MortalityLSTM_t48_lr0.0001_z32_h256_p0.0_wFalse_dTrue_seed0...


Initializing model of type Mortality...


2025-04-22 18:14:15 [INFO] Best model loaded successfully.
2025-04-22 18:14:15 [INFO] --- Test Set Results ---                          
2025-04-22 18:14:15 [INFO] ✅ Test Loss : 0.3734
2025-04-22 18:14:15 [INFO] ✅ Test AUROC: 0.6725


In [None]:
# Save test predictions and labels
test_results_path = os.path.join(model_dir, 'test_predictions.npz')

# Both tensors are converted to numpy arrays and stored in a single .npz archive.
np.savez(
    test_results_path,
    predictions=test_preds.numpy(),
    labels=test_trues.numpy(),
)
logger.info(f"Test predictions saved to {test_results_path}")


# Append to overall summary CSV (one row per run, shared by all runs under RESULTS_DIR)
summary_file_path = os.path.join(RESULTS_DIR, 'summary_results.csv')
logger.info(f"Appending results to summary file: {summary_file_path}")

# Column order of the summary file.
fieldnames = [
    'timestamp',
    'model_run_name',
    't_hours',
    'seed',
    'n_bins',
    'lr',
    'batch_size',
    'latent_dim',
    'hidden_dim',
    'p_dropout',
    'weighted',
    'dynamic',
    'max_epochs',
    'best_valid_auroc',
    'test_loss',
    'test_auroc',
    'n_tokens',
    'n_params',
    'model_dir',
    'predictions_path',
]
n_params = get_n_param(best_model)

# Metrics are stored as 4-decimal strings; unavailable values get a textual placeholder.
result_data = dict(
    timestamp=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    model_run_name=model_run_name,
    t_hours=T_HOURS,
    seed=SEED,
    n_bins=N_BINS,
    lr=LR,
    batch_size=BATCH_SIZE,
    latent_dim=LATENT_DIM,
    hidden_dim=HIDDEN_DIM,
    p_dropout=P_DROPOUT,
    weighted=WEIGHTED,
    dynamic=DYNAMIC,
    max_epochs=EPOCHS,  # Or store actual epochs run before early stopping
    best_valid_auroc="-inf" if not np.isfinite(max_v_auroc) else f"{max_v_auroc:.4f}",
    test_loss="nan" if not np.isfinite(test_loss) else f"{test_loss:.4f}",
    test_auroc=f"{test_auroc:.4f}" if test_auroc is not None and np.isfinite(test_auroc) else "N/A",
    n_tokens=n_tokens,
    n_params=n_params,
    model_dir=model_dir,
    predictions_path=test_results_path if os.path.exists(test_results_path) else 'N/A',
)


# The header is written only when the summary file is new or still empty.
file_exists = os.path.isfile(summary_file_path)
with open(summary_file_path, 'a', newline='', encoding='utf-8') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames, extrasaction='ignore') # Ignore extra keys if any
    if not (file_exists and os.path.getsize(summary_file_path) > 0):
        writer.writeheader()
    writer.writerow(result_data)




2025-04-22 18:14:15 [INFO] Test predictions saved to /changed/results/MortalityLSTM_t48_lr0.0001_z32_h256_p0.0_wFalse_dTrue_seed0/test_predictions.npz
2025-04-22 18:14:15 [INFO] Appending results to summary file: /changed/results/summary_results.csv
