In [13]:
import hemiSwap_consts as hconsts
import os
import re
from datetime import datetime
import numpy as np
import numpy.matlib as npm
import pandas as pd
import xarray as xr
import seaborn as sns
import spynal as sp
import matplotlib.pyplot as plt
from spynal.spikes import times_to_bool, rate
from itertools import product
from sklearn.cross_decomposition import CCA
import random
from scipy import stats
from matplotlib.ticker import MaxNLocator
from scipy.stats import pearsonr
from joblib import Parallel, delayed
from sklearn.model_selection import KFold
from spynal.matIO import loadmat
from spynal.spikes import psth

def load_and_process_session(session_id, hemiswap, variables=None):
    """Load a single session's .mat file and split its spike data by condition.

    Parameters
    ----------
    session_id : str
        Session name; the file read is ``os.path.join(hemiswap.loadDir, f'{session_id}.mat')``.
    hemiswap : object
        Project constants object; only its ``loadDir`` attribute is used here.
    variables : sequence of str, optional
        MAT-file variables to load. Defaults to the full list below (the
        original behavior). 'trialInfo', 'unitInfo' and 'spikeTimes' are always
        added because the processing below needs them, so a shorter list can be
        passed to skip variables this function never uses (e.g. 'lfp', 'ain',
        presumably the bulk of each file -- confirm).

    Returns
    -------
    dict
        Every loaded variable, plus:
        - 'trial_indices': '{noswap|swap}_{right|left}' -> trial indices
          selected on trialInfo['isSwap'] and trialInfo['sampleHemifield'];
        - 'hemisphere_indices': 'right'/'left' -> unit indices from
          unitInfo['hemisphere'];
        - 'spike_data': the output of process_spike_data.
    """
    filepath = os.path.join(hemiswap.loadDir, f'{session_id}.mat')
    print(filepath)

    default_variables = ['ain', 'ainSchema', 'analogChnlInfo', 'electrodeInfo', 'eventSchema',
                         'fileInfo', 'lfp', 'lfpSchema', 'sessionInfo', 'spikeChnlInfo',
                         'spikeTimes', 'spikeTimesSchema', 'trialInfo', 'unitInfo']
    required = ['trialInfo', 'unitInfo', 'spikeTimes']
    if variables is None:
        variables = default_variables
    # Never let a caller-supplied list drop what the processing below reads.
    variables = list(variables) + [v for v in required if v not in variables]

    data = loadmat(filepath, variables=variables,
                   typemap={'trialInfo': 'DataFrame'}, verbose=True)

    trialInfo = data['trialInfo']
    unitInfo = data['unitInfo']
    spikeTimes = data['spikeTimes']

    # Trials per (swap condition, sample hemifield). Explicit == False / == True
    # comparisons are kept on purpose (same selection as before, even for
    # non-bool dtypes). Key order: noswap_right, noswap_left, swap_right, swap_left.
    swap_flags = {'noswap': False, 'swap': True}
    trial_indices = {
        f'{swap_label}_{hemifield}': np.where(
            (trialInfo['sampleHemifield'] == hemifield) &
            (trialInfo['isSwap'] == flag))[0]
        for swap_label, flag in swap_flags.items()
        for hemifield in ('right', 'left')
    }

    # Units recorded in each hemisphere.
    hemisphere_indices = {
        hemi: np.where(unitInfo['hemisphere'] == hemi)[0]
        for hemi in ('right', 'left')
    }

    spike_data = process_spike_data(spikeTimes, trial_indices, hemisphere_indices)

    return {
        **data,
        'trial_indices': trial_indices,
        'hemisphere_indices': hemisphere_indices,
        'spike_data': spike_data,
    }

def _squeeze_extra_axes(arr):
    """Drop length-1 axes after the leading (trial, unit) axes, never those two."""
    arr = np.asarray(arr)
    extra = tuple(ax for ax in range(2, arr.ndim) if arr.shape[ax] == 1)
    return np.squeeze(arr, axis=extra) if extra else arr


