In [122]:
import yaml
from violmulti.data.dataset_loader import DatasetLoader
from violmulti.data.dataset_loader import DatasetLoader

from violmulti.features.design_matrix_generator_PWM import *

import pandas as pd

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [145]:
true_config = {
    "data": {
        "s_a_stand": lambda df: (standardize(df.s_a)),
        "s_b_stand": lambda df: (standardize(df.s_b)),
        "stim_avg_stand": lambda df: standardize(
            (combine_two_cols(df.s_a, df.s_b, operation="mean"))
        ),
        "prev_correct": lambda df: (
            shift_n_trials_up(df.correct_side, df.session, shift_size=1)
        ),
        "labels": binary_choice_labels(),
    }
}

true_config  # to compare the .yaml file to
config_path = "/Users/jessbreda/Desktop/github/violations-multinomial/src/violmulti/experiments/cluster_scripts/example_experiment/config.yaml"

In [146]:
animal_ids = ["W078"]
relative_data_path = "/Users/jessbreda/Desktop/github/violations-multinomial/data"

df = DatasetLoader(
    animal_ids=animal_ids,
    data_type="new_trained",
    relative_data_path=relative_data_path,
).load_data()

Loading data for animal ids:  ['W078']


### Testing Config Utilty Functions

In [156]:
from violmulti.utils.config_utils import *

config = load_config_from_yaml(config_path)
config["dmg_config"] = convert_dmg_config_functions(config["dmg_config"])
config

{'dmg_config': {'s_a_stand': <function violmulti.utils.config_utils.<lambda>(df)>,
  's_b_stand': <function violmulti.utils.config_utils.<lambda>(df)>,
  'stim_avg_stand': <function violmulti.utils.config_utils.<lambda>(df)>,
  'prev_correct': <function violmulti.utils.config_utils.<lambda>(df)>,
  'labels': {'column_name': 'choice', 'mapping': {0: 0, 1: 1, 2: nan}}}}

In [158]:
dmg = DesignMatrixGeneratorPWM(df, config["dmg_config"], verbose=True)
X, y = dmg.create()
X.head(), y[0:5]

DMG: Creating data matrix with columns: dict_keys(['s_a_stand', 's_b_stand', 'stim_avg_stand', 'prev_correct'])
DMG: Creating labels with column: choice.
DMG: Dropping 8254 nan rows from data and labels.
DMG: Binary encoding labels.


(         s_a_stand  s_b_stand  stim_avg_stand  prev_correct
 1180811  -1.672853  -0.782846       -1.344194           0.0
 1180812  -0.036372   0.851860        0.446731           0.0
 1180813  -0.036372  -0.782846       -0.448732           0.0
 1180814  -0.854613  -1.600199       -1.344194           1.0
 1180815   0.781869   0.034507        0.446731           1.0,
 array([0, 0, 0, 0, 1]))

### Pre config utility functions

In [139]:
def load_config_from_yaml(file_path):
    with open(file_path, "r") as file:
        config = yaml.safe_load(file)
    return config


# Example usage
config = load_config_from_yaml(config_path)
type(config), print(config)

{'data': {'s_a_stand': 'lambda df: standardize(df.s_a)', 's_b_stand': 'lambda df: standardize(df.s_b)', 'stim_avg_stand': "lambda df: standardize(combine_two_cols(df.s_a, df.s_b, operation='mean'))", 'prev_correct': 'lambda df: shift_n_trials_up(df.correct_side, df.session, shift_size=1)', 'labels': 'binary_choice_labels()'}}


(dict, None)

In [140]:
def deserialize_function_or_call(func_str):
    try:
        if "lambda" in func_str:
            return eval(func_str)
        elif "()" in func_str:  # Simple check to see if it's a function call
            return eval(func_str)
        else:
            raise ValueError(f"Unknown function format: {func_str}")
    except (SyntaxError, NameError) as e:
        raise ValueError(f"Error evaluating function string: {func_str} - {e}")


def convert_config_functions(config):
    deserialized_config = {}
    for key, func_str in config.items():
        deserialized_config[key] = deserialize_function_or_call(func_str)
    return deserialized_config

In [141]:
config  # pre-deserialization

{'data': {'s_a_stand': 'lambda df: standardize(df.s_a)',
  's_b_stand': 'lambda df: standardize(df.s_b)',
  'stim_avg_stand': "lambda df: standardize(combine_two_cols(df.s_a, df.s_b, operation='mean'))",
  'prev_correct': 'lambda df: shift_n_trials_up(df.correct_side, df.session, shift_size=1)',
  'labels': 'binary_choice_labels()'}}

In [142]:
config["dmg_config"] = convert_config_functions(config["dmg_config"])
config  # post-deserialization

{'data': {'s_a_stand': <function __main__.<lambda>(df)>,
  's_b_stand': <function __main__.<lambda>(df)>,
  'stim_avg_stand': <function __main__.<lambda>(df)>,
  'prev_correct': <function __main__.<lambda>(df)>,
  'labels': {'column_name': 'choice', 'mapping': {0: 0, 1: 1, 2: nan}}}}

In [131]:

dmg = DesignMatrixGeneratorPWM(df, config["dmg_config"], verbose=True)
X, y = dmg.create()

Loading data for animal ids:  ['W078']
DMG: Creating data matrix with columns: dict_keys(['s_a_stand', 's_b_stand', 'stim_avg_stand', 'prev_correct'])
DMG: Creating labels with column: choice.
DMG: Dropping 8254 nan rows from data and labels.
DMG: Binary encoding labels.


In [133]:
dmg2 = DesignMatrixGeneratorPWM(df, true_config["dmg_config"], verbose=True)
X_true, y_true = dmg2.create()

DMG: Creating data matrix with columns: dict_keys(['s_a_stand', 's_b_stand', 'stim_avg_stand', 'prev_correct'])
DMG: Creating labels with column: choice.
DMG: Dropping 8254 nan rows from data and labels.
DMG: Binary encoding labels.


In [119]:
X

Unnamed: 0,s_a_stand,s_b_stand,stim_avg_stand,prev_correct
1180811,-1.672853,-0.782846,-1.344194,0.0
1180812,-0.036372,0.851860,0.446731,0.0
1180813,-0.036372,-0.782846,-0.448732,0.0
1180814,-0.854613,-1.600199,-1.344194,1.0
1180815,0.781869,0.034507,0.446731,1.0
...,...,...,...,...
1261383,-0.036372,-0.782846,-0.448732,1.0
1261384,-0.854613,0.034507,-0.448732,1.0
1261385,-1.672853,-0.782846,-1.344194,0.0
1261387,-0.036372,-0.782846,-0.448732,0.0


In [120]:
X_true

Unnamed: 0,s_a_stand,s_b_stand,stim_avg_stand,prev_correct
1180811,-1.672853,-0.782846,-1.344194,0.0
1180812,-0.036372,0.851860,0.446731,0.0
1180813,-0.036372,-0.782846,-0.448732,0.0
1180814,-0.854613,-1.600199,-1.344194,1.0
1180815,0.781869,0.034507,0.446731,1.0
...,...,...,...,...
1261383,-0.036372,-0.782846,-0.448732,1.0
1261384,-0.854613,0.034507,-0.448732,1.0
1261385,-1.672853,-0.782846,-1.344194,0.0
1261387,-0.036372,-0.782846,-0.448732,0.0
