In [None]:
# Reload edited Python modules automatically before executing each cell.
%load_ext autoreload
%autoreload 2

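# Binary (`.bin`) vs. CED (`.smrx`) recordings

Check that the `.bin` file under `KS2_5/` matches the concatenated CED `.smrx` recordings (including around the first concatenation point), then run Kilosort 2.5 on the CED data.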

In [None]:
from pathlib import Path
import numpy as np
from pprint import pprint
from datetime import datetime
import matplotlib.pyplot as plt

import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw

# from mease_lab_to_nwb.convert_ced.cednwbconverter import quick_write

## Load & concatenate CED recordings

In [None]:
# All CED Spike2 files (*.smrx) anywhere below the session folder, sorted by
# path; this order is the order in which they are concatenated further down.
ced_dir = Path(
    r"/mnt/sds-hd/sd19b001/PainData/Corrected_Channel_Map/L6/Cortex/20.8.21/KS2"
)
ced_files = sorted(ced_dir.glob("**/*.smrx"))
# Probe geometry attached to every CED recording.
probe_file = "cambridge_neurotech_H3.prb"

In [None]:
# Load each CED file restricted to its "Rhd" channels, attach the probe
# geometry, and join the recordings in time (in ced_files order).
assert ced_files, "no .smrx files found - check the CED data directory"
recordings = []
for file in ced_files:
    # Automatically select channels whose title contains "Rhd"
    channel_info = se.CEDRecordingExtractor.get_all_channels_info(file)
    rhd_channels = [
        ch for ch, info in channel_info.items() if "Rhd" in info["title"]
    ]
    # Fail here with a clear message rather than inside the extractor.
    assert rhd_channels, f"no Rhd channels found in {file}"
    recording = se.CEDRecordingExtractor(file, smrx_channel_ids=rhd_channels)
    recording = se.load_probe_file(recording, probe_file)
    recordings.append(recording)
# instantiate a MultiRecording object
multirecording = se.MultiRecordingTimeExtractor(recordings)

In [None]:
# Print the info stored for each epoch of the concatenated recording
# (presumably one epoch per CED file - confirm).
epoch_names = multirecording.get_epoch_names()
for epoch_name in epoch_names:
    epoch_info = multirecording.get_epoch_info(epoch_name)
    print(epoch_info)

In [None]:
# Re-reference the concatenated CED recording with spiketoolkit's
# common_reference, using its default reference mode (presumably median,
# matching the "_cmr" suffix - confirm for the installed version).
multirecording_cmr = st.preprocessing.common_reference(multirecording)

## Load the binary (`.bin`) recording

In [None]:
# Binary export of the same session, presumably the Kilosort 2.5 input (it lives
# under KS2_5/). An earlier comparison used KS2/m6.bin in the session folder.
bin_file = Path(
    r"/mnt/sds-hd/sd19b001/PainData/Corrected_Channel_Map/L6/Cortex/20.8.21/KS2_5/Troubleshooting_11_3_22/m6Troubleshooting11322.bin"
)
# NOTE(review): defined but never applied to recording_bin, while the CED
# recordings do get load_probe_file; the bin channel order is assumed to already
# match the probe-mapped CED order - confirm.
recording_prb = "cambridge_neurotech_H3.prb"
# = 30000 / 0.999 Hz. TODO confirm it equals multirecording.get_sampling_frequency().
sampling_frequency = 3.003003003003003e04
data_type = "int16"
numChan = 64
# No header offset is passed to the reader, so the file must hold a whole number
# of frames of numChan samples; otherwise numChan or data_type is wrong and the
# traces would be silently misread.
bytes_per_frame = numChan * np.dtype(data_type).itemsize
assert bin_file.stat().st_size % bytes_per_frame == 0, (
    f"{bin_file} size is not a multiple of {bytes_per_frame} bytes "
    f"({numChan} x {data_type}); check numChan / data_type"
)
recording_bin = se.BinDatRecordingExtractor(
    bin_file, sampling_frequency, numChan, data_type
)
recording_bin_cmr = st.preprocessing.common_reference(recording_bin)

In [None]:
# First three samples of channel 63. Request only those frames instead of
# reading the whole channel and slicing afterwards.
recording_bin.get_traces(channel_ids=63, start_frame=0, end_frame=3)[0]

## Compare traces (common-referenced CED vs. bin)

In [None]:
# Channel and frame window [start, end) used by the comparison plots below.
channel, start, end = 1, 0, 500
# Hand-tuned gain applied to the CED traces so they overlay the bin traces
# ("arbitrary rescaling" in the plot title).
scaleFactor = 8.423

In [None]:
# Time (s) at which the first concatenation point (frame 111319410) occurs.
first_concat_frame = 111319410
first_concat_frame / sampling_frequency

In [None]:
# Overlay the rescaled, common-referenced CED trace and the bin trace for one
# channel over frames [start, end).
ced_trace = scaleFactor * multirecording_cmr.get_traces(
    channel_ids=channel, start_frame=start, end_frame=end
)[0]
bin_trace = recording_bin.get_traces(
    channel_ids=channel, start_frame=start, end_frame=end
)[0]
fig, ax = plt.subplots()
ax.set_title("Traces agree (with arbitrary rescaling)")
# x = absolute frame numbers (each trace uses its own length, in case one is shorter)
ax.plot(start + np.arange(ced_trace.size), ced_trace, label=f"ced channel {channel}")
ax.plot(start + np.arange(bin_trace.size), bin_trace, label=f"bin channel {channel}")
ax.set_xlabel("frame")
ax.set_ylabel("amplitude (bin data units)")
ax.legend()
plt.show()

In [None]:
# Residual: rescaled common-referenced CED trace minus bin trace over [start, end).
ced_trace = scaleFactor * multirecording_cmr.get_traces(
    channel_ids=channel, start_frame=start, end_frame=end
)[0]
bin_trace = recording_bin.get_traces(
    channel_ids=channel, start_frame=start, end_frame=end
)[0]
difference = ced_trace - bin_trace
fig, ax = plt.subplots()
ax.set_title("Trace difference")
ax.plot(start + np.arange(difference.size), difference, label=f"channel {channel}")
ax.set_xlabel("frame")
ax.set_ylabel("CED x scaleFactor - bin")
ax.legend()
plt.show()

## Compare traces around the first concatenation point

In [None]:
# Zoom on channel 12, 10 frames either side of the first concatenation point.
channel = 12
concat_frame = 111319410
start, end = concat_frame - 10, concat_frame + 10

In [None]:
# Overlay rescaled CED and bin traces around the concatenation point. The x-axis
# uses absolute frame numbers so the join position is visible.
ced_trace = scaleFactor * multirecording_cmr.get_traces(
    channel_ids=channel, start_frame=start, end_frame=end
)[0]
bin_trace = recording_bin.get_traces(
    channel_ids=channel, start_frame=start, end_frame=end
)[0]
fig, ax = plt.subplots()
ax.set_title("Trace: bin concatenation skips a value")
ax.plot(start + np.arange(ced_trace.size), ced_trace, label=f"ced channel {channel}")
ax.plot(start + np.arange(bin_trace.size), bin_trace, label=f"bin channel {channel}")
ax.set_xlabel("frame")
ax.set_ylabel("amplitude (bin data units)")
ax.legend()
plt.show()

In [None]:
# Enlarge every figure created from this point on (the plots above are unaffected).
plt.rcParams.update({"figure.figsize": (20, 20)})

## Raw traces, all channels

In [None]:
# Plot all channels of the bin recording over this window
# (trange units presumably seconds - confirm with spikewidgets).
time_window = [3920, 3980]
w_ts_ap = sw.plot_timeseries(recording_bin, trange=time_window)

In [None]:
# Default spikesorters parameters for Kilosort 2.5.
sorter_name = "kilosort2_5"
ks_params = ss.get_default_params(sorter_name)

In [None]:
# Display the Kilosort 2.5 parameters used for the run below.
ks_params

In [None]:
# Spike-sort the common-referenced, concatenated CED recording (not the bin)
# with ks_params. Keep the returned sorting so it can be inspected or compared
# without re-running the sorter; the last line still displays it.
# NOTE(review): no output_folder is passed, so spikesorters uses its default
# location (presumably under the working directory) - consider setting one.
sorting_ks = ss.run_kilosort2_5(multirecording_cmr, verbose=True, **ks_params)
sorting_ks