In [23]:
import yaml
from violmulti.data.dataset_loader import DatasetLoader
from violmulti.data.dataset_loader import DatasetLoader

from violmulti.features.design_matrix_generator_PWM import *

import pandas as pd
import numpy as np
import ssm

from violmulti.models.ssm_glm_hmm import SSMGLMHMM
from violmulti.utils.save_load import *

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Experiment Runner

### Create folder

In [44]:
from violmulti.utils.create_experiment_folder import *

In [48]:
# Create the directory for experiment "name2"; the call's return value (the
# experiment path) is displayed as this cell's output.
create_experiment_directory("name2", "standard")

'/Volumes/brody/jbreda/behavioral_analysis/violations_multinomial/data/results/experiment_name2'

# SSM Model Class

In [2]:
# Set the parameters of the GLM-HMM
num_states = 3  # number of discrete states
obs_dim = 1  # number of observed dimensions
num_categories = 2  # number of categories for output
input_dim = 2  # input dimensions

# Make a GLM-HMM with "input_driven_obs" observations (C output categories)
# and "standard" transitions
true_glmhmm = ssm.HMM(
    num_states,
    obs_dim,
    input_dim,
    observations="input_driven_obs",
    observation_kwargs=dict(C=num_categories),
    transitions="standard",
)

# Hand-picked generative weights: a (3, 1, 2) array, i.e. one row of two
# weights per state.
# NOTE(review): presumably laid out as (num_states, num_categories - 1,
# input_dim) -- confirm against the ssm documentation.
gen_weights = np.array([[[6, 1]], [[2, -3]], [[2, 3]]])
# Log of a 3x3 transition matrix (each row sums to 1), wrapped in one extra
# leading axis so the array has shape (1, 3, 3)
gen_log_trans_mat = np.log(
    np.array([[[0.98, 0.01, 0.01], [0.05, 0.92, 0.03], [0.03, 0.03, 0.94]]])
)
# Overwrite the model's observation and transition parameters with the
# hand-picked values above
true_glmhmm.observations.params = gen_weights
true_glmhmm.transitions.params = gen_log_trans_mat

In [113]:
import ssm
import logging
import pickle
from pathlib import Path


