In [26]:
import os
import pickle
import pandas as pd
import numpy as np
import matplotlib as mpl
from matplotlib import rcParams
import matplotlib.pyplot as plt
import h5py
import os
import tqdm
import scipy
from scipy import signal
import pickle
from tqdm import tnrange
from tqdm import tqdm

import seaborn as sns
from scipy.stats import norm,entropy,linregress
from scipy.optimize import minimize, curve_fit
from scipy.io import savemat
#import multiprocess as mp
#from multiprocess import Pool
from scipy.special import erf
import sys
import warnings
from odor_breathing_functions import*

#import odor_breathing_functions
#import functions_beh

warnings.filterwarnings('ignore')

cmap = plt.rcParams['axes.prop_cycle'].by_key()['color']
modulename = 'multiprocess'
mpl.rcParams['svg.fonttype'] = 'none'
sns.set_context('poster', font_scale=1.1)

# =============================
# CONFIGURATION
# =============================

class LoaderConfig:
    def __init__(
        self,
        path,
        animals,
        min_pulse=1,
        max_pulse=20,
        num_bins=20,
        num_hist=6,
        load_sniff=True,
        load_breathing=True,
        load_autocorr=True,
        verbose=True,
    ):
        self.path = path
        self.animals = animals
        self.min_pulse = min_pulse
        self.max_pulse = max_pulse
        self.num_bins = num_bins
        self.num_hist = num_hist
        self.load_sniff = load_sniff
        self.load_breathing = load_breathing
        self.load_autocorr = load_autocorr
        self.verbose = verbose

        print(self.path)

# =============================
# HELPER FUNCTIONS
# =============================

def find_session_files(path, animal):
    """Return all valid session pickle filenames for one animal."""
    fnames = []
    for date_number in range(20190419, 20220526):
        for session_id in range(10):
            fname = f"{animal}_{date_number}_{session_id}.pickle"
            fullpath = os.path.join(path, "session_" + fname)
            if os.path.isfile(fullpath):
                fnames.append(fname)
    return fnames


def load_session(fullpath):
    """Load a session pickle file and return its contents."""
    with open(fullpath, "rb") as handle:
        session_list = pickle.load(handle)
    return session_list[0]

def get_sniff_histogram(session,shuffled=False):
    bins = np.linspace(0,250,16) #16
    num_trials = session['num_trials']
    sniff_hist = np.zeros((num_trials,15)) #15
    sniff_kernel = np.load('inhalation_kernel_fine_weights_active.npy')
    sniff_kernel = sniff_kernel/sniff_kernel.mean()
    
    for i_trial in range(num_trials):
        sniff_raw = np.append(session['trial_pre_breath'][i_trial],session['trial_breath'][i_trial])
        #print(sniff_raw.shape)
        sniff = butter_lowpass_filter(sniff_raw,8,1000,3)
        sniff = (sniff - sniff.mean() +1)/2
        #print(sniff.shape)
        sniff_onset,_ = scipy.signal.find_peaks(sniff,distance=100, width=99)
        sniff_onset = sniff_onset[sniff_onset>2250]
        sniff_onset = sniff_onset[sniff_onset<7750]
        sniff_markers = np.zeros((10000,))
        sniff_phase = np.zeros((10000,))

        for i in range(len(sniff_onset)-1):
            nsample = sniff_onset[i+1]-sniff_onset[i]
            sniff_phase[sniff_onset[i]:sniff_onset[i+1]] = scipy.signal.resample(np.arange(0,250),nsample)
        sniff_markers[sniff_onset] = 1
        sniff_sampling_epoch = sniff[2500:7500]
        sniff_phase_sampling_epoch = sniff_phase[2500:]
        odor_command = session['trial_odor'][i_trial]
        odor = get_odor_profile_actual(odor_command)[0:5000]
        if shuffled:
            n_pulses = (np.diff(odor_command)==100).sum()
            valve_onset = np.random.randint(0,5000,(n_pulses,))
        else:
            valve_onset = np.argwhere(np.diff(odor_command)==100)[0:5000]
            #valve_onset = np.argwhere(odor==1)[0:5000]
        odor_onset = valve_onset + 25
        odor_phase = sniff_phase_sampling_epoch[odor_onset].squeeze()
        hist,_ = np.histogram(odor_phase,bins)
        sniff_hist[i_trial,:] = hist
    return sniff_hist


def process_session(session, config):
    """Process one session and return all arrays/lists of interest."""
    # Initialize dictionary for this sessionâ€™s outputs
    out = {
        "all_high_choices": [],
        "all_cum_odor": [],
        "all_correct_trials": [],
        "all_sniff_hist": [],
        "all_sniff_hist_shuffled": [],
        "autocorr": [],
        "all_breathing": [],
        
    }

    if session["type"] != "random":
        return out
    if session["delay_time"] != 5:
        return out
    if round(session["high_count"] / session["low_count"]) != 3:
        return out

    # Example: compute main trial variables
    non_idle_trials = np.invert(session["idle_trials"])
    correct_trials = session["correct_trials"][non_idle_trials]
    high_trials = session["high_trials"][non_idle_trials]
    high_choices = (correct_trials == high_trials)
    trial_odor = session["trial_odor"][non_idle_trials]

    cum_odor = np.ceil(trial_odor.sum(axis=1) / 5000)

    out["all_correct_trials"] = correct_trials
    out["all_high_choices"] = high_choices
    out["all_cum_odor"] = cum_odor

    if config.load_sniff:
        out["all_sniff_hist"] = get_sniff_histogram(session, False)[non_idle_trials]
        out["all_sniff_hist_shuffled"] = get_sniff_histogram(session, True)[non_idle_trials]

    if config.load_breathing:
        breathing_signal = np.append(
            session["trial_pre_breath"][non_idle_trials, -2500:],
            session["trial_breath"][non_idle_trials, :6500],
            axis=1,
        )
        out["all_breathing"] = breathing_signal

    if config.load_autocorr:
        out["autocorr"] = get_autocorr(session, config.num_hist)

    return out

