# Cell 1

In [1]:
## 1. Setup and Configuration


# Import standard libraries
import logging
import os
import sys
import torch
import numpy as np
import pandas as pd
import pprint # For pretty printing config

# --- Logging ---
# INFO-level root handler so records emitted by the project modules show up in
# the notebook, each tagged with the emitting module's name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - [%(name)s] - %(message)s',
    level=logging.INFO,
)
log = logging.getLogger(__name__)  # notebook-level logger

# --- Project modules on sys.path ---
# utils.py, data_loader.py, ... are expected to live next to this notebook;
# point project_path elsewhere if the notebook is moved.
project_path = '.'
if project_path not in sys.path:
    sys.path.append(project_path)
    log.info(f"Added {project_path} to sys.path")

# --- Import Project Modules ---
# Fail fast: every later cell depends on these names, so a missing module must
# stop "Run All" here rather than surface later as an unrelated NameError.
# NOTE(review): the captured output of this cell shows the `tuning` module
# loading preprocessed data as a side effect of being imported; consider
# importing run_tuning lazily in the tuning cell instead.
try:
    from utils import load_config, load_preprocessed_data, safe_get
    from data_loader import load_all_datasets
    # Main entry point of the preprocessing module
    from preprocessing import preprocess_all_subjects
    # Main entry point of the data_pipeline module
    from data_pipeline import prepare_dataloaders
    from models import get_model
    from training import train_model
    from evaluation import evaluate_model, find_best_threshold, calculate_shap_importance
    from tuning import run_tuning  # Optional: for hyperparameter tuning
    # Widget setup functions for the interactive plots
    from widget_setup import (
        setup_raw_signal_plotter,
        setup_comparison_plotter,
        setup_hrv_plotter,
        setup_prediction_plotter,
        display_evaluation_results
    )
    # plot_training_history (visualization) is imported where it is used.
except ImportError as e:
    log.critical(f"Failed to import necessary project modules: {e}. Ensure modules are in the Python path.")
    raise

# --- Load Configuration ---
CONFIG_PATH = 'config.json'  # relative to the notebook's working directory
config = load_config(CONFIG_PATH)

if not config:
    log.critical("Failed to load configuration. Please check config.json path and format.")
    # Every later cell reads `config`, so stop here instead of continuing with None.
    raise RuntimeError(f"Could not load configuration from {CONFIG_PATH!r}.")

log.info("Configuration loaded successfully.")

# Set SHOW_CONFIG = True to pretty-print the loaded configuration.
SHOW_CONFIG = False
if SHOW_CONFIG:
    print("--- Configuration ---")
    pprint.pprint(config)
    print("-" * 20)

