In [None]:
import time, os, json, warnings
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import spikeinterface.core as si
import spikeinterface.extractors as se
import probeinterface as pi
from probeinterface.plotting import plot_probe
import spikeinterface.preprocessing as spre
import spikeinterface.widgets as sw
from spikeinterface.sortingcomponents.motion import (
    correct_motion_on_peaks,
    interpolate_motion,
    estimate_motion,
)

In [None]:
from estimate_drift_motion import AP_band_drift_estimation, LFP_band_drift_estimation

In [None]:
def LFP_band_drift_estimation(group, raw_rec, oe_folder):
    """Estimate rigid motion for one shank from its LFP band and save a summary figure.

    NOTE(review): this notebook-local definition shadows the
    ``LFP_band_drift_estimation`` imported from ``estimate_drift_motion`` in
    an earlier cell; whichever cell ran last wins. Keep only one of them.

    Parameters
    ----------
    group
        Label of the shank/group; only used to name the output folder.
    raw_rec
        Recording for a single shank. It is passed directly to
        ``spre.bandpass_filter``.
    oe_folder : str or pathlib.Path
        Session folder. The figure is written to
        ``<oe_folder>/lfp_motion_shank<group>/dredge_lfp.png`` (the folder is
        created if needed).

    Returns
    -------
    The object returned by ``estimate_motion(..., method='dredge_lfp', rigid=True)``.
    """
    # LFP preprocessing chain: 0.5-250 Hz band-pass -> global median reference
    # -> resample to 250 Hz -> spre.average_across_direction.
    lfprec = spre.bandpass_filter(
        raw_rec,
        freq_min=0.5,
        freq_max=250,
        margin_ms=1500.,
        filter_order=3,
        dtype="float32",
        add_reflect_padding=True,
    )
    lfprec = spre.common_reference(lfprec, reference="global", operator="median")
    lfprec = spre.resample(lfprec, resample_rate=250, margin_ms=1000)
    lfprec = spre.average_across_direction(lfprec)

    # Left panel: first second of the LFP traces. Right panel: estimated motion.
    fig0 = plt.figure()
    ax = fig0.add_subplot(121)
    sw.plot_traces(lfprec, backend="matplotlib", mode="auto", ax=ax, time_range=(0, 1))
    motion_lfp = estimate_motion(lfprec, method='dredge_lfp', rigid=True, progress_bar=True)
    ax = fig0.add_subplot(122)
    sw.plot_motion(motion_lfp, mode='line', ax=ax)

    # mkdir(exist_ok=True) already tolerates an existing folder, so no is_dir() pre-check.
    motion_folder = Path(oe_folder) / f"lfp_motion_shank{group}"
    motion_folder.mkdir(parents=True, exist_ok=True)
    fig0.savefig(motion_folder / "dredge_lfp.png")
    plt.show()
    return motion_lfp

In [None]:
# Open Ephys session folder to analyse; it is wrapped in a Path (oe_folder) in a later cell.
# Alternative session, kept commented out for quick switching:
#thisDir = r"Z:\DATA\experiment_openEphys\H-series-128channels\2025-03-23_21-33-38"
thisDir = r"Z:\DATA\experiment_openEphys\H-series-128channels\2025-03-23_20-47-26"
# Analysis options: either a path to a JSON file or an already-built dict
# (the cell that reads it handles both cases).
json_file = "./analysis_methods_dictionary.json"

