# E. Computation of states per trial epoch 


In [2]:

""" 
IMPORTS
"""
import os
import autograd.numpy as np  # NOTE: this `np` binding is replaced by the plain `import numpy as np` below
import jax.numpy as jnp
import jax.random as jr
import pickle
import seaborn as sns
from collections import defaultdict
import pandas as pd
from matplotlib import colors as mcolors
from one.api import ONE
import numpy as np  # rebinds `np`: from here on `np` refers to plain NumPy, not autograd.numpy
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler, Normalizer

# Get my functions
# Each project-local module is imported by first changing the working directory
# to the folder that holds it and then importing it by its bare module name.
# NOTE(review): presumably this relies on the current directory being on the
# import path (as in a notebook session) — confirm before running as a script.
functions_path =  '/home/ines/repositories/representation_learning_variability/Models/Sub-trial//2_fit_models/'
# functions_path = '/Users/ineslaranjeira/Documents/Repositories/representation_learning_variability//Models/Sub-trial//2_fit_models/'
os.chdir(functions_path)
from preprocessing_functions import idxs_from_files
functions_path =  '/home/ines/repositories/representation_learning_variability/Models/Sub-trial//3_postprocess_results/'
# functions_path = '/Users/ineslaranjeira/Documents/Repositories/representation_learning_variability//Models/Sub-trial//2_fit_models/'
os.chdir(functions_path)
from postprocessing_functions import remove_states_str, state_identifiability, align_bin_design_matrix, states_per_trial_phase, broader_label
functions_path =  '/home/ines/repositories/representation_learning_variability/Functions/'
os.chdir(functions_path)
from data_processing import save_and_log
# Back to the post-processing folder; this is the working directory left in
# place for the rest of the notebook.
functions_path =  '/home/ines/repositories/representation_learning_variability/Models/Sub-trial//3_postprocess_results/'
os.chdir(functions_path)
from plotting_functions import wheel_over_wavelet_clusters

# one = ONE(base_url='https://alyx.internationalbrainlab.org')
# ONE client instantiated with mode='remote'.
one = ONE(mode='remote')



## Parameters

In [3]:
# Parameters
# Scalar settings shared by the cells below.
bin_size = 0.017
multiplier = 1
num_states = 2
num_iters = 100
threshold = 0
sticky = False
optimal_k = 4
num_train_batches = 20

# Trial events used for alignment, with their display names.
event_type_list = ['goCueTrigger_times']  # , 'feedback_times', 'firstMovement_times'
event_type_name = ['Go cue']  # , 'Feedback time', 'First movement onset'

# All result folders live under one root that depends on the bin size.
_results_root = '/home/ines/repositories/representation_learning_variability/DATA/Sub-trial/Results/' + str(bin_size) + '/'

save_path = f'{_results_root}{num_states}_states/most_likely_states/'

# LOAD DATA
data_path = f'/home/ines/repositories/representation_learning_variability/DATA/Sub-trial/Design matrix/v5_15Jan2025/{bin_size}/'
all_files = os.listdir(data_path)
# Keep the raw design matrices only (the standardized copies are skipped).
design_matrices = [
    fname
    for fname in all_files
    if 'design_matrix' in fname and 'standardized' not in fname
]
idxs, mouse_names = idxs_from_files(design_matrices, bin_size)

states_path = f'{_results_root}{num_states}_states/most_likely_states/'
wavelet_states_path = f'{_results_root}wavelet_transform_states/'

# Only the final definitions are kept: earlier variants of `use_sets` and
# `var_interest_map` were assigned and then immediately overwritten.
use_sets = [
    ['avg_wheel_vel'],
    ['whisker_me'],
    ['Lick count'],
    ['0.25', '0.5', '1.0', '2.0', '4.0', '8.0', '16.0'],
]
var_interest_map = ['wavelet', 'whisker_me', 'Lick count']

# One results folder per entry of `var_interest_map`, in the same order.
path_sets = [wavelet_states_path, states_path, states_path]

idx_init_list = [1, 2, 4]
idx_end_list = [2, 3, 8]
# var_interest = 'avg_wheel_vel'

