In [None]:

# Pick up edits made to the repo's modules without having to restart the kernel.
%load_ext autoreload
%autoreload 2


In [None]:
# Outside imports
import os, sys
import numpy as np
import matplotlib.pyplot as plt

# Resolve project paths from the directory the notebook was launched in
# (presumably <project>/_prototypes/unit_matcher, making project_path = ...../neuroscikit/).
# The launch directory is captured only once: this cell chdirs to project_path
# below, so re-running it with a fresh os.getcwd() would climb two more levels.
if '_NOTEBOOK_LAUNCH_DIR' not in globals():
    _NOTEBOOK_LAUNCH_DIR = os.getcwd()


def _parent_dir(path):
    # Absolute, normalised parent directory of ``path``.
    return os.path.abspath(os.path.join(path, os.pardir))


unit_matcher_path = _NOTEBOOK_LAUNCH_DIR
prototype_path = _parent_dir(unit_matcher_path)
project_path = _parent_dir(prototype_path)
lab_path = _parent_dir(project_path)
out_path = _parent_dir(lab_path)

# Avoid stacking duplicate sys.path entries when the cell is re-run.
if project_path not in sys.path:
    sys.path.append(project_path)
os.chdir(project_path)
print(project_path)

In [None]:
# Internal imports

# Read write modules
from x_io.rw.axona.batch_read import make_study
from _prototypes.unit_matcher.read_axona import read_sequential_sessions, temp_read_cut
from _prototypes.unit_matcher.write_axona import format_new_cut_file_name

# Unit matching modules
from _prototypes.unit_matcher.main import format_cut, run_unit_matcher, map_unit_matches_first_session, map_unit_matches_sequential_session
from _prototypes.unit_matcher.session import compare_sessions
from _prototypes.unit_matcher.waveform import time_index, derivative, derivative2, morphological_points

# External imports

import math
from sklearn import linear_model
from scipy import stats
import statsmodels.api as sm
from statsmodels.genmod.families import Poisson, Gaussian
from sklearn.metrics import r2_score
from statsmodels.genmod.families.links import identity, log


In [None]:
""" If a setting is not used for your analysis (e.g. smoothing_factor), just pass in an arbitrary value or pass in 'None' """
# Acquisition devices: toggle True/False to match what was used (more devices may be added later).
device_settings = {'axona_led_tracker': True, 'implant': True}

# Implant metadata: correct it if needed, but AT THE MINIMUM keep implant_type: tetrode.
implant_settings = {
    'implant_type': 'tetrode',
    'implant_geometry': 'square',
    'wire_length': 25,
    'wire_length_units': 'um',
    'implant_units': 'uV',
}

# Device and implant settings are assumed to be consistent across sessions.
SESSION_SETTINGS = {
    # Default is 4. Per the loader's convention, the number of tetrode files found
    # is used as the channel count regardless of this value.
    'channel_count': 4,  # EDIT HERE
    'devices': device_settings,  # EDIT HERE
    'implant': implant_settings,  # EDIT HERE
}

STUDY_SETTINGS = {
    'ppm': 511,  # EDIT HERE
    'smoothing_factor': None,  # EDIT HERE
    # False -> run the unit matcher; True -> load the previously written matched.cut file.
    'useMatchedCut': False,  # EDIT HERE
    'session': SESSION_SETTINGS,
}

# Alias (not a copy): downstream cells read the study configuration through this name.
settings_dict = STUDY_SETTINGS

In [None]:
# Recording directory. os.path.join keeps this portable; a literal '\' separator
# only resolves correctly on Windows.
# data_dir = os.path.join(lab_path, 'neuroscikit_test_data', '20180530-ROUND-3300-1X2B3A')
data_dir = os.path.join(lab_path, 'neuroscikit_test_data', 'LEC_odor', 'AD', 'Odor_119a-6')

# Settings for loading the study for cross-session unit analysis; useMatchedCut=True
# loads the matched.cut files. Build a copy rather than an alias, otherwise the
# assignment would also flip useMatchedCut in settings_dict / STUDY_SETTINGS.
settings_dict_unmatched = dict(settings_dict, useMatchedCut=True)

In [None]:
# Load every session under data_dir into a study, then group its sessions by animal.
study = make_study(
    [data_dir],
    settings_dict_unmatched,
)
study.make_animals()