class SSMGLMHMM(ssm.HMM):
    """
    Child class of ssm.HMM that adds additional functionality for
    fitting glm-hmm models specific to my (Jess Breda's) use case
    of fitting models to binary and trinomial trial-by-trial choice
    data.

    On top of ssm.HMM this class:
    - builds the model from a config dictionary (see unpack_model_config)
    - wraps EM fitting with the configured settings (see fit)
    - computes post-fit statistics and pickles itself to disk (see save)
    """

    def __init__(
        self,
        model_config: dict,
        model_name: str = "glmhmm",
        n_fold: int = 1,
        results_path=None,
    ):
        """
        Build the model from a config dictionary.

        params
        ------
        model_config : dict
            must contain "n_states", "n_features" and "n_categories";
            the optional keys and their defaults are listed in
            unpack_model_config
        model_name : str (default="glmhmm")
            used in the file name written by save()
        n_fold : int (default=1)
            fold identifier, used in the file name written by save()
        results_path : str or Path (default=None)
            directory save() writes to; None means a "model_results"
            folder inside the current working directory
        """

        self.model_config = model_config
        self.unpack_model_config()
        self.set_up_priors()

        self.n_fold = n_fold
        self.model_name = model_name
        self.results_path = results_path

        # At this point self.transitions is still the transition type
        # name (a string) read from the config; it is handed to ssm.HMM
        # together with the kwargs assembled in set_up_priors
        super().__init__(
            K=self.K,
            D=1,  # never have more than 1 output dimension
            M=self.M,
            observations="input_driven_obs",
            observation_kwargs=self.observation_kwargs,
            transitions=self.transitions,
            transition_kwargs=self.transition_kwargs,
        )

        # TODO logic here for initializing weights and transitions
        # TODO if model config var exists

    def unpack_model_config(self):
        """
        Method to unpack the model config dictionary into class
        attributes. Some are exact duplicates from ssm.HMM, others
        are custom to this class.

        Required keys: "n_states" (K), "n_features" (M),
        "n_categories" (C). A missing required key raises KeyError.

        Optional keys and defaults: "transitions" ("standard"),
        "n_iters" (200), "prior_sigma" (None), "prior_alpha" (0),
        "prior_kappa" (0), "masks" (None), "tolerance" (1e-4),
        "seed" (0).
        """
        self.K = self.model_config["n_states"]
        self.M = self.model_config["n_features"]
        self.C = self.model_config["n_categories"]
        self.transitions = self.model_config.get("transitions", "standard")
        self.n_iters = self.model_config.get("n_iters", 200)
        self.prior_sigma = self.model_config.get("prior_sigma", None)
        self.prior_alpha = self.model_config.get("prior_alpha", 0)
        self.prior_kappa = self.model_config.get("prior_kappa", 0)
        self.masks = self.model_config.get("masks", None)
        self.tolerance = self.model_config.get("tolerance", 1e-4)
        self.seed = self.model_config.get("seed", 0)
        logging.info(f"Unpacked model config: {self.model_config}")

    def set_up_priors(self):
        """
        Assemble the transition and observation kwargs handed to
        ssm.HMM from the unpacked config values.

        Sets self.transition_kwargs and self.observation_kwargs.
        Raises ValueError if the transition type is neither "sticky"
        nor "standard".
        """
        # Set up the kwargs for the model- can't pass in 0 values and
        # have them be ignored, so need to do this manually
        if self.transitions == "sticky":
            # The prior values must be passed by keyword:
            # dict(alpha, kappa) with two positional arguments raises
            # a TypeError instead of building the kwargs
            self.transition_kwargs = dict(
                alpha=self.prior_alpha, kappa=self.prior_kappa
            )
        elif self.transitions == "standard":
            self.transition_kwargs = None
        else:
            raise ValueError("Invalid transition type for SSM GLM-HMM.")
        logging.info(f"Transition kwargs set: {self.transition_kwargs}")

        # prior_sigma is only forwarded when the config supplied one
        if self.prior_sigma is None:
            self.observation_kwargs = dict(C=self.C)
        else:
            self.observation_kwargs = dict(C=self.C, prior_sigma=self.prior_sigma)
        logging.info(f"Observation kwargs set: {self.observation_kwargs}")

    def initialize_weights(self):
        """
        Initialize the weights of the model. Placeholder for actual
        implementation; currently it only seeds numpy's global random
        state with the configured seed.
        """
        np.random.seed(self.seed)

    def initialize_transitions(self):
        """
        Initialize the transitions of the model. Placeholder for actual implementation.
        """
        # Initialize transitions logic here

    def fit(self, X, y):
        """
        Fit the model with EM using the configured number of
        iterations, tolerance and masks.

        params
        ------
        X : inputs, forwarded to ssm.HMM.fit as `inputs`
            (in this notebook: a list with one array per session)
        y : choices, forwarded to ssm.HMM.fit as `datas`
            (in this notebook: a list with one array per session)

        returns
        -------
        the value returned by ssm.HMM.fit, also stored as
        self.log_probs

        X and y are kept on the instance so that post-fit statistics
        can be computed later (see compute_stats_of_interest).
        """
        self.X = X
        self.y = y

        self.log_probs = super().fit(
            datas=self.y,
            inputs=self.X,
            masks=self.masks,
            method="em",
            num_iters=self.n_iters,
            tolerance=self.tolerance,
        )

        return self.log_probs

    def compute_stats_of_interest(self):
        """
        Statistics of interest that are easiest to calculate
        after fitting the model while the data is still in memory.

        1. log likelihood
        2. posterior state probs (in list of list by session)

        Raises AttributeError if fit() has not stored data on the
        model yet.
        """
        # Same exception type as the bare attribute lookup would raise,
        # but with a message that says what to do about it
        if not (hasattr(self, "X") and hasattr(self, "y")):
            raise AttributeError(
                "No data stored on the model; call fit(X, y) before "
                "computing statistics or saving."
            )

        self.log_like = self.log_likelihood(self.y, self.X)
        self.posterior_state_probs = self.get_posterior_state_probs()

    def get_posterior_state_probs(self):
        """
        Compute the posterior state probabilities for every session
        stored by fit().

        returns
        -------
        list with one entry per (choices, inputs) pair in
        zip(self.y, self.X)
        """
        # expected_states returns
        # [posterior_state_probs, posterior_joint_probs, normalizer]
        # so we only need the first element of the returned list
        return [
            self.expected_states(data=session_choices, input=session_inputs)[0]
            for session_choices, session_inputs in zip(self.y, self.X)
        ]

    def save(self):
        """
        Compute the post-fit statistics and pickle the whole model
        object to
        <results_path>/<model_name>_model_fold_<n_fold>.pkl

        If no results_path was given, a "model_results" folder in the
        current working directory is used. The directory is created
        when missing, and self.results_path is left as a Path.
        """

        self.compute_stats_of_interest()

        if self.results_path is None:
            self.results_path = Path.cwd() / "model_results"
        else:
            self.results_path = Path(self.results_path)

        # Ensure the directory exists (same handling for both branches)
        self.results_path.mkdir(parents=True, exist_ok=True)

        file_path = (
            self.results_path / f"{self.model_name}_model_fold_{self.n_fold}.pkl"
        )
        with open(file_path, "wb") as f:
            pickle.dump(self, f)