# Did this on the 28th February! The dataset from a week before was scrambled.
# Three-character state label -> float code. The first character runs 0-3 and
# the second and third run 0-1; the code is first + 4 * second + 8 * third,
# generated in the order '000', '100', ..., '311'.
identifiable_mapping = {
    f'{d0}{d1}{d2}': float(d0 + 4 * d1 + 8 * d2)
    for d2 in range(2)
    for d1 in range(2)
    for d0 in range(4)
}
identifiable_mapping['nan'] = np.nan


# Individual sessions

In [4]:
# Identify sessions available to process
# A session is kept only when all three state files (whisker, licks and
# wavelet) exist on disk for it.
sessions_to_process = []
for mat in idxs:
    # Each entry holds the session id in its first 36 characters and the
    # mouse name from character 37 onwards.
    session, mouse_name = mat[:36], mat[37:]
    fit_id = str(mouse_name + session)

    required_files = [
        os.path.join(states_path, "most_likely_states" + 'whisker_me' + '_' + fit_id),
        os.path.join(states_path, "most_likely_states" + 'Lick count' + '_' + fit_id),
        os.path.join(wavelet_states_path, "most_likely_states_" + str(optimal_k) + '_' + fit_id),
    ]

    if all(os.path.exists(path) for path in required_files):
        sessions_to_process.append((mouse_name, session))

print(f"Found {len(sessions_to_process)} sessions to process.")

Found 215 sessions to process.


In [None]:
# Build one table of states per trial epoch, accumulating every processed
# session into `states_trial_type`.
states_trial_type = pd.DataFrame(columns=['mouse_name', 'session', 'correct', 'choice', 'contrast', 
                                          'reaction', 'response', 'elongation', 'most_likely_states', 
                                          'identifiable_states', 'Bin', 'label'])
vars_interest = [0, 1, 2]  # column indices selected from the standardized design matrix
var_names = ['avg_wheel_vel', 'whisker_me', 'Lick count']  # labels given to those columns