In [None]:
def load_data(oe_folder, analysis_methods):
    """Load the recording of one session.

    Resolution order:

    1. ``<oe_folder>/preprocessed_compressed.zarr`` if
       ``analysis_methods["load_prepocessed_file"]`` is true and the folder exists;
    2. ``<oe_folder>/preprocessed`` under the same flag;
    3. otherwise the raw Open Ephys recording, with the probe selected by
       ``analysis_methods["probe_type"]`` attached (grouped by shank).

    Parameters
    ----------
    oe_folder : str or pathlib.Path
        Open Ephys session folder.
    analysis_methods : dict
        Options. Keys read here: ``"experimenter"``, ``"probe_type"``,
        ``"load_prepocessed_file"`` (spelling as used throughout the
        notebook) and ``"load_raw_traces"``.

    Returns
    -------
    The saved recording (cases 1 and 2) or the raw recording with the probe
    set and a ``description`` annotation (case 3).

    Raises
    ------
    ValueError
        If the raw route is taken and ``probe_type`` is neither ``"P2"`` nor
        ``"H10_stacked"``. (Previously this called ``exit()``, which in a
        notebook tears down the kernel instead of reporting the problem.)
    """
    oe_folder = Path(oe_folder)
    this_experimenter = analysis_methods.get("experimenter")
    probe_type = analysis_methods.get("probe_type")
    # `== True` is kept on purpose so that only True/1 enable the flag,
    # exactly as before (a string such as "False" must not count as enabled).
    use_preprocessed = analysis_methods.get("load_prepocessed_file") == True  # noqa: E712

    zarr_folder = oe_folder / "preprocessed_compressed.zarr"
    if use_preprocessed and zarr_folder.is_dir():
        recording_saved = si.read_zarr(zarr_folder)
        print(recording_saved.get_property_keys())
        return recording_saved

    saved_folder = oe_folder / "preprocessed"
    if use_preprocessed and saved_folder.is_dir():
        print(
            "Looks like you do not have compressed files. Read the original instead"
        )
        return si.load_extractor(saved_folder)

    # Raw route. Validate the probe name first so a typo fails before any data is read.
    known_probe_types = ("P2", "H10_stacked")
    if probe_type not in known_probe_types:
        raise ValueError(
            f"Unknown probe_type {probe_type!r}; expected one of {known_probe_types}"
        )

    print("Load meta information from openEphys")
    raw_rec = se.read_openephys(oe_folder, load_sync_timestamps=True)
    # To show the start of recording time: raw_rec.get_times()[0]

    # Sync-pulse events (author's note: this records the ON phase of the sync pulse).
    # NOTE(review): events_times is neither used nor returned yet.
    event = se.read_openephys_event(oe_folder)
    events_times = event.get_event_times(
        channel_id=event.channel_ids[1], segment_index=0
    )

    fs = raw_rec.get_sampling_frequency()
    if analysis_methods.get("load_raw_traces") == True:  # noqa: E712
        # NOTE(review): the first 2 s of traces are read but not used or returned.
        trace_snippet = raw_rec.get_traces(
            start_frame=int(fs * 0), end_frame=int(fs * 2)
        )

    ################load probe information################
    if probe_type == "P2":
        probe = pi.get_probe("cambridgeneurotech", "ASSY-37-P-2")
        print(probe)
        probe.wiring_to_device("ASSY-116>RHD2132")
    else:  # "H10_stacked" (guaranteed by the validation above)
        # NOTE(review): relative path, resolved against the current working directory.
        stacked_probes = pi.read_probeinterface("H10_stacked_probes.json")
        probe = stacked_probes.probes[0]

    # drop AUX channels here (author's note on what set_probe is used for)
    raw_rec = raw_rec.set_probe(probe, group_mode='by_shank')
    raw_rec.annotate(description=f"Dataset of {this_experimenter}")
    return raw_rec

In [None]:
# Resolve the session folder and the analysis options.
oe_folder = Path(thisDir)
if isinstance(json_file, dict):
    # Options were supplied directly as a dict.
    analysis_methods = json_file
else:
    with open(json_file, "r") as f:
        print(f"load analysis methods from file {json_file}")
        analysis_methods = json.load(f)

motion_corrector = analysis_methods.get("motion_corrector")

# Leave four cores free for the rest of the system, but never drop below one
# worker: os.cpu_count() may return None, and "count - 4" is <= 0 on small machines.
n_cpus = os.cpu_count() or 1
n_jobs = max(1, n_cpus - 4)
job_kwargs = dict(n_jobs=n_jobs, chunk_duration="1s", progress_bar=True)


In [None]:
# LFP-band drift estimate, one shank at a time.
# The flag is switched off so that load_data() takes the raw Open Ephys route.
analysis_methods["load_prepocessed_file"] = False

# Reuse a recording that is already in the kernel; load it only when missing.
if 'raw_rec' not in locals():
    raw_rec = load_data(oe_folder, analysis_methods)