In [4]:
# Minimal model config: 2 states, 2 input features, 2 output categories
model_config = {
    "n_states": 2,
    "n_features": 2,
    "n_categories": 2,
}


# Model used to generate the synthetic sessions below
true_glmhmm = SSMGLMHMM(model_config)

num_sess = 20  # number of example sessions
num_trials_per_sess = 100  # number of trials in a session
# Size the inputs from the model's own config rather than from a global
# defined in another cell, so the input width always matches the model
inpts = np.ones(
    (num_sess, num_trials_per_sess, model_config["n_features"])
)  # initialize inpts array
stim_vals = [-1, -0.5, -0.25, -0.125, -0.0625, 0, 0.0625, 0.125, 0.25, 0.5, 1]
# NOTE(review): no random seed is set in this cell, so the generated data
# differs between runs unless one is set elsewhere
inpts[:, :, 0] = np.random.choice(
    stim_vals, (num_sess, num_trials_per_sess)
)  # generate random sequence of stimuli
inpts = list(inpts)  # convert inpts to correct format


# Sample latent states and choices session by session
true_latents, true_choices = [], []
for sess in range(num_sess):
    true_z, true_y = true_glmhmm.sample(num_trials_per_sess, input=inpts[sess])
    true_latents.append(true_z)
    true_choices.append(true_y)

len(true_latents), true_latents[0].shape

(20, (100,))

In [5]:
# Fresh model to be fitted, built from the same config as the generating model
new_glm = SSMGLMHMM(model_config)

In [6]:
# Fit to the synthetic inputs and choices; keep the value returned by fit
log_ps = new_glm.fit(inpts, true_choices)

Converged to LP: -1166.2:  10%|▉         | 19/200 [00:00<00:03, 47.14it/s]


In [7]:
# Inspect the fitted parameters: initial state distribution, transitions,
# and observation parameters
print(new_glm.init_state_distn.params)
print(new_glm.transitions.params)
print(new_glm.observations.params)

(array([-1.74610071, -0.19170891]),)
(array([[-0.01161171, -4.4615418 ],
       [-3.88070606, -0.02085215]]),)
[[[ 1.47712929  0.38278679]]

 [[-1.08570381  1.58486934]]]


In [9]:
# Save twice: first without a results_path (the helper picks the location),
# then with an explicit results_path.
# NOTE(review): the explicit path is absolute and machine-specific.
save_model_to_pickle(new_glm)  # no path given: helper makes one
save_model_to_pickle(  # path given explicitly
    new_glm,
    results_path="/Users/jessbreda/Desktop/github/violations-multinomial/src/violmulti/experiments/cluster_scripts/example_experiment/",
)

Directory ensured at: /Users/jessbreda/Desktop/github/violations-multinomial/src/violmulti/experiments/cluster_scripts/example_experiment/model_results
Directory ensured at: /Users/jessbreda/Desktop/github/violations-multinomial/src/violmulti/experiments/cluster_scripts/example_experiment


In [10]:
# Load the model back from the "model_results" folder of the example
# experiment (the location reported by the first save call's output)
loaded_model = load_model_from_pickle(
    animal_id="",
    n_states=2,
    model_name="glmhmm",
    n_fold=0,
    results_path="/Users/jessbreda/Desktop/github/violations-multinomial/src/violmulti/experiments/cluster_scripts/example_experiment/model_results/",
)

