In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import dartsort
import numpy as np
import dartsort.vis as dartvis
import matplotlib.pyplot as plt
from pathlib import Path
import h5py
import torch
import torch.nn.functional as F
from sklearn.decomposition import PCA, TruncatedSVD
import spikeinterface.full as si
from dartsort.config import *
from dartsort.cluster import initial, density
import dataclasses
from dartsort.util import drift_util, data_util
import warnings
from tqdm.auto import trange, tqdm
from scipy.stats import chi2
# from ephysx import spike_gmm, spike_lrgmm, spike_basic, ppca
from ephysx import spike_basic, spike_interp
from matplotlib import colors
import seaborn as sns
from scipy.cluster.hierarchy import linkage, fcluster
from dredge import motion_util
import os

In [3]:
%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
from matplotlib.markers import MarkerStyle
from matplotlib.transforms import offset_copy
from matplotlib.patches import Ellipse, Rectangle, ConnectionPatch
from matplotlib.lines import Line2D
from matplotlib.legend_handler import HandlerTuple
import contextlib
import colorcet as cc

# Font sizes shared by the rc settings below.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 5, 7, 8

# Notebook-wide matplotlib defaults: one plt.rc call per rc group.
plt.rc("figure", dpi=300, figsize=(7, 4), titlesize=BIGGER_SIZE)
plt.rc("font", size=SMALL_SIZE)
plt.rc("axes", titlesize=MEDIUM_SIZE, labelsize=SMALL_SIZE)
plt.rc("xtick", labelsize=SMALL_SIZE)
plt.rc("ytick", labelsize=SMALL_SIZE)
plt.rc("legend", fontsize=SMALL_SIZE)

@contextlib.contextmanager
def subplots(*args, **kwargs):
    """Context-manager wrapper around ``plt.subplots``.

    All arguments are forwarded to ``plt.subplots`` and the resulting
    ``(fig, axes)`` pair is yielded to the ``with`` body. When the body
    finishes -- normally or by raising -- the figure is displayed with
    ``plt.show()`` and then closed.
    """
    figure, axes_out = plt.subplots(*args, **kwargs)
    try:
        yield figure, axes_out
    finally:
        plt.show()
        plt.close(figure)

In [4]:
def withgc(function):
    """Decorate ``function`` so that a garbage-collection pass follows each call.

    The returned wrapper forwards all positional and keyword arguments to
    ``function`` and returns its result unchanged. ``gc.collect()`` runs in a
    ``finally`` clause, so it is also executed when ``function`` raises; the
    exception then propagates to the caller.

    The wrapper keeps the wrapped function's ``__name__``, ``__doc__`` and
    related metadata via ``functools.wraps``.
    """
    # Imported locally so this cell works on its own, as the original did.
    import functools
    import gc

    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        try:
            return function(*args, **kwargs)
        finally:
            gc.collect()

    return wrapper

In [5]:
# Global sorter configuration for this notebook.
# model_radius is the single radius reused below for tpca_fit_radius,
# denoising_fit_radius and channel_selection_radius.
# NOTE(review): presumably in microns, like the *_um fields -- confirm.
model_radius = 15.0
cfg = DARTsortConfig(
    matching_iterations=2,
    # Settings handed to the subtraction step (see the commented-out
    # dartsort.subtract call further down the notebook).
    subtraction_config=SubtractionConfig(
        detection_thresholds=(12, 9, 6, 5, 4),
        extract_radius=75.0,
        max_waveforms_fit=20_000,
        subtraction_denoising_config=FeaturizationConfig(
            denoise_only=True,
            input_waveforms_name="raw",
            output_waveforms_name="subtracted",
            tpca_fit_radius=model_radius,
            tpca_centered=False,
        ),
        residnorm_decrease_threshold=20.0,
    ),
    matching_config=MatchingConfig(
        threshold=2500.0,
        max_waveforms_fit=20_000,
        extract_radius=75.0,
    ),
    # Same three values as split_merge_config.merge_template_config below.
    template_config=TemplateConfig(
        denoising_fit_radius=model_radius,
        denoising_snr_threshold=100.0,
        superres_templates=False,
    ),
    # Used by initial.initial_clustering and passed to the InterpClusterer.
    clustering_config=ClusteringConfig(
        cluster_strategy="density_peaks",
        sigma_regional=25.0,
        noise_density=1.0,
        ensemble_strategy=None,
        remove_duplicates=False,
        remove_big_units=False,
        remove_clusters_smaller_than=50,
    ),
    split_merge_config=SplitMergeConfig(
        min_spatial_cosine=0.0,
        linkage="single",
        # linkage="weighted_template",
        split_strategy_kwargs=dict(
            channel_selection_radius=model_radius,
            max_spikes=10_000,
        ),
        # This template config is the one passed to
        # dartsort.merge.merge_templates in the global-merge cells.
        merge_template_config=TemplateConfig(
            denoising_fit_radius=model_radius,
            denoising_snr_threshold=100.0,
            superres_templates=False,
        )
    ),
    featurization_config=FeaturizationConfig(
        tpca_fit_radius=model_radius,
        localization_radius=50.0,
        localization_model="dipole",
        tpca_centered=False,
        input_tpca_projs_temporal_slice=slice(20, 81),
    ),
    # Expanded with dataclasses.asdict(...) into dartsort.estimate_motion.
    motion_estimation_config=MotionEstimationConfig(
        max_dt_s=1000,
        window_scale_um=250,
        window_step_um=75,
        # NOTE(review): the margin is negative -- confirm this is intended.
        window_margin_um=-150,
        # min_amplitude=15.0,
    ),
)

In [6]:
# Template settings used for every DARTsortAnalysis built by get_analysis.
analysis_tcfg = dartsort.TemplateConfig(
    superres_templates=False,
    realign_peaks=False,
    denoising_snr_threshold=100.0,
    denoising_fit_radius=25.0,
)
# Denoising TSVD, fitted lazily by the first get_analysis call and then reused.
analysis_tpca = None


def get_analysis(labels, keepers=None, base_sorting=None):
    """Build a CPU ``DARTsortAnalysis`` for ``labels`` on top of ``base_sorting``.

    Relies on the notebook globals ``rec`` and ``motion_est``.

    Parameters
    ----------
    labels : array-like or torch.Tensor
        Unit labels. A tensor is converted to a NumPy array first (copying it
        off the GPU if needed). When ``keepers`` is None, ``labels`` replaces
        ``base_sorting.labels`` as given.
    keepers : index array, optional
        Positions in ``base_sorting.labels`` that ``labels`` refers to. All
        other spikes are labelled -1.
    base_sorting : optional
        Sorting whose fields other than ``labels`` are reused.
        NOTE(review): the default falls back to a global ``ref_clust`` that is
        not defined in any visible cell; always pass ``base_sorting`` or
        define ``ref_clust`` first.

    Returns
    -------
    The object returned by ``dartsort.DARTsortAnalysis.from_sorting``.

    Notes
    -----
    The denoising TSVD is fitted once, on the ``base_sorting`` of the first
    call, and cached in the global ``analysis_tpca`` for all later calls.
    """
    global analysis_tpca
    if base_sorting is None:
        base_sorting = ref_clust
    if analysis_tpca is None:
        analysis_tpca = dartsort.templates.get_templates.fit_tsvd(rec, base_sorting)

    # GMM labels are torch tensors and may live on the GPU; NumPy cannot read
    # those directly, so bring them to host memory first.
    if torch.is_tensor(labels):
        labels = labels.numpy(force=True)

    if keepers is not None:
        # Scatter the kept spikes' labels into a full-length vector of -1s.
        full_labels = np.full_like(base_sorting.labels, -1)
        full_labels[keepers] = labels
        labels = full_labels

    sorting = dataclasses.replace(base_sorting, labels=labels)
    return dartsort.DARTsortAnalysis.from_sorting(
        rec,
        sorting,
        motion_est=motion_est,
        allow_template_reload=False,
        template_config=analysis_tcfg,
        denoising_tsvd=analysis_tpca,
        device="cpu",
    )

In [7]:
@withgc
def makeplots(
    subdir,
    gmm=None,
    sorting=None,
    n_jobs=0,
    with_summaries=True,
    with_dpcs=True,
    with_over_time=False,
    overwrite=False,
    dpc_par=False,
):
    """Write the diagnostic figures for one sorting stage into ``subdir``.

    Relies on the notebook globals ``rec`` and ``motion_est`` and on
    ``get_analysis``. The analysis object is built at most once, and only if
    some figure actually needs it.

    Parameters
    ----------
    subdir : pathlib.Path
        Output directory. Created if missing; its parent must already exist.
    gmm : optional
        Clusterer whose ``labels`` and ``data.keepers`` define the units to
        plot. It is moved to the CPU while plotting and moved back to the
        device it started on when this function exits, even if plotting
        raised.
    sorting : optional
        Passed to ``get_analysis`` as ``base_sorting``. When ``gmm`` is None,
        ``sorting.labels`` are the labels that get plotted.
    n_jobs : int
        Forwarded to the over-time and unit summaries, and to the GMM
        summaries when ``dpc_par`` is true.
    with_summaries : bool
        Write per-unit summaries to ``subdir / "summaries"``.
    with_dpcs : bool
        Write GMM summaries to ``subdir / "gmm"`` (only when ``gmm`` is given).
    with_over_time : bool
        Write ``animation.mp4`` and the ``over_time`` summaries.
    overwrite : bool
        Regenerate outputs that already exist.
    dpc_par : bool
        If false, GMM summaries are made with ``n_jobs=0``.
    """
    device = None
    if gmm is not None:
        device = gmm.device
        gmm.cpu()

    try:
        subdir.mkdir(exist_ok=True)

        # Labels are read only after the move to CPU, so the reference that is
        # handed to get_analysis is the host copy.
        if gmm is None:
            analysis_kw = dict(labels=sorting.labels, base_sorting=sorting)
        else:
            analysis_kw = dict(
                labels=gmm.labels, keepers=gmm.data.keepers, base_sorting=sorting
            )

        analysis = None

        def ga():
            # Build the analysis on first use and reuse it afterwards.
            nonlocal analysis
            if analysis is None:
                analysis = get_analysis(**analysis_kw)
            return analysis

        def stale(path):
            # True when the output has to be (re)generated.
            return overwrite or not path.exists()

        print("scatters...")
        sorting_png = subdir / "sorting.png"
        if stale(sorting_png):
            a0 = ga()
            fig = plt.figure(figsize=(15, 15))
            fig = dartvis.make_sorting_summary(a0, figure=fig)
            fig.savefig(sorting_png, dpi=200)
            plt.close(fig)

        scatter_png = subdir / "scatter.png"
        if stale(scatter_png):
            a0 = ga()
            fig = plt.figure(figsize=(15, 15))
            fig, ax, ss = dartvis.scatter_spike_features(
                sorting=a0.sorting,
                show_triaged=False,
                figure=fig,
                width_ratios=[1, 1, 1],
            )
            motion_util.plot_me_traces(motion_est, ax=ax[-1], color="r", lw=1)
            fig.savefig(scatter_png, dpi=200)
            plt.close(fig)

        # Registered scatters, without and with the triaged spikes.
        for png_name, show_triaged in (
            ("scatter_reg.png", False),
            ("scatter_regt.png", True),
        ):
            scatter_reg_png = subdir / png_name
            if not stale(scatter_reg_png):
                continue
            a0 = ga()
            fig = plt.figure(figsize=(15, 15))
            fig, ax, ss = dartvis.scatter_spike_features(
                sorting=a0.sorting,
                motion_est=motion_est,
                registered=True,
                show_triaged=show_triaged,
                figure=fig,
                width_ratios=[1, 1, 1],
            )
            fig.savefig(scatter_reg_png, dpi=200)
            plt.close(fig)
        print("scatters done")

        print("animation")
        animation_mp4 = subdir / "animation.mp4"
        if with_over_time and stale(animation_mp4):
            dartvis.sorting_scatter_animation(
                ga(),
                animation_mp4,
                chunk_length_samples=300 * rec.sampling_frequency,
                device="cpu",
            )
        print("animation done")

        print("over_time")
        if with_over_time:
            # The analysis must exist before its unit ids can be used to name
            # the last expected output file.
            a0 = ga()
            if overwrite or not (
                subdir
                / "over_time"
                / f"unit{a0.coarse_template_data.unit_ids.max():04d}.png"
            ).exists():
                dartvis.make_all_over_time_summaries(
                    a0,
                    subdir / "over_time",
                    chunk_length_s=300.0,
                    channel_show_radius_um=15.0,
                    amplitude_color_cutoff=25.0,
                    pca_radius_um=25.0,
                    # max_height=18,
                    figsize=(20, 20),
                    dpi=200,
                    image_ext="png",
                    n_jobs=n_jobs,
                    n_jobs_templates=0,
                    show_progress=True,
                    overwrite=overwrite,
                    analysis_kw=dict(device="cpu"),
                )
        print("over_time done")

        print("dpcs")
        if with_dpcs and gmm is not None:
            dartvis.gmm.make_all_gmm_summaries(
                gmm,
                subdir / "gmm",
                show_progress=True,
                n_jobs=n_jobs if dpc_par else 0,
                overwrite=overwrite,
            )
        print("dpcs done")

        print("summaries")
        if with_summaries:
            dartvis.make_all_summaries(
                ga(),
                subdir / "summaries",
                channel_show_radius_um=15,
                overwrite=overwrite,
                n_jobs=n_jobs,
            )
        print("summaries done")
    finally:
        # Put the clusterer back where the caller had it, even after a failure.
        if gmm is not None:
            gmm.to(device)

In [8]:
ultra_root = Path("/home/charlie/scratch/NPultra_ImposedMotion")
ultra_out = Path("/home/charlie/scratch/uhd/NPultra_sorting")
figs_out = Path("/home/charlie/scratch/uhd/NPultra_sortingfigs")
ultra_ppx = Path("/home/charlie/scratch/uhd/NPultra_preprocessed")

# scratch_dir = Path(f"/scratch/{os.environ['USER']}/job_{os.environ['SLURM_JOBID']}")
scratch_dir = None
if scratch_dir and not scratch_dir.exists():
    scratch_dir = None

In [9]:
rec_orig_dir = ultra_root / "ZYE_0021" / "2021-05-01" / "1"

In [10]:
ap_bin = rec_orig_dir / "p1_g0_t0.imec0.ap.bin"

In [11]:
recid = "-".join(str(rec_orig_dir.relative_to(ultra_root)).split("/"))
out_dir = ultra_out / recid

In [12]:
ppx_dir = ultra_ppx / recid

In [13]:
fig_dir = figs_out / recid
fig_dir.mkdir(exist_ok=True, parents=True)

In [14]:
# Preprocess the raw recording once and cache it in ppx_dir; when the folder
# already exists only the final reload runs.
if not ppx_dir.exists():
    geom = np.load(rec_orig_dir / "channel_positions.npy")

    rec0 = si.read_binary(ap_bin, 30_000, np.int16, num_channels=385)
    print(rec0)
    # Keep only as many channels as the geometry file has entries.
    rec0 = rec0.channel_slice(rec0.channel_ids[: len(geom)])
    rec0.set_dummy_probe_from_locations(geom)

    # highpass filter -> common reference -> z-score, then cast to float16.
    rec = si.zscore(
        si.common_reference(si.highpass_filter(rec0, dtype="float64")),
        num_chunks_per_segment=100,
        mode="mean+std",
    )
    print(rec)
    rec = rec.astype(np.float16)
    print(rec)

    # Preview the same trace window before (rec0) and after (rec) preprocessing.
    for shown_rec in (rec0, rec):
        fig = plt.figure()
        plt.imshow(shown_rec.get_traces(0, 0, 1000).T, aspect="auto")
        plt.colorbar()
        plt.show()
        plt.close(fig)

    rec = rec.save_to_folder(ppx_dir, chunk_memory="40M", n_jobs=8)
rec = si.read_binary_folder(ppx_dir)
rec

BinaryFolderRecording: 384 channels - 30.0kHz - 1 segments - 121,209,192 samples 
                       4,040.31s (1.12 hours) - float16 dtype - 86.70 GiB

In [15]:
if scratch_dir is not None:
    if not (scratch_dir / ppx_dir.stem).exists():
        !rsync -avP {ppx_dir} {scratch_dir}/
    rec = si.read_binary_folder(scratch_dir / ppx_dir.stem)

In [16]:
# (out_dir / "subtraction").mkdir(parents=True, exist_ok=True)
# sub_st, sub_h5 = dartsort.subtract(
#     rec,
#     out_dir / "subtraction",
#     subtraction_config=cfg.subtraction_config,
#     featurization_config=cfg.featurization_config,
#     overwrite=True,
#     n_jobs=2,
# )

In [17]:
sub_st = dartsort.DARTsortSorting.from_peeling_hdf5(out_dir / "subtraction" / "subtraction.h5")

In [18]:
sub_st

DARTsortSorting: 331080 spikes, 1 unit. extra features: denoised_logpeaktotrough, denoised_peak_amplitude_vectors, denoised_ptp_amplitude_vectors, denoised_ptp_amplitudes, point_source_localizations, times_seconds. from parent h5 file /home/charlie/scratch/uhd/NPultra_sorting/ZYE_0021-2021-05-01-1/subtraction/subtraction.h5.

In [None]:
motion_est = dartsort.estimate_motion(
    rec,
    sub_st,
    sub_st.parent_h5_path.parent,
    **dataclasses.asdict(cfg.motion_estimation_config),
    overwrite=False,
)

In [20]:
# fig, ax = plt.subplots()
# dartvis.scatter_time_vs_depth(sorting=sub_st, geom=rec.get_channel_locations(), ax=ax, amplitude_color_cutoff=20)
# motion_util.plot_me_traces(motion_est, ax, color="r");

In [48]:
full_dpc = initial.initial_clustering(
    peeling_hdf5_filename=sub_st.parent_h5_path,
    clustering_config=cfg.clustering_config,
    recording=rec,
    sorting=sub_st,
    motion_est=motion_est,
)

In [49]:
full_dpc

DARTsortSorting: 331080 spikes, 35 units. extra features: denoised_logpeaktotrough, denoised_peak_amplitude_vectors, denoised_ptp_amplitude_vectors, denoised_ptp_amplitudes, point_source_localizations, times_seconds. from parent h5 file /home/charlie/scratch/uhd/NPultra_sorting/ZYE_0021-2021-05-01-1/subtraction/subtraction.h5.

In [50]:
# tpca = torch.load(full_dpc.parent_h5_path.parent / "subtraction_models" / "featurization_pipeline.pt").transformers[0]
# tpca.mean.shape, tpca.mean.abs().max()

In [51]:
# makeplots(
#     fig_dir / "aaa_000_fulldpc",
#     gmm=None,
#     sorting=full_dpc,
#     n_jobs=4,
#     with_summaries=True,
#     with_dpcs=True,
#     with_over_time=False,
#     overwrite=False,
# )

In [52]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fc0a8bdf990>

In [53]:
import gc; gc.collect()

24741

In [54]:
gmm = spike_interp.InterpClusterer(
    full_dpc,
    (0, rec.get_total_duration()),
    motion_est=motion_est,
    do_interp=False,
    sampling_method="random",
    split_sampling_method="time_amp_reweighted",
    fa_kwargs=dict(
        do_prior=True,
        lengthscale=250,
        latent_update="gradient",
        points_per_lengthscale=1,
    ),
    residual_pca_kwargs=dict(
        centered=False,
        atol=1e-3,
        max_iter=100,
        transform_iter=0,
        pca_on_waveform_channels=False,
    ),
    dpc_split_kwargs=spike_interp.DPCSplitKwargs(
        split_on_train=True,
        sigma_local="rule_of_thumb*0.5",
        radius_search=8.0,
    ),
    outlier_explained_var=0.5,
    scale_residual_embed=False,
    reassign_metric="1-r^2",
    merge_metric="1-r^2",
    fit_radius=20.0,
    waveform_radius=20.0,
    n_spikes_fit=4096,
    clustering_config=cfg.clustering_config,
)

In [102]:
gmm.models.clear()
gmm.labels = torch.tensor(full_dpc.labels[gmm.data.keepers]).to(gmm.labels)
gmm.cleanup()


In [103]:
# gmm.parcellate()

In [104]:
# gmm.unit_ids()

In [105]:
# import gc; gc.collect()

In [106]:
gmm.cuda();

In [107]:
gmm.m_step()

M step:   0%|          | 0/35 [00:00<?, ?it/s]

In [112]:
# g: splitem
# h: transform_iter
# i: fit_pca args same for testing
# j: fewer fit iters
# k: more fit iters, fewer transform iters
# l: even fewer transform iters
# m: no transform iters
# n: raise atol
# o: raise atol again
# p: drop max iter
# q: split on train
# r: raise max iter
# s: lower atol (now 1e-2)
# t: lower atol (now 1e-3)
# u: reweighted sampling
# v: back to split full with transform iters
# w: smaller fit radius
# x: bigger fit radius
# z: split on wf chans, rot*1.5, handle duplicates
# ba: pca not on wf chans. too many all nan cases.
# bb: to pandesk.
# bc: parcellate 0.5
# bd: bigger chan hoods w parcellation
# be: we have bigger things to fry than parcels 
# bf: lower sampling sigma
# bg: channels switchup
# bh: trying to do zipper
# bi: reassign was not maintaining zipper. try no scaling.
# bj: frame slice tpca
# bk: rule_of_thumb*0.5
# bl: assign at random for center, weighted for split, handle CIs right
# bm: search rad
# bn: add in outlier re-clustering
# bo: back to split on orig chans
# bp: no cleanup after recluster, no //2 in dpc, adding stuff to zipper vis
# bq: zipper with centroid criterion
tag = "abq"

In [108]:
import gc; gc.collect()

82911

In [77]:
makeplots(
    fig_dir / f"{tag}_000_init",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=2,
    with_summaries=True,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=False,
)

scatters...


Denoised templates:   0%|          | 0/35 [00:00<?, ?template/s]

q


Pairwise convolution:   0%|          | 0/25 [00:00<?, ?pair block/s]

scatters done
scatters done
animation
animation done
over_time
over_time done
dpcs


GMM summaries:   0%|          | 0/35 [00:00<?, ?it/s]

dpcs done
summaries


Unit summaries:   0%|          | 0/35 [00:00<?, ?it/s]

summaries done


In [109]:
gmm.residual_dpc_split()
gmm.recluster_outliers()
gmm.m_step()
gmm.residual_dpc_split()
gmm.recluster_outliers()
gmm.m_step()
gmm.residual_dpc_split()
(gmm.labels < 0).to(torch.float).mean()

Split round 0:   0%|          | 0/35 [00:00<?, ?it/s]

M step:   0%|          | 0/52 [00:00<?, ?it/s]

Split round 1:   0%|          | 0/37 [00:00<?, ?it/s]

M step:   0%|          | 0/23 [00:00<?, ?it/s]

Split round 2:   0%|          | 0/11 [00:00<?, ?it/s]

M step:   0%|          | 0/8 [00:00<?, ?it/s]

Split: 35 + (20+6+0) = 61.
Reclustering found 26 new clusters with spike counts from 63 to 22495. Outlier fraction: 45.0% -> 3.7%.


M step:   0%|          | 0/26 [00:00<?, ?it/s]

Split round 0:   0%|          | 0/87 [00:00<?, ?it/s]

M step:   0%|          | 0/53 [00:00<?, ?it/s]

Split round 1:   0%|          | 0/32 [00:00<?, ?it/s]

M step:   0%|          | 0/20 [00:00<?, ?it/s]

Split round 2:   0%|          | 0/8 [00:00<?, ?it/s]

M step:   0%|          | 0/6 [00:00<?, ?it/s]

Split round 3:   0%|          | 0/4 [00:00<?, ?it/s]

M step:   0%|          | 0/3 [00:00<?, ?it/s]

Split: 87 + (17+4+2+0) = 110.
Reclustering found 14 new clusters with spike counts from 64 to 19077. Outlier fraction: 34.8% -> 3.4%.


M step:   0%|          | 0/14 [00:00<?, ?it/s]

Split round 0:   0%|          | 0/124 [00:00<?, ?it/s]

M step:   0%|          | 0/43 [00:00<?, ?it/s]

Split round 1:   0%|          | 0/19 [00:00<?, ?it/s]

M step:   0%|          | 0/14 [00:00<?, ?it/s]

Split round 2:   0%|          | 0/6 [00:00<?, ?it/s]

M step:   0%|          | 0/3 [00:00<?, ?it/s]

Split: 124 + (10+3+0) = 137.


tensor(0.1956)

In [110]:
gmm.zipper_split()

Zipper split:   0%|          | 0/137 [00:00<?, ?it/s]

Zipper split broke off 26 new units.


M step:   0%|          | 0/50 [00:00<?, ?it/s]

In [113]:
gmm.residual_dpc_split()

Split round 0:   0%|          | 0/163 [00:00<?, ?it/s]

M step:   0%|          | 0/32 [00:00<?, ?it/s]

Split round 1:   0%|          | 0/8 [00:00<?, ?it/s]

M step:   0%|          | 0/6 [00:00<?, ?it/s]

Split: 163 + (4+0) = 167.


In [114]:
gmm.device

device(type='cpu')

In [116]:
makeplots(
    fig_dir / f"{tag}_002_dpcsplit_recluster_unzip",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=2,
    with_summaries=False,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=True,
)

scatters...


Denoised templates:   0%|          | 0/167 [00:00<?, ?template/s]

q


Pairwise convolution:   0%|          | 0/441 [00:00<?, ?pair block/s]

scatters done
scatters done
animation
animation done
over_time
over_time done
dpcs


GMM summaries:   0%|          | 0/167 [00:00<?, ?it/s]

  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  density = density / reg_density
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefi

// error in unit 73
Traceback (most recent call last):
  File "/home/charlie/spike-psvae/src/dartsort/vis/gmm.py", line 1081, in _summary_job
    make_unit_gmm_summary(
  File "/home/charlie/spike-psvae/src/dartsort/vis/gmm.py", line 945, in make_unit_gmm_summary
    figure = layout.flow_layout(
             ^^^^^^^^^^^^^^^^^^^
  File "/home/charlie/spike-psvae/src/dartsort/vis/layout.py", line 77, in flow_layout
    plot.draw(panel, **plot_kwargs)
  File "/home/charlie/spike-psvae/src/dartsort/vis/gmm.py", line 69, in draw
    dens = density.density_peaks_clustering(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/charlie/spike-psvae/src/dartsort/cluster/density.py", line 361, in density_peaks_clustering
    nhdn, distances, indices = nearest_higher_density_neighbor(
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/charlie/spike-psvae/src/dartsort/cluster/density.py", line 161, in nearest_higher_density_neighbor
    is_lower_density = density_pa

  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
  fig.savefig(tmp_out, dpi=_summary_job_contex

BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.

In [None]:
for j in range(10):
    gmm.reassign()
    gmm.m_step(force=True, fit_residual=j == 9)

In [None]:
makeplots(
    fig_dir / f"{tag}_003_em",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=2,
    with_summaries=False,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=True,
)

In [None]:
gmm.cpu();

In [None]:
makeplots(
    fig_dir / f"{tag}_000_zipper",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=2,
    with_summaries=False,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
gmm.cuda();

In [None]:
import gc; gc.collect()

In [None]:
for _ in range(10):
    gmm.reassign()
    gmm.m_step(force=True)

In [None]:
gmm.cpu(); gc.collect();

In [None]:
makeplots(
    fig_dir / f"{tag}_000_zipreas10",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=2,
    with_summaries=True,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
dpczipreas10labels = gmm.labels.numpy(force=True)
np.save("dpczipreas10labels.npy", dpczipreas10labels)

In [None]:
dpczipreas10labels = np.load("dpczipreas10labels.npy")

In [None]:
# let's try a global merge

In [None]:
full_labels = np.full_like(full_dpc.labels, -1)
full_labels[gmm.data.keepers] = dpczipreas10labels
sorting = dataclasses.replace(full_dpc, labels=full_labels)

In [None]:
merged_sorting = dartsort.merge.merge_templates(
    sorting,
    rec,
    template_config=cfg.split_merge_config.merge_template_config,
    motion_est=motion_est,
    sym_function=np.maximum,
)

In [None]:
gmm.models.clear()
globmerge_labels = merged_sorting.labels[gmm.data.keepers]
gmm.labels = torch.tensor(globmerge_labels).to(gmm.labels)
gmm.cleanup()

gmm.do_interp = True

gmm.cuda();

gmm.m_step()

gmm.cpu(); gc.collect()

In [None]:
makeplots(
    fig_dir / f"{tag}_001_u_globmerge",
    gmm=gmm,
    sorting=merged_sorting,
    n_jobs=2,
    with_summaries=True,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
merged2_sorting = dartsort.merge.merge_templates(
    merged_sorting,
    rec,
    template_config=cfg.split_merge_config.merge_template_config,
    motion_est=motion_est,
    sym_function=np.maximum,
)

In [None]:
gmm.models.clear()
globmerge2_labels = merged2_sorting.labels[gmm.data.keepers]
gmm.labels = torch.tensor(globmerge2_labels).to(gmm.labels)
gmm.cleanup()
gmm.do_interp = True
# gmm.cuda();
gmm.m_step()
gmm.cpu(); gc.collect()

In [None]:
gc.collect()

In [None]:
makeplots(
    fig_dir / f"{tag}_001_u_globmerge2",
    gmm=gmm,
    sorting=merged_sorting,
    n_jobs=4,
    with_summaries=True,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
np.save("globmerge2_labels.npy", globmerge2_labels)

In [None]:
globmerge2_labels = np.load("globmerge2_labels.npy")

In [None]:
# e/m over time
gmm.models.clear()
gmm.labels = torch.tensor(globmerge2_labels).to(gmm.labels)
gmm.cleanup()
gmm.do_interp = True
gmm.cuda();
gmm.m_step()
# gmm.cpu(); gc.collect()

In [None]:
for _ in range(3):
    gmm.reassign()
    gmm.m_step(force=True)

In [None]:
gmm.cpu(); gc.collect()

In [None]:
makeplots(
    fig_dir / f"{tag}_001_u_globmerge2_emtime",
    gmm=gmm,
    sorting=merged_sorting,
    n_jobs=4,
    with_summaries=True,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
# e/m NOT over time
gmm.models.clear()
gmm.labels = torch.tensor(globmerge2_labels).to(gmm.labels)
gmm.cleanup()
gmm.do_interp = False
gmm.cuda();
gmm.m_step()
# gmm.cpu(); gc.collect()

In [None]:
for _ in range(3):
    gmm.reassign()
    gmm.m_step(force=True)

In [None]:
gmm.cpu(); gc.collect()

In [None]:
makeplots(
    fig_dir / f"{tag}_001_u_globmerge2_emnotime",
    gmm=gmm,
    sorting=merged_sorting,
    n_jobs=4,
    with_summaries=True,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
# e/m NOT over time, but scaled
gmm.models.clear()
gmm.labels = torch.tensor(globmerge2_labels).to(gmm.labels)
gmm.cleanup()
gmm.do_interp = False
gmm.reassign_metric = "1-scaledr^2"
gmm.cuda();
gmm.m_step()
# gmm.cpu(); gc.collect()

In [None]:
for _ in range(3):
    gmm.reassign()
    gmm.m_step(force=True)

In [None]:
gmm.cpu(); gc.collect()

In [None]:
makeplots(
    fig_dir / f"{tag}_001_u_globmerge2_emnotimescaled",
    gmm=gmm,
    sorting=merged_sorting,
    n_jobs=4,
    with_summaries=True,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
# e/m over time, higher rank
gmm.models.clear()
# gmm.unit_kw['fa_kwargs']['latent_dim'] = 1
# gmm.unit_kw['fa_kwargs']['do_prior'] = True
# gmm.unit_kw['fa_kwargs']['latent_update'] = 'gradient'
# gmm.unit_kw['fa_kwargs']['interp_kind'] = 'linear'
# gmm.unit_kw['fa_kwargs']['learn_prior_noise_fraction'] = True
gmm.labels = torch.tensor(globmerge2_labels).to(gmm.labels)
gmm.cleanup()
gmm.do_interp = False
gmm.reassign_metric = "1-r^2"
gmm.cuda();
gmm.m_step()
# gmm.cpu(); gc.collect()

In [None]:
gmm.residual_dpc_split()
gmm.zipper_split()

In [None]:
for _ in range(10):
    gmm.reassign()
    gmm.m_step(force=True)

In [None]:
makeplots(
    fig_dir / f"{tag}_001_u_globmerge2_split_em10",
    gmm=gmm,
    sorting=merged_sorting,
    n_jobs=4,
    with_summaries=True,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=True,
)

In [None]:
# Scatter the GMM labels (one per kept spike) into a full-length label vector;
# spikes outside gmm.data.keepers stay at -1.
full_labels = np.full_like(full_dpc.labels, -1)
# numpy(force=True) copies to host memory first, so this also works while the
# GMM (and its labels tensor) is on the GPU.
full_labels[gmm.data.keepers] = gmm.labels.numpy(force=True)
merged_sorting = dataclasses.replace(full_dpc, labels=full_labels)

# Two consecutive passes of template-based merging.
for _ in range(2):
    merged_sorting = dartsort.merge.merge_templates(
        merged_sorting,
        rec,
        template_config=cfg.split_merge_config.merge_template_config,
        motion_est=motion_est,
        sym_function=np.maximum,
    )

In [None]:
gmm.models.clear()
gmm.labels = torch.tensor(merged_sorting.labels[gmm.data.keepers]).to(gmm.labels)
gmm.cleanup()
gmm.do_interp = False
gmm.reassign_metric = "1-r^2"
gmm.cuda();
gmm.m_step()

In [None]:
gmm.cpu(); gc.collect()

In [None]:
makeplots(
    fig_dir / f"{tag}_001_u_globmerge2_split_em10_globmerge2",
    gmm=gmm,
    sorting=merged_sorting,
    n_jobs=4,
    with_summaries=True,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=True,
)

In [None]:
for _ in range(3):
    gmm.reassign()
    gmm.m_step(force=True)

In [None]:
gmm.cpu(); gc.collect()

In [None]:
makeplots(
    fig_dir / f"{tag}_001_u_globmerge2_emtimelinear",
    gmm=gmm,
    sorting=merged_sorting,
    n_jobs=4,
    with_summaries=True,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=True,
)

In [None]:
1

In [None]:
for _ in range(10):
    gmm.reassign()
    gmm.m_step(force=True)

In [None]:
makeplots(
    fig_dir / f"{tag}_000_zipreasmoremore",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=0,
    with_summaries=True,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
# Split pass: residual DPC split and zipper split, re-cluster outliers, run an
# m_step, then residual DPC split and zipper split once more.
gmm.residual_dpc_split()
gmm.zipper_split()
gmm.recluster_outliers()
gmm.m_step()
gmm.residual_dpc_split()
gmm.zipper_split()

In [None]:
# NOTE(review): `gc` is imported here, mid-notebook, although earlier cells
# already call gc.collect(); consider moving the import to the top import cell.
import gc; gc.collect()

In [None]:
gmm.cpu();

In [None]:
# Display the device the model currently reports.
gmm.device

In [None]:
gmm.cuda();

In [None]:
# Plots after the split / recluster / split pass (summaries disabled),
# tagged "001_splitreclussplit".
makeplots(
    fig_dir / f"{tag}_001_splitreclussplit",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=4,
    with_summaries=False,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
# Three identical cells: each runs three reassign / forced m_step rounds
# (nine rounds in total when the notebook is run top to bottom).
for _ in range(3):
    gmm.reassign()
    gmm.m_step(force=True)

In [None]:
for _ in range(3):
    gmm.reassign()
    gmm.m_step(force=True)

In [None]:
for _ in range(3):
    gmm.reassign()
    gmm.m_step(force=True)

In [None]:
# Plots after the reassign / m_step rounds, tagged "001_split_em";
# overwrite=True here, unlike most makeplots calls in this section.
makeplots(
    fig_dir / f"{tag}_001_split_em",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=4,
    with_summaries=False,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=True,
)

In [None]:
# Second split pass, this time without the zipper split: residual DPC split,
# re-cluster outliers, m_step, residual DPC split.
gmm.residual_dpc_split()
gmm.recluster_outliers()
gmm.m_step()
gmm.residual_dpc_split()

In [None]:
# Plots for the second split pass, tagged "002_split".
makeplots(
    fig_dir / f"{tag}_002_split",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=4,
    with_summaries=True,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
# Two identical cells of three reassign / forced m_step rounds each.
for _ in range(3):
    gmm.reassign()
    gmm.m_step(force=True)

In [None]:
for _ in range(3):
    gmm.reassign()
    gmm.m_step(force=True)

In [None]:
# Plots after those rounds, tagged "002_splitem".
makeplots(
    fig_dir / f"{tag}_002_splitem",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=4,
    with_summaries=True,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
# Unit-by-unit divergences as returned by the model.
dists = gmm.central_divergences()

In [None]:
dists

In [None]:
# Heat map of the divergences with the colour scale clipped to [0, 1].
plt.imshow(dists, cmap=plt.cm.rainbow, vmin=0, vmax=1);
plt.colorbar();

In [None]:
# Symmetrise with the element-wise minimum of the matrix and its transpose,
# then convert to a float NumPy array.
d = np.minimum(dists, dists.T).numpy(force=True).astype(float)

In [None]:
# Three-level closeness map: 2 where d < 0.25, 1 where 0.25 <= d < 0.5, else 0.
# One operand is cast to int because `+` on two boolean arrays is a logical
# OR, which would collapse the result to just `d < 0.5`.
(d < 0.25).astype(int) + (d < 0.5)

In [None]:
# Cluster map of the three-level closeness matrix (2: d < 0.25, 1: d < 0.5,
# 0: otherwise). The int cast keeps `+` from acting as a boolean OR, which
# would discard the 0.25 threshold.
sns.clustermap((d < 0.25).astype(int) + (d < 0.5))

In [None]:
gmm.cuda()

In [None]:
# Turn interpolation on, discard the existing per-unit models and run m_step.
gmm.do_interp = True
gmm.models.clear()
gmm.m_step()

In [None]:
# NOTE(review): the output tag says "overtime" but with_over_time=False is
# passed — confirm which is intended.
makeplots(
    fig_dir / f"{tag}_002_splitem_overtime",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=4,
    with_summaries=True,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
# Unit inspected in the cells that follow.
unit_id = 29

In [None]:
gmm.cpu();

In [None]:
# Flag the unit as needing a fit before the (non-forced) m_step below;
# presumably m_step then refits flagged units only — TODO confirm.
gmm[unit_id].needs_fit = True

In [None]:
gmm.m_step()

In [None]:
# Keep the second and third return values of split_features: in_unit0 (checked
# against in_unit1 with torch.equal further down) and the features z0.
_, in_unit0, z0 = gmm.split_features(unit_id)

In [None]:
# Training loadings stored on the unit's PCA, compared with z0 further down.
z00 = gmm[unit_id].pca.train_loadings.numpy(force=True)

In [None]:
# Training data for the unit, restricted to in_unit0, plus the same waveforms
# mapped onto the unit's channels (waveforms2).
in_unit1, data1 = gmm.get_training_data(unit_id, in_unit=in_unit0)
waveforms2 = gmm[unit_id].to_unit_channels(
    waveforms=data1["waveforms"],
    times=data1["times"],
    waveform_channels=data1["waveform_channels"],
)
# PCA of the training waveforms, flattened to one row per spike -> z1.
loadings1, mean1, components1, svs1 = spike_interp.fit_pcas(
    data1["waveforms"].reshape(in_unit1.numel(), -1),
    missing=None,
    empty=None,
    rank=gmm.dpc_split_kw.rank,
    show_progress=False,
)
z1 = loadings1.numpy(force=True)
# Same fit repeated on identical input -> z11 (presumably to gauge run-to-run
# variation of fit_pcas — TODO confirm). Rebinds mean1 / components1 / svs1.
loadings11, mean1, components1, svs1 = spike_interp.fit_pcas(
    data1["waveforms"].reshape(in_unit1.numel(), -1),
    missing=None,
    empty=None,
    rank=gmm.dpc_split_kw.rank,
    show_progress=False,
)
z11 = loadings11.numpy(force=True)
# PCA of the unit-channel waveforms -> z2.
loadings2, mean2, components2, svs2 = spike_interp.fit_pcas(
    waveforms2.reshape(in_unit1.numel(), -1),
    missing=None,
    empty=None,
    rank=gmm.dpc_split_kw.rank,
    show_progress=False,
)
z2 = loadings2.numpy(force=True)

In [None]:
# Same mapping onto unit channels as waveforms2, but with fill_mode="constant".
waveforms3 = gmm[unit_id].to_unit_channels(
    waveforms=data1["waveforms"],
    times=data1["times"],
    waveform_channels=data1["waveform_channels"],
    fill_mode="constant"
)

In [None]:
# Map waveforms3 back to waveform channels, using the original_channel_index
# row of the unit's max channel for every spike.
wfcs = torch.tensor(gmm.data.original_channel_index[gmm[unit_id].max_channel])
waveforms5 = gmm[unit_id].to_waveform_channels(
    waveforms_rel=waveforms3,
    waveform_channels=wfcs[None].broadcast_to(len(waveforms3), -1)
)

In [None]:
# PCA of the constant-filled unit-channel waveforms -> z3.
loadings3, mean3, components3, svs3 = spike_interp.fit_pcas(
    waveforms3.reshape(in_unit1.numel(), -1),
    missing=None,
    empty=None,
    rank=gmm.dpc_split_kw.rank,
    show_progress=False,
)
z3 = loadings3.numpy(force=True)

In [None]:
# PCA of the waveforms mapped back to waveform channels -> z5.
loadings5, mean5, components5, svs5 = spike_interp.fit_pcas(
    waveforms5.reshape(in_unit1.numel(), -1),
    missing=None,
    empty=None,
    rank=gmm.dpc_split_kw.rank,
    show_progress=False,
)
z5 = loadings5.numpy(force=True)

In [None]:
# Check that split_features and get_training_data selected the same spikes.
torch.equal(in_unit0, in_unit1)

In [None]:
data1["waveforms"].shape, waveforms2.shape

In [None]:
# Overlay three embeddings: split_features output (z0), PCA of the raw
# training waveforms (z1) and PCA of the unit-channel waveforms (z2).
plt.scatter(*z0.T, s=3, lw=0)
plt.scatter(*z1.T, s=3, lw=0)
plt.scatter(*z2.T, s=3, lw=0)

In [None]:
# Overlay the two repeated fits (z1, z11) and the waveforms5 embedding (z5).
plt.scatter(*z1.T, s=3, lw=0)
plt.scatter(*z11.T, s=3, lw=0)
plt.scatter(*z5.T, s=3, lw=0)

In [None]:
# Exact equality check between z0 and the PCA's stored training loadings.
np.array_equal(z0, z00)

In [None]:
plt.scatter(*z0.T, s=3, lw=0)
plt.scatter(*z00.T, s=3, lw=0)
plt.scatter(*z3.T, s=3, lw=0)

In [None]:
waveforms5.shape

In [None]:
# Number of spikes whose waveform is NaN in every entry.
np.isnan(waveforms5).all(axis=(1,2)).sum()

In [None]:
nwf = len(waveforms5)

In [None]:
# Shape after dropping duplicate flattened waveforms.
np.unique(waveforms5.reshape(nwf, -1), axis=0).shape

In [None]:
# Number of all-NaN rows among the unique flattened waveforms.
np.isnan(np.unique(waveforms5.reshape(nwf, -1), axis=0)).all(axis=1).sum()

In [None]:

# Residual features for every spike currently labelled unit_id.
(in_unit_full,) = (gmm.labels == unit_id).nonzero(as_tuple=True)
n = in_unit_full.numel()
# Preallocated outputs, filled batch by batch through the `out=` arguments.
features = torch.empty((n, gmm.residual_pca_rank), device=gmm.device)
rrel = torch.empty((n, gmm[unit_id].waveform_rank, gmm[unit_id].n_chans_unit), device=gmm.device)
unit = gmm[unit_id]
for sl, data in gmm.batches(in_unit_full):
    unit.residual_embed(**data,  out=features[sl])
    unit.residuals_rel(**data, out=rrel[sl])
# Map the residuals to waveform channels, using the reassign_channel_index row
# of the unit's max channel for every spike.
wfcs = torch.tensor(gmm.data.reassign_channel_index[gmm[unit_id].max_channel])
rrel = gmm[unit_id].to_waveform_channels(
    waveforms_rel=rrel,
    waveform_channels=wfcs[None].broadcast_to(len(rrel), -1)
)
# zfull: first dpc_split_kw.rank columns of the embedded residual features.
zfull = features[:, : gmm.dpc_split_kw.rank].numpy(force=True)
# zfullw: loadings of a fresh PCA on the flattened residual waveforms.
# NOTE: this rebinds mean5 / components5 / svs5 from the waveforms5 fit.
zfullw, mean5, components5, svs5 = spike_interp.fit_pcas(
    rrel.reshape(len(rrel), -1),
    missing=None,
    empty=None,
    rank=gmm.dpc_split_kw.rank,
    show_progress=False,
)
zfullw = zfullw.numpy(force=True)

In [None]:
# Density-peaks clustering of the de-duplicated residual-waveform embedding,
# with half the rule-of-thumb local bandwidth and a minimum cluster size of 25.
z = np.unique(zfullw, axis=0)
dens = density.density_peaks_clustering(
    z,
    # sigma_local=gmm.dpc_split_kw.sigma_local,
    sigma_local="rule_of_thumb*0.5",
    n_neighbors_search=gmm.dpc_split_kw.n_neighbors_search,
    remove_clusters_smaller_than=25,
    return_extra=True,
)

# ru (unique labels) is not used again within this cell.
ru = np.unique(dens["labels"])
panel, axes = dartvis.analysis_plots.density_peaks_study(
    z,
    dens,
    s=10,
)

In [None]:
# The model's configured sigma_local, for comparison with the override above.
gmm.dpc_split_kw.sigma_local

In [None]:
# Same density-peaks study on the de-duplicated split_features embedding (z0),
# with 1.5x the rule-of-thumb local bandwidth.
z = np.unique(z0, axis=0)
dens = density.density_peaks_clustering(
    z,
    # sigma_local=gmm.dpc_split_kw.sigma_local,
    sigma_local="rule_of_thumb*1.5",
    n_neighbors_search=gmm.dpc_split_kw.n_neighbors_search,
    remove_clusters_smaller_than=5,
    return_extra=True,
)

ru = np.unique(dens["labels"])
panel, axes = dartvis.analysis_plots.density_peaks_study(
    z,
    dens,
    s=10,
)

In [None]:
# Same study on the de-duplicated waveforms3 embedding (z3), with the plain
# rule-of-thumb local bandwidth.
z = np.unique(z3, axis=0)
dens = density.density_peaks_clustering(
    z,
    # sigma_local=gmm.dpc_split_kw.sigma_local,
    sigma_local="rule_of_thumb",
    n_neighbors_search=gmm.dpc_split_kw.n_neighbors_search,
    remove_clusters_smaller_than=5,
    return_extra=True,
)

ru = np.unique(dens["labels"])
panel, axes = dartvis.analysis_plots.density_peaks_study(
    z,
    dens,
    s=10,
)

In [None]:
# Same study on the de-duplicated waveforms5 embedding (z5), with 1.5x the
# rule-of-thumb local bandwidth.
z = np.unique(z5, axis=0)
dens = density.density_peaks_clustering(
    z,
    # sigma_local=gmm.dpc_split_kw.sigma_local,
    sigma_local="rule_of_thumb*1.5",
    n_neighbors_search=gmm.dpc_split_kw.n_neighbors_search,
    remove_clusters_smaller_than=5,
    return_extra=True,
)

ru = np.unique(dens["labels"])
panel, axes = dartvis.analysis_plots.density_peaks_study(
    z,
    dens,
    s=10,
)

In [None]:
# Compare row counts before and after de-duplication.
np.unique(z5, axis=0).shape

In [None]:
z5.shape

In [None]:
# Element counts of z1 with and without duplicate rows.
z1.size, np.unique(z1, axis=0).size

In [None]:
# Same study on z1 without de-duplicating first.
z = z1
dens = density.density_peaks_clustering(
    z,
    # sigma_local=gmm.dpc_split_kw.sigma_local,
    sigma_local="rule_of_thumb",
    n_neighbors_search=gmm.dpc_split_kw.n_neighbors_search,
    remove_clusters_smaller_than=5,
    return_extra=True,
)

ru = np.unique(dens["labels"])
panel, axes = dartvis.analysis_plots.density_peaks_study(
    z,
    dens,
    s=10,
)

In [None]:
# Turn interpolation on, discard the existing per-unit models and run m_step.
gmm.do_interp = True
gmm.models.clear()
gmm.m_step()

In [None]:
# Display unit 0's model.
gmm[0]

In [None]:
# Plots after the interpolated refit (summaries disabled).
makeplots(
    # fig_dir / f"{tag}_001_splitem",
    fig_dir / f"{tag}_001_splitem_interp",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=4,
    with_summaries=False,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
# Three reassign / forced m_step rounds, then plots ("interpem").
for _ in range(3):
    gmm.reassign()
    gmm.m_step(force=True)

In [None]:
makeplots(
    # fig_dir / f"{tag}_001_splitem",
    fig_dir / f"{tag}_001_splitem_interpem",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=4,
    with_summaries=True,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
# Three more rounds, then plots again ("interpemem").
for _ in range(3):
    gmm.reassign()
    gmm.m_step(force=True)

In [None]:
makeplots(
    # fig_dir / f"{tag}_001_splitem",
    fig_dir / f"{tag}_001_splitem_interpemem",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=4,
    with_summaries=True,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
# Switch to unit 8 and collect the indices of the spikes labelled with it.
# NOTE: this rebinds unit_id (29 earlier in the notebook).
unit_id = 8
in_unit = np.flatnonzero(gmm.labels == unit_id)

In [None]:
# Spike times of the selected unit (the source attribute is named
# times_seconds) and one amplitude per spike: the peak-to-peak range along
# axis 1 of the spike's static_amp_vecs row, with NaNs replaced by 0 first.
times = gmm.data.times_seconds[in_unit].numpy(force=True)
# Function form np.ptp(...) rather than the array method: ndarray.ptp was
# removed in NumPy 2.0, while np.ptp works on both old and new versions.
amps = np.ptp(np.nan_to_num(gmm.data.static_amp_vecs[in_unit]), axis=1)
times.shape, amps.shape

In [None]:
# NOTE(review): redundant — `density` is already imported from
# dartsort.cluster in the notebook's top import cell; this cell can be removed.
from dartsort.cluster import density

In [None]:
def mad(x):
    """Median absolute deviation of ``x`` about its median.

    Computes ``median(|x - median(x)|)`` over all elements of ``x`` (NumPy's
    default median over the flattened input). No consistency scaling factor
    is applied, so this is the raw MAD rather than a standard-deviation
    estimate.

    Parameters
    ----------
    x : array_like
        Input values. The input is never modified. Scalars and 0-d arrays
        are accepted and give 0.

    Returns
    -------
    float
        The median absolute deviation of ``x``.
    """
    # Written as one expression instead of an in-place `np.abs(..., out=...)`
    # so that scalar / 0-d input does not fail on the `out=` argument.
    return np.median(np.abs(x - np.median(x)))

In [None]:
ds = density.get_smoothed_densities(
    np.c_[times / mad(times), amps / mad(amps)],
    sigmas=(0.5, 1),
)

In [None]:
ds = ds[0] / ds[1]

In [None]:
# Amplitude against time, coloured by the density ratio.
plt.scatter(times, amps, c=ds, lw=0, s=3)

In [None]:
# Density-peaks clustering in MAD-scaled (time, amplitude) space with explicit
# local / regional bandwidths (the same 0.5 and 1 used for the density ratio).
z = np.c_[times / mad(times), amps / mad(amps)]
dens = density.density_peaks_clustering(
    z,
    # sigma_local=gmm.dpc_split_kw.sigma_local,
    sigma_local=0.5,
    sigma_regional=1.,
    min_bin_size=0.05,
    n_neighbors_search=gmm.dpc_split_kw.n_neighbors_search,
    remove_clusters_smaller_than=5,
    return_extra=True,
)

ru = np.unique(dens["labels"])
panel, axes = dartvis.analysis_plots.density_peaks_study(
    z,
    dens,
    s=10,
)

In [None]:
# Uniform random subsample of spike indices without replacement, from a
# generator seeded with 0, sorted so the indices stay in ascending order.
rg = np.random.default_rng(0)
# Sampling without replacement raises ValueError when more items are requested
# than exist, so the size is capped at the number of spikes in the unit.
choice0 = rg.choice(len(ds), size=min(1024, len(ds)), replace=False)
choice0.sort()

In [None]:
# The uniform subsample, coloured by density ratio.
plt.scatter(times[choice0], amps[choice0], c=ds[choice0], lw=0)

In [None]:
# Distribution of the density ratio within the subsample, raw and log scale.
plt.hist(ds[choice0])

In [None]:
plt.hist(np.log(ds[choice0]))

In [None]:
ds.min(), ds.max()

In [None]:
# Second subsample, drawn with probability proportional to 1 / ds.
# NOTE(review): rg.choice samples with replacement by default, so choice1 may
# contain repeated indices, unlike choice0 (replace=False) — confirm intended.
# NOTE(review): a zero in ds would give an infinite weight here.
p = np.reciprocal(ds)
p/=p.sum()
choice1 = rg.choice(len(ds), size=1024, p=p)
choice1.sort()

In [None]:
# Uniform subsample in red, inverse-density-weighted subsample in blue.
plt.scatter(times[choice0], amps[choice0], c="r")
plt.scatter(times[choice1], amps[choice1], c="b")

In [None]:
# NOTE(review): redundant — seaborn is already imported as `sns` in the
# notebook's top import cell; this cell can be removed.
import seaborn as sns

In [None]:
# Joint KDE plots of (time, amplitude) for each subsample.
sns.jointplot(x=times[choice0], y=amps[choice0], kind="kde")

In [None]:
sns.jointplot(x=times[choice1], y=amps[choice1], kind="kde")

In [None]:
plt.scatter(times[choice1], amps[choice1], c=ds[choice1], lw=0)

In [None]:
# Smoothed density (single sigma=1) recomputed within each subsample.
ds0 = density.get_smoothed_densities(np.c_[times / mad(times), amps / mad(amps)][choice0], sigmas=1)

In [None]:
ds1 = density.get_smoothed_densities(np.c_[times / mad(times), amps / mad(amps)][choice1], sigmas=1)

In [None]:
# Compare the two within-subsample density distributions.
plt.hist(ds0, histtype="step")
plt.hist(ds1, histtype="step")

In [None]:
# Recompute the unit-by-unit divergences for the current model state
# (these cells repeat the earlier divergence cells verbatim).
dists = gmm.central_divergences()

In [None]:
dists

In [None]:
# Heat map of the divergences with the colour scale clipped to [0, 1].
plt.imshow(dists, cmap=plt.cm.rainbow, vmin=0, vmax=1);
plt.colorbar();

In [None]:
# Symmetrise with the element-wise minimum of the matrix and its transpose,
# then convert to a float NumPy array.
d = np.minimum(dists, dists.T).numpy(force=True).astype(float)

In [None]:
# Three-level closeness map: 2 where d < 0.25, 1 where 0.25 <= d < 0.5, else 0.
# One operand is cast to int because `+` on two boolean arrays is a logical
# OR, which would collapse the result to just `d < 0.5`.
(d < 0.25).astype(int) + (d < 0.5)

In [None]:
# Cluster map of the three-level closeness matrix (2: d < 0.25, 1: d < 0.5,
# 0: otherwise). The int cast keeps `+` from acting as a boolean OR, which
# would discard the 0.25 threshold.
sns.clustermap((d < 0.25).astype(int) + (d < 0.5))

In [None]:
# Plots tagged "001_split" (summaries disabled).
makeplots(
    fig_dir / f"{tag}_001_split",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=4,
    with_summaries=False,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
# One reassignment step, plotted before the model is refitted.
gmm.reassign()

In [None]:
makeplots(
    fig_dir / f"{tag}_001_splitreas",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=4,
    with_summaries=False,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
gmm.m_step(force=True)

In [None]:
# Disabled plotting call for the refitted model ("001_splitreasfit").
# NOTE(review): dead cell — delete it or re-enable it.
# makeplots(
#     fig_dir / f"{tag}_001_splitreasfit",
#     gmm=gmm,
#     sorting=full_dpc,
#     n_jobs=4,
#     with_summaries=False,
#     with_dpcs=True,
#     with_over_time=False,
#     overwrite=False,
# )

In [None]:
# Residual DPC split, then plots of the result.
gmm.residual_dpc_split()

In [None]:
# NOTE(review): an earlier makeplots call in this notebook writes to the same
# f"{tag}_002_split" directory, and overwrite=False is passed here — confirm
# whether makeplots skips existing output, which would leave stale figures.
makeplots(
    fig_dir / f"{tag}_002_split",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=4,
    with_summaries=False,
    with_dpcs=True,
    dpc_par=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
# Reassign, plot ("002_splitreas"), forced m_step, plot ("002_splitreasfit").
gmm.reassign()

In [None]:
makeplots(
    fig_dir / f"{tag}_002_splitreas",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=4,
    with_summaries=False,
    dpc_par=True,
    with_dpcs=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
gmm.m_step(force=True)

In [None]:
makeplots(
    fig_dir / f"{tag}_002_splitreasfit",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=4,
    with_summaries=False,
    dpc_par=True,
    with_dpcs=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
gc.collect()

In [None]:
# Eight identical cells follow, each one reassign + forced m_step round.
# NOTE(review): a single `for _ in range(8):` cell, as used elsewhere in this
# notebook, would express the same thing.
gmm.reassign()
gmm.m_step(force=True)

In [None]:
gmm.reassign()
gmm.m_step(force=True)

In [None]:
gmm.reassign()
gmm.m_step(force=True)

In [None]:
gmm.reassign()
gmm.m_step(force=True)

In [None]:
gmm.reassign()
gmm.m_step(force=True)

In [None]:
gmm.reassign()
gmm.m_step(force=True)

In [None]:
gmm.reassign()
gmm.m_step(force=True)

In [None]:
gmm.reassign()
gmm.m_step(force=True)

In [None]:
# Plots after the eight reassign / m_step rounds, tagged "003_em".
makeplots(
    fig_dir / f"{tag}_003_em",
    gmm=gmm,
    sorting=full_dpc,
    n_jobs=4,
    with_summaries=False,
    dpc_par=True,
    with_dpcs=True,
    with_over_time=False,
    overwrite=False,
)

In [None]:
# Summary figure for unit 16.
dartvis.gmm.make_unit_gmm_summary(
    gmm,
    16,
)

In [None]:
# Unit inspected in the cells that follow (rebinds unit_id again).
unit_id = 2

In [None]:
# Spike selection and training-data dict for the unit; the cells below read
# the "waveforms", "times" and "waveform_channels" entries of utd.
in_unit, utd = gmm.get_training_data(unit_id)

In [None]:
# Undo the temporal PCA on the unit's training waveforms.
# Dimension names follow the unpacking below: n, r, c — presumably
# (spikes, TPCA rank, channels); TODO confirm.
waveforms = utd["waveforms"]
n, r, c = waveforms.shape
# Flatten to one row per (spike, channel) pair with r features each, which is
# the layout passed to the TPCA inverse transform.
waveforms = gmm.data.tpca._inverse_transform_in_probe(waveforms.permute(0, 2, 1).reshape(n * c, r))
# Unflatten in the same (spike, channel) row order, then put the channel axis
# last again.
waveforms = waveforms.reshape(n, c, -1).permute(0, 2, 1)
# Per-spike peak-to-peak amplitude over both remaining axes, NaNs set to 0.
# Function form np.ptp(...) rather than the array method: ndarray.ptp was
# removed in NumPy 2.0, while np.ptp works on both old and new versions.
amps = np.ptp(np.nan_to_num(waveforms.numpy(force=True)), axis=(1, 2))

In [None]:
waveforms.shape

In [None]:
# Distribution of per-spike amplitudes.
plt.hist(amps);

In [None]:
# Amplitude against spike time.
plt.scatter(utd["times"], amps, lw=0, s=2);

In [None]:
# Mask of spikes with time > 2500 and amplitude < 11.
# NOTE(review): both thresholds are hard-coded, presumably read off the
# scatter plot above for this particular unit — TODO confirm.
small = torch.logical_and(utd["times"] > 2500, torch.tensor(amps < 11))

In [None]:
# Unit means evaluated at each spike's time, mapped onto each spike's waveform
# channels and taken back through the temporal PCA inverse transform.
recons = gmm[unit_id].get_means(utd["times"])
recons = gmm[unit_id].to_waveform_channels(recons, waveform_channels=utd["waveform_channels"])
n, r, c = recons.shape
# Flatten to one row per (spike, channel) pair with r features each.
recons = gmm.data.tpca._inverse_transform_in_probe(recons.permute(0, 2, 1).reshape(n * c, r))
# Rows are ordered (spike, channel), so unflatten as (n, c, -1) before moving
# the channel axis last. Reshaping straight to (n, -1, c) would interleave
# samples from different channels; the other inverse-TPCA cells in this
# notebook use the (n, c, -1) form as well.
recons = recons.reshape(n, c, -1).permute(0, 2, 1)

In [None]:
# Per-spike channel indices as a NumPy array, for plotting.
chans = utd["waveform_channels"].numpy(force=True)

In [None]:
chans.shape, waveforms.shape

In [None]:
small

In [None]:
small.shape

In [None]:
# Overlay the waveforms on the probe geometry: masked ("small") spikes in
# blue, all other spikes in red, on one shared axis.
fig, ax = plt.subplots()
dartvis.geomplot(
    waveforms[small],
    channels=chans[small],
    geom=gmm.data.registered_geom.numpy(force=True),
    max_abs_amp=15,
    lw=1,
    show_zero=False,
    subar=True,
    color="b",
    msbar=False,
    zlim="tight",
    ax=ax,
)
dartvis.geomplot(
    waveforms[~small],
    channels=chans[~small],
    geom=gmm.data.registered_geom.numpy(force=True),
    max_abs_amp=15,
    lw=1,
    show_zero=False,
    subar=True,
    color="r",
    msbar=False,
    zlim="tight",
    ax=ax,
)

In [None]:
# Same overlay with the draw order swapped (red first, blue on top) and
# alpha=0.1 so that overlapping traces stay visible.
fig, ax = plt.subplots()
dartvis.geomplot(
    waveforms[~small],
    channels=chans[~small],
    geom=gmm.data.registered_geom.numpy(force=True),
    max_abs_amp=15,
    lw=1,
    show_zero=False,
    subar=True,
    color="r",
    msbar=False,
    zlim="tight",
    ax=ax,
    alpha=0.1,
)
dartvis.geomplot(
    waveforms[small],
    channels=chans[small],
    geom=gmm.data.registered_geom.numpy(force=True),
    max_abs_amp=15,
    lw=1,
    show_zero=False,
    subar=True,
    color="b",
    msbar=False,
    zlim="tight",
    ax=ax,
    alpha=0.1,
)

In [None]:
# Per-spike overlaps between each spike's waveform channels and the unit, plus
# rel_ix, which is passed to to_unit_channels further down.
overlaps, rel_ix = gmm[unit_id].overlaps(utd["waveform_channels"])

In [None]:
# Overlap distributions for the masked spikes and for the rest.
plt.hist(overlaps[small], bins=np.linspace(0.5, 1, 40), density=True, histtype="step")
plt.hist(overlaps[~small], bins=np.linspace(0.5, 1, 40), density=True, histtype="step");

In [None]:
# Third return value of spike_badnesses, indexed by metric name below.
_, _, res = gmm[unit_id].spike_badnesses(utd["times"], utd["waveforms"], utd["waveform_channels"])

In [None]:
# "1-r^2" distributions for the two groups.
plt.hist(res["1-r^2"][small], bins=np.linspace(0.0, 1.5, 40), density=True, histtype="step")
plt.hist(res["1-r^2"][~small], bins=np.linspace(0.0, 1.5, 40), density=True, histtype="step");

In [None]:
# "1-scaledr^2" distributions, clipped to the histogram range [0, 1.5].
plt.hist(res["1-scaledr^2"][small].clip(0, 1.5), bins=np.linspace(0.0, 1.5, 40), density=True, histtype="step")
plt.hist(res["1-scaledr^2"][~small].clip(0, 1.5), bins=np.linspace(0.0, 1.5, 40), density=True, histtype="step");

In [None]:
# Distinct overlap values in each group.
overlaps[small].unique()

In [None]:
overlaps[~small].unique()

In [None]:
# Badness against overlap, coloured by group membership.
plt.scatter(overlaps, res["1-scaledr^2"], c=small, s=3, lw=0)

In [None]:
# Training waveforms mapped onto the unit's channels via the precomputed rel_ix.
wfs_u = gmm[unit_id].to_unit_channels(utd["waveforms"], rel_ix=rel_ix)

In [None]:
# cluster_channel_index row for the unit's max channel.
chans_u = gmm.data.cluster_channel_index[gmm[unit_id].max_channel]

In [None]:
chans_u.shape, wfs_u.shape

In [None]:
# Repeat the channel row once per spike.
# NOTE(review): this rebinds chans_u from itself, so re-running the cell adds
# another leading dimension; re-run the cell that defines chans_u first.
chans_u = chans_u[None].broadcast_to((len(wfs_u), *chans_u.shape)).contiguous()

In [None]:
# Undo the temporal PCA, using the same flatten / unflatten pattern as for the
# raw training waveforms.
n, r, c = wfs_u.shape
wfs_ur = gmm.data.tpca._inverse_transform_in_probe(wfs_u.permute(0, 2, 1).reshape(n * c, r))
wfs_ur = wfs_ur.reshape(n, c, -1).permute(0, 2, 1)

In [None]:
wfs_ur.shape

In [None]:
# NOTE(review): stray scratch cell (bare literal); safe to delete.
1

In [None]:
wfs_ur[~small].shape

In [None]:
# Number of spikes in the mask.
small.sum()

In [None]:
small.shape

In [None]:
# Overlay of the unit-channel waveforms on the probe geometry: unmasked spikes
# in red first, masked ("small") spikes in blue on top, both with alpha=0.1.
fig, ax = plt.subplots(figsize=(4, 4))
dartvis.geomplot(
    wfs_ur[~small].numpy(),
    channels=chans_u[~small].numpy(),
    geom=gmm.data.registered_geom.numpy(force=True),
    max_abs_amp=15,
    lw=1,
    show_zero=False,
    subar=True,
    color="r",
    msbar=False,
    zlim="tight",
    ax=ax,
    alpha=0.1,
)
dartvis.geomplot(
    wfs_ur[small].numpy(),
    channels=chans_u[small].numpy(),
    geom=gmm.data.registered_geom.numpy(force=True),
    max_abs_amp=15,
    lw=1,
    show_zero=False,
    subar=True,
    color="b",
    msbar=False,
    zlim="tight",
    ax=ax,
    alpha=0.1,
)

In [None]:
# Rank-2 PCA of the flattened unit-channel features (still in TPCA space).
loadings, mean, components, svs = spike_interp.fit_pcas(
    wfs_u.reshape(len(wfs_u), -1),
    missing=None,
    empty=None,
    rank=2,
    max_iter=100,
    check_every=5,
    n_oversamples=10,
    atol=1e-3,
    show_progress=False,
    centered=True,
)

In [None]:
z = loadings.numpy()

In [None]:
# Embedding coloured by the small / not-small mask.
plt.scatter(*z.T, c=small)

In [None]:
# NOTE: after de-duplication the rows of z no longer line up with `small`.
z = np.unique(z, axis=0)
dens = density.density_peaks_clustering(
    z,
    # sigma_local=gmm.dpc_split_kw.sigma_local,
    sigma_local="rule_of_thumb",
    sigma_regional="rule_of_thumb",
    n_neighbors_search=gmm.dpc_split_kw.n_neighbors_search,
    remove_clusters_smaller_than=5,
    return_extra=True,
)

ru = np.unique(dens["labels"])
panel, axes = dartvis.analysis_plots.density_peaks_study(
    z,
    dens,
    s=10,
)

In [None]:
z.shape

In [None]:
# Balanced subset: every masked spike plus an equally sized random draw
# (without replacement) from the unmasked spikes, from a generator seeded
# with 0. np.unique sorts the combined indices; subset_mask marks which
# entries of `subset` are masked spikes.
# NOTE(review): the draw raises if there are fewer unmasked than masked spikes.
rg = np.random.default_rng(0)
subset = np.unique(np.concatenate((np.flatnonzero(small), rg.choice(np.flatnonzero(~small), size=small.sum(), replace=False))))
subset_mask = np.isin(subset, np.flatnonzero(small))

In [None]:
# Rank-2 PCA fitted on the balanced subset only.
loadings, mean, components, svs = spike_interp.fit_pcas(
    wfs_u[subset].reshape(len(subset), -1),
    missing=None,
    empty=None,
    rank=2,
    max_iter=100,
    check_every=5,
    n_oversamples=10,
    atol=1e-3,
    show_progress=False,
    centered=True,
)

In [None]:
z = loadings.numpy()

In [None]:
plt.scatter(*z.T, c=subset_mask)

In [None]:
# Channel index built from the registered geometry with the argument 10.0
# (presumably a radius — TODO confirm units), used to subset channels below.
subci = dartsort.make_channel_index(gmm.data.registered_geom.numpy(), 10.0)
subci.shape

In [None]:
wfs_u.shape

In [None]:
# Restrict the unit-channel waveforms from the cluster channel index to subci,
# passing the unit's max channel as the channel of every spike.
wfs_usub = dartsort.util.waveform_util.channel_subset_by_index(
    wfs_u.numpy(),
    torch.full((len(wfs_u),), gmm[unit_id].max_channel).numpy(),
    gmm.data.cluster_channel_index.numpy(force=True),
    subci,
)

In [None]:
wfs_usub.shape

In [None]:
torch.tensor(wfs_usub).shape

In [None]:
# Rank-2 PCA of the channel-subset waveforms, all spikes.
loadings, mean, components, svs = spike_interp.fit_pcas(
    torch.tensor(wfs_usub).reshape(len(wfs_u), -1),
    missing=None,
    empty=None,
    rank=2,
    max_iter=100,
    check_every=5,
    n_oversamples=10,
    atol=1e-3,
    show_progress=False,
    centered=True,
)

In [None]:
z = loadings.numpy()

In [None]:
# Unmasked spikes in red, masked ("small") spikes in blue.
plt.scatter(*z[~small].T, c="r", lw=0, s=3, alpha=0.8)
plt.scatter(*z[small].T, c="b", lw=0, s=3, alpha=0.8)

In [None]:
# Rank-2 PCA of the channel-subset waveforms, balanced subset only.
loadings, mean, components, svs = spike_interp.fit_pcas(
    torch.tensor(wfs_usub[subset].reshape(len(subset), -1)),
    missing=None,
    empty=None,
    rank=2,
    max_iter=100,
    check_every=5,
    n_oversamples=10,
    atol=1e-3,
    show_progress=False,
    centered=True,
)

In [None]:
z = loadings.numpy()

In [None]:
plt.scatter(*z.T, c=subset_mask)

In [None]:
gc.collect()

In [None]:
# Turn interpolation on and discard the existing per-unit models.
gmm.do_interp = True
gmm.models.clear()

In [None]:
# Run m_step after the reset, then a residual DPC split.
gmm.m_step()
gmm.residual_dpc_split()