In [1]:
module_path = "/home/claysmyth/code/integrated_rcs_analysis/python"
#class_path = "sklearn_model/clustering_classification.py"
import sys
sys.path.append(module_path)
import pickle
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from typing import Dict, Any
import numpy as np
import glob
import json
from pathlib import Path

In [None]:
from datetime import datetime

# Record when this cell ran.
execution_time = datetime.now()
print("Cell executed at: " + str(execution_time))

## Model check

In [20]:
# --- Export configuration ----------------------------------------------------
# Multiplier applied to the classifier coefficients and intercept before they are
# rounded into RC+S weights / bias (see get_adaptive_weights_and_bias).
weight_gain_factor = 1e5

# Glob templates; '{model_group}' and '{device_id}' are filled in per device.
model_path = '/media/longterm_hdd/Clay/Sleep_aDBS/embeddable_models/{model_group}/{device_id}/ClusterClassificationModel_*/model.pkl'
feature_path = '/media/longterm_hdd/Clay/Sleep_aDBS/embeddable_models/{model_group}/{device_id}/ClusterClassificationModel_*/{device_id}_features.pickle'

# Base adaptive config that gets copied and overridden per device.
input_adaptive_template_path = "/home/claysmyth/code/bayes_opt_infra/configs/templates/adaptive_config.json"

devices = ['RCS01L', 'RCS01R']

# Stimulation amplitude (written to *AmpInMilliamps) and frequency (written to
# *RateTargetInHz) for each device.
device_stim_settings = {
    'RCS01L': {'stim_amp': 2.7, 'stim_freq': 130.2},
    'RCS01R': {'stim_amp': 3.5, 'stim_freq': 130.2},
}

# Model-group sub-directory (under embeddable_models) each device's model lives in.
device_group_mapping = {
    'RCS01L': 'unsup_embeddable_models_15s_averaged',
    'RCS01R': 'unsup_embeddable_models_15s_averaged',
}

# Dotted-key overrides for the adaptive template. Entries left as None are filled
# in per device by update_params_for_device.
updated_params_template = {
    "Detection.LD0.Comment": "Settings for Unsupervised-Supervised NREM Generation. NREM as State1",
    "Detection.LD0.B0": None,
    "Detection.LD0.UpdateRate": 15,  # Assumes 1000ms FftInterval. This corresponds to an LD update every 15 seconds...
    "Detection.LD0.StateChangeBlankingUponStateChange": 7,  # Blank for 7 seconds after state change
    "Detection.LD0.WeightVector": None,
    # This ramp rate corresponds to 1 mA/s rate change
    "Adaptive.Program0.RiseTimes": 65536,
    "Adaptive.Program0.FallTimes": 65536,
    "Adaptive.Program0.State0AmpInMilliamps": None,
    "Adaptive.Program0.State1AmpInMilliamps": None,
    "Adaptive.Rates.State0.RateTargetInHz": None,
    "Adaptive.Rates.State1.RateTargetInHz": None,
    "Adaptive.Rates.State2.RateTargetInHz": None,
    "Adaptive.Rates.State3.RateTargetInHz": None,
    "Adaptive.Rates.State4.RateTargetInHz": None,
    "Adaptive.Rates.State5.RateTargetInHz": None,
    "Adaptive.Rates.State6.RateTargetInHz": None,
    "Adaptive.Rates.State7.RateTargetInHz": None,
    "Adaptive.Rates.State8.RateTargetInHz": None,
}

# Results are written under <out_path_base>/<device id minus last char>/<device id>.
out_path_base = "/media/longterm_hdd/Clay/Sleep_aDBS/bayes_opt_experiments"

# The two RC+S state names used in the saved state-to-NREM mapping.
states = {'State 0', 'State 1'}

-------------

In [None]:
def plot_feature_distributions(model, X, y, feature_names=('Delta', 'Alpha+Theta', 'Beta', 'Gamma'), title=None, inverted=False):
    """Plot per-feature histograms of ``X`` split by the classifier's predicted class.

    Args:
        model: Object whose ``classifier_model.predict(X)`` returns class labels 0/1.
        X: Feature matrix indexed as (samples, features); anything ``np.asarray`` accepts
            (e.g. a numpy array or a pandas DataFrame).
        y: Not used; kept so existing call sites keep working.
        feature_names: Display name per column of ``X``. Columns beyond the supplied
            names are shown as ``"Feature <index>"``.
        title: Optional figure super-title.
        inverted: If True, class 0 is labelled NREM and class 1 Wake+REM (used after
            the model has been inverted for RC+S); otherwise class 1 is NREM.

    Returns:
        The matplotlib ``Figure`` with one histogram subplot per feature, two per row.
    """
    # Predict on the original X so estimators that check feature names still work;
    # index into a plain array so DataFrames are supported too.
    predictions = np.asarray(model.classifier_model.predict(X))
    X_arr = np.asarray(X)
    n_features = X_arr.shape[1]
    names = list(feature_names) + [f'Feature {i}' for i in range(len(feature_names), n_features)]

    # (class value, legend label, color). Depends only on `inverted`, so decide it once.
    # NREM is always drawn blue and Wake+REM orange.
    if inverted:
        class_styles = ((0, 'Class 0: NREM', 'blue'), (1, 'Class 1: Wake+REM', 'orange'))
    else:
        class_styles = ((0, 'Class 0: Wake+REM', 'orange'), (1, 'Class 1: NREM', 'blue'))

    # Size the grid to the number of features; the default 4 bands give the
    # original 2x2 layout with figsize (15, 12).
    ncols = 2 if n_features > 1 else 1
    nrows = max(1, -(-n_features // ncols))  # ceiling division, at least one row
    fig, axes = plt.subplots(nrows, ncols, figsize=(7.5 * ncols, 6 * nrows), squeeze=False)
    axes = axes.ravel()

    for i, ax in enumerate(axes[:n_features]):
        per_class = [(X_arr[predictions == cls, i], label, color) for cls, label, color in class_styles]

        # Histograms first so the legend lists them in class order.
        for data, label, color in per_class:
            mean = data.mean() if data.size else np.nan
            ax.hist(data, bins=30, alpha=0.4, color=color, label=f'{label} (mean={mean:.2f})', density=False)

        # Dashed line at each class mean; skipped when a class has no samples.
        for data, _, color in per_class:
            if data.size:
                ax.axvline(data.mean(), color=color, linestyle='--', alpha=0.5)

        ax.set_xlabel(f'{names[i]} Value')
        ax.set_ylabel('Count')
        ax.set_title(f'Distribution of {names[i]} by Predicted Class')
        ax.legend()

    # Hide grid cells left over when the number of features is odd.
    for ax in axes[n_features:]:
        ax.set_visible(False)

    if title:
        fig.suptitle(title)
    fig.tight_layout()
    return fig

In [17]:
def update_json_config(filepath: str, update_dict: Dict[str, Any], outpath: str) -> Dict[str, Any]:
    """
    Read a JSON config, apply dotted-key updates, and write the result to a new file.

    Args:
        filepath: Path (str or path-like) to the source JSON config. It is only read,
            never modified.
        update_dict: Key-value pairs to set. Keys can be nested using dot notation
            (e.g. 'field0.field1.field2'); missing intermediate objects are created.
        outpath: Path (str or path-like) of the JSON file to write (indent=4).

    Returns:
        The updated config as a dictionary.

    Raises:
        TypeError: If a key path runs through a value that is not a JSON object
            (e.g. 'a.b' when 'a' holds a number or list). Nothing is written in that case.
    """
    import json

    with open(filepath, 'r') as f:
        config = json.load(f)

    for key_path, value in update_dict.items():
        keys = key_path.split('.')
        node = config
        for depth, key in enumerate(keys):
            # Fail with a clear message instead of an obscure indexing error.
            if not isinstance(node, dict):
                traversed = '.'.join(keys[:depth]) or '<root>'
                raise TypeError(
                    f"Cannot set '{key_path}': '{traversed}' is a "
                    f"{type(node).__name__}, not a JSON object"
                )
            if depth == len(keys) - 1:
                node[key] = value
            else:
                node = node.setdefault(key, {})

    # Write to outpath; the input template at filepath is left untouched.
    with open(outpath, 'w') as f:
        json.dump(config, f, indent=4)

    return config


### TODO:
- Remove the need to round the scaled weights and bias?
- Document the transformations applied to the LDA decision function that make it compatible with the RC+S (specifically, moving the bias term to the other side of the inequality, or offloading it to the subtract vector, and the resulting math).

In [18]:
def offload_bias_to_subtract_vector(coefs, intercept):
    """Spread a linear model's bias term across a per-feature subtract vector.

    For weights ``w`` (length ``n``) and intercept ``b`` this returns
    ``s_i = -b / (n * w_i)``. Then ``sum_i w_i * s_i = -b``, so
    ``sum_i w_i * (x_i - s_i) == w . x + b``: subtracting ``s`` from the features
    before weighting reproduces the bias. (Presumably this matches how the RC+S
    applies its subtract vector; TODO confirm against the device documentation.)

    Args:
        coefs: Coefficient array; extra singleton dimensions (e.g. sklearn's
            ``(1, n_features)`` ``coef_``) are squeezed away.
        intercept: Bias term (scalar or single-element array).

    Returns:
        numpy array with the subtract vector, one entry per coefficient.

    Raises:
        ValueError: If any coefficient is zero (the bias cannot be offloaded onto it).
    """
    coefs = np.asarray(coefs, dtype=float)
    if coefs.ndim > 1:
        coefs = coefs.squeeze()
    # A zero weight would turn its subtract entry into inf/nan and produce an invalid config.
    if np.any(coefs == 0):
        raise ValueError("Cannot offload the bias onto a zero coefficient (division by zero).")
    n_coefs = np.size(coefs)
    # Same operation order as the derivation: b / w_i / n, then negate.
    return -1 * ((np.ones_like(coefs) * intercept) / coefs / n_coefs)

In [19]:
def check_bias(model):
    """Return True if the model can be exported to the RC+S without inversion.

    The exported RC+S bias is the negated intercept (the intercept is moved to the
    other side of the decision inequality), so a positive intercept would give a
    negative RC+S bias. RC+S bias cannot be negative, so such a model has to be
    inverted first (see ``invert_model``).

    Args:
        model: Object exposing ``classifier_model.intercept_``, expected to hold a
            single value (e.g. shape (1,)); a multi-element intercept raises ValueError.

    Returns:
        ``False`` when the intercept is positive, otherwise ``True``.
    """
    intercept_is_positive = bool(model.classifier_model.intercept_ > 0)
    return not intercept_is_positive

In [None]:
def get_adaptive_weights_and_bias(model, weight_gain_factor, zero_weight_offset=0.001):
    """Convert the classifier's coefficients and intercept into RC+S LD0 weights and bias.

    Both are multiplied by ``weight_gain_factor`` and rounded. The bias is the negated
    intercept: the intercept is moved to the other side of the decision inequality.

    Args:
        model: Object exposing ``classifier_model.coef_`` and ``classifier_model.intercept_``.
        weight_gain_factor: Multiplier applied before rounding.
        zero_weight_offset: Added to any weight that rounds to exactly 0.0, to avoid
            RC+S errors with zero weights (default 0.001, as before).

    Returns:
        ``(weights, bias)``: ``weights`` is a list of floats, always a list even for a
        single-feature model; ``bias`` is a Python float (never ``-0.0``).
    """
    clf = model.classifier_model

    # squeeze drops sklearn's leading (1, n_features) dimension; atleast_1d keeps a
    # single-feature model as a 1-element list instead of a bare float.
    scaled_weights = np.atleast_1d(np.round(np.squeeze(clf.coef_) * weight_gain_factor))
    weights = [w + zero_weight_offset if w == 0.0 else w for w in scaled_weights.tolist()]

    # ravel accepts both a scalar and a shape-(1,) intercept.
    bias = float(-1 * np.round(np.ravel(clf.intercept_) * weight_gain_factor)[0])
    # A zero intercept would otherwise produce -0.0; normalize it to 0.0 for the config.
    bias += 0.0
    return weights, bias

In [24]:
def update_params_for_device(device, model, model_inverted, weight_gain_factor, device_stim_settings, updated_params_template):
    """Build the dotted-key adaptive-config overrides for one device.

    Args:
        device: Device id used to look up ``device_stim_settings``.
        model: Model handed to ``get_adaptive_weights_and_bias`` (already inverted if needed).
        model_inverted: Whether ``model`` was sign-flipped; only changes the LD0 comment.
        weight_gain_factor: Multiplier applied to weights/bias before rounding.
        device_stim_settings: Mapping of device id -> {'stim_amp': ..., 'stim_freq': ...}.
        updated_params_template: Dotted-key template; it is shallow-copied, not modified.

    Returns:
        A new dict with the template's entries plus the LD0 weights and bias, the same
        amplitude for State0 and State1, and the device frequency in every
        ``*RateTargetInHz`` entry.
    """
    weights, bias = get_adaptive_weights_and_bias(model, weight_gain_factor)
    stim = device_stim_settings[device]
    amp, freq = stim["stim_amp"], stim["stim_freq"]

    params = dict(updated_params_template)
    if model_inverted:
        params["Detection.LD0.Comment"] = "Settings for Unsupervised-Supervised NREM Generation. NREM as State0. Inverted model."
    params.update({
        "Detection.LD0.WeightVector": weights,
        "Detection.LD0.B0": bias,
        "Adaptive.Program0.State0AmpInMilliamps": amp,
        "Adaptive.Program0.State1AmpInMilliamps": amp,
        "Adaptive.Rates.State0.RateTargetInHz": freq,
        "Adaptive.Rates.State1.RateTargetInHz": freq,
    })
    # Every rate slot (not only State0/State1) gets the device's stimulation frequency.
    params.update({key: freq for key in params if 'RateTargetInHz' in key})
    return params


In [25]:
def invert_model(model):
    """Negate the classifier's ``coef_`` and ``intercept_`` in place and return ``model``.

    For a linear decision function ``w . x + b``, this flips its sign. That turns a
    positive intercept into a negative one so that ``check_bias`` passes. The same
    object is mutated and returned; no copy is made.
    """
    clf = model.classifier_model
    for attr in ("coef_", "intercept_"):
        setattr(clf, attr, -1 * getattr(clf, attr))
    return model


_______

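### Note that inverting the model only inverts its implementation on the RC+S device. It does NOT INVERT THE STATE PREDICTIONS FOR THE SKLEARN MODEL! SKLEARN MODEL WILL ALWAYS PREDICT NREM AS STATE 1 (BY DEFINITION)!

Caveat: `invert_model` negates `coef_` and `intercept_` of the in-memory model in place. If the classifier's `predict` uses these attributes (as a linear classifier's does), the in-memory model's predictions are flipped after inversion. That is why `plot_feature_distributions` is called with `inverted=True` for inverted models. The pickled model on disk is not changed.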

In [None]:
def _resolve_single_match(pattern: str) -> Path:
    """Return the first (sorted) file matching a glob pattern; raise if nothing matches."""
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No file matches pattern: {pattern}")
    if len(matches) > 1:
        print(f"Warning: {len(matches)} files match {pattern}; using {matches[0]}")
    return Path(matches[0])


def _load_pickle(path: Path):
    """Unpickle a file, closing the handle. Only use on trusted, locally generated files."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Maps device id -> RC+S state name ('State 0' / 'State 1') that corresponds to NREM.
nrem_state = {}

for device in devices:
    print(f"Checking {device}")
    model_group = device_group_mapping[device]
    model_file = _resolve_single_match(model_path.format(model_group=model_group, device_id=device))
    feature_file = _resolve_single_match(feature_path.format(model_group=model_group, device_id=device))
    # Both globs wildcard the ClusterClassificationModel_* directory; make sure the
    # model and its features come from the same run.
    if model_file.parent != feature_file.parent:
        raise ValueError(
            f"Model ({model_file}) and features ({feature_file}) for {device} "
            "come from different directories"
        )
    model = _load_pickle(model_file)
    features = _load_pickle(feature_file)
    X = features.X
    y = features.y

    # A positive intercept would give a negative RC+S bias, so invert the model.
    model_inverted = not check_bias(model)
    if model_inverted:
        print(f"Bias fails for {device}. Need to invert model.")
        model = invert_model(model)
        nrem_state[device] = 'State 0'
    else:
        print(f"Bias passes for {device}")
        nrem_state[device] = 'State 1'
    updated_params = update_params_for_device(device, model, model_inverted, weight_gain_factor, device_stim_settings, updated_params_template)

    fig = plot_feature_distributions(model, X, y, title=device, inverted=model_inverted)

    # <out_path_base>/<device id minus last char>/<device id>
    dir_path = Path(out_path_base) / device[:-1] / device
    dir_path.mkdir(parents=True, exist_ok=True)

    outpath = dir_path / f"adaptive_config_{device[-1]}_template.json"
    update_json_config(input_adaptive_template_path, updated_params, outpath)
    fig.savefig(dir_path / f"feature_distributions_{device[-1]}.png")

    # Save the state mapping for this device: the non-NREM state is Wake+REM.
    wake_rem_state = next(iter(states - {nrem_state[device]}))
    state_mapping = {
        nrem_state[device]: 'NREM',
        wake_rem_state: 'Wake+REM',
    }
    with open(dir_path / "rcs_state_to_NREM_mapping.json", 'w') as f:
        json.dump(state_mapping, f, indent=4)
