In [None]:
import os, shutil, time
import yaml
import numpy as np
import pyqtgraph as pg
import tridesclous as tdc

# Load the data

In [None]:
# Working directory where tridesclous keeps its intermediate files. Any previous
# workspace is removed first, so every run of the notebook starts from scratch.
dirname = '../ffs-ignore/spike-sorting-working-dir'

# Recording to sort (per its name: BN3 channel only, artifacts already removed).
filenames = [
    '../ffs-ignore/2018-06-21_IN-VIVO_JG-08 002 BN3-Only-Artifacts-Removed.axgx',
]
# Alternative input: the raw recording, restricted to channel 3 (BN3) via a channel group.
# filenames = ['../ffs-ignore/2018-06-21_IN-VIVO_JG-08 002.axgd']
# channel_groups = {0: {"channels": [3], "geometry": {3: [0, 0]}}} # BN3 only

if os.path.exists(dirname):
    shutil.rmtree(dirname)

dataio = tdc.DataIO(dirname=dirname)
dataio.set_data_source(type='Axograph', filenames=filenames)
# dataio.set_channel_groups(channel_groups)  # only needed with the raw-recording alternative
print(dataio)

cc = tdc.CatalogueConstructor(dataio=dataio)
print(cc)

# Parameters

In [None]:
# All settings for building the catalogue. The same dict is unpacked into
# tdc.apply_all_catalogue_steps() and dumped to YAML at export time, so it is the
# single record of which parameters produced the results.
tdc_params = dict(
    fullchain_kargs=dict(
        duration=300.,  # seconds of signal used for catalogue construction
        preprocessor=dict(
            highpass_freq=None,
            lowpass_freq=None,
            chunksize=4096,
            lostfront_chunksize=64,
        ),
        peak_detector=dict(
            peak_sign='-',
            relative_threshold=5,  # a much higher value (100) was also tried
            peak_span=0.0002,
        ),
        noise_snippet=dict(
            nb_snippet=300,
        ),
        extract_waveforms=dict(
            n_left=-20,
            n_right=30,
            mode='rand',
            nb_max=20000,
            align_waveform=False,
        ),
        clean_waveforms=dict(
            alien_value_threshold=400.,  # relative threshold for discarding spikes
        ),
    ),

    feat_method='global_pca',
    feat_kargs=dict(n_components=5),

    clust_method='sawchaincut',
    clust_kargs=dict(),
    # Alternative clustering that was also tried:
    # clust_method='gmm',
    # clust_kargs=dict(n_clusters=3, covariance_type='full', n_init=10),
)

# Create initial catalogue all at once ...

In [None]:
# Build the initial catalogue in a single call, driven entirely by `tdc_params`
# (every key of the dict is forwarded as a keyword argument).
tdc.apply_all_catalogue_steps(
    cc,
    **tdc_params
)
print(cc)

# ... or step by step

##### Set parameters for filter and peak detection

In [None]:
# cc.set_preprocessor_params(**tdc_params['fullchain_kargs']['preprocessor'], **tdc_params['fullchain_kargs']['peak_detector'])

# # cc.set_preprocessor_params(
# #     chunksize = chunksize,
# #     lostfront_chunksize = 64,#1,#None,#64,
    
# #     highpass_freq = None,#14.,
# #     lowpass_freq  = None,#100000.,
    
# #     peak_sign = '-',
# #     relative_threshold = 100,
# #     peak_span = 0.0002,
# # )

# print(cc)

##### Estimate background noise

In [None]:
# cc.estimate_signals_noise(seg_num=0, duration=min(10., tdc_params['fullchain_kargs']['duration'], dataio.get_segment_length(seg_num=0)/dataio.sample_rate*.99))

# # cc.estimate_signals_noise(seg_num=0, duration=15.)
# # cc.estimate_signals_noise(seg_num=0, duration=10.)

# print(cc.signals_medians)
# print(cc.signals_mads)
# print(cc)

##### Run the filter and peak detection on a data subset

In [None]:
# cc.run_signalprocessor(duration=tdc_params['fullchain_kargs']['duration'])

# # cc.run_signalprocessor(duration=60.)
# # cc.run_signalprocessor(duration=300.)

# print(cc)

##### Extract waveforms around the detected peaks

In [None]:
# cc.extract_some_waveforms(**tdc_params['fullchain_kargs']['extract_waveforms'])

# # cc.extract_some_waveforms(n_left=-25, n_right=40, mode='rand', nb_max=10000, align_waveform=True)
# # cc.extract_some_waveforms(n_left=-20, n_right=30, mode='rand', nb_max=20000, align_waveform=False)

# print(cc)

##### Discard some bad waveforms (artifacts)

In [None]:
# cc.clean_waveforms(**tdc_params['fullchain_kargs']['clean_waveforms'])

# # cc.clean_waveforms(alien_value_threshold=400.) # relative threshold for discarding spikes

# print(cc)

##### Shorten or extend waveforms based on amplitude above noise

In [None]:
# n_left, n_right = cc.find_good_limits(mad_threshold = 1.1,)