# NOTE(review): the [:1] slice limits the run to the first available session —
# confirm this is intended before saving the resulting table.
for m, mat in enumerate(sessions_to_process[:1]):

    mouse_name = mat[0]
    session = mat[1]
    fit_id = str(mouse_name+session)

    # Get mouse data
    # Get session data
    trials_file = data_path + "session_trials_" + str(session) + '_'  + mouse_name
    session_trials = pd.read_parquet(trials_file, engine='pyarrow').reset_index()  # I think resetting index is what gives the trial number?
    # Get design_matrix
    filename = data_path + "design_matrix_" + str(session) + '_'  + mouse_name
    unnorm_design_matrix =  pd.read_parquet(filename, engine='pyarrow').dropna().reset_index()
    # Get standardized design matrix
    data_file = data_path + "standardized_design_matrix_" + str(session) + '_'  + mouse_name
    standardized_designmatrix = np.load(data_file+str('.npy'))
    # Drop every row that contains at least one NaN
    filtered_matrix = standardized_designmatrix[~np.isnan(standardized_designmatrix).any(axis=1)]
    design_matrix = filtered_matrix[:, vars_interest]    
    num_timesteps = np.shape(design_matrix)[0]
    mat_length = np.min([(num_timesteps // 5) * 5, (num_timesteps // 20) * 20])  # To account for different lengths
    
    # Get states per variable and concatenate
    c_states = []
    for v, var in enumerate(var_interest_map):
        use_path = path_sets[v]
        states_filename = os.path.join(use_path, "most_likely_states" + f"{'_'+str(optimal_k) if var=='wavelet' else var}"+ '_' + fit_id)
        # Open the pickle in a context manager so the file handle is closed
        # straight away instead of leaking one handle per variable and session.
        # NOTE(review): pickle.load can execute arbitrary code; only load files
        # produced by this project's own fitting pipeline.
        with open(states_filename, "rb") as states_file:
            loaded_states = pickle.load(states_file)
        if var == 'wavelet':
            # The wavelet file holds the state sequence directly
            most_likely_states = loaded_states
        else:
            # The other files hold a 3-tuple; only the first element is used
            most_likely_states, _, _ = loaded_states

        # Stack the truncated state sequences, one row per variable
        if len(c_states) == 0:
            c_states = most_likely_states[:mat_length]
        else:
            c_states = np.vstack((c_states, most_likely_states[:mat_length]))      
    combined_states = remove_states_str(c_states.T, threshold)

    # Prepare data
    
    design_matrix_heading = pd.DataFrame(columns=var_names)
    design_matrix_heading[var_names] = design_matrix[0:len(combined_states)]  # TODO: need to understand why number is the same and whether it depends on lag
    bins = unnorm_design_matrix[:len(combined_states)]['Bin']
    design_matrix_heading['Bin'] = bins
    # Overwrite the standardized lick column with the values from the unnormalized matrix
    design_matrix_heading['Lick count'] = unnorm_design_matrix[:len(combined_states)]['Lick count']

    # Transform states into identifiable states
    sets_to_identify = [[], ['whisker_me'], ['Lick count']]
    identifiable_states = state_identifiability(combined_states, design_matrix_heading, sets_to_identify)
    
    
    # # Change states back to integer
    # state_labels = np.unique(identifiable_states)
    # int_state = np.arange(0, len(state_labels), 1).astype(float)
    # if state_labels[-1] == 'nan':
    #     int_state[-1] = np.nan
    # # Define the mapping as a dictionary
    # mapping = {unique: key for unique, key in zip(state_labels, int_state)}
    # inverted_mapping = {v: k for k, v in mapping.items()}

    # # Use np.vectorize to apply the mapping
    # replace_func = np.vectorize(mapping.get)
    # new_states = replace_func(identifiable_states)
    
    # Map each identifiable-state label to its float code through the fixed table
    replace_func = np.vectorize(identifiable_mapping.get)
    new_states = replace_func(identifiable_states)
    inverted_mapping = {v: k for k, v in identifiable_mapping.items()}
    
    
    # Align bins
    init = -1 * multiplier
    end = 1.5 * multiplier
    empirical_data = align_bin_design_matrix(init, end, event_type_list, session_trials, design_matrix_heading, new_states, multiplier)
    empirical_data = empirical_data.drop(columns=['new_bin'])
    empirical_data['identifiable_states'] = identifiable_states
    
    # Trial types
    # Split in trial types
    states_trial = states_per_trial_phase(empirical_data, session_trials, multiplier)
    states_trial['mouse_name'] = mouse_name
    states_trial['session'] = session
    # replace_func = np.vectorize(inverted_mapping.get)
    # str_states = replace_func(np.array(states_trial['most_likely_states']))
    # states_trial['identifiable_states'] = str_states
    states_trial = broader_label(states_trial)
    
    
    # # Plot raw trace over states       
    # init = 400
    # inter = 1000
    # wheel_over_wavelet_clusters(init, inter, empirical_data, session_trials)
    # print(session)
    

    # Save to big df
    states_trial_type = pd.concat([states_trial_type, states_trial], ignore_index=True)
    # if len(states_trial_type) == 0:
    #     states_trial_type = states_trial.copy()
    # else:
    #     states_trial_type = states_trial_type.append(states_trial)


An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
  return f_raw(*args, **kwargs)
You are setting values through chained assignment. Currently this works in certain cases, but when using Copy-on-Write (which will become the default behaviour in pandas 3.0) this will never work to update the original DataFrame or Series, because the intermediate object on which we are setting values will behave as a copy.
A typical example is when you are setting values in a column of a DataFrame, like:

df["col"][row_indexer] = value

Use `df.loc[row_indexer, "col"] = values` instead, to perform the assignment in a single step and ensure this keeps updating the original `df`.

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy

  trials['prev_choice'][1:] = trials['choice'][:-1]
A value is trying to be set on a copy of a slice from a DataFrame

See the cav

## Save data

In [23]:

# file_to_save = states_trial_type
# Save the accumulated `states_trial_type` table through the project helper
# `save_and_log`.
filename = "states_trial_type"
# NOTE: this rebinds `save_path`, which the parameters cell had set to the
# most_likely_states results folder.
save_path = '/home/ines/repositories/representation_learning_variability/Models/Sub-trial/3_postprocess_results/'
file_format = 'parquet'
script_name = 'E_states_trial_types_v5.ipynb'
# The helper's return value is kept as `metadata`.
# NOTE(review): presumably a record of what was saved and by which script — confirm against save_and_log.
metadata = save_and_log(states_trial_type, filename, file_format, save_path, script_name)