# --- Device Setup ---
# Use the CUDA device when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
log.info(f"Using device: {device}")
if device.type == "cuda":
    log.info(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")

2025-04-01 15:40:15,568 - INFO - [__main__] - Added . to sys.path
2025-04-01 15:40:19,765 - INFO - [utils] - Attempting to load configuration from: config.json
2025-04-01 15:40:19,787 - INFO - [utils] - Configuration loaded successfully.
2025-04-01 15:40:19,787 - INFO - [utils] - Checking and creating save directories specified in config...
2025-04-01 15:40:19,789 - INFO - [tuning] - Loading preprocessed data for tuning...
2025-04-01 15:40:19,789 - INFO - [utils] - --- Attempting to LOAD preprocessed data from .joblib files ---
2025-04-01 15:40:20,030 - INFO - [utils] - Successfully loaded Processed Data from: d:\Downloads\StressProject\outputs\processed_data\processed_aligned_data.joblib
2025-04-01 15:40:20,042 - INFO - [utils] - Successfully loaded Static Features from: d:\Downloads\StressProject\outputs\static_features\static_features_results.joblib
2025-04-01 15:40:20,052 - INFO - [utils] - Successfully loaded R-Peak Indices from: d:\Downloads\StressProject\outputs\static_features\

# Cell 2 

In [None]:
# Load the raw signals for every configured subject via the data_loader module.
all_subject_data, subjects_loaded, subjects_failed = load_all_datasets(config)

if all_subject_data:
    log.info(f"Loaded raw data for {len(subjects_loaded)} subjects.")
    if subjects_failed:
        log.warning(f"Failed to load data for: {subjects_failed}")
else:
    log.critical("No raw data loaded. Exiting.")

# Subject IDs in their original form (e.g. 2, 'NURSE_1'); the widget setup
# helpers in later cells take this list.
loaded_subject_ids_orig = subjects_loaded

# Cell 3 

In [None]:
# Interactive raw-signal viewer: needs the raw data dict and the original subject IDs.
if not (all_subject_data and loaded_subject_ids_orig):
    print("Skipping raw signal plotter setup: Raw data or subject list missing.")
else:
    setup_raw_signal_plotter(config, all_subject_data, loaded_subject_ids_orig)

# Cell 4

In [None]:
import logging
import os
# load_preprocessed_data / preprocess_all_subjects are imported inside the
# branches below, so a missing module is reported with a specific message.
# config, all_subject_data and subjects_loaded come from the earlier cells.

log = logging.getLogger(__name__)

# --- Control Flag ---
# True: load the artifacts saved by a previous preprocessing run.
# False: re-run preprocessing on the raw data loaded earlier.
LOAD_SAVED_DATA = True

# --- Initialize variables ---
# Reset on every run so a failure below cannot leave results from a previous
# execution of this cell in the kernel namespace.
processed_data = None
static_features_results = None
r_peak_results = None

# --- Conditional Preprocessing or Loading ---
if LOAD_SAVED_DATA:
    log.info("Attempting to load preprocessed data...")
    try:
        from utils import load_preprocessed_data
        processed_data, static_features_results, r_peak_results = load_preprocessed_data(config)
        if not processed_data:
            log.warning("load_preprocessed_data returned empty data. Preprocessing might need to be run.")
        else:
            log.info("Successfully loaded preprocessed data.")
    except ImportError:
        log.error("Could not import load_preprocessed_data from utils. Cannot load data.")
    except Exception as load_e:
        # No automatic fallback to preprocessing: the final check stops the
        # notebook; set LOAD_SAVED_DATA = False and re-run to rebuild the data.
        log.error(f"Error loading preprocessed data: {load_e}", exc_info=True)

else:
    log.info("Running preprocessing pipeline (LOAD_SAVED_DATA is False)...")
    if 'all_subject_data' in locals() and 'subjects_loaded' in locals() and 'config' in locals():
        try:
            from preprocessing import preprocess_all_subjects
            processed_data, static_features_results, r_peak_results = preprocess_all_subjects(
                all_subject_data, subjects_loaded, config
            )
        except ImportError:
            log.error("Could not import preprocess_all_subjects from preprocessing. Cannot run preprocessing.")
        except Exception as preprocess_e:
            log.error(f"Error running preprocessing pipeline: {preprocess_e}", exc_info=True)
            # processed_data stays None; the final check below stops execution.
    else:
        log.error("Raw data ('all_subject_data', 'subjects_loaded') or 'config' not available. Cannot run preprocessing.")


# --- Final Check ---
# Runs whether loading or preprocessing was attempted.
if not processed_data:
    log.critical("Preprocessing/loading failed or no processed data obtained. Exiting or handling error...")
    # Stop "Run All" here: every downstream cell needs processed_data, and
    # continuing would only produce a chain of "skipping ..." messages.
    raise RuntimeError("Failed to obtain processed data; see the log messages above.")

log.info(f"Preprocessing complete or data loaded successfully for {len(processed_data)} subjects.")



# Cell 5

In [None]:
# Raw vs. resampled comparison plotter.
if not (all_subject_data and processed_data and loaded_subject_ids_orig):
    print("Skipping comparison plotter setup: Raw or processed data missing.")
else:
    setup_comparison_plotter(config, all_subject_data, processed_data, loaded_subject_ids_orig)

# ECG + R-peaks + labels plotter. r_peak_results may be None; in that case the
# plotter is expected to load the R-peak indices itself.
if not processed_data:
    print("Skipping ECG plotter setup: Processed data missing.")
else:
    setup_hrv_plotter(config, processed_data, r_peak_results)


# Cell 6 

In [None]:
# Build the train/val/test DataLoaders and the model input dimensions from the
# processed signals and static features produced by the preprocessing cell.
if not (processed_data and static_features_results):
    log.critical("Processed data or static features missing. Cannot create DataLoaders.")
    raise RuntimeError("Processed data or static features missing; run the preprocessing cell first.")

train_loader, val_loader, test_loader, input_dim_sequence, input_dim_static = prepare_dataloaders(
    processed_data, static_features_results, config
)
if not (train_loader and val_loader and test_loader):
    log.critical("Failed to create DataLoaders. Cannot proceed with training.")
    # Stop here: the following cells only check that these names exist, so
    # continuing would hand None loaders / dimensions to model building and training.
    raise RuntimeError("prepare_dataloaders did not return all three DataLoaders.")

log.info("DataLoaders created successfully.")
log.info(f"Input Dimensions - Sequence: {input_dim_sequence}, Static: {input_dim_static}")

# Cell 7 

In [None]:
# Build the model from the input dimensions computed by prepare_dataloaders.
# A name that exists but holds None (failed DataLoader build) counts as missing.
if 'input_dim_sequence' not in locals() or 'input_dim_static' not in locals() or input_dim_sequence is None:
    log.critical("Input dimensions not available. Cannot build model.")
    raise RuntimeError("Input dimensions not available; run the DataLoader cell first.")

model = get_model(config, input_dim_sequence, input_dim_static)
if model is None:
    log.critical("get_model returned None. Cannot build model.")
    raise RuntimeError("get_model returned None; check the configuration passed to get_model.")

# Move the parameters to the target device up front: the prediction cell sends
# its inputs to `device` and calls the model directly.
model = model.to(device)
log.info(f"Model '{type(model).__name__}' built.")

# Cell 8

In [None]:
import torch
import logging
import os
# train_model and safe_get are imported in the setup cell; plot_training_history
# is imported below where it is used.

log = logging.getLogger(__name__)

# Reset the outputs first. If training fails or is interrupted (the saved
# output of this cell shows a KeyboardInterrupt), the evaluation cell must not
# pick up a stale checkpoint from an earlier run.
best_model_state, history = None, None

# Inputs produced by earlier cells. A name that exists but is None (e.g. a
# failed DataLoader build) is treated as missing too.
missing_inputs = [
    name for name in ('model', 'train_loader', 'val_loader', 'config', 'device')
    if globals().get(name) is None
]
if missing_inputs:
    log.critical("Model, DataLoaders, Config, or Device not available. Skipping training and cache clearing.")
    raise RuntimeError(f"Cannot start training; missing or None: {missing_inputs}")

output_dir_models = safe_get(config, ['save_paths', 'models'])
if not output_dir_models:
    log.warning("Model output directory not specified in config. Best model will not be saved to file.")

log.info("Starting model training...")
try:
    best_model_state, history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        device=device,
        output_dir=output_dir_models  # directory where the best model is saved
    )
    log.info("Model training finished.")
