In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import h5py
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import analysis_tools
from dredge import motion_util
import seaborn as sns
import dartsort
import spikeinterface.full as si

In [None]:
%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
from matplotlib.markers import MarkerStyle
from matplotlib.transforms import offset_copy
from matplotlib.patches import Ellipse, Rectangle, ConnectionPatch
from matplotlib.lines import Line2D
from matplotlib.legend_handler import HandlerTuple
import contextlib
import colorcet as cc

# Font sizes (points) used throughout the notebook's figures.
SMALL_SIZE = 5
MEDIUM_SIZE = 7
BIGGER_SIZE = 8

# Default canvas: 300 dpi, 4x4 inches.
plt.rc("figure", dpi=300, figsize=(4, 4))
# Base text size, then per-element sizes: small labels/ticks/legend,
# medium axes titles, slightly bigger figure suptitles.
plt.rc("font", size=SMALL_SIZE)
plt.rc("axes", titlesize=MEDIUM_SIZE, labelsize=SMALL_SIZE)
plt.rc("xtick", labelsize=SMALL_SIZE)
plt.rc("ytick", labelsize=SMALL_SIZE)
plt.rc("legend", fontsize=SMALL_SIZE)
plt.rc("figure", titlesize=BIGGER_SIZE)
# plt.rc('text.latex', preamble=preamble)
# Vector export: keep SVG text as text, distill PostScript via xpdf,
# embed TrueType (Type 42) fonts in PDF.
plt.rc("svg", fonttype="none")
plt.rc("ps", usedistiller="xpdf")
plt.rc("pdf", fonttype=42)


# plt.rcParams.update({
#     "text.usetex": True,
#     # "font.family": "serif",
# })
# preamble = r"""
# \renewcommand{\familydefault}{\sfdefault}
# \usepackage[scaled=1]{helvet}
# \usepackage[helvet]{sfmath}
# \usepackage{textgreek}
# """
# plt.rc('text.latex', preamble=preamble)

In [None]:
# Input recordings (MEArec .h5 files) and per-tool output locations.
measims_dir = Path("~/proj/measims/recordings").expanduser()
# Path.glob yields entries in filesystem-dependent order; sort so every loop
# below processes recordings (and prints/plots them) in a stable order.
measims = sorted(measims_dir.glob("*.h5"))
scaleprocs_dir = Path("~/proj/measims/scaleproc_recordings").expanduser()
darts_dir = Path("~/proj/measims/dart_recordings").expanduser()
rez_dir = Path("~/proj/measims/ksrez").expanduser()
# Alternative DARTsort output runs:
# dart_out_dir = Path("~/proj/measims/dart_out").expanduser()
# dart_out_dir = Path("~/proj/measims/dartnonn_out").expanduser()
dart_out_dir = Path("~/proj/measims/dartnonn4_out").expanduser()

In [None]:
# Root for all visualisation outputs, created next to the recordings directory.
# parents=True so this also works when the measims root does not exist yet.
vis_dir = measims_dir.parent / "vis"
vis_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Per-recording comparison figures are saved here by the comparison loop.
# parents=True keeps this cell working even if vis_dir was not created yet.
heatmaps_dir = vis_dir / "heatmaps"
heatmaps_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Estimate motion for each recording from its DARTsort *subtraction* output.
# NOTE(review): the following cell calls dartsort.estimate_motion again with the
# same output_directory (dart_dir) and overwrite=True, from thresholding.h5, so
# whatever this cell writes there is presumably replaced before it is read —
# confirm this cell is still needed.
estimate_kwargs = dict(overwrite=True, temporal_bin_length_s=2.05)
for p in measims:
    print(p)
    dart_dir = dart_out_dir / p.stem
    subtraction_h5 = dart_dir / "subtraction.h5"
    rec = si.read_binary_folder(darts_dir / p.stem)
    dartsort.estimate_motion(
        rec,
        dartsort.DARTsortSorting.from_peeling_hdf5(subtraction_h5),
        output_directory=dart_dir,  # output_directory=None to skip writing
        **estimate_kwargs,
    )