In [11]:
# Load the copy that was saved with the explicit results_path (the example
# experiment folder itself)
loaded_model2 = load_model_from_pickle(
    animal_id="",
    n_states=2,
    model_name="glmhmm",
    n_fold=0,
    results_path="/Users/jessbreda/Desktop/github/violations-multinomial/src/violmulti/experiments/cluster_scripts/example_experiment/",
)

In [13]:
# Print the loaded model's parameters so they can be compared with the
# fitted model's parameters printed before saving
print(loaded_model2.init_state_distn.params)
print(loaded_model2.transitions.params)
print(loaded_model2.observations.params)

(array([-1.74610071, -0.19170891]),)
(array([[-0.01161171, -4.4615418 ],
       [-3.88070606, -0.02085215]]),)
[[[ 1.47712929  0.38278679]]

 [[-1.08570381  1.58486934]]]


# DMG Config

In [14]:
# Hand-written reference version of the design matrix config, used to check
# that the config deserialized from the .yaml file produces the same result.
# The column specs live under "dmg_config" because that is the key the
# comparison cell reads (true_config["dmg_config"]) and the key used by the
# loaded YAML config.
true_config = {
    "dmg_config": {
        "s_a_stand": lambda df: (standardize(df.s_a)),
        "s_b_stand": lambda df: (standardize(df.s_b)),
        "stim_avg_stand": lambda df: standardize(
            (combine_two_cols(df.s_a, df.s_b, operation="mean"))
        ),
        "prev_correct": lambda df: (
            shift_n_trials_up(df.correct_side, df.session, shift_size=1)
        ),
        # Called immediately, so the entry holds the returned labels spec
        "labels": binary_choice_labels(),
    }
}

# .yaml config that true_config is compared against
# NOTE(review): absolute, machine-specific path
config_path = "/Users/jessbreda/Desktop/github/violations-multinomial/src/violmulti/experiments/cluster_scripts/example_experiment/config.yaml"

In [15]:
animal_ids = ["W078"]
# NOTE(review): despite the variable name, this is an absolute,
# machine-specific path
relative_data_path = "/Users/jessbreda/Desktop/github/violations-multinomial/data"

# Load the "new_trained" data for the selected animal(s)
df = DatasetLoader(
    animal_ids=animal_ids,
    data_type="new_trained",
    relative_data_path=relative_data_path,
).load_data()

Loading data for animal ids:  ['W078']


### Testing Config Utility Functions

In [36]:
from violmulti.utils.config_utils import *

# Load the YAML config, then replace the "dmg_config" entry with its
# converted version (the output shows the string specs turned into functions)
config = load_config_from_yaml(config_path)
config["dmg_config"] = convert_dmg_config_functions(config["dmg_config"])
config

{'animal_ids': ['W078'],
 'relative_data_path': '/Users/jessbreda/Desktop/github/violations-multinomial/data',
 'data_type': 'new_trained',
 'dmg_config': {'s_a_stand': <function violmulti.utils.config_utils.<lambda>(df)>,
  's_b_stand': <function violmulti.utils.config_utils.<lambda>(df)>,
  'stim_avg_stand': <function violmulti.utils.config_utils.<lambda>(df)>,
  'prev_correct': <function violmulti.utils.config_utils.<lambda>(df)>,
  'labels': {'column_name': 'choice', 'mapping': {0: 0, 1: 1, 2: nan}}}}

In [37]:
# Build the design matrix X and labels y from the loaded config, then preview
# the first rows of each
dmg = DesignMatrixGeneratorPWM(df, config["dmg_config"], verbose=True)
X, y = dmg.create()
X.head(), y[0:5]

DMG: Creating data matrix with columns: dict_keys(['s_a_stand', 's_b_stand', 'stim_avg_stand', 'prev_correct'])
DMG: Creating labels with column: choice.
DMG: Dropping 8254 nan rows from data and labels.
DMG: Binary encoding labels.