finally:
    # --- Clear CUDA Cache (if CUDA was used) ---
    # In `finally` so cached GPU memory is released even when training raises or
    # is interrupted. Comparing device.type also matches indexed devices such as
    # "cuda:0", which compare unequal to torch.device("cuda").
    if device.type == "cuda":
        try:
            log.info("Clearing CUDA cache...")
            torch.cuda.empty_cache()
            log.info("CUDA cache cleared.")
        except Exception as e:
            log.error(f"Failed to clear CUDA cache: {e}")

# --- Plot Training History ---
if history:
    output_dir_results = safe_get(config, ['save_paths', 'results'])
    if output_dir_results:
        try:
            from visualization import plot_training_history
            plot_training_history(history, output_dir_results)
        except ImportError:
            log.error("Could not import plot_training_history from visualization.")
        except Exception as plot_e:
            log.error(f"Error plotting training history: {plot_e}", exc_info=True)
    else:
        log.warning("Results directory not specified. Cannot save training history plot.")
elif best_model_state is not None:
    log.warning("Training completed but history dictionary was not returned.")
else:
    # train_model returned (None, None), indicating failure.
    log.error("Training failed or history was not returned.")

# Optional: free memory if heavy GPU/CPU work follows in this kernel.
# del model, train_loader, val_loader, history, best_model_state