def process_spike_data(spikeTimes, trial_indices, hemisphere_indices, include_swap=False):
    """Split spike data by sample hemifield, swap condition and recording hemisphere.

    Parameters
    ----------
    spikeTimes : np.ndarray
        Indexed as ``spikeTimes[trial, unit, ...]`` (assumed trial x unit as
        loaded from the session .mat file -- confirm). Trailing length-1 axes
        are dropped.
    trial_indices : dict
        Maps '{noswap|swap}_{right|left}' to integer trial indices.
    hemisphere_indices : dict
        Maps 'right' / 'left' to integer unit indices.
    include_swap : bool, default False
        When True, swap trials are split as well. Their keys carry a 'swap_'
        prefix, e.g. 'swap_left_right_hemi_trials'.

    Returns
    -------
    dict
        '{start_hemi}_{record_hemi}_hemi_trials' (no-swap trials) -> array of
        shape (n_trials, n_units, ...). The trial and unit axes are always kept,
        so a condition with a single trial or a hemisphere with a single unit no
        longer changes the array rank or raises IndexError (previously
        np.squeeze removed those axes).
    """
    spike_data = {}
    conditions = ('noswap', 'swap') if include_swap else ('noswap',)

    for condition in conditions:
        # No-swap keys keep their historical, unprefixed names.
        prefix = '' if condition == 'noswap' else f'{condition}_'
        for start_hemi in ('right', 'left'):
            trials = np.asarray(trial_indices[f'{condition}_{start_hemi}'], dtype=int)
            spikes = _squeeze_extra_axes(spikeTimes[trials])

            for record_hemi in ('right', 'left'):
                units = np.asarray(hemisphere_indices[record_hemi], dtype=int)
                key = f'{prefix}{start_hemi}_{record_hemi}_hemi_trials'
                spike_data[key] = spikes[:, units]

    return spike_data

# Main execution
def load_all_sessions(hemiswap, subject=None, n_sessions=None, raise_errors=False):
    """Load and process every session for one subject (or all subjects if None).

    Parameters
    ----------
    hemiswap : object
        Project constants object; ``hemiswap.sessions['full']`` must be a
        sliceable sequence of session IDs.
    subject : str, optional
        Keep only session IDs that contain this substring. None or '' keeps all.
    n_sessions : int, optional
        Despite the name, this is a *step*: every ``n_sessions``-th session of
        the (filtered) list is loaded. None or 0 loads every session.
    raise_errors : bool, default False
        If True, re-raise the first loading error instead of skipping that session.

    Returns
    -------
    dict
        Session ID -> result of ``load_and_process_session``. Sessions that
        failed to load are omitted (their traceback is printed).
    """
    import traceback

    if subject:
        session_list = [s for s in hemiswap.sessions['full'] if subject in s]
        print(session_list)
    else:
        session_list = hemiswap.sessions['full']

    if n_sessions:
        session_list = session_list[::n_sessions]

    sessions = {}
    failed = []
    for session_id in session_list:
        print(f"Processing session: {session_id}")
        try:
            sessions[session_id] = load_and_process_session(session_id, hemiswap)
        except Exception as e:
            if raise_errors:
                raise
            # Keep going, but keep the full traceback: a bare str(e) hides
            # where the failure happened and leaves an unexplained empty result.
            print(f"Error processing session {session_id}: {str(e)}")
            traceback.print_exc()
            failed.append(session_id)

    if failed:
        print(f"{len(failed)} of {len(session_list)} sessions failed: {failed}")

    return sessions



    

In [None]:
# Run configuration for this notebook.
# Positional arguments forwarded to hconsts.HemiSwap_consts (their meaning is
# defined in hemiSwap_consts -- presumably machine and user name; confirm).
HEMISWAP_ARGS = ('miller-lab-3', 'tiergan')
SUBJECT = 'Edith'    # only session IDs containing this string are loaded
SESSION_STEP = 20    # stride through the matching sessions (load_all_sessions' n_sessions)

hemiswap = hconsts.HemiSwap_consts(*HEMISWAP_ARGS)
sessions = load_all_sessions(hemiswap, SUBJECT, n_sessions=SESSION_STEP)

  warn(msg)


['hemiSwap_Edith_20180515', 'hemiSwap_Edith_20180517', 'hemiSwap_Edith_20180524', 'hemiSwap_Edith_20180525', 'hemiSwap_Edith_20180529', 'hemiSwap_Edith_20180530', 'hemiSwap_Edith_20180531', 'hemiSwap_Edith_20180601', 'hemiSwap_Edith_20180604', 'hemiSwap_Edith_20180605', 'hemiSwap_Edith_20180606', 'hemiSwap_Edith_20180607', 'hemiSwap_Edith_20180608', 'hemiSwap_Edith_20180612', 'hemiSwap_Edith_20180613', 'hemiSwap_Edith_20180614', 'hemiSwap_Edith_20180615', 'hemiSwap_Edith_20180619', 'hemiSwap_Edith_20180620', 'hemiSwap_Edith_20180621', 'hemiSwap_Edith_20180622']
Processing session: hemiSwap_Edith_20180515
/mnt/common/datasets/hemiSwap/mat/hemiSwap_Edith_20180515.mat


In [10]:
# Display the dict returned by load_all_sessions (session ID -> processed data).
# It is empty when every session failed to load (errors are caught and printed
# during loading) or when this cell ran before the loading cell finished, as the
# out-of-order execution counts here (In [10] after In [None]) suggest.
sessions

{}

dict_keys([])