(         s_a_stand  s_b_stand  stim_avg_stand  prev_correct
 1180811  -1.672853  -0.782846       -1.344194           0.0
 1180812  -0.036372   0.851860        0.446731           0.0
 1180813  -0.036372  -0.782846       -0.448732           0.0
 1180814  -0.854613  -1.600199       -1.344194           1.0
 1180815   0.781869   0.034507        0.446731           1.0,
 array([0, 0, 0, 0, 1]))

In [38]:
# Write X and y to parquet files using the helper's default naming and
# location (reported in the cell output)
save_data_and_labels_to_parquet(X, y)

Directory ensured at: /Users/jessbreda/Desktop/github/violations-multinomial/src/violmulti/experiments/cluster_scripts/example_experiment/data
DataFrame saved to /Users/jessbreda/Desktop/github/violations-multinomial/src/violmulti/experiments/cluster_scripts/example_experiment/data/animal__model__fold_0_X.parquet
Labels saved to /Users/jessbreda/Desktop/github/violations-multinomial/src/violmulti/experiments/cluster_scripts/example_experiment/data/animal__model__fold_0_y.parquet


In [39]:
# Read the parquet files back from the example experiment's data folder for
# the round-trip check.
# NOTE(review): absolute, machine-specific path
X_loaded, y_loaded = load_data_and_labels_from_parquet(
    animal_id="",
    model_name="",
    n_fold=0,
    data_path=Path(
        "/Users/jessbreda/Desktop/github/violations-multinomial/src/violmulti/experiments/cluster_scripts/example_experiment/data"
    ),
)

here
DataFrame loaded from /Users/jessbreda/Desktop/github/violations-multinomial/src/violmulti/experiments/cluster_scripts/example_experiment/data/animal__model__fold_0_X.parquet
Labels loaded from /Users/jessbreda/Desktop/github/violations-multinomial/src/violmulti/experiments/cluster_scripts/example_experiment/data/animal__model__fold_0_y.parquet


In [42]:
# Round-trip check: the design matrix read back from parquet must equal the
# one that was saved
assert X.equals(X_loaded)

# Same check for the labels, compared element-wise
assert np.all(y == y_loaded)

### Pre config utility functions

In [139]:
def load_config_from_yaml(file_path):
    """
    Read a YAML config file and return its parsed content.

    The file at `file_path` is opened in text mode and parsed with
    yaml.safe_load; whatever safe_load produces is returned unchanged.
    """
    with open(file_path, "r") as yaml_file:
        return yaml.safe_load(yaml_file)


# Example usage
config = load_config_from_yaml(config_path)
# The tuple's second element is print's return value, so the cell prints the
# config and then displays (dict, None)
type(config), print(config)

{'data': {'s_a_stand': 'lambda df: standardize(df.s_a)', 's_b_stand': 'lambda df: standardize(df.s_b)', 'stim_avg_stand': "lambda df: standardize(combine_two_cols(df.s_a, df.s_b, operation='mean'))", 'prev_correct': 'lambda df: shift_n_trials_up(df.correct_side, df.session, shift_size=1)', 'labels': 'binary_choice_labels()'}}


(dict, None)

In [140]:
def deserialize_function_or_call(func_str):
    """
    Turn a config string into the Python object it describes.

    Two kinds of strings are accepted, both evaluated with eval:
    - text containing "lambda": evaluates to the lambda function
    - text containing "()": evaluates the call and returns its result

    SECURITY: eval executes arbitrary code, so this must only be used
    on config files from a trusted source.

    params
    ------
    func_str : str
        the string to evaluate

    returns
    -------
    the evaluated object (a function for lambda strings, the call's
    return value for call strings)

    raises
    ------
    ValueError
        if func_str is not a string, matches neither accepted format,
        or raises a SyntaxError/NameError when evaluated (the original
        error is attached as __cause__)
    """
    # Non-strings (None, or a value that was already deserialized by an
    # earlier run of the conversion cell) cannot be evaluated; report
    # them the same way as any other unrecognized format
    if not isinstance(func_str, str):
        raise ValueError(f"Unknown function format: {func_str}")

    # Simple substring checks to see if it's a lambda or a function call
    if "lambda" not in func_str and "()" not in func_str:
        raise ValueError(f"Unknown function format: {func_str}")

    try:
        return eval(func_str)  # trusted config only, see docstring
    except (SyntaxError, NameError) as e:
        # Chain the original error so its traceback is preserved
        raise ValueError(f"Error evaluating function string: {func_str} - {e}") from e


def convert_config_functions(config):
    """
    Deserialize every value of a config dictionary.

    Each value of `config` is passed through
    deserialize_function_or_call; the results are returned in a new
    dictionary under the same keys. The input dictionary is not modified.
    """
    return {
        name: deserialize_function_or_call(spec) for name, spec in config.items()
    }

In [141]:
config  # pre-deserialization: every entry is still a plain string

{'data': {'s_a_stand': 'lambda df: standardize(df.s_a)',
  's_b_stand': 'lambda df: standardize(df.s_b)',
  'stim_avg_stand': "lambda df: standardize(combine_two_cols(df.s_a, df.s_b, operation='mean'))",
  'prev_correct': 'lambda df: shift_n_trials_up(df.correct_side, df.session, shift_size=1)',
  'labels': 'binary_choice_labels()'}}

In [142]:
# Replace the string specs under "dmg_config" with their deserialized objects.
# NOTE(review): this overwrites the entry in place, so re-running the cell
# feeds already-converted values back into the converter; reload the config
# first when re-running.
config["dmg_config"] = convert_config_functions(config["dmg_config"])
config  # post-deserialization

{'data': {'s_a_stand': <function __main__.<lambda>(df)>,
  's_b_stand': <function __main__.<lambda>(df)>,
  'stim_avg_stand': <function __main__.<lambda>(df)>,
  'prev_correct': <function __main__.<lambda>(df)>,
  'labels': {'column_name': 'choice', 'mapping': {0: 0, 1: 1, 2: nan}}}}

In [131]:

# Design matrix and labels built from the deserialized YAML config
dmg = DesignMatrixGeneratorPWM(df, config["dmg_config"], verbose=True)
X, y = dmg.create()

Loading data for animal ids:  ['W078']
DMG: Creating data matrix with columns: dict_keys(['s_a_stand', 's_b_stand', 'stim_avg_stand', 'prev_correct'])
DMG: Creating labels with column: choice.
DMG: Dropping 8254 nan rows from data and labels.
DMG: Binary encoding labels.


In [133]:
# Same construction from the hand-written reference config, for comparison
# with X and y; expects true_config to have a "dmg_config" entry
dmg2 = DesignMatrixGeneratorPWM(df, true_config["dmg_config"], verbose=True)
X_true, y_true = dmg2.create()

DMG: Creating data matrix with columns: dict_keys(['s_a_stand', 's_b_stand', 'stim_avg_stand', 'prev_correct'])
DMG: Creating labels with column: choice.
DMG: Dropping 8254 nan rows from data and labels.
DMG: Binary encoding labels.


In [119]:
# Display the design matrix built from the YAML config
X

Unnamed: 0,s_a_stand,s_b_stand,stim_avg_stand,prev_correct
1180811,-1.672853,-0.782846,-1.344194,0.0
1180812,-0.036372,0.851860,0.446731,0.0
1180813,-0.036372,-0.782846,-0.448732,0.0
1180814,-0.854613,-1.600199,-1.344194,1.0
1180815,0.781869,0.034507,0.446731,1.0
...,...,...,...,...
1261383,-0.036372,-0.782846,-0.448732,1.0
1261384,-0.854613,0.034507,-0.448732,1.0
1261385,-1.672853,-0.782846,-1.344194,0.0
1261387,-0.036372,-0.782846,-0.448732,0.0


In [120]:
# Display the design matrix built from the hand-written reference config
X_true

Unnamed: 0,s_a_stand,s_b_stand,stim_avg_stand,prev_correct
1180811,-1.672853,-0.782846,-1.344194,0.0
1180812,-0.036372,0.851860,0.446731,0.0
1180813,-0.036372,-0.782846,-0.448732,0.0
1180814,-0.854613,-1.600199,-1.344194,1.0
1180815,0.781869,0.034507,0.446731,1.0
...,...,...,...,...
1261383,-0.036372,-0.782846,-0.448732,1.0
1261384,-0.854613,0.034507,-0.448732,1.0
1261385,-1.672853,-0.782846,-1.344194,0.0
1261387,-0.036372,-0.782846,-0.448732,0.0