KeyboardInterrupt: 

# Cell 9 

In [None]:
import json
import logging
import os

import numpy as np
import torch

log = logging.getLogger(__name__)


def to_jsonable(obj):
    """``json.dumps`` ``default=`` hook: convert NumPy / PyTorch values to built-ins.

    json cannot serialize ndarrays, NumPy scalars such as ``np.float32``, or
    tensors. The evaluation results include probabilities and labels, which
    may be of these types (NOTE(review): confirm against evaluation.py).

    Args:
        obj: a value json could not serialize on its own.

    Returns:
        A list or Python scalar equivalent to ``obj``.

    Raises:
        TypeError: for any other unsupported type, mirroring json's own error.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Evaluate the best checkpoint from training on the held-out test set.
if 'best_model_state' in locals() and best_model_state and 'test_loader' in locals() and test_loader:
    log.info("Evaluating the best model on the test set...")
    model.load_state_dict(best_model_state)

    # Loss used only for reporting. pos_weight from training may reflect the
    # *sampled* train set, so report with unweighted BCE (or FocalLoss).
    from losses import FocalLoss
    eval_loss_type = safe_get(config, ['training_config', 'loss_function'], 'bce').lower()
    if eval_loss_type == 'focal':
        alpha = safe_get(config, ['training_config', 'focal_loss_alpha'], 0.25)
        gamma = safe_get(config, ['training_config', 'focal_loss_gamma'], 2.0)
        eval_criterion = FocalLoss(alpha=alpha, gamma=gamma, reduction='mean')
    else:  # Default to unweighted BCE for evaluation reporting
        eval_criterion = torch.nn.BCEWithLogitsLoss()
    log.info(f"Using {type(eval_criterion).__name__} for reporting evaluation loss.")

    output_dir_results = safe_get(config, ['save_paths', 'results'])
    if not output_dir_results:
        log.error("Results output directory not specified. Cannot save evaluation results/plots.")

    test_results = evaluate_model(
        model=model,
        dataloader=test_loader,
        criterion=eval_criterion,
        device=device,
        config=config,
        output_dir=output_dir_results,  # directory for evaluation plots
        set_name="Test",
        # NOTE(review): threshold=None selects the best-F1 threshold on the test
        # set itself, which optimistically biases the reported test metrics.
        # Consider tuning it on val_loader (find_best_threshold is imported in
        # the setup cell) and passing the value here.
        threshold=None
    )

    # Save the numerical results as JSON, then display them via the widget helper.
    if test_results and output_dir_results:
        results_save_path = os.path.join(output_dir_results, "test_evaluation_results.json")
        try:
            # Serialize before opening the file, so a non-serializable value
            # fails without leaving a truncated JSON file behind.
            results_json = json.dumps(test_results, indent=4, default=to_jsonable)
            with open(results_save_path, 'w') as f:
                f.write(results_json)
            log.info(f"Test evaluation results saved to {results_save_path}")

            display_evaluation_results(config, results_save_path)

        except Exception as e:
            log.error(f"Failed to save or display test evaluation results: {e}")

    elif not test_results:
        log.error("Model evaluation failed.")

else:
    log.critical("Best model state or test loader not available. Skipping evaluation.")


# Cell 10

In [None]:
# This requires generating predictions for all windows first.
# The evaluate_model function returns probabilities and labels, but not mapped back to subjects easily.
# You might need a separate function to run inference on the test_loader (or val_loader)
# and store predictions per subject ID and window start time.

# --- Placeholder: Function to get predictions per subject ---
def _decode_subject_ids(raw_ids, id_lookup):
    """Turn one batch's subject-ID field into a list of original subject IDs.

    raw_ids is whatever the DataLoader collated into the subject-ID slot: a
    tensor of integer codes or a plain sequence (e.g. strings). Integer codes
    are mapped through id_lookup when the dataset provides one; other values
    pass through unchanged.
    NOTE(review): assumes the integer codes index into the dataset's
    `subject_ids_list` (string IDs cannot live in a tensor) — confirm against
    the Dataset in data_pipeline.py.
    """
    if torch.is_tensor(raw_ids):
        raw_ids = raw_ids.reshape(-1).tolist()
    if id_lookup is None:
        return list(raw_ids)
    return [id_lookup[code] if isinstance(code, int) else code for code in raw_ids]


def get_all_predictions(model, dataloader, device, threshold=0.5):
    """Run inference over `dataloader` and group binary predictions by subject.

    Args:
        model: module called as ``model(seq, static)``; its output is flattened
            with ``reshape(-1)`` and treated as one logit per window.
        dataloader: iterable of ``(seq, static, label, subject_ids, starts)``
            batches with a ``.dataset`` attribute.
        device: device the sequence (and static) tensors are moved to.
        threshold: probability cut-off; ``sigmoid(logit) > threshold`` -> 1.

    Returns:
        dict mapping subject ID -> ``([window starts], [0/1 predictions])``,
        with entries in dataloader order.
    """
    model.eval()
    # Per-subject IDs must come from each batch's own subject-ID field;
    # indexing the dataset list by batch position gives every batch the first
    # batch's IDs.
    id_lookup = getattr(dataloader.dataset, 'subject_ids_list', None)
    use_static = hasattr(model, 'input_dim_static') and model.input_dim_static > 0
    predictions_by_subject = {}  # {subj_id: ([starts], [preds])}
    with torch.no_grad():
        for seq, static, _, raw_subject_ids, starts_tensor in dataloader:
            seq = seq.to(device)
            static = static.to(device) if use_static else None
            outputs = model(seq, static)
            # reshape(-1) instead of squeeze(): squeeze() collapses a 1-sample
            # batch to a 0-d tensor, which cannot be iterated.
            probs = torch.sigmoid(outputs.reshape(-1))
            preds = (probs > threshold).int().cpu().numpy()
            starts = starts_tensor.reshape(-1).cpu().numpy()
            subject_ids = _decode_subject_ids(raw_subject_ids, id_lookup)

            for subj_id, start, pred in zip(subject_ids, starts, preds):
                subj_starts, subj_preds = predictions_by_subject.setdefault(subj_id, ([], []))
                subj_starts.append(start)
                subj_preds.append(pred)
    return predictions_by_subject
# --- End Placeholder ---

# --- Generate predictions (e.g., on test set) and feed the prediction plotter ---
if not ('model' in locals() and 'test_loader' in locals() and test_loader):
    print("Skipping prediction plotter: Model or test loader not available.")
else:
    log.info("Generating predictions for visualization...")
    # Prefer the threshold reported by the evaluation cell; otherwise use 0.5.
    if 'test_results' in locals() and test_results:
        eval_threshold = test_results.get('threshold_used', 0.5)
    else:
        eval_threshold = 0.5
    all_predictions_map = get_all_predictions(model, test_loader, device, threshold=eval_threshold)

    if not (processed_data and all_predictions_map):
        print("Skipping prediction plotter: Processed data or predictions map missing.")
    else:
        setup_prediction_plotter(config, processed_data, all_predictions_map)


# Cell 11

In [None]:
# --- Run Optuna Tuning (optional, expensive) ---
# Off by default so "Restart & Run All" does not launch a long hyperparameter
# search; set RUN_TUNING = True to run it.
RUN_TUNING = False
num_tuning_trials = 50  # number of Optuna trials

if RUN_TUNING:
    log.info(f"Starting hyperparameter tuning for {num_tuning_trials} trials...")
    run_tuning(n_trials=num_tuning_trials)
    log.info("Hyperparameter tuning finished.")
    # After tuning, check the 'best_hyperparameters.json' file saved in the
    # results directory and update config.json accordingly before final training.
else:
    log.info("Skipping hyperparameter tuning (RUN_TUNING is False).")

# Cell 12

# Cell 13