In [None]:
%load_ext autoreload
%autoreload 2

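# SpikeInterface pipeline for Mease Lab – CED

Loads one recording both from its pre-converted `.bin` file and from the original CED `.smrx` files, attaches the probe layout, applies common referencing (CMR), and plots the traces and their differences to compare the two sources.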

In [None]:
from pathlib import Path
from os import getenv
import numpy as np
from pprint import pprint
import matplotlib.pyplot as plt

In [None]:
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.toolkit as st
import spikeinterface.sorters as ss
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw
import spikeinterface.exporters as sx
import probeinterface as sp

In [None]:
import spikeextractors as oldse

## Bin recording

In [None]:
# Data folder holding both the binary export and (below) the .smrx files.
data_dir = Path(
    r"/mnt/sds-hd/sd19b001/PainData/Corrected_Channel_Map/L6/Cortex/20.8.21/KS2"
)
bin_file = data_dir / "m6.bin"
# Bare file name: resolved relative to the notebook's working directory.
recording_prb = "cambridge_neurotech_H3.prb"

# Raw-file parameters of the .bin export.
# 30030.03003003003 Hz ≈ 1e6 / 33.3, i.e. a ~33.3 µs sample interval.
sampling_frequency = 30030.03003003003
data_type = "int16"
numChan = 64

In [None]:
# smrx2bin is expected to have kept only the Rhd channels when it wrote the
# .bin file, so the whole file is opened as-is (no channel selection here).
recording = se.BinaryRecordingExtractor(bin_file, sampling_frequency, numChan, data_type)

In [None]:
# load probe file
# `recording_prb` is a bare file name, so it is resolved relative to the
# notebook's current working directory. The resulting probegroup is attached
# to both the binary and the CED recordings further down.
probegroup = sp.read_prb(recording_prb)

In [None]:
# add probe file to recording
# Derive `recording_bin` from the raw `recording` opened above. Building it from
# `recording` (not from itself) makes the cell runnable on a fresh kernel and
# idempotent: re-running it always starts from the probe-less recording.
recording_bin = recording.set_probegroup(probegroup)

In [None]:
# Basic summary of the binary recording: channel layout, rate and duration.
for label, value in (
    ("Num channels", recording_bin.get_num_channels()),
    ("Channel ids", recording_bin.get_channel_ids()),
    ("Sampling rate", recording_bin.get_sampling_frequency()),
    (
        "Duration (s)",
        recording_bin.get_num_frames() / recording_bin.get_sampling_frequency(),
    ),
):
    print(f"{label}: {value}")

## Smrx recordings

In [None]:
# Collect every .smrx file under the data folder (recursively). Sorting by path
# fixes the order in which the files are concatenated further down.
# NOTE(review): this is a lexicographic sort (e.g. "x_10.smrx" sorts before
# "x_2.smrx"); confirm the file names sort in recording order.
smrx_dir = Path(
    r"/mnt/sds-hd/sd19b001/PainData/Corrected_Channel_Map/L6/Cortex/20.8.21/KS2"
)
smrx_files = sorted(smrx_dir.glob("**/*.smrx"))
# Fail early with a clear message; otherwise smrx_files[0] below raises a bare IndexError.
assert smrx_files, f"No .smrx files found under {smrx_dir}"

In [None]:
# Automatically select Rhd channels: read the channel table of the first smrx
# file with the legacy spikeextractors reader and keep every channel whose
# title contains "Rhd".
# NOTE(review): only smrx_files[0] is inspected, which assumes all files share
# the same channel layout; confirm.
channel_info = oldse.CEDRecordingExtractor.get_all_channels_info(smrx_files[0])

rhd_channels = [ch for ch, info in channel_info.items() if "Rhd" in info["title"]]
# An empty selection would otherwise flow silently into the channel slicing below.
assert rhd_channels, f"No channel with 'Rhd' in its title found in {smrx_files[0]}"

In [None]:
recordings_ced = [se.read_ced(file, stream_id="1") for file in smrx_files]
recordings_ced

In [None]:
recordings_ced = [r.channel_slice(r.channel_ids[rhd_channels]) for r in recordings_ced]

In [None]:
recordings_ced

In [None]:
# Legacy alternative using the old spikeextractors API, kept for reference only (not executed):
# recordings_ced = [oldse.CEDRecordingExtractor(file, smrx_channel_ids=rhd_channels) for file in smrx_files]

In [None]:
recording_ced = si.concatenate_recordings(recordings_ced)

In [None]:
# add probe file to recording
recording_ced = recording_ced.set_probegroup(probegroup)

In [None]:
# Basic summary of the concatenated CED recording: channel layout, rate and duration.
for label, value in (
    ("Num channels", recording_ced.get_num_channels()),
    ("Channel ids", recording_ced.get_channel_ids()),
    ("Sampling rate", recording_ced.get_sampling_frequency()),
    (
        "Duration (s)",
        recording_ced.get_num_frames() / recording_ced.get_sampling_frequency(),
    ),
):
    print(f"{label}: {value}")

## Common reference (CMR)

In [None]:
# Common referencing with spikeinterface's default settings, applied to both
# sources (CED first, then bin) so they can be compared like-for-like below.
recording_ced_cmr, recording_bin_cmr = (
    st.preprocessing.common_reference(rec) for rec in (recording_ced, recording_bin)
)

## Inspect signals

In [None]:
# Inspection window: samples [start, end) and a channel *index* (a column of
# the get_traces output, not a channel id).
start = 0
end = 500
channel = 2

traces_bin = recording_bin.get_traces(start_frame=start, end_frame=end)
traces_ced = recording_ced.get_traces(start_frame=start, end_frame=end)

# Overlay the same channel index from both sources.
fig, ax = plt.subplots()
ax.set_title("Trace")
ax.plot(traces_bin[:, channel], label=f"bin channel {channel}")
ax.plot(traces_ced[:, channel], label=f"ced channel {channel}")
ax.legend()
plt.show()

In [None]:
# Overlay one CED channel before and after common referencing.
traces_ced_raw = recording_ced.get_traces(start_frame=start, end_frame=end)
traces_ced_cmr = recording_ced_cmr.get_traces(start_frame=start, end_frame=end)

fig, ax = plt.subplots()
ax.set_title("Trace")
ax.plot(traces_ced_raw[:, channel], label=f"ced channel {channel}")
ax.plot(traces_ced_cmr[:, channel], label=f"ced channel {channel} CMR")
ax.legend()
plt.show()

In [None]:
# Sample-by-sample difference between the binary export and the CED source for
# the same channel index; a flat zero line means both sources agree.
# Cast to float64 before subtracting: at least the .bin recording is int16
# (see `data_type`), and subtracting int16 arrays silently wraps around on
# overflow, which would hide exactly the large mismatches this plot is for.
# NOTE(review): assumes column `channel` refers to the same electrode in both
# recordings after set_probegroup; confirm.
diff_bin_ced = (
    recording_bin.get_traces(start_frame=start, end_frame=end)[:, channel].astype(np.float64)
    - recording_ced.get_traces(start_frame=start, end_frame=end)[:, channel].astype(np.float64)
)

fig, ax = plt.subplots()
ax.set_title("Difference")
ax.plot(diff_bin_ced, label=f"bin channel {channel} - ced channel {channel}")
ax.set_xlabel(f"Sample (offset from frame {start})")
ax.set_ylabel("bin - ced")
ax.legend()
plt.show()

In [None]:
# Sample-by-sample difference between the raw CED trace and its common-referenced
# version, i.e. what referencing removed from this channel.
# Cast to float64 before subtracting. The referenced output presumably keeps the
# parent's integer dtype, and integer subtraction would wrap around silently
# on overflow.
diff_ced_cmr = (
    recording_ced.get_traces(start_frame=start, end_frame=end)[:, channel].astype(np.float64)
    - recording_ced_cmr.get_traces(start_frame=start, end_frame=end)[:, channel].astype(np.float64)
)

fig, ax = plt.subplots()
ax.set_title("Difference")
ax.plot(diff_ced_cmr, label=f"ced channel {channel} - ced channel {channel} CMR")
ax.set_xlabel(f"Sample (offset from frame {start})")
ax.set_ylabel("ced - ced CMR")
ax.legend()
plt.show()