# NOTE(review): raw_rec is rebound in place, so re-running this cell stacks the
# same two steps again, and later cells see the float32, depth-ordered version.
raw_rec = spre.astype(raw_rec, np.float32)
raw_rec = spre.depth_order(raw_rec)
raw_rec_dict = raw_rec.split_by(property='group', outputs='dict')

motion_lfp_dict = {}
for group, rec_per_shank in raw_rec_dict.items():
    motion_lfp = LFP_band_drift_estimation(group, rec_per_shank, oe_folder)
    motion_lfp_dict[group] = motion_lfp

In [None]:
# Build `rec_of_interest`: the band-passed, common-referenced recording used for
# AP-band drift estimation, either from a saved copy or recomputed from raw data.
analysis_methods.update({"load_prepocessed_file": True})
# The flag was just set to True, so only the presence of the saved copy decides the branch.
if (oe_folder / "preprocessed_compressed.zarr").is_dir():
    recording_saved = load_data(oe_folder, analysis_methods)
    fs = recording_saved.get_sampling_frequency()
    # rec_of_interest is assigned inside each branch (instead of testing
    # `"recording_cmr" in locals()` afterwards) so that a recording_cmr left
    # over from an earlier run cannot be picked up by mistake.
    rec_of_interest = recording_saved
    rec_of_interest.annotate(
        is_filtered=True
    )  # needed to add this somehow because when loading a preprocessed data saved in the past, that data would not be labeled as filtered data
else:
    # NOTE(review): with the flag set to True, load_data() may still return the
    # uncompressed "preprocessed" folder here, which would then be filtered again.
    if 'raw_rec' not in locals():
        raw_rec = load_data(oe_folder, analysis_methods)
    fs = raw_rec.get_sampling_frequency()
    recording_f = spre.bandpass_filter(raw_rec, freq_min=600, freq_max=6000)
    if analysis_methods.get("analyse_good_channels_only") == True:  # noqa: E712
        # This step should be done before saving preprocessed files because ideally the
        # preprocessed file we want to create is something ready for spiking detection,
        # which means neural traces gone through bandpass filter and common reference.
        # However, applying common reference takes signals from channels of interest,
        # which requires us to decide what we want to do with other bad or noisy
        # channels first.
        bad_channel_ids, channel_labels = spre.detect_bad_channels(
            recording_f, method="coherence+psd"
        )  # bad_channel_ids=np.array(['CH1','CH2','CH3','CH4','CH5','CH6','CH7','CH8','CH9','CH10','CH11','CH12','CH13','CH14','CH15','CH16'],dtype='<U64')
        #bad channel ids in ['CH2' 'CH22' 'CH25' 'CH27' 'CH29' 'CH30' 'CH31' 'CH32' 'CH33' 'CH34''CH35' 'CH36' 'CH37' 'CH38' 'CH39' 'CH41' 'CH43' 'CH80' 'CH112'] in 2025-03-19_18-02-13"
        print("bad_channel_ids", bad_channel_ids)
        print("channel_labels", channel_labels)

        recording_f = recording_f.remove_channels(
            bad_channel_ids
        )  # need to check if I can do this online
    # TODO: not sure if I should apply CAR by shank by shank
    recording_cmr = spre.common_reference(
        recording_f, reference="global", operator="median"
    )
    rec_of_interest = recording_cmr

# Slice the recording if needed
if analysis_methods.get("analyse_entire_recording") == False:  # noqa: E712
    start_sec = 1
    end_sec = 899
    # fs is a float; frame indices must be integers (same int(...) cast as in load_data).
    rec_of_interest = rec_of_interest.frame_slice(
        start_frame=int(start_sec * fs), end_frame=int(end_sec * fs)
    )

In [None]:
# AP-band drift estimation, one shank at a time.
# AP_band_drift_estimation is imported from estimate_drift_motion; its result is
# unpacked as a pair and each half is stored per shank.
win_um = 100
recordings_dict = rec_of_interest.split_by(property='group', outputs='dict')

recording_corrected_dict = {}
motion_ap_dict = {}
for group, sub_recording in recordings_dict.items():
    recording_corrected, motion_ap_list = AP_band_drift_estimation(
        group,
        sub_recording,
        oe_folder,
        analysis_methods,
        win_um,
        job_kwargs,
    )
    recording_corrected_dict[group] = recording_corrected
    motion_ap_dict[group] = motion_ap_list