In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os
import io
import copy
import cv2
import matplotlib.pyplot as plt
import numpy as np
import skimage
import toml
import glob
import h5py
from tqdm.auto import tqdm
from markovids import vid
from scipy import spatial, signal, ndimage
import sleap_io

In [3]:
from joblib import Parallel, delayed
import pandas as pd
import seaborn as sns

In [4]:
import re
import warnings

In [5]:
# TODO:
# 1. add hampel filtering...
# 2. add kalman filtering...

## Gather data and predicted keypoints

In [6]:
# True -> build a hindleg-only training set (see the hindleg filter further down).
hindleg_only = False
# Switch to the "_joints" directory below to process the knee-joint dataset instead.
base_dir = "/home/jmarkow/data_dir/active_projects/quantum_dots/timecourse_02"
# base_dir = "/home/jmarkow/data_dir/active_projects/quantum_dots/timecourse_02_joints"
_fluo_pattern = os.path.join(base_dir, "**", "Basler*fluorescence.avi")
fluo_files = sorted(glob.glob(_fluo_pattern, recursive=True))

In [7]:
# Keypoint inference output directory; the knee-joint dataset has its own export.
kpoint_dir = (
    "/home/jmarkow/data_dir/active_projects/keypoints_basler_nir_plexiglass_arena/keypoint_inference_kneejoints_export_fused_weights-None_bpass-None"
    if "joint" in base_dir
    else "/home/jmarkow/data_dir/active_projects/keypoints_basler_nir_plexiglass_arena/keypoint_inference_export_fused_weights-None_bpass-None"
)


In [8]:
_slp_pattern = os.path.join(kpoint_dir, "**", "*.slp")  # every SLEAP file below kpoint_dir
sleap_files = sorted(glob.glob(_slp_pattern, recursive=True))

In [9]:
# Subdirectory (inside each fluorescence video's directory) holding its background HDF5 file.
bground_path = "_bground"

In [10]:
# One record per SLEAP prediction file, linking it to its fluorescence / reflectance
# videos and the background file. Keypoint measurements happen later in get_kpoint_qd_df.
sleap_dcts = []
for _sfile in tqdm(sleap_files):
    # "<name>.predictions.slp" sits next to "<name>.toml" with the export metadata.
    metadata = toml.load(_sfile.replace(".predictions.slp", ".toml"))
    export_meta = metadata["export_metadata"]
    reflect_fname = export_meta["file"]
    fluo_fname = reflect_fname.replace("-reflectance.avi", "-fluorescence.avi")
    fluo_stem = os.path.splitext(os.path.basename(fluo_fname))[0]
    sleap_dcts.append(
        {
            "sleap_fname": _sfile,
            "fluo_fname": fluo_fname,
            # <fluorescence dir>/<bground_path>/<fluorescence stem>.hdf5
            "fluo_bground_fname": os.path.join(
                os.path.dirname(fluo_fname), bground_path, f"{fluo_stem}.hdf5"
            ),
            "reflect_fname": reflect_fname,
            "kpoint_avi_fname": _sfile.replace(".predictions.slp", ""),
            "start_time": export_meta["original_metadata"]["start_time"],
            "camera": export_meta["cam"],
        }
    )

  0%|          | 0/52 [00:00<?, ?it/s]

In [11]:
# https://stackoverflow.com/questions/66171969/compute-center-of-mass-in-numpy
def center_of_mass(array: np.ndarray, yidx = None, xidx = None):
    """Return the intensity-weighted centroid ``(x, y)`` of a 2-D array.

    Parameters
    ----------
    array : np.ndarray
        2-D weights, e.g. a masked background-subtracted image patch.
    yidx, xidx : sequence of numbers, optional
        Coordinate assigned to each row / column. Defaults to ``0..n-1``. Passing a
        patch's absolute row/column ranges gives full-image coordinates.

    Returns
    -------
    tuple
        ``(x, y)``. Both are NaN when ``array`` sums to zero; the 0/0 "invalid"
        warning is suppressed.
    """
    row_coords = np.asarray(range(array.shape[0]) if yidx is None else yidx)
    col_coords = np.asarray(range(array.shape[1]) if xidx is None else xidx)
    with np.errstate(invalid="ignore"):
        weight = array.sum()
        # project onto each axis, then take the weighted mean coordinate
        y_coord = np.dot(array.sum(axis=1), row_coords) / weight
        x_coord = np.dot(array.sum(axis=0), col_coords) / weight
    return x_coord, y_coord

In [12]:
def get_kpoint_qd_df(
    sleap_fname=None,
    fluo_fname=None,
    fluo_bground_fname=None,
    batch_size=500,
    search_radius=5,
    confidence_threshold=0, # used no confidence threshold for first run, then .5 for joints
    reader_kwargs={"threads": 2},
    force=True,
    **kwargs
):
    """Measure background-subtracted fluorescence around every SLEAP keypoint.

    For each frame of ``fluo_fname`` and each predicted keypoint, the keypoint is
    measured when its x/y are not NaN and its confidence is ``>= confidence_threshold``.
    A disk of radius ``search_radius`` centred on the rounded keypoint is cut out of
    the background-subtracted frame. The intensity-weighted center of mass and the
    mean and peak intensity inside the disk are recorded. Windows that would leave
    the frame are skipped.

    Parameters
    ----------
    sleap_fname : str
        SLEAP ``.slp`` predictions. Results are cached next to it as ``.parquet``.
    fluo_fname : str
        Fluorescence video, read with ``vid.io.AutoReader``.
    fluo_bground_fname : str
        HDF5 file with datasets ``bground`` and ``frame_idxs``. For each frame, the
        background whose frame index is closest is subtracted.
    batch_size : int
        Frames read from the video per request.
    search_radius : float
        Radius passed to ``skimage.morphology.disk``.
    confidence_threshold : float
        Minimum keypoint confidence to measure. NaN confidences are skipped.
    reader_kwargs : dict
        Extra keyword arguments for the video reader (not modified here).
    force : bool
        When False and the parquet cache exists, the cache is loaded and returned.
    **kwargs
        Extra constant columns added to the returned table.

    Returns
    -------
    pandas.DataFrame
        One row per measured (frame, keypoint). The table is also written to the
        parquet cache.
    """
    save_file = os.path.splitext(sleap_fname)[0] + ".parquet"
    if (not force) and os.path.exists(save_file):
        return pd.read_parquet(save_file)

    dset = sleap_io.load_slp(sleap_fname)

    # TODO: also read in frame index...
    # NOTE(review): squeeze() assumes one predicted instance per frame, giving
    # [frame, node, (x, y, confidence)] -- confirm for multi-animal files.
    kpoint_arr = dset.numpy(return_confidence=True).squeeze()
    skeleton = dset.skeleton

    search_obj = skimage.morphology.disk(radius=search_radius).astype("float")
    search_y, search_x = np.where(search_obj > 0)
    center = search_obj.shape[0] // 2
    # Window offsets relative to the rounded keypoint. The extra element lets
    # (first, last) serve as a half-open slice exactly as large as search_obj.
    xx = np.arange(search_obj.shape[1] + 1) - center
    yy = np.arange(search_obj.shape[0] + 1) - center

    with h5py.File(fluo_bground_fname, "r") as f:
        bgrounds = f["bground"][()]
        bgrounds_idxs = f["frame_idxs"][()]

    dcts = []
    reader = vid.io.AutoReader(fluo_fname, **reader_kwargs)
    try:
        nframes = reader.nframes
        for _batch in tqdm(range(0, nframes, batch_size)):
            # Clamp the final batch so no frames past the end of the video are requested.
            batch_range = range(_batch, min(_batch + batch_size, nframes))
            use_frames = reader.get_frames(batch_range)
            frame_idx = list(batch_range)

            for idx, _frame in tqdm(zip(frame_idx, use_frames), total=len(frame_idx)):
                # The background-subtracted frame does not depend on the keypoint:
                # build it lazily, at most once per frame.
                bground_sub = None
                for kpoint_idx, _kpoint in enumerate(kpoint_arr[idx]):
                    if np.isnan(_kpoint[0]) or np.isnan(_kpoint[1]):
                        continue
                    # `not >=` (instead of `<`) also skips NaN confidences.
                    if not (_kpoint[2] >= confidence_threshold):
                        continue

                    if bground_sub is None:
                        use_bground_idx = np.argmin(np.abs(bgrounds_idxs - idx))
                        bground_sub = np.clip(
                            _frame.astype("int16") - bgrounds[use_bground_idx], 0, np.inf
                        )

                    kpoint_x_idx = int(np.round(_kpoint[0]))
                    kpoint_y_idx = int(np.round(_kpoint[1]))
                    x0, x1 = kpoint_x_idx + xx[0], kpoint_x_idx + xx[-1]
                    y0, y1 = kpoint_y_idx + yy[0], kpoint_y_idx + yy[-1]
                    # Skip windows that leave the frame. A shape-mismatch error is not
                    # a reliable guard: negative starts wrap to the opposite edge, and a
                    # window clipped to one row/column broadcasts against search_obj.
                    if (
                        x0 < 0
                        or y0 < 0
                        or y1 > bground_sub.shape[0]
                        or x1 > bground_sub.shape[1]
                    ):
                        continue
                    masked_bground_sub = search_obj * bground_sub[y0:y1, x0:x1].astype("float")

                    com = center_of_mass(
                        masked_bground_sub, yidx=range(y0, y1), xidx=range(x0, x1)
                    )
                    fluo_ave = np.nanmean(masked_bground_sub[search_y, search_x])
                    fluo_peak = np.nanmax(masked_bground_sub[search_y, search_x])

                    dcts.append(
                        {
                            "com_x": com[0],
                            "com_y": com[1],
                            "kpoint_name": skeleton.nodes[kpoint_idx].name,
                            "kpoint_x": _kpoint[0],
                            "kpoint_y": _kpoint[1],
                            "kpoint_confidence": _kpoint[2],
                            "kpoint_to_com_l2": np.hypot(
                                _kpoint[0] - com[0], _kpoint[1] - com[1]
                            ),
                            "fluo_ave": fluo_ave,
                            "fluo_peak": fluo_peak,
                            "frame_index": idx,
                        }
                    )
    finally:
        # always release the video handle, even if a frame fails to decode
        reader.close()

    store_dat = pd.DataFrame(dcts)
    store_dat["sleap_file"] = sleap_fname
    store_dat["fluo_file"] = fluo_fname
    store_dat["search_radius"] = search_radius
    for k, v in kwargs.items():
        store_dat[k] = v
    store_dat.to_parquet(save_file)
    return store_dat

In [13]:
# One deferred job per SLEAP file; confidence_threshold stays at the function default.
delays = [
    delayed(get_kpoint_qd_df)(
        sleap_fname=_dct["sleap_fname"],
        fluo_fname=_dct["fluo_fname"],
        fluo_bground_fname=_dct["fluo_bground_fname"],
        start_time=_dct["start_time"],
        kpoint_avi_fname=_dct["kpoint_avi_fname"],
        camera=_dct["camera"],
        search_radius=7.5,
        force=False,
    )
    for _dct in sleap_dcts
]

In [14]:
# probably save intermediate files...

In [15]:
_runner = Parallel(n_jobs=7, verbose=10, backend="multiprocessing")
dat = _runner(delays)  # list of per-file DataFrames, in the order of `delays`

[Parallel(n_jobs=7)]: Using backend MultiprocessingBackend with 7 concurrent workers.
[Parallel(n_jobs=7)]: Batch computation too fast (0.18516826629638672s.) Setting batch_size=2.
[Parallel(n_jobs=7)]: Done   4 tasks      | elapsed:    0.4s
[Parallel(n_jobs=7)]: Done  11 tasks      | elapsed:    0.6s
[Parallel(n_jobs=7)]: Done  22 tasks      | elapsed:    1.0s
[Parallel(n_jobs=7)]: Done  45 out of  52 | elapsed:    1.6s remaining:    0.3s
[Parallel(n_jobs=7)]: Done  52 out of  52 | elapsed:    1.8s finished


In [16]:
# Stack the per-file tables returned by the workers into one frame with a fresh index.
qd_df = pd.concat(dat, ignore_index=True)

In [17]:
# Matching reflectance video for every row (plain substring replacement, not a regex).
qd_df["reflect_file"] = qd_df["fluo_file"].str.replace(
    "-fluorescence.avi", "-reflectance.avi", regex=False
)

In [18]:
config = toml.load("config.toml")
_analysis_dir = config["dirs"]["analysis"]
_analysis_dir

'/home/jmarkow/data_dir/active_projects/quantum_dots/_analysis'

In [19]:
_alignment_fname = (
    "kpoint_kneejoints_qd_alignment.parquet"
    if "joint" in base_dir
    else "kpoint_qd_alignment.parquet"
)
qd_df.to_parquet(os.path.join(config["dirs"]["analysis"], _alignment_fname))

In [23]:
# Knee-joint and hindleg-only runs use a stricter confidence cut and require
# fewer keypoints per frame than the full-skeleton run.
_joint_or_hindleg = ("joint" in base_dir) or hindleg_only
confidence_threshold_hi = 0.6 if _joint_or_hindleg else 0.3
fluo_ave_threshold_hi = 25
l2_threshold_hi = np.inf
min_confidence_per_frame = 0.7
min_nkpoints = 1 if _joint_or_hindleg else 5

In [24]:
_clauses = [
    f"kpoint_confidence > {confidence_threshold_hi}",
    f"fluo_ave > {fluo_ave_threshold_hi}",
    f"kpoint_to_com_l2 < {l2_threshold_hi}",
]
condition2 = "(" + " and ".join(_clauses) + ")"

In [25]:
# Per-keypoint filtering (confidence, brightness, keypoint-to-COM distance).
filtered_qd_df = qd_df.query(condition2).copy()

In [26]:
if hindleg_only:
    # Only one mouse (qd_beads_07) is used, because the knee-joint data this is compared
    # against has only one mouse. Within it, keep only the hindleg keypoints.
    _from_bead07 = filtered_qd_df["sleap_file"].str.lower().str.contains("qd_beads_07")
    _is_hindleg = filtered_qd_df["kpoint_name"].str.contains("hindleg")
    filtered_qd_df = filtered_qd_df.loc[_from_bead07 & _is_hindleg].copy()

In [27]:
# Keep a frame only if every remaining keypoint in it clears min_confidence_per_frame.
_frame_min_conf = filtered_qd_df.groupby(["kpoint_avi_fname", "frame_index"])[
    "kpoint_confidence"
].transform("min")
filtered_qd_df = filtered_qd_df[_frame_min_conf >= min_confidence_per_frame].copy()

In [28]:
# Keep a frame only if at least min_nkpoints keypoints survive in it.
_frame_nkpoints = filtered_qd_df.groupby(["kpoint_avi_fname", "frame_index"])[
    "kpoint_confidence"
].transform("size")
filtered_qd_df = filtered_qd_df[_frame_nkpoints >= min_nkpoints].copy()

In [29]:
if "joint" in base_dir:
    # Drop the "+48h" sessions from the knee-joint data. The pattern is a regex, so "+"
    # is escaped; a raw string avoids Python's invalid "\+" string-escape warning.
    filtered_qd_df = filtered_qd_df.loc[~filtered_qd_df["sleap_file"].str.contains(r"\+48h")].copy()

In [30]:
# A minimum number of keypoints per frame (min_nkpoints) is already required above.

In [31]:
from scipy import spatial

In [32]:
def get_distance(x):
    """Condensed pairwise Euclidean distances between the rows of ``x``.

    ``x`` is a DataFrame (e.g. keypoint x/y columns for one frame). Returns None
    when it has fewer than two rows, since there are no pairs.
    """
    if len(x) <= 1:
        return None
    return spatial.distance.pdist(x.to_numpy())

In [33]:
# filtered_qd_df.groupby(["sleap_file", "frame_index"])[["kpoint_x", "kpoint_y"]].apply(
#     get_distance
# ).dropna().min()

In [34]:
# Number of keypoint measurements left after all filters.
len(filtered_qd_df)

174548

# Make new training file

1. ADD HAMPEL FILTER?

In [35]:
import sleap_io as sio
import sleap

2024-08-15 16:00:29.335355: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /home/jmarkow/miniconda3/envs/sleap-analysis/lib/python3.10/site-packages/cv2/../../lib64:
2024-08-15 16:00:29.335389: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [36]:
# Existing training labels whose skeleton is reused for the new QD-derived labels.
old_labels_fname = (
    "/home/jmarkow/data_dir/active_projects/keypoints_basler_nir_plexiglass_arena/sleap_training/_labels/train_frames_kneejoints_fused_weights-None_bpass-None.slp"
    if "joint" in base_dir
    else "/home/jmarkow/data_dir/active_projects/keypoints_basler_nir_plexiglass_arena/sleap_training/_labels/train_frames_fused_weights-None_bpass-None.slp"
)

In [37]:
_old_labels = sleap.load_file(old_labels_fname)
use_skeleton = _old_labels.skeleton

In [38]:
if hindleg_only:
    # Two-node skeleton (left/right hindleg) with the pair declared symmetric.
    _hind_nodes = ("hindleg_L", "hindleg_R")
    new_skeleton = sleap.skeleton.Skeleton(name=use_skeleton.name)
    for _node_name in _hind_nodes:
        new_skeleton.add_node(_node_name)
    new_skeleton.add_symmetry(*_hind_nodes)
    use_skeleton = new_skeleton

In [39]:
# old_labels = sio.load_slp(old_labels_fname)
# use_skeleton = old_labels.skeleton

In [40]:
# Keypoint name <-> row index in the per-frame point arrays (alphabetical order).
_kpoint_names = sorted(filtered_qd_df["kpoint_name"].unique())
kpoint_idx_mapping = dict(zip(_kpoint_names, range(len(_kpoint_names))))
idx_kpoint_mapping = dict(enumerate(_kpoint_names))

In [41]:
# Number of distinct keypoint names after filtering; row count of each per-frame point array.
n_nodes = len(kpoint_idx_mapping)

In [42]:
# Integer node index for each row, used to place points in the per-frame arrays.
filtered_qd_df = filtered_qd_df.assign(
    kpoint_idx=filtered_qd_df["kpoint_name"].map(kpoint_idx_mapping)
)

In [43]:
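# To label with the raw SLEAP keypoints instead of the fluorescence center of mass, pass x_key="kpoint_x", y_key="kpoint_y".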
def convert_to_point_array(df, x_key = "com_x", y_key= "com_y", n_nodes=n_nodes):
    """Scatter one frame's keypoint rows into an ``(n_nodes, 2)`` coordinate array.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows for a single frame, with columns ``kpoint_name``, ``kpoint_idx``,
        ``x_key`` and ``y_key``. Each keypoint name may appear at most once.
    x_key, y_key : str
        Columns written to output column 0 (x) and 1 (y).
    n_nodes : int
        Output row count. The default is bound when this cell is executed.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_nodes, 2)``; rows for keypoints absent from ``df`` stay NaN.
    """
    # one row per keypoint, otherwise later rows would overwrite earlier ones
    assert len(df) == df["kpoint_name"].nunique()
    new_array = np.full((n_nodes, 2), fill_value=np.nan)
    # vectorized scatter instead of a per-row iterrows loop (called once per frame)
    rows = df["kpoint_idx"].to_numpy().astype(int)
    new_array[rows, 0] = df[x_key].to_numpy()
    new_array[rows, 1] = df[y_key].to_numpy()
    return new_array

In [44]:
# Group by "fluo_file", "reflect_file", or "kpoint_avi_fname" to choose which video the new labels point at.

In [45]:
# Column naming the video the new labels point at ("reflect_file", "fluo_file" or "kpoint_avi_fname").
use_vid_key = "reflect_file"

In [46]:
# Series indexed by (video, frame_index) -> (n_nodes, 2) array of point coordinates.
_frame_groups = filtered_qd_df.groupby([use_vid_key, "frame_index"])
arrays = _frame_groups.apply(convert_to_point_array, include_groups=False)

In [47]:
uniq_videos = sorted({_key[0] for _key in arrays.index})  # distinct video paths

In [48]:
# One SLEAP Video handle per distinct video path.
videos = {
    _path: sleap.io.video.Video.from_filename(_path) for _path in uniq_videos
}

In [49]:
# Turn each (video, frame) point array into a SLEAP LabeledFrame with one instance.
all_labeled_frames = []
for (_vid, _frame_idx), _arr in tqdm(arrays.items(), total=len(arrays)):
    use_points = {}
    for i, (x, y) in enumerate(_arr):
        # Nodes without a measurement in this frame are NaN. Leave them out of the
        # instance instead of storing them as visible, completed user points.
        if np.isnan(x) or np.isnan(y):
            continue
        use_points[idx_kpoint_mapping[i]] = sleap.instance.Point(
            x=x, y=y, visible=True, complete=True
        )
    if len(use_points) > 0:
        instance = sleap.instance.Instance(skeleton=use_skeleton, points=use_points)
        labeled_frame = sleap.instance.LabeledFrame(
            video=videos[_vid], instances=[instance], frame_idx=_frame_idx
        )
        all_labeled_frames.append(labeled_frame)

  0%|          | 0/141570 [00:00<?, ?it/s]

In [50]:
# Frame count before removing overlaps with already-labeled Segments.ai samples.
nlabeled_frames = len(all_labeled_frames)

In [51]:
from segments import SegmentsClient

# Never hard-code credentials in a notebook: read the Segments.ai API key from the
# environment (keys are listed at https://segments.ai/account). Set SEGMENTS_API_KEY
# before starting Jupyter; a missing key fails here with a KeyError.
api_key = os.environ["SEGMENTS_API_KEY"]
client = SegmentsClient(api_key)

INFO:segments.client:Initialized successfully.


In [52]:
dataset_name = (
    "jmarkow/basler-nir-plexiglass-arena-keypoints-fused-kneejoints"
    if "joint" in base_dir
    else "jmarkow/basler-nir-plexiglass-arena-keypoints-fused"
)
# Samples already LABELED/REVIEWED in Segments.ai (a single page of up to 10000 is requested).
samples = client.get_samples(dataset_name, per_page=10000, label_status=["LABELED", "REVIEWED"])
pre_segments_labels = []
for _sample in samples:
    pre_segments_labels.append((client.get_label(_sample.uuid), _sample))

In [53]:
# Number of LABELED/REVIEWED samples fetched from Segments.ai.
len(samples)

592

In [54]:
# Make sure new frames don't collide with frames already labeled in Segments.ai; colliding frames are removed from the new training data.

In [55]:
# A new frame is dropped when it comes from the same session video as an
# already-labeled Segments.ai sample and lies fewer than min_frame_gap frames from it.
min_frame_gap = 10


def _session_key(path):
    # last three path parts identify the session video independent of the mount point
    return os.sep.join(path.split(os.sep)[-3:])


# session key -> frame indices already labeled in Segments.ai (built once, instead of
# rescanning every sample for every candidate frame)
_labeled_frame_idxs = {}
for _label, _sample in pre_segments_labels:
    _labeled_frame_idxs.setdefault(
        _session_key(_sample.metadata["dat_path_reflect"]), []
    ).append(int(_sample.metadata["frame_index"]))
_labeled_frame_idxs = {k: np.asarray(v) for k, v in _labeled_frame_idxs.items()}

keep_idxs = []
for i, _labeled_frame in tqdm(enumerate(all_labeled_frames), total=len(all_labeled_frames)):
    _matches = _labeled_frame_idxs.get(_session_key(_labeled_frame.video.filename))
    chk_frame_idx = int(_labeled_frame.frame_idx)
    if _matches is not None and np.any(np.abs(_matches - chk_frame_idx) < min_frame_gap):
        continue
    keep_idxs.append(i)

  0%|          | 0/141570 [00:00<?, ?it/s]

In [56]:
# Frames left after dropping those close to an already-labeled sample.
print(len(keep_idxs))

137655


In [57]:
# Frames that passed the overlap check, in their original order.
all_labeled_frames_nomatch = [all_labeled_frames[_kept] for _kept in keep_idxs]
nlabeled_frames = len(all_labeled_frames_nomatch)

In [58]:
nlabeled_frames  # frames available for the training subsets below

137655

In [59]:
# Training-set sizes (number of frames) to export; the commented line is a larger sweep.
# label_range = [  60, 125, 250, 500, 1000, 2500, 5000, 10000, 25000, 50000, 100000 ]
label_range = [  60, 125, 250, 500, 1000, 2500, 5000, 10000, 20000 ]

In [None]:
save_dir = "/home/jmarkow/data_dir/active_projects/keypoints_basler_nir_plexiglass_arena/sleap_training/_labels_qd/"
os.makedirs(save_dir, exist_ok=True)

# Filename tag for the dataset; "joint" takes precedence over hindleg_only.
if "joint" in base_dir:
    _dataset_tag = "kneejoints"
elif hindleg_only:
    _dataset_tag = "hindleg"
else:
    _dataset_tag = "v2"

for _nlabels in tqdm(label_range):
    save_path = os.path.join(
        save_dir,
        f"training_data_qdots_{_dataset_tag}_nframes-{_nlabels}_vidtype-{use_vid_key}_minkpoints-{min_nkpoints}.slp",
    )
    if os.path.exists(save_path):
        continue  # already exported

    # _nlabels evenly spaced picks over the kept frames (the spacing starts at index 1)
    choose_frames = np.round(np.linspace(1, nlabeled_frames - 1, _nlabels)).astype("int")
    save_labeled_frames = [all_labeled_frames_nomatch[i] for i in choose_frames]
    labels = sleap.Labels(labeled_frames=save_labeled_frames)
    labels.save(save_path, with_images=True)

  0%|          | 0/9 [00:00<?, ?it/s]

In [70]:
save_dir = "/home/jmarkow/data_dir/active_projects/keypoints_basler_nir_plexiglass_arena/sleap_training/_labels_qd_noimages/"
os.makedirs(save_dir, exist_ok=True)

# Filename tag for the dataset; "joint" takes precedence over hindleg_only.
if "joint" in base_dir:
    _dataset_tag = "kneejoints"
elif hindleg_only:
    _dataset_tag = "hindleg"
else:
    _dataset_tag = "v2"

for _nlabels in tqdm(label_range):
    save_path = os.path.join(
        save_dir,
        f"training_data_qdots_{_dataset_tag}_nframes-{_nlabels}_vidtype-{use_vid_key}_minkpoints-{min_nkpoints}.slp",
    )
    if os.path.exists(save_path):
        continue  # already exported

    # same frame selection as the with-images export, but labels reference the videos
    choose_frames = np.round(np.linspace(1, nlabeled_frames - 1, _nlabels)).astype("int")
    save_labeled_frames = [all_labeled_frames_nomatch[i] for i in choose_frames]
    labels = sleap.Labels(labeled_frames=save_labeled_frames)
    labels.save(save_path, with_images=False)

  0%|          | 0/9 [00:00<?, ?it/s]

In [71]:
# Leftover sanity-check cell; it only prints "test" and can be deleted.
print("test")

test
