In [None]:
""" ACTIVE JUPYTER NOTEBOOK TO BATCH RUN UNIT MATCHING ALGORITHM """

# To be able to make edits to repo without having to restart notebook:
# load IPython's autoreload extension, then select mode 2, which re-imports
# modules before each cell is executed (see the IPython autoreload docs).
%load_ext autoreload
%autoreload 2


In [None]:
# Outside imports
import copy
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Set necessary paths / make project path = ...../neuroscikit/
# NOTE(review): presumably the notebook starts in <project>/_prototypes/unit_matcher,
# so the project root is two levels above the starting directory -- confirm.
#
# This cell changes the working directory at the end. The starting directory is
# therefore remembered on the first run, so that re-running the cell does not
# climb two more levels up and register a wrong project path.
if '_NOTEBOOK_START_DIR' not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()

unit_matcher_path = _NOTEBOOK_START_DIR
prototype_path = os.path.abspath(os.path.join(unit_matcher_path, os.pardir))
project_path = os.path.abspath(os.path.join(prototype_path, os.pardir))
lab_path = os.path.abspath(os.path.join(project_path, os.pardir))

# Make the project importable; skip the append if a previous run already did it
if project_path not in sys.path:
    sys.path.append(project_path)

# Work from the project root from here on
os.chdir(project_path)
print(project_path)

# Internal imports

# Read write modules
from x_io.rw.axona.batch_read import make_study
from _prototypes.unit_matcher.read_axona import read_sequential_sessions, temp_read_cut
from _prototypes.unit_matcher.write_axona import format_new_cut_file_name

# Unit matching modules
from _prototypes.unit_matcher.main import format_cut, run_unit_matcher, map_unit_matches_first_session, map_unit_matches_sequential_session
from _prototypes.unit_matcher.session import compare_sessions
from _prototypes.unit_matcher.waveform import time_index, derivative, derivative2, morphological_points

In [None]:
""" ONLY EDIT THE SETTINGS IN THE NEXT TWO CELLS """

In [None]:
""" If a setting is not used for your analysis (e.g. smoothing_factor), just pass in an arbitrary value or pass in 'None' """
STUDY_SETTINGS = {
    # EDIT HERE
    'ppm': 511,
    # EDIT HERE
    'smoothing_factor': None,
    # EDIT HERE, set to False if you want to use runUnitMatcher,
    # set to True after to load in matched.cut file
    'useMatchedCut': False,
}

# Switch devices to True/False based on what is used in the acquisition
# (to be extended for more devices in future)
device_settings = {
    'axona_led_tracker': True,
    'implant': True,
}

# Make sure implant metadata is correct, change if not,
# AT THE MINIMUM leave implant_type: tetrode
implant_settings = {
    'implant_type': 'tetrode',
    'implant_geometry': 'square',
    'wire_length': 25,
    'wire_length_units': 'um',
    'implant_units': 'uV',
}

# WE ASSUME DEVICE AND IMPLANT SETTINGS ARE CONSISTENT ACROSS SESSIONS

# Set channel count + add device/implant settings
SESSION_SETTINGS = {
    # EDIT HERE, default is 4, you can change to other number but code will check
    # how many tetrode files are present and set that to channel count regardless
    'channel_count': 4,
    # EDIT HERE
    'devices': device_settings,
    # EDIT HERE
    'implant': implant_settings,
}

# settings_dict is a second name for STUDY_SETTINGS (same dictionary object);
# the per-session settings are attached under the 'session' key.
settings_dict = STUDY_SETTINGS
settings_dict['session'] = SESSION_SETTINGS

In [None]:
# EDIT HERE --> change to path to your data, can ignore lab_path and put full file path to a folder as: r'path_to_data'
# Paths are assembled with os.path.join so they do not depend on the Windows '\' separator.
# data_dir = os.path.join(lab_path, 'neuroscikit_test_data', '20180502-ROUND-3000')
data_dir = os.path.join(lab_path, 'neuroscikit_test_data', 'single_sequential')
# data_dir = os.path.join(lab_path, 'neuroscikit_test_data', 'Outputs')
# data_dir = os.path.join(lab_path, 'neuroscikit_test_data', '20170315-270-3525_Test')
# data_dir = os.path.join(lab_path, 'neuroscikit_test_data', 'RadhaData', 'Data', 'highPHF')

# To use in unit matching.
# This is an independent copy, not another name for settings_dict: changing
# 'useMatchedCut' on settings_dict later in the notebook must not silently
# change the settings used for the unmatched run.
settings_dict_unmatched = copy.deepcopy(settings_dict)
settings_dict_unmatched['useMatchedCut'] = False

In [None]:
# study = make_study([data_dir], settings_dict_unmatched)
# study.make_animals()

In [None]:
# for animal in study.animals:
#     # print(animal.animal_id)
#     print('sessions for animal ' + str(animal.animal_id))
#     for seskey in animal.sessions:
#         ses = animal.sessions[seskey]
#         # print(ses.animal_id)
#         print(ses.session_metadata.file_paths['cut'])

In [None]:
# Run unit matching on the non-matched study; this will save a new matched cut file.
# Order of operations, as described by the original notes:
#   1. All data is loaded
#      ('Animal ID set': Animal1_tet1, Animal1_tet2, etc..)
#   2. Sessions PER animal and spikes per cell are sorted
#      ('Session data added (to animal), spikes sorted by cell': Animal1_tet1_ses1, Animal1_tet1_ses2, etc..)
#   3. Unit matching begins, comparing cells between sessions:
#      Animal1_tet1_ses1_cell_1 vs Animal1_tet1_ses2_cell_1,
#      Animal1_tet1_ses1_cell_1 vs Animal1_tet1_ses2_cell_2,
#      Animal1_tet1_ses1_cell_1 vs Animal1_tet1_ses2_cell_3, etc...
unmatched_study = run_unit_matcher(
    [data_dir],
    settings_dict_unmatched,
)
print('COMPLETED UNIT MATCHING')

In [None]:
# Quick check that session files are grouped correctly + ordered in time.
# Only the first animal of the unmatched study is inspected. For each of its
# sessions the file paths and the session datetime are printed, and the 'tet'
# file path is collected into `tets`.
tets = []
for session_key, session_to_check in unmatched_study.animals[0].sessions.items():
    session_file_paths = session_to_check.session_metadata.file_paths
    print(session_file_paths)
    tets.append(session_file_paths['tet'])
    print(session_to_check.datetime)

In [None]:
# Display the collected 'tet' file paths in sorted order (cell output)
sorted(tets)

In [None]:
# New settings dictionary indicating the matched cut file should now be used for file loading.
# It is a deep copy rather than another name for settings_dict, so switching
# 'useMatchedCut' on here does not mutate settings_dict or any settings object
# already handed to the unmatched study.
settings_dict_matched = copy.deepcopy(settings_dict)
settings_dict_matched['useMatchedCut'] = True

# Load new study but using labels from matched cut file
matched_study = make_study([data_dir], settings_dict_matched)

# Sort spikes by cell and order sessions sequentially
matched_study.make_animals()

In [None]:
# Look at matched cell ids from new cut files:
# for every animal, print each session's datetime followed by the unique
# cluster labels found in that session's 'spike_cluster' data.
for animal in matched_study.animals:
    print('New Animal')
    for session in animal.sessions.values():
        session_stamp = str(session.datetime)
        print('Session date & time: ' + session_stamp)
        spike_cluster = session.get_spike_data()['spike_cluster']
        matched_ids = np.unique(spike_cluster.cluster_labels)
        print('Matched cell ids: ' + str(matched_ids))

In [None]:
# Pull out session classes and cell ensembles from matched and unmatched studies for plotting.
# Only the first animal of each study is used, and only its 'session_1' / 'session_2' entries.
matched_animal = matched_study.animals[0]
unmatched_animal = unmatched_study.animals[0]

session1, ensemble1 = matched_animal.sessions['session_1'], matched_animal.ensembles['session_1']
session2, ensemble2 = matched_animal.sessions['session_2'], matched_animal.ensembles['session_2']

unmatched_ensembles1, unmatched_ensembles2 = (
    unmatched_animal.ensembles['session_1'],
    unmatched_animal.ensembles['session_2'],
)

In [None]:
""" Plot session 1 (left) & session 2 (right) MATCHED units """

# One figure per unit index: a grid with one row per channel and two columns
# (left = session 1, right = session 2). Each panel shows every waveform of the
# cell in gray and the mean waveform in black.
# NOTE(review): cells are paired by position (ensemble1.cells[i] with
# ensemble2.cells[i]) and the number of figures comes from ensemble1 only --
# confirm both ensembles hold the matched units in the same order.
pair_count = len(ensemble1.get_label_ids())

for i in range(pair_count):
# for i in range(2):

    fig = plt.figure(figsize=(6,12))

    waveforms1 = ensemble1.cells[i].signal
    waveforms2 = ensemble2.cells[i].signal

    # Average over the first axis (one entry per waveform)
    avg_waveforms1 = np.mean(waveforms1, axis=0)
    avg_waveforms2 = np.mean(waveforms2, axis=0)

    # The channel axis of the raw waveforms must line up with the rows of the average
    n_channels = avg_waveforms1.shape[0]
    assert waveforms1.shape[1] == n_channels

    axes = []
    for channel in range(n_channels):
        # Subplot numbering is 1-based and row-major: odd = left column, even = right column
        ax1 = plt.subplot(n_channels, 2, 2 * channel + 1)
        ax2 = plt.subplot(n_channels, 2, 2 * channel + 2)

        channel_title = 'Channel ' + str(channel + 1)
        panels = (
            (ax1, waveforms1, avg_waveforms1),
            (ax2, waveforms2, avg_waveforms2),
        )
        for panel_ax, all_waveforms, mean_waveforms in panels:
            panel_ax.plot(all_waveforms[:, channel].T, color='gray', lw=0.5, alpha=0.5)
            panel_ax.plot(mean_waveforms[channel], color='k', lw=2)
            panel_ax.set_title(channel_title)
            panel_ax.set_xlabel('Bin Number')
            panel_ax.set_ylabel('Waveform')
            axes.append(panel_ax)

    fig.suptitle('Session 1 (left) & 2 (right) - Unit ' + str(i+1))

    fig.tight_layout()
    plt.show()