def get_autocorr(session,n):
    high_trials = session['high_trials']
    autocorr = list()
    for i in range(n):
        autocorr.append(np.corrcoef(high_trials[i+1:],high_trials[:-(i+1)])[0][1])
    return np.array(autocorr)

def butter_lowpass(cutoff, fs, order=5):
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    b, a = scipy.signal.butter(order, normal_cutoff, btype='low', analog=False)
    return b, a

def butter_lowpass_filter(data, cutoff, fs, order=5):
    b, a = butter_lowpass(cutoff, fs, order=order)
    y = scipy.signal.filtfilt(b, a, data)
    return y

# =============================
# MAIN LOADER FUNCTION
# =============================

def run_loader(config):
    """Run the loader for all animals in config."""
    pulse_bins = np.linspace(config.min_pulse, config.max_pulse, config.max_pulse - config.min_pulse + 1)

    phigh_list = []
    performance_list = []
    bin_counts_list = []
    all_high_choices_list = []
    all_cum_odor_list = []
    all_correct_trials_list = []
    all_sniff_hist_list = []
    all_sniff_hist_shuffled_list = []
    all_breathing_list = []
    autocorr_list = []
    fname_list=[]

    for animal in config.animals:
        if config.verbose:
            print(f"Loading data for {animal}...")

        fname_list = find_session_files(config.path, animal)
        for i_file in tnrange(len(fname_list), desc=f"{animal} sessions"):
            fname = fname_list[i_file]
            session_path = os.path.join(config.path, "session_" + fname)
            session = load_session(session_path)
            results = process_session(session, config)

            all_high_choices_list.append(results["all_high_choices"])
            all_cum_odor_list.append(results["all_cum_odor"])
            all_correct_trials_list.append(results["all_correct_trials"])

            if config.load_sniff:
                all_sniff_hist_list.append(results["all_sniff_hist"])
                all_sniff_hist_shuffled_list.append(results["all_sniff_hist_shuffled"])

            if config.load_breathing:
                all_breathing_list.append(results["all_breathing"])

            if config.load_autocorr:
                autocorr_list.append(results["autocorr"])
            
            fname_list.append(fname)

    # Example: return everything in one dictionary
        return {
            'fname':fname_list,
            
            "phigh_list": phigh_list,
            "performance_list": performance_list,
            "bin_counts_list": bin_counts_list,
            "all_high_choices_list": all_high_choices_list,
            "all_cum_odor_list": all_cum_odor_list,
            "all_correct_trials_list": all_correct_trials_list,
            "all_sniff_hist_list": all_sniff_hist_list,
            "all_sniff_hist_shuffled_list": all_sniff_hist_shuffled_list,
            "all_breathing_list": all_breathing_list,
            "autocorr_list": autocorr_list,
        }


In [27]:
config = LoaderConfig(
    path= os.getcwd()+"/Session/",
    animals=["Tabby",'Bengal'],
    load_sniff=True,         # enable sniff hist
    load_breathing=False,    # skip breathing
    load_autocorr=False,     # skip autocorr
)


data = run_loader(config)

/Users/boero/OEA_Data_Analysis/Session/
Loading data for Tabby...


Tabby sessions:   0%|          | 0/18 [00:00<?, ?it/s]

In [23]:
data

{'fname': ['Bengal_20190419_0.pickle',
  'Bengal_20190422_0.pickle',
  'Bengal_20190423_0.pickle',
  'Bengal_20190426_0.pickle',
  'Bengal_20190427_0.pickle',
  'Bengal_20190430_0.pickle',
  'Bengal_20190501_0.pickle',
  'Bengal_20190502_0.pickle',
  'Bengal_20190503_0.pickle',
  'Bengal_20190504_0.pickle',
  'Bengal_20190505_0.pickle',
  'Bengal_20190507_0.pickle',
  'Bengal_20190508_0.pickle',
  'Bengal_20190509_0.pickle',
  'Bengal_20190510_0.pickle',
  'Bengal_20190511_0.pickle',
  'Bengal_20190512_0.pickle',
  'Bengal_20190513_0.pickle',
  'Bengal_20190516_0.pickle',
  'Bengal_20190517_0.pickle',
  'Bengal_20190518_0.pickle',
  'Bengal_20190519_0.pickle',
  'Bengal_20190419_0.pickle',
  'Bengal_20190422_0.pickle',
  'Bengal_20190423_0.pickle',
  'Bengal_20190426_0.pickle',
  'Bengal_20190427_0.pickle',
  'Bengal_20190430_0.pickle',
  'Bengal_20190501_0.pickle',
  'Bengal_20190502_0.pickle',
  'Bengal_20190503_0.pickle',
  'Bengal_20190504_0.pickle',
  'Bengal_20190505_0.pickle',
 

In [18]:
print(len(data['all_high_choices_list']))

40