In [None]:
def event_times_to_count(event_times, T, dt=0.5):
    """Bin event times into counts over ``[0, T]`` with bins of width ``dt``.

    Parameters
    ----------
    event_times : array-like
        Event times, in the same units as ``T`` and ``dt``.
    T : float
        Length of the interval to cover, starting at 0.
    dt : float, optional
        Bin width (default 0.5).

    Returns
    -------
    ct : np.ndarray
        Count per bin; there are ``ceil(T / dt)`` bins (at least one). When ``T``
        is not a multiple of ``dt``, the last bin extends slightly past ``T``.
    left_edges : np.ndarray
        Left edge of each bin (``0, dt, 2*dt, ...``), same length as ``ct``.
    """
    # Build edges from integer multiples of dt so the final edge reaches T. An
    # ``arange(0, T, dt)`` edge list stops one bin short, which drops events in
    # the last dt and yields T/dt - 1 bins. The 1e-9 keeps an exact multiple
    # from rounding up.
    n_bins = max(int(np.ceil(T / dt - 1e-9)), 1)
    edges = np.arange(n_bins + 1) * dt
    ct, bins = np.histogram(event_times, bins=edges)
    return ct, bins[:-1]

def aggregate_cell_trials(agg_events, agg_event_objects, trial_start_times, trial_length, trial_ids):
    """Cut each cell's per-session event times into fixed-length trials and bin them.

    Parameters
    ----------
    agg_events : sequence
        ``agg_events[i][j]`` holds the event times of cell ``i`` in the ``j``-th
        session that contains it (the layout returned by
        ``aggregate_event_times_matched``).
    agg_event_objects : sequence
        Parallel to ``agg_events``; ``agg_event_objects[i][j]`` is the matching cell object.
    trial_start_times : sequence of float
        Start time of each trial, same units as the event times.
    trial_length : float
        Duration of every trial.
    trial_ids : sequence
        Label of each trial, parallel to ``trial_start_times``; used as dict keys.

    Returns
    -------
    sequential_trials : np.ndarray
        One row per cell: the binned counts of every (trial, session) pair,
        concatenated trial-major. If cells have rows of different lengths, a 1-D
        object array of rows is returned instead.
    trial_dict : dict
        ``trial_dict[i]['data'][trial_id]`` lists the count vectors (one per session)
        and ``trial_dict[i]['obj'][trial_id]`` the matching cell objects.
    time_index : np.ndarray
        Left bin edges of the counts, relative to the trial start (empty if nothing
        was binned).
    """
    sequential_trials = []
    trial_dict = {}
    time_index = np.empty(0)

    for cell_idx in range(len(agg_events)):
        cell_sessions = agg_events[cell_idx]
        trial_dict[cell_idx] = {'obj': {}, 'data': {}}
        cell_counts = []

        for trial_idx, start in enumerate(trial_start_times):
            end = start + trial_length
            trial_id = trial_ids[trial_idx]
            trial_dict[cell_idx]['data'][trial_id] = []
            trial_dict[cell_idx]['obj'][trial_id] = []

            for ses_idx in range(len(cell_sessions)):
                events = np.asarray(cell_sessions[ses_idx])
                in_trial = events[(events >= start) & (events < end)]
                # The count bins span [0, trial_length), so shift to trial-relative
                # time; otherwise every trial after the first bins to all zeros.
                ct, time_index = event_times_to_count(in_trial - start, trial_length)

                trial_dict[cell_idx]['data'][trial_id].append(ct)
                trial_dict[cell_idx]['obj'][trial_id].append(agg_event_objects[cell_idx][ses_idx])
                cell_counts.append(ct)

        # One concatenation per cell instead of re-copying on every hstack; float dtype as before.
        sequential_trials.append(np.concatenate(cell_counts).astype(float) if cell_counts else np.empty(0))

    try:
        trials_array = np.array(sequential_trials)
    except ValueError:
        # Cells found in different numbers of sessions give rows of unequal length;
        # NumPy >= 1.24 refuses to build such ragged arrays implicitly.
        trials_array = np.empty(len(sequential_trials), dtype=object)
        for row_idx, row in enumerate(sequential_trials):
            trials_array[row_idx] = row

    return trials_array, trial_dict, np.array(time_index)
                

def aggregate_event_times_matched(study):
    """Collect the event times of every matched cell across each animal's sessions.

    Cells are looked up by label ``1..N``, where ``N`` is the number of cells in the
    animal's last session. Sessions are accessed as ``'session_1'``, ``'session_2'``, ...

    Parameters
    ----------
    study
        Study whose ``animals`` each expose a ``sessions`` mapping (project type).

    Returns
    -------
    agg_events : list
        ``agg_events[c][s]`` is the event times of cell ``c`` in the ``s``-th session
        that contains it; sessions lacking the cell are skipped.
    agg_event_objects : np.ndarray
        Matching cell objects, same nesting as ``agg_events``.
    agg_events_binary : np.ndarray
        Matching binned counts from ``event_times_to_count``.
    time_index : np.ndarray
        Bin left edges of the last binned session (empty if nothing was binned).
    Both arrays fall back to a 1-D object array when cells appear in different
    numbers of sessions.

    Raises
    ------
    AssertionError
        If two binned sessions are binned on different time grids.
    """
    def _stack(items):
        # NumPy >= 1.24 raises on ragged nesting; keep the old object-array result.
        try:
            return np.array(items)
        except ValueError:
            ragged = np.empty(len(items), dtype=object)
            for idx, item in enumerate(items):
                ragged[idx] = item
            return ragged

    agg_events = []
    agg_event_objects = []
    agg_events_binary = []
    prev_time_index = None
    time_index = np.empty(0)

    for animal in study.animals:
        n_sessions = len(animal.sessions)
        # Address the last session numerically. A lexicographic sort of the keys
        # would rank 'session_9' after 'session_10'.
        last_session = animal.sessions['session_' + str(n_sessions)]
        max_matched_cell_count = len(last_session.get_cell_data()['cell_ensemble'].cells)

        for cell_label in range(1, int(max_matched_cell_count) + 1):
            cell_events = []
            cell_event_objects = []
            cell_events_binary = []
            for ses_num in range(1, n_sessions + 1):
                ensemble = animal.sessions['session_' + str(ses_num)].get_cell_data()['cell_ensemble']
                if cell_label not in ensemble.get_cell_label_dict():
                    continue
                cell = ensemble.get_cell_by_id(cell_label)
                cell_events.append(cell.event_times)
                cell_event_objects.append(cell)
                ct, time_index = event_times_to_count(cell.event_times, cell.cluster.time_index[-1])
                if prev_time_index is not None:
                    # Compare the grids element-wise; only the last grid is returned.
                    assert np.array_equal(time_index, prev_time_index), \
                        'sessions are binned on different time grids'
                cell_events_binary.append(ct)
                prev_time_index = time_index
            agg_events.append(cell_events)
            agg_event_objects.append(cell_event_objects)
            agg_events_binary.append(cell_events_binary)

    # agg_events is nested as (cell, sessions_cell_is_in, event_times)
    return agg_events, _stack(agg_event_objects), _stack(agg_events_binary), np.array(time_index)

def get_time_regressors(time_index, trial_start_times, trial_length):
    """Build a session-time ramp and a within-trial time ramp on ``time_index``.

    Parameters
    ----------
    time_index : np.ndarray
        Evenly spaced sample times; the spacing ``dt`` is taken from the first two
        samples, so at least two samples are required.
    trial_start_times : sequence of float
        One entry per trial; only the number of trials and the last start time are used.
    trial_length : float
        Duration of each trial.

    Returns
    -------
    session_ramping : np.ndarray
        ``time_index`` itself.
    trial_ramping : np.ndarray
        ``0, dt, 2*dt, ...`` restarting for every trial. If the trials are shorter
        than ``time_index``, it is followed by a ramp over the time left after the
        last trial ends.
    """
    session_ramping = time_index
    dt = time_index[1] - time_index[0]

    # Float dtype, as hstack-ing onto an empty list produced before.
    trial_ramping = np.tile(np.arange(0, trial_length, dt), len(trial_start_times)).astype(float)

    if len(trial_ramping) < len(session_ramping):
        # The last trial ends at trial_start_times[-1] + trial_length, not a fixed +60.
        remainder_end = session_ramping[-1] - trial_start_times[-1] - trial_length + dt
        trial_ramping = np.hstack((trial_ramping, np.arange(0, remainder_end, dt)))

    return session_ramping, trial_ramping

def get_odor_regressors(time_index, odor_labels, odor_presentation_times, odor_window, trial_length):
    """Build odor-on and post-odor indicator regressors on ``time_index``.

    Parameters
    ----------
    time_index : array-like
        Sample times.
    odor_labels : sequence
        One (pulse, post-window) pair of regressors is produced per entry.
    odor_presentation_times : sequence of float
        Odor onset times, same units as ``time_index``.
    odor_window : float
        Length of the window following each onset.
    trial_length : float
        Not used by the current implementation.

    Returns
    -------
    odor_pulses, post_odor_window : list of np.ndarray
        One 0/1 vector of ``len(time_index)`` per label (independent copies).
        For each onset, ``start`` is the first sample at/after the onset, ``end``
        the last sample before ``onset + odor_window``, and ``end_post`` the first
        sample at/after it (or the end of ``time_index``). The pulse marks
        ``[start, end)`` and the post window ``[end, end_post)``. Onsets whose
        window falls outside ``time_index`` are skipped instead of raising IndexError.

    NOTE(review): every label receives the same pulse train built from *all* onsets;
    labels are not matched to their own presentation. Also, the last in-window sample
    lands in the post window rather than the pulse, and the post window is only
    ``[end, end_post)``. Confirm both are intended.
    """
    times = np.asarray(time_index)
    n_samples = len(times)
    pulse = np.zeros(n_samples)
    post_pulse = np.zeros(n_samples)

    for onset in odor_presentation_times:
        offset = float(onset) + odor_window
        at_or_after_onset = np.flatnonzero(times >= float(onset))
        before_offset = np.flatnonzero(times < offset)
        if at_or_after_onset.size == 0 or before_offset.size == 0:
            continue  # window lies entirely outside the sampled range
        start = at_or_after_onset[0]
        end = before_offset[-1]
        at_or_after_offset = np.flatnonzero(times >= offset)
        # If the window runs past the last sample, close the post window at the end.
        end_post = at_or_after_offset[0] if at_or_after_offset.size else n_samples

        pulse[start:end] = 1
        post_pulse[end:end_post] = 1

    # The trains do not depend on the label: compute once, hand each label its own copy.
    odor_pulses = [pulse.copy() for _ in odor_labels]
    post_odor_window = [post_pulse.copy() for _ in odor_labels]
    return odor_pulses, post_odor_window

def collect_synth_regressors(time_index, trial_start_times=(0, 60, 120, 180), trial_length=60,
                             odor_labels=('X', 'B', 'A', 'X'), odor_window=10,
                             odor_presentation_times=None):
    """Stack the synthetic time and odor regressors into one design matrix.

    The defaults reproduce the fixed protocol this notebook was written for: four
    trials of length 60 starting at 0, 60, 120 and 180, with odor onsets at the
    trial starts and a 10-unit odor window.

    Parameters
    ----------
    time_index : np.ndarray
        Evenly spaced sample times the regressors are built on.
    trial_start_times : sequence of float, optional
        Trial start times.
    trial_length : float, optional
        Duration of each trial.
    odor_labels : sequence, optional
        One odor-pulse row and one post-odor row are produced per entry.
    odor_window : float, optional
        Length of the odor window after each onset.
    odor_presentation_times : sequence of float, optional
        Odor onset times; defaults to ``trial_start_times``.

    Returns
    -------
    np.ndarray
        Rows in order: session ramp, trial ramp, one odor pulse per label, then
        one post-odor window per label (``2 + 2 * len(odor_labels)`` rows).
        Columns follow ``time_index``.
    """
    trial_start_times = list(trial_start_times)
    if odor_presentation_times is None:
        odor_presentation_times = trial_start_times

    session_ramping, trial_ramping = get_time_regressors(time_index, trial_start_times, trial_length)
    odor_pulses, post_odor_window = get_odor_regressors(
        time_index, list(odor_labels), list(odor_presentation_times), odor_window, trial_length)

    return np.vstack((session_ramping, trial_ramping, odor_pulses, post_odor_window))

def split_test_train(data, percentage=0.70, axis=-1):
    """Split ``data`` into a leading train part and a trailing test part.

    The cut is taken along ``axis`` at ``int(percentage * data.shape[axis])``.
    The default is the last axis: the columns (samples) of a 2-D
    ``(signals, samples)`` array, or the only axis of a 1-D array. The original
    sized the cut from ``shape[1]`` but then sliced rows; the size and the slice
    now use the same axis.

    Parameters
    ----------
    data : array-like
        Data to split.
    percentage : float, optional
        Fraction of ``axis`` that goes to the train part (default 0.70).
    axis : int, optional
        Axis to split along (default -1).

    Returns
    -------
    train, test : np.ndarray
        Views of ``data`` before and after the cut.
    """
    data = np.asarray(data)
    idx = int(percentage * data.shape[axis])
    train, test = np.split(data, [idx], axis=axis)
    return train, test

def split_endog_exog(y, X, p):
    """Split a response vector and its regressor matrix at the same sample.

    ``y`` is cut along its first axis and ``X`` along its second (regressors x
    samples), both at ``int(p * y.shape[0])``. Shapes are printed as a debugging aid.

    Returns
    -------
    trainX, trainY, testX, testY
    """
    cut = int(p * y.shape[0])
    trainY, testY = y[:cut], y[cut:]

    print(X.shape, y.shape, cut)
    trainX, testX = X[:, :cut], X[:, cut:]

    print(trainX.shape, trainY.shape, testX.shape, testY.shape)
    return trainX, trainY, testX, testY


In [None]:

# Trial layout of the recording: start time (same units as the event times),
# shared duration and label of each trial.
TRIAL_START_TIMES = [0, 60, 120, 180]
TRIAL_LENGTH = 60
TRIAL_IDS = ['O', 'B', 'A', 'X']

agg_event_times, agg_event_objects, agg_events_binary, time_index = aggregate_event_times_matched(study)

sequential_trials, trial_dict, new_time_index = aggregate_cell_trials(
    agg_event_times, agg_event_objects, TRIAL_START_TIMES, TRIAL_LENGTH, TRIAL_IDS
)



In [None]:
# plt.plot(sequential_trials)

In [None]:
np.shape(time_index)