## Libraries

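Note: Run this notebook from inside the `high-order-anesthesia` project (its root folder or any subfolder). The first cell locates the project root and changes the working directory to it, so that the project imports and relative file paths resolve correctly.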

In [1]:
from pathlib import Path
import os

TARGET_DIR_NAME = "high-order-anesthesia"


def ensure_project_root(target_name: str = TARGET_DIR_NAME) -> Path:
    """Locate the project root and make it the current working directory.

    Looks for a directory named ``target_name``, checking the current working
    directory first and then its ancestors from nearest to farthest. If the cwd
    itself matches, it is returned unchanged. Otherwise the nearest matching
    ancestor becomes the new cwd (via ``os.chdir``), so that the notebook's
    relative paths and imports resolve from the project root.

    Args:
        target_name: Name of the project root directory. Defaults to
            ``TARGET_DIR_NAME``.

    Returns:
        The resolved path of the project root.

    Raises:
        RuntimeError: If neither the cwd nor any of its parents is named
            ``target_name``.
    """
    cwd = Path.cwd().resolve()
    for candidate in (cwd, *cwd.parents):
        if candidate.name == target_name:
            # Only change directory when we are not already at the root.
            if candidate != cwd:
                os.chdir(candidate)
            return candidate
    raise RuntimeError(
        f"Could not find '{target_name}' in current path or parents. "
        "Please run the notebook from inside the project."
    )
# Move to the project root so the relative paths and project imports below resolve.
ROOT = ensure_project_root(TARGET_DIR_NAME)
print(f"Now in: {ROOT.name}")


Now in: high-order-anesthesia


In [None]:
import pandas as pd
import torch
import numpy as np
import pandas as pd
import h5py
from collections import defaultdict
import time
import itertools
import logging
from tqdm.notebook import tqdm, trange

  from .autonotebook import tqdm as notebook_tqdm


Custom libraries (project modules under `src/hoi_anesthesia`)

In [3]:
from src.hoi_anesthesia.thoi_utils import simulated_annealing_parallel
from src.hoi_anesthesia.utils import max_difference_pairs
from src.hoi_anesthesia.io import load_covariance_dict, print_time
from src.hoi_anesthesia.plotting import plot_measures_accross_states

#### Data loading and preparation

In [5]:
# Folder holding the precomputed inputs; the optimization cell also writes its CSVs here.
results_path = "results"

# Covariance matrices, indexed by the optimization cell as all_covs[dataset][state].
all_covs = load_covariance_dict(f"{results_path}/covariance_matrices_gc.h5")

# States for each dataset (MA: multi-anesthesia, DBS: deep brain stimulation),
# split into conscious and non-responsive conditions.
conscious_states = dict(
    MA=["MA_awake"],
    DBS=["DBS_awake", "ts_on_5V"],
)
nonresponsive_states = dict(
    MA=["ts_selv2", "ts_selv4", "moderate_propofol", "deep_propofol", "ketamine"],
    DBS=["ts_off", "ts_on_3V_control", "ts_on_5V_control"],
)


#### Simulated Annealing parameters

In [6]:
# Settings forwarded to simulated_annealing_parallel by the optimization cell.
max_iter = 10000   # passed as max_iterations
early_stop = 1000  # presumably: stop after this many iterations without improvement — confirm in thoi_utils
repeat = 100       # passed as repeat; the optimization cell keeps the best score along dim 0 of the result
batch_size = 300   # pairs per call; also used to chunk the pair tensor

# Use the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
torch.cuda.empty_cache()

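#### Select datasets and orders to optimize
The optimization cell writes one checkpoint CSV per dataset and order (`results/R1_A_max_O_diff_<dataset>_<order>.csv`). It rewrites that file after every state pair and task, so an interrupted run keeps the results completed so far. Re-running does not resume from a checkpoint: it recomputes everything and overwrites the files.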

In [7]:
# Datasets to process (keys of all_covs) and the n-plet orders optimized for each one.
datasets_to_optimize = ["DBS", "MA"]
orders = list(range(3, 5))  # i.e. [3, 4]

In [None]:
def _build_pair_tensor(covs_c, covs_nr, target_task):
    """Stack every (conscious, non-responsive) covariance pair into one tensor.

    Pairs are enumerated with the conscious subject index ``i`` as the outer
    loop and the non-responsive subject index ``j`` as the inner loop. For
    ``target_task == "Cpos"`` the conscious matrix comes first in each pair;
    for ``"NRpos"`` the non-responsive matrix comes first.

    Returns:
        ``(X, subject_indices)``. ``X`` stacks the pairs along a new leading
        dimension, so ``X[k]`` has shape ``(2, *matrix_shape)`` and lives on
        the CPU. ``subject_indices[k] == [i, j]`` identifies the subjects of
        pair ``k``.

    Raises:
        ValueError: If ``target_task`` is neither "Cpos" nor "NRpos".
    """
    if target_task not in ("Cpos", "NRpos"):
        raise ValueError(f"Unknown target_task: {target_task!r}")
    pairs = []
    subject_indices = []
    for i, j in itertools.product(range(covs_c.shape[0]), range(covs_nr.shape[0])):
        cov_c = torch.from_numpy(covs_c[i])
        cov_nr = torch.from_numpy(covs_nr[j])
        first, second = (cov_c, cov_nr) if target_task == "Cpos" else (cov_nr, cov_c)
        pairs.append(torch.stack([first, second], dim=0))
        subject_indices.append([i, j])
    return torch.stack(pairs, dim=0), subject_indices


def _save_results_csv(records, path):
    """Write ``records`` (a list of dicts) to ``path`` and return the DataFrame.

    The file is ';'-separated with decimal commas and UTF-8-BOM encoding, and
    any existing file is overwritten.
    """
    frame = pd.DataFrame(records)
    frame.to_csv(path, index=False, encoding="utf-8-sig", sep=";", decimal=",")
    return frame


for selected_dataset in datasets_to_optimize:
    t_i = time.time()
    for order in orders:
        results = []
        # One checkpoint file per (dataset, order), rewritten after every task.
        checkpoint_path = f"{results_path}/R1_A_max_O_diff_{selected_dataset}_{order}.csv"
        print("*" * 30)
        print("ORDER:", order)
        # Compare every conscious state with every non-responsive state.
        for state_c, state_nr in itertools.product(
            conscious_states[selected_dataset], nonresponsive_states[selected_dataset]
        ):
            covs_c = all_covs[selected_dataset][state_c]  # shape (N_c, 82, 82)
            covs_nr = all_covs[selected_dataset][state_nr]  # shape (N_nr, 82, 82)
            # "Cpos" puts the conscious matrix first in each pair, "NRpos" the
            # non-responsive one (presumably so the difference metric is
            # maximized in both directions).
            for target_task in ["Cpos", "NRpos"]:
                torch.cuda.empty_cache()
                X, subject_indices = _build_pair_tensor(covs_c, covs_nr, target_task)
                X = X.to(device)
                # Ceiling division. The former `n // batch_size + 1` produced a
                # trailing empty batch whenever n was a multiple of batch_size.
                n_batches = (X.shape[0] + batch_size - 1) // batch_size
                t_x = time.time()
                print(
                    f"Evaluating {state_c} vs {state_nr} with {X.shape[0]} pairs for task: {target_task}"
                )
                for idx in range(n_batches):
                    start, stop = idx * batch_size, (idx + 1) * batch_size
                    batched_X = X[start:stop, ...]
                    batched_sub_indices = subject_indices[start:stop]
                    torch.cuda.empty_cache()
                    optimal_nplets, optimal_scores = simulated_annealing_parallel(
                        X=batched_X,
                        order=order,
                        device=device,
                        largest=True,
                        metric=max_difference_pairs,
                        repeat=repeat,
                        early_stop=early_stop,
                        max_iterations=max_iter,
                        covmat_precomputed=True,
                        batch_size=batch_size,
                        verbose=logging.WARNING,
                    )
                    # For each pair (dim 1) keep the highest-scoring entry along
                    # dim 0 (presumably one row per `repeat` restart).
                    max_idx = torch.argmax(optimal_scores, dim=0)
                    # Create the column index on max_idx's device so both
                    # advanced indices live on the same device.
                    pair_idx = torch.arange(optimal_scores.size(1), device=max_idx.device)
                    best_nplets = optimal_nplets[max_idx, pair_idx]
                    best_scores = optimal_scores[max_idx, pair_idx]

                    for best_score, best_nplet, (subject_c, subject_nr) in zip(
                        best_scores, best_nplets, batched_sub_indices
                    ):
                        results.append(
                            {
                                "order": order,
                                "task": target_task,
                                "state_c": state_c,
                                "state_nr": state_nr,
                                "subject_c": subject_c,
                                "subject_nr": subject_nr,
                                # Sorted so an n-plet always has one canonical representation.
                                "optimal_nplet": sorted(best_nplet.tolist()),
                                "optimal_score": best_score.item(),
                            }
                        )
                    torch.cuda.empty_cache()

                # Checkpoint: rewrite this order's CSV with everything computed so far.
                results_df = _save_results_csv(results, checkpoint_path)
                t_y = time.time()
                print(
                    f"{X.shape[0]} pairs evaluated in:",
                    np.round(t_y - t_x, 1),
                    "seconds",
                )
        results_df = _save_results_csv(results, checkpoint_path)
    print(f"{selected_dataset} finished in:", np.round(time.time() - t_i, 1), "seconds")


******************************
ORDER: 3
Evaluating DBS_awake vs ts_off with 1008 pairs for task: Cpos


  0%|          | 0/10000 [00:00<?, ?it/s]                            
Processing n-plets:  57%|█████▋    | 57/100 [00:00<00:00, 561.59it/s] 