# print(n_left, n_right)
# print(cc)

##### Extract noise samples for comparison to spikes

In [None]:
# cc.extract_some_noise(**tdc_params['fullchain_kargs']['noise_snippet'])

# # cc.extract_some_noise(nb_snippet = 300)

# print(cc)

##### Extract spike features (PCA)

In [None]:
# cc.extract_some_features(method=tdc_params['feat_method'], **tdc_params['feat_kargs'])

# # cc.extract_some_features(method='global_pca', n_components=5)
# # cc.extract_some_features(method='peak_max')

# print(cc)

##### Cluster spikes based on extracted features

In [None]:
# cc.find_clusters(method=tdc_params['clust_method'], **tdc_params['clust_kargs'])

# # cc.find_clusters(method='kmeans', n_clusters=12)
# # cc.find_clusters(method='gmm', n_clusters=5, covariance_type='full', n_init=10)
# # cc.find_clusters(method='gmm', n_clusters=3, covariance_type='full', n_init=10)

# print(cc)

# Preview and manually merge/split/delete clusters ...

In [None]:
# Interactive catalogue editor. Must click "Make catalogue for peeler" when
# finished, otherwise the Peeler cell further down finds no catalogue to load.
# app.exec_() runs the Qt event loop, so this cell blocks until the window closes.
app = pg.mkQApp()
win = tdc.CatalogueWindow(cc)

viewer = win.traceviewer
viewer.params['xsize_max'] = 300.0   # raise the upper bound on time zoom
viewer.params['zoom_size'] = 30.0    # more time shown after clicking on a spike
viewer.spinbox_xsize.setValue(30.0)  # more time shown when the window opens
viewer.gain_zoom(50)                 # more voltage range shown initially

win.show()
app.exec_()

# ... or merge/split/delete clusters programmatically

In [None]:
# #order cluster by waveforms rms
# cc.order_clusters(by='waveforms_rms')

# #put labels to trash
# # mask = (
# #     cc.all_peaks['cluster_label'] == 0 or
# #     cc.all_peaks['cluster_label'] == 1 or
# #     cc.all_peaks['cluster_label'] == 2 or
# #     cc.all_peaks['cluster_label'] == 3
# # )
# # mask = cc.all_peaks['cluster_label'] != 4
# # cc.all_peaks['cluster_label'][mask] = -1
# # cc.on_new_cluster()

# #save the catalogue
# cc.make_catalogue_for_peeler()

# print(cc)

# Run Peeler: classify spikes in the full dataset using template matching

In [None]:
# Load the catalogue saved from the editor (None when none has been made yet),
# then let the Peeler classify spikes over the whole dataset by template matching.
initial_catalogue = dataio.load_catalogue(chan_grp=0)
if initial_catalogue is None:
    print('You need to make a catalogue for the peeler first!')
else:
    print(cc)

    peeler = tdc.Peeler(dataio)
    # Reuse the preprocessor's chunk size so both stages see the signal in the same chunks.
    peeler.change_params(
        catalogue=initial_catalogue,
        chunksize=tdc_params['fullchain_kargs']['preprocessor']['chunksize'],
    )

    run_start = time.perf_counter()
    peeler.run()
    run_stop = time.perf_counter()
    print('peeler.run', run_stop - run_start)

    print()
    # Report how many spikes were classified in each segment.
    for seg_num in range(dataio.nb_segment):
        seg_spikes = dataio.get_spikes(seg_num)
        print('seg_num', seg_num, 'nb_spikes', seg_spikes.size)

# View the final result

In [None]:
# Inspect the Peeler's result interactively; app.exec_() blocks until the window closes.
if initial_catalogue is None:
    print('You need to make a catalogue for the peeler and run the peeler first!')
else:
    app = pg.mkQApp()
    win = tdc.PeelerWindow(dataio=dataio, catalogue=initial_catalogue)

    viewer = win.traceviewer
    viewer.params['xsize_max'] = 300.0   # raise the upper bound on time zoom
    viewer.params['zoom_size'] = 30.0    # more time shown after clicking on a spike
    viewer.spinbox_xsize.setValue(30.0)  # more time shown when the window opens
    viewer.gain_zoom(50)                 # more voltage range shown initially

    win.show()
    app.exec_()

# Export the spikes and (initial) parameters to a file

In [None]:
# Export the Peeler's spikes as CSV and store the parameter set used to build the
# initial catalogue next to them, so the export records how it was produced.
export_dirname = '../ffs-ignore/spike-sorting-export-dir'
if initial_catalogue is not None:
    dataio.export_spikes(export_dirname, formats = 'csv')
    # NOTE(review): it is not visible here whether export_spikes always creates
    # export_dirname; make sure it exists so writing the YAML file cannot fail
    # on a missing directory.
    os.makedirs(export_dirname, exist_ok=True)
    params_path = os.path.join(export_dirname, 'tdc_initial_params.yml')
    with open(params_path, 'w') as f:
        yaml.dump(tdc_params, f, default_flow_style=False)
else:
    print('You need to make a catalogue for the peeler and run the peeler first!')