In [None]:
# Per-recording comparison of ground-truth (MEArec), Kilosort and DREDge motion.
# For every recording this cell:
#   1. re-runs dartsort.estimate_motion on the thresholding peeling output,
#      writing into the recording's dart_dir (overwrite=True);
#   2. loads GT / KS / DREDge displacement maps (KS and DREDge on the GT grid);
#   3. removes the mean GT-minus-estimate offset from KS and DREDge, so only the
#      shape of the motion is compared;
#   4. saves a 3x3 figure: spike rasters (row 0), displacement heatmaps (row 1)
#      and GT-minus-estimate residual heatmaps (row 2).
# Residual maps are kept in ks_errs / dart_errs for the summary table below.
gt_dispmaps = {}
ks_dispmaps = {}
dart_dispmaps = {}
ks_errs = {}
dart_errs = {}


def _on_gt_grid(displacement, gt_dispmap):
    """Wrap `displacement` as a motion estimate on `gt_dispmap`'s time/depth bins."""
    return motion_util.get_motion_estimate(
        displacement,
        time_bin_centers_s=gt_dispmap.time_bin_centers_s,
        spatial_bin_centers_um=gt_dispmap.spatial_bin_centers_um,
    )


for p in measims:
    print(p)
    dart_dir = dart_out_dir / p.stem
    ks_dir = rez_dir / p.stem
    rez2_mat = ks_dir / "rez2.mat"
    subtraction_h5 = dart_dir / "thresholding.h5"
    motion_est_pkl = dart_dir / "motion_est.pkl"

    rec = si.read_binary_folder(darts_dir / p.stem)
    # Only the on-disk result is used below (presumably motion_est.pkl — confirm).
    dartsort.estimate_motion(
        rec,
        dartsort.DARTsortSorting.from_peeling_hdf5(subtraction_h5),
        output_directory=dart_dir,
        temporal_bin_length_s=2.05,
        # spatial_bin_length_um=5.0,
        amplitudes_dataset_name="ptp_amplitudes",
        # correlation_threshold=0.4,
        # max_disp_um=75.0,
        overwrite=True,
    )
    if not motion_est_pkl.exists():
        print("Not done")
        continue

    gtme, gt_dispmaps[p.stem] = analysis_tools.dispmap_from_mearec(p)
    ksme, ks_dispmaps[p.stem] = analysis_tools.dispmap_from_ks(gt_dispmaps[p.stem], rez2_mat, fs=32000)
    dredgeme, dart_dispmaps[p.stem] = analysis_tools.dispmap_from_dredge(gt_dispmaps[p.stem], motion_est_pkl)
    gt_dm, ks_dm, dart_dm = gt_dispmaps[p.stem], ks_dispmaps[p.stem], dart_dispmaps[p.stem]

    gtsp = analysis_tools.spikes_from_mearec(p, gtme)
    kssp = analysis_tools.spikes_from_ks(rez2_mat, fs=32000)
    dredgesp = analysis_tools.spikes_from_dredge(subtraction_h5)

    # Offset-aligned estimates and their residuals w.r.t. GT.
    ks_offset = np.mean(gt_dm.displacement - ks_dm.displacement)
    dart_offset = np.mean(gt_dm.displacement - dart_dm.displacement)
    ksd = _on_gt_grid(ks_dm.displacement + ks_offset, gt_dm)
    dartd = _on_gt_grid(dart_dm.displacement + dart_offset, gt_dm)
    ks_err = _on_gt_grid(gt_dm.displacement - ks_dm.displacement - ks_offset, gt_dm)
    dart_err = _on_gt_grid(gt_dm.displacement - dart_dm.displacement - dart_offset, gt_dm)
    ks_errs[p.stem] = ks_err
    dart_errs[p.stem] = dart_err

    # One symmetric colour scale (largest |displacement| of the three maps)
    # shared by the displacement and residual heatmaps.
    vm = max(np.abs(dm.displacement).max() for dm in (gt_dm, ks_dm, dart_dm))
    kw = dict(aspect="auto", cmap=plt.cm.seismic, vmin=-vm, vmax=vm)

    fig, axes = plt.subplots(ncols=3, nrows=3, figsize=(10, 10), layout="constrained", sharex=True, sharey=True)

    # Row 0: spike rasters, with each method's motion traces overlaid.
    motion_util.show_spike_raster(*gtsp, ax=axes[0, 0], aspect="auto", vmax=15)
    motion_util.show_spike_raster(*kssp, ax=axes[0, 1], aspect="auto", vmax=15)
    motion_util.plot_me_traces(ksme, ax=axes[0, 1], color="r", lw=1)
    motion_util.show_spike_raster(*dredgesp, ax=axes[0, 2], aspect="auto", vmax=15)
    motion_util.plot_me_traces(dredgeme, ax=axes[0, 2], color="r", lw=1)

    # Row 1: displacement maps (KS / DREDge offset-aligned to GT).
    motion_util.show_displacement_heatmap(gt_dm, ax=axes[1, 0], **kw)
    motion_util.show_displacement_heatmap(ksd, ax=axes[1, 1], **kw)
    motion_util.show_displacement_heatmap(dartd, ax=axes[1, 2], **kw)

    # Row 2: residuals (no GT panel).
    axes[2, 0].axis("off")
    motion_util.show_displacement_heatmap(ks_err, ax=axes[2, 1], **kw)
    motion_util.show_displacement_heatmap(dart_err, ax=axes[2, 2], **kw)

    axes[0, 0].set_title("GT localizations")
    axes[0, 1].set_title("KS localizations")
    axes[0, 2].set_title("DREDge localizations")
    axes[1, 0].set_title("GT dispmap")
    axes[1, 1].set_title("KS dispmap")
    axes[1, 2].set_title("DREDge dispmap")
    axes[2, 1].set_title(f"KS MSE={np.square(ks_err.displacement).mean():0.3f}")
    axes[2, 2].set_title(f"DREDge MSE={np.square(dart_err.displacement).mean():0.3f}")
    print(f"{dart_offset=} {ks_offset=}")

    fig.suptitle(p.stem)
    fig.savefig(heatmaps_dir / f"{p.stem}.png", dpi=200)
    plt.show()
    plt.close(fig)

In [None]:
# Long-format table of per-bin residuals for every processed recording and both
# algorithms: one row per (depth bin, time bin, recording, algorithm).
# Recording stems are split into four "_"-separated fields:
# probe, drift type, cell position distribution, firing type.
if not ks_errs:
    raise ValueError("ks_errs is empty: no recording finished the comparison loop")

depths = []
times = []
probes = []
drift_types = []
cell_position_dists = []
firing_types = []
dart_sqerrs = []
ks_sqerrs = []

for stem in ks_errs:
    print(stem)
    probe, drift_type, cell_position_dist, firing_type = stem.split("_")

    ks_err = ks_errs[stem]
    dart_err = dart_errs[stem]

    # indexing="ij": axis 0 = spatial bins, axis 1 = time bins, so these ravel()s
    # line up with the error maps' ravel() below (assumes displacement is
    # (space, time) — TODO confirm against ks_err.displacement.shape).
    depth, time_s = np.meshgrid(ks_err.spatial_bin_centers_um, ks_err.time_bin_centers_s, indexing="ij")
    nij = depth.size

    depths.append(depth.ravel())
    times.append(time_s.ravel())
    probes.append(np.full(nij, probe))
    drift_types.append(np.full(nij, drift_type))
    cell_position_dists.append(np.full(nij, cell_position_dist))
    firing_types.append(np.full(nij, firing_type))
    dart_sqerrs.append(np.square(dart_err.displacement).ravel())
    ks_sqerrs.append(np.square(ks_err.displacement).ravel())

# Bin coordinates and recording metadata are identical for both algorithms,
# so concatenate them once.
shared_columns = dict(
    depth=np.concatenate(depths),
    time=np.concatenate(times),
    probe=np.concatenate(probes),
    drift_type=np.concatenate(drift_types),
    cell_position_dist=np.concatenate(cell_position_dists),
    firing_type=np.concatenate(firing_types),
)


def _error_frame(shared, sqerrs, algorithm):
    """Build one algorithm's rows from its per-recording squared-error arrays."""
    sqerr = np.concatenate(sqerrs)
    # NOTE: each row holds a single bin, so "rmse" is |error| for that bin;
    # seaborn's default mean over rows therefore plots mean absolute error.
    return pd.DataFrame({**shared, "sqerr": sqerr, "rmse": np.sqrt(sqerr), "algorithm": algorithm})


# Both frames now carry the same columns (previously DREDge rows had no sqerr).
dartdf = _error_frame(shared_columns, dart_sqerrs, "DREDge")
ksdf = _error_frame(shared_columns, ks_sqerrs, "KS")
# ignore_index: both frames start at 0, so a plain concat would duplicate index
# labels, which can break downstream reindexing (e.g. in seaborn).
df = pd.concat([dartdf, ksdf], ignore_index=True)

In [None]:
# Residual rows for recordings with a uniform cell layout.
df[df["cell_position_dist"] == "uniform"]

In [None]:
# Per-bin error ("rmse" column) vs depth, uniform cell layout:
# columns = drift type, rows = firing type; bands are the central 90%
# percentile interval, the same interval used by the bimodal and global plots.
uniform_grid = sns.relplot(
    df.query("cell_position_dist=='uniform'"),
    hue="algorithm",
    x="depth",
    y="rmse",
    col="drift_type",
    row="firing_type",
    kind="line",
    errorbar=("pi", 90),
    height=2,
)
uniform_grid.figure.suptitle("Uniform cell layout", y=1.03);

In [None]:
# Per-bin error ("rmse" column) vs depth, bimodal cell layout:
# columns = drift type, rows = firing type, bands = central 90% percentile interval.
bimodal_rows = df.query("cell_position_dist=='bimodal'")
bimodal_grid = sns.relplot(
    data=bimodal_rows,
    x="depth",
    y="rmse",
    hue="algorithm",
    row="firing_type",
    col="drift_type",
    kind="line",
    height=2,
    errorbar=("pi", 90),
)
bimodal_grid.figure.suptitle("Bimodal cell layout", y=1.03)

In [None]:
# Per-bin error ("rmse" column) vs depth pooled over all recordings (no
# drift-type / firing-type facets), one line per algorithm, 90% percentile band.
global_grid = sns.relplot(
    data=df,
    x="depth",
    y="rmse",
    hue="algorithm",
    kind="line",
    height=2,
    errorbar=("pi", 90),
)
global_grid.figure.suptitle("Global", y=1.03)

In [None]:
# Shape of the KS residual map left over from the last loop iteration above.
ks_err.displacement.shape

In [None]:
# Depth (uu) and time (vv) grids for ks_err; indexing="ij" puts spatial bins on axis 0, time bins on axis 1.
uu, vv = np.meshgrid(ks_err.spatial_bin_centers_um, ks_err.time_bin_centers_s, indexing="ij")

In [None]:
# Compare with ks_err.displacement.shape to check the displacement map is (space, time).
uu.shape

In [None]:
# Pick one "bumps" drift recording for the inspection cells below.
# Sorting makes the choice independent of glob order; an explicit error
# replaces the bare IndexError when no such recording exists.
bumps_recordings = sorted(path for path in measims if "bumps" in path.stem)
if not bumps_recordings:
    raise ValueError(f"No recording with 'bumps' in its name under {measims_dir}")
p = bumps_recordings[0]

In [None]:
# Sizes presumably meant for a (recording x displacement-bin) array:
# number of bins in the selected recording's GT map, and number of GT maps.
n_disps = np.size(gt_dispmaps[p.stem].displacement)
n_recs = len(gt_dispmaps.keys())

In [None]:
# Walk the ground-truth displacement maps of the processed recordings.
# Uses its own loop variable so `p` (the "bumps" recording picked above, which
# the inspection cells below open) is not rebound. Recordings skipped as
# "Not done" in the comparison loop have no GT map and are skipped here.
# TODO: the original cell was truncated at `gt_disp = gt_d` (NameError);
# completing it as a lookup into gt_dispmaps is an assumption — confirm intent.
for rec_path in measims:
    if rec_path.stem not in gt_dispmaps:
        continue
    gt_disp = gt_dispmaps[rec_path.stem].displacement

In [None]:
# Inspect the Kilosort rez2.mat for the selected recording `p`, the same one the
# MEArec inspection below opens, instead of the `ks_dir` left over from the
# last comparison-loop iteration.
# h5py can only read rez2.mat if it was saved as MATLAB v7.3 (HDF5); arrays
# read this way come back transposed relative to MATLAB's own shapes.
ks_dir = rez_dir / p.stem
with h5py.File(ks_dir / "rez2.mat", "r") as h5:
    rez = h5["rez"]
    for k, v in rez.items():
        print(k, v)
    dshift = rez["dshift"][:]
    yc = rez["ycoords"][:]
    st0 = rez["st0"][:]

In [None]:
# Inspect the MEArec recording `p`: print its top-level contents and the
# drift_list/1 entry, and load drift parameters, one spike train, voltage
# peaks, template locations, channel positions and timestamps into memory.
# Opened read-only explicitly; older h5py versions could open files writable by default.
with h5py.File(p, "r") as h5:
    for k, v in h5["drift_list"].items():
        print(k, v)

    for k, v in h5.items():
        print(k, v)

    # NOTE(review): hard-coded to drift_list entry "1" — confirm this is the one wanted.
    g = h5["drift_list/1"]
    print("drift--------------")
    for k, v in g.items():
        print(k, v)
    drift_fs = g["drift_fs"][()]
    drift_factors = g["drift_factors"][()]
    drift_times = g["drift_times"][()]
    drift_vector_idxs = g["drift_vector_idxs"][()]
    drift_vector_um = g["drift_vector_um"][()]

    # First spike train: list its members, then keep only its spike times.
    st_group = h5["spiketrains/0"]
    for k, v in st_group.items():
        print(k, v)
    st = st_group["times"][:]
    vp = h5["voltage_peaks"][:]

    temp_locs = h5["template_locations"][:]
    geom = h5["channel_positions"][:]
    ttt = h5["timestamps"][:]

In [None]:
# Image of the MEArec "voltage_peaks" array; the trailing ';' hides the AxesImage repr.
plt.imshow(vp);

In [None]:
# Spike times of spiketrains/0 loaded from the MEArec file.
st

In [None]:
# NOTE(review): duplicate of the previous cell; candidate for removal.
st

In [None]:
# Shape of drift_factors from drift_list/1.
drift_factors.shape

In [None]:
# Shape of drift_vector_idxs from drift_list/1.
drift_vector_idxs.shape

In [None]:
# Shape of drift_vector_um from drift_list/1.
drift_vector_um.shape

In [None]:
# Range of the MEArec "timestamps" dataset (units presumably seconds — confirm).
ttt.min(), ttt.max()

In [None]:
# Length of a grid from 0 to the last timestamp sampled at drift_fs,
# presumably to compare against the length of drift_factors — confirm.
np.arange(0, ttt.max(), 1 / drift_fs).shape

In [None]:
# Shape of the MEArec "template_locations" array.
temp_locs.shape

In [None]:
# Min and max of coordinate 2 (presumably depth, as for geom[:, 2]) of the
# template locations at index 50 along axis 1 (meaning of that axis: TODO confirm).
# Returned as one tuple so the cell displays both values; previously a bare
# first expression's value was discarded and only the max was shown.
temp_locs[:, 50, 2].min(), temp_locs[:, 50, 2].max()

In [None]:
# Coordinate 2 of every channel position (presumably depth — confirm).
geom[:, 2]

In [None]:
# Recording currently bound to `p` (the file the inspection cells read).
p

In [None]:
# drift_times from drift_list/1.
drift_times

In [None]:
# drift_fs from drift_list/1.
drift_fs

In [None]:
# NOTE(review): repeats an earlier drift_vector_idxs.shape cell; candidate for removal.
drift_vector_idxs.shape

In [None]:
# drift_factors from drift_list/1.
drift_factors

In [None]:
# Plot drift_factors; the trailing ';' suppresses the Line2D list repr.
plt.plot(drift_factors);

In [None]:
# Distinct values in drift_vector_idxs.
np.unique(drift_vector_idxs)

In [None]:
# drift_vector_um from drift_list/1.
drift_vector_um

In [None]:
# Plot drift_vector_um; the trailing ';' suppresses the Line2D list repr.
plt.plot(drift_vector_um);

In [None]:
# Distinct values in drift_vector_um.
np.unique(drift_vector_um)