In [1]:
# Reload edited Python modules (e.g. markovids) automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [2]:
import os
import numpy as np
import skimage
import toml
import sleap_io as sio
import joblib
import pandas as pd
import json
from glob import glob
from tqdm.auto import tqdm
from markovids import pcl

In [3]:
# Shared preprocessing config; the analysis directory is displayed as a sanity check.
_config_path = "../preprocessing/config.toml"
config = toml.load(_config_path)
config["dirs"]["analysis"]

'/storage/home/hcoda1/4/jmarkowitz30/shared_folder/active_lab_members/markowitz_jeffrey/active_projects/quantum_dots/_analysis'

In [4]:
def get_nested_value(data, key_string, delimiter='.'):
    """
    Walk a nested dictionary along a delimited key path and return the leaf.

    Args:
        data (dict): The nested dictionary to search.
        key_string (str): Keys joined by ``delimiter``, e.g. ``"model.heads.sigma"``.
        delimiter (str, optional): Separator between keys. Defaults to '.'.

    Returns:
        The value at the end of the path, or None if any key is missing or an
        intermediate value is not a dict.
    """
    node = data
    for part in key_string.split(delimiter):
        # Stop as soon as we cannot descend further.
        if not isinstance(node, dict) or part not in node:
            return None
        node = node[part]
    return node

In [5]:
# save_dir = "/mnt/data/jmarkow/panels/2024-06 (QD paper)"
# os.makedirs(save_dir, exist_ok=True)

## Gather data and predicted keypoints

In [6]:
# pool repeats

In [7]:

root_dir = "/storage/home/hcoda1/4/jmarkowitz30/shared_folder/active_lab_members/markowitz_jeffrey/active_projects/"

# Model sets available for analysis; keep exactly one `model_sub_dirs` assignment active.
# model_sub_dirs = ["keypoints_basler_nir_plexiglass_arena/sleap_training_round2/keypoints_from_manual_data_final_model/models"]
# model_sub_dirs = ["keypoints_basler_nir_plexiglass_arena/sleap_training_round2/keypoints_from_manual_data_parameter_sweep/models",
#                   "keypoints_basler_nir_plexiglass_arena/sleap_training_round2/keypoints_from_manual_data_parameter_sweep_part2/models"]
# model_sub_dirs = ["keypoints_basler_nir_plexiglass_arena/sleap_training_round2/keypoints_from_manual_data_subsample/models",
#                   "keypoints_basler_nir_plexiglass_arena/sleap_training_round2/keypoints_from_manual_data_subsample_morerepeats/models"]
model_sub_dirs = ["keypoints_basler_nir_plexiglass_arena/sleap_training_round2/keypoints_from_manual_data_subsample_reflectanceonly/models"]

# model_sub_dirs = ["keypoints_basler_nir_plexiglass_arena/sleap_training_round2/keypoints_from_manual_data_different_modalities/models"]

# model_sub_dirs = ["keypoints_basler_nir_plexiglass_arena/sleap_training_round2/keypoints_from_manual_data_different_modalities_kneejoints/models"]
# model_sub_dirs = ["keypoints_basler_nir_plexiglass_arena/sleap_training_round2/keypoints_from_manual_data_final_model_kneejoints/models"]

# model_sub_dirs = ["keypoints_basler_nir_plexiglass_arena/sleap_training_round2/keypoints_from_manual_data_different_modalities_bottomup/models"]

# Every "*_instance" directory under each model sub-directory (sorted within each one).
model_dirs = [
    _candidate
    for _model_sub_dir in model_sub_dirs
    for _candidate in sorted(glob(os.path.join(root_dir, _model_sub_dir, "*_instance")))
    if os.path.isdir(_candidate)
]

In [8]:
# Knee-joint model sets use the "-kneejoints" fused labels; every other set uses "-round2".
training_fname = os.path.join(
    root_dir,
    "keypoints_basler_nir_plexiglass_arena/sleap_training_round2/_labels/basler-nir-plexiglass-arena-keypoints-fused-kneejoints_weights-(1.0, 0.0)_bpass-None.slp"
    if "kneejoints" in model_sub_dirs[0]
    else "keypoints_basler_nir_plexiglass_arena/sleap_training_round2/_labels/basler-nir-plexiglass-arena-keypoints-fused-round2_weights-(1.0, 0.0)_bpass-None.slp",
)

In [9]:
# Aggregated results live one directory above the first model sub-directory.
_first_model_sub_dir = model_sub_dirs[0]
save_file = os.path.join(root_dir, _first_model_sub_dir, "../aggregated_results.parquet")

In [10]:
# Training labels: provide the skeleton and the frames whose fluorescence is measured below.
fluo_labels = sio.load_slp(training_fname)
skeleton = fluo_labels.skeleton

In [11]:
# Cache every training-label image in memory, in labeled-frame order.
fluo_ims = []
for _labeled_frame in tqdm(fluo_labels.labeled_frames):
    fluo_ims.append(_labeled_frame.image)

  0%|          | 0/862 [00:00<?, ?it/s]

In [12]:
# Label-file split suffixes and the matching training-config split names. They are
# paired positionally (zip), so both lists must have the same length and order.
file_types = ["train", "val", "test"]
metadata_types = ["training", "validation", "test"]
# file_types = ["test"]
# metadata_types = ["test"]
assert len(file_types) == len(metadata_types), "file_types and metadata_types must pair up one-to-one"

In [13]:
# Output column name -> dotted path into each model's initial_config.json.
save_params = dict(
    head_output_stride="model.heads.single_instance.output_stride",
    head_sigma="model.heads.single_instance.sigma",
    backbone_output_stride="model.backbone.unet.output_stride",
    backbone_filters="model.backbone.unet.filters",
    backbone_filters_rate="model.backbone.unet.filters_rate",
    backbone_max_stride="model.backbone.unet.max_stride",
)

In [14]:
# training_config["data"]["labels"]["validation_inds"]

In [15]:
def _load_split_frames(model_dir, prefix, file_types):
    """Load labeled frames from ``<model_dir>/<prefix>.<type>.slp`` for each split.

    Args:
        model_dir (str): Model directory containing the exported label files.
        prefix (str): File prefix, e.g. ``"labels_gt"`` or ``"labels_pr"``.
        file_types (list of str): Split suffixes to look for, in order.

    Returns:
        dict: ``{file_type: labeled_frames}`` in ``file_types`` order. Splits whose
        file does not exist are omitted.
    """
    frames = {}
    for _type in file_types:
        try:
            frames[_type] = sio.load_slp(os.path.join(model_dir, f"{prefix}.{_type}.slp")).labeled_frames
        except FileNotFoundError:
            continue
    return frames


def _measure_kpoint_fluo(bground_sub, kpoint_x_idx, kpoint_y_idx, search_obj, search_y, search_x, xx, yy):
    """Summarize fluorescence in a disk-shaped window centred on a predicted keypoint.

    The window spans ``xx``/``yy`` offsets around (``kpoint_x_idx``, ``kpoint_y_idx``).
    Pixels outside ``search_obj`` and outside the blob returned by
    ``pcl.fluo.get_closest_blob`` are zeroed before fitting.

    Returns:
        dict with ``kpoint_com`` (moment centroid, image coords), ``kpoint_com_gauss``
        (Gaussian-fit centroid, image coords), ``kpoint_amp_gauss``, ``kpoint_fluo_ave``
        and ``kpoint_fluo_peak``. Every value is NaN when the window leaves the image
        or any step raises ``ValueError``.
    """
    nan_result = {
        "kpoint_com": (np.nan, np.nan),
        "kpoint_com_gauss": (np.nan, np.nan),
        "kpoint_amp_gauss": np.nan,
        "kpoint_fluo_ave": np.nan,
        "kpoint_fluo_peak": np.nan,
    }

    # Half-open window bounds in image coordinates.
    xrange = (xx[0] + kpoint_x_idx, xx[-1] + kpoint_x_idx)
    yrange = (yy[0] + kpoint_y_idx, yy[-1] + kpoint_y_idx)

    # A window hanging off the image edge cannot be measured. Negative starts would
    # otherwise wrap around when slicing.
    if (
        xrange[0] < 0
        or yrange[0] < 0
        or xrange[1] > bground_sub.shape[1]
        or yrange[1] > bground_sub.shape[0]
    ):
        return nan_result

    try:
        # First restrict to the search disk, then keep only the blob closest to the centre.
        window = search_obj * bground_sub[slice(*yrange), slice(*xrange)].astype("float")
        mask = pcl.fluo.get_closest_blob(window)
        window[mask == 0] = 0

        gauss_params, moment_params = pcl.fluo.fit_2d_gaussian_with_moments(
            window,
            loss="linear",
        )
        # Fit results are window-relative; shift back to image coordinates.
        kpoint_com = [moment_params["x0"] + xrange[0], moment_params["y0"] + yrange[0]]
        if gauss_params is not None:
            kpoint_com_gauss = [gauss_params["x0"] + xrange[0], gauss_params["y0"] + yrange[0]]
            kpoint_amp_gauss = gauss_params["amplitude"]
        else:
            kpoint_com_gauss = [np.nan, np.nan]
            kpoint_amp_gauss = np.nan

        kpoint_fluo_ave = np.nanmean(window[search_y, search_x])
        kpoint_fluo_peak = np.nanmax(window[search_y, search_x])
    except ValueError:
        # Reset *every* measurement so nothing leaks over from a previous keypoint.
        return nan_result

    return {
        "kpoint_com": kpoint_com,
        "kpoint_com_gauss": kpoint_com_gauss,
        "kpoint_amp_gauss": kpoint_amp_gauss,
        "kpoint_fluo_ave": kpoint_fluo_ave,
        "kpoint_fluo_peak": kpoint_fluo_peak,
    }


def get_kpoint_gt_df(
    model_dir,
    fluo_ims,
    file_types=file_types,
    metadata_types=metadata_types,
    search_radius=10,
    incl_other_node_thresh=5,
    save_params=save_params,
):
    """Compare predicted and ground-truth keypoints for one model, with local fluorescence.

    For each split in ``file_types`` whose ``labels_gt.<type>.slp`` and
    ``labels_pr.<type>.slp`` exist in ``model_dir``, each predicted node is paired with
    the ground-truth node of the same name.

    Predicted nodes with finite coordinates (and finite ground truth) also get
    fluorescence summaries around the prediction. Predicted nodes with non-finite
    coordinates are still recorded, without fluorescence columns, so that missed
    keypoints are counted.

    Args:
        model_dir (str): SLEAP model directory containing ``training_config.json``,
            ``initial_config.json`` and the exported label files.
        fluo_ims (sequence of arrays): Images indexed by the ground-truth ``frame_idx``.
        file_types (list of str): Label-file split suffixes; paired positionally with
            ``metadata_types``.
        metadata_types (list of str): Split names used as ``<name>_inds`` in the
            training config.
        search_radius (float): Radius (pixels) of the disk used for fluorescence
            measurements.
        incl_other_node_thresh (float): Currently unused; kept for API compatibility.
        save_params (dict): Output column name -> dotted path into ``initial_config.json``.

    Returns:
        pandas.DataFrame or None: One row per (frame, node). Returns ``None`` (after
        printing ``misalign: <model_dir>``) when the gt, prediction and metadata splits
        do not line up.
    """
    search_obj = skimage.morphology.disk(radius=search_radius).astype("float")
    search_y, search_x = np.where(search_obj > 0)
    center = search_obj.shape[0] // 2
    # Offsets relative to the keypoint. The extra element acts as an exclusive slice
    # stop, so the window is exactly search_obj.shape pixels.
    xx = np.arange(search_obj.shape[1] + 1) - center
    yy = np.arange(search_obj.shape[0] + 1) - center

    with open(os.path.join(model_dir, "training_config.json")) as f:
        training_config = json.load(f)

    with open(os.path.join(model_dir, "initial_config.json")) as f:
        initial_config = json.load(f)

    labels_config = training_config["data"]["labels"]
    training_data_path = os.path.join(model_dir, "../../", labels_config["training_labels"])
    training_metadata_path = os.path.splitext(training_data_path)[0] + ".toml"
    training_metadata = toml.load(training_metadata_path)["segments_metadata"]

    gt = _load_split_frames(model_dir, "labels_gt", file_types)
    pred = _load_split_frames(model_dir, "labels_pr", file_types)

    use_metadata = {}
    for _type in metadata_types:
        try:
            use_metadata[_type] = [training_metadata[i] for i in labels_config[f"{_type}_inds"]]
        except (KeyError, TypeError):
            continue

    # Pair each available split with its metadata split by name rather than by dict
    # position, so a missing split cannot shift labels onto the wrong data.
    split_pairs = [(f, m) for f, m in zip(file_types, metadata_types) if f in pred]
    if (
        list(pred.keys()) != list(gt.keys())
        or len(pred) != len(use_metadata)
        or any(m not in use_metadata for _, m in split_pairs)
    ):
        print(f"misalign: {model_dir}")
        return None

    nframes = {k: len(v) for k, v in gt.items()}

    dcts = []
    for _type, _meta_type in split_pairs:
        split_iter = zip(gt[_type], pred[_type], use_metadata[_meta_type])
        for _dset_index, (_gt, _pred, _metadata) in enumerate(split_iter):
            try:
                points_gt = _gt.instances[0].points
                points_pred = _pred.instances[0].points
            except IndexError:
                continue

            points_gt = {k.name: v for k, v in points_gt.items()}
            points_pred = {k.name: v for k, v in points_pred.items()}

            # NOTE(review): assumes _gt.frame_idx indexes directly into fluo_ims, i.e. the
            # training labels hold one labeled frame per frame index - confirm.
            bground_sub = fluo_ims[_gt.frame_idx].squeeze()

            for node, _kpoint in points_pred.items():
                gt_x = points_gt[node].x
                gt_y = points_gt[node].y
                predicted = np.isfinite(_kpoint.x) and np.isfinite(_kpoint.y)

                _dct = {
                    "gt_x": gt_x,
                    "gt_y": gt_y,
                    "kpoint_name": node,
                    "kpoint_x": _kpoint.x,
                    "kpoint_y": _kpoint.y,
                }
                if predicted:
                    # Skip nodes without usable ground truth (nothing to compare against).
                    if not (np.isfinite(gt_x) and np.isfinite(gt_y)):
                        continue
                    fluo = _measure_kpoint_fluo(
                        bground_sub,
                        int(np.round(_kpoint.x)),
                        int(np.round(_kpoint.y)),
                        search_obj,
                        search_y,
                        search_x,
                        xx,
                        yy,
                    )
                    _dct.update(
                        {
                            "kpoint_com_x": fluo["kpoint_com"][0],
                            "kpoint_com_y": fluo["kpoint_com"][1],
                            "kpoint_com_gauss_x": fluo["kpoint_com_gauss"][0],
                            "kpoint_com_gauss_y": fluo["kpoint_com_gauss"][1],
                            "kpoint_amp_gauss": fluo["kpoint_amp_gauss"],
                            "kpoint_fluo_ave": fluo["kpoint_fluo_ave"],
                            "kpoint_fluo_peak": fluo["kpoint_fluo_peak"],
                        }
                    )
                # Missed keypoints (non-finite prediction) are still recorded so they count.
                _dct.update(
                    {
                        "kpoint_score": _kpoint.score,
                        "frame_index": _metadata["frame_index"],
                        "dset_index": _dset_index,
                        "dset_type": _type,
                        "model_dir": model_dir,
                        "search_radius": search_radius,
                    }
                )
                dcts.append(_dct)

    df = pd.DataFrame(dcts)
    for k, v in save_params.items():
        df[k] = get_nested_value(initial_config, v)
    for k, v in nframes.items():
        df[f"nframes {k}"] = v
    return df

In [16]:
def get_l2_norm(df, x1="kpoint_x", x2="gt_x", y1="kpoint_y", y2="gt_y"):
    """Euclidean distance between points (x1, y1) and (x2, y2) stored in ``df``.

    Args:
        df: A single row (pandas Series) or a whole DataFrame holding the columns.
        x1, y1: Column names of the first point's coordinates.
        x2, y2: Column names of the second point's coordinates.

    Returns:
        A scalar for a row, or a Series with one distance per row for a DataFrame.
        NaN coordinates propagate to NaN.
    """
    # np.hypot is element-wise. np.linalg.norm over [dx, dy] would collapse a whole
    # DataFrame into a single Frobenius norm instead of per-row distances.
    return np.hypot(df[x1] - df[x2], df[y1] - df[y2])

In [17]:
# get_kpoint_gt_df(_dir, fluo_ims, search_radius=10.)

In [18]:
# Search-disk radii (pixels) to evaluate; each radius is run against every model directory.
search_radius_scan = [10.0]

In [19]:
# One deferred job per (search radius, model directory) pair, radius-major order.
delays = [
    joblib.delayed(get_kpoint_gt_df)(_model_dir, fluo_ims, search_radius=_radius)
    for _radius in search_radius_scan
    for _model_dir in model_dirs
]

In [20]:
# Recompute even when the cached parquet exists; set to False to load the cache instead.
force = True

In [21]:
# get_kpoint_gt_df(model_dirs[2], fluo_ims, search_radius=_search_radius)

In [22]:
if (not os.path.exists(save_file)) or force:
    print(f"Processing {len(delays)} jobs")
    results = joblib.Parallel(n_jobs=18, verbose=10)(delays)

    # get_kpoint_gt_df returns None for misaligned models. pd.concat would drop those
    # silently (and fail outright if every job returned None), so handle it explicitly.
    valid_results = [_res for _res in results if _res is not None]
    n_skipped = len(results) - len(valid_results)
    if n_skipped:
        print(f"Skipped {n_skipped} misaligned model(s)")
    if not valid_results:
        raise RuntimeError("No model directory produced results; nothing to aggregate")

    result_df = pd.concat(valid_results, ignore_index=True)
    # expand=False yields a Series (one capture group) rather than a 1-column DataFrame.
    result_df["model_weight"] = result_df["model_dir"].str.extract("weights-(.*?)_", expand=False)
    result_df["model_subsample"] = result_df["model_dir"].str.extract("subsample-(.*?)_", expand=False)

    # Vectorized Euclidean errors against ground truth; NaN wherever either point is missing.
    for _out_col, _x_col, _y_col in [
        ("kpoint_gt_l2", "kpoint_x", "kpoint_y"),
        ("kpoint_com_gauss_gt_l2", "kpoint_com_gauss_x", "kpoint_com_gauss_y"),
        ("kpoint_com_gt_l2", "kpoint_com_x", "kpoint_com_y"),
    ]:
        result_df[_out_col] = np.hypot(
            result_df[_x_col].astype(float) - result_df["gt_x"].astype(float),
            result_df[_y_col].astype(float) - result_df["gt_y"].astype(float),
        )
    result_df.to_parquet(save_file)
else:
    print(f"Loading saved data from {save_file}")
    result_df = pd.read_parquet(save_file)

Processing 56 jobs


[Parallel(n_jobs=18)]: Using backend LokyBackend with 18 concurrent workers.
[Parallel(n_jobs=18)]: Done   5 tasks      | elapsed:   30.3s
[Parallel(n_jobs=18)]: Done  14 tasks      | elapsed:   51.3s
[Parallel(n_jobs=18)]: Done  27 out of  56 | elapsed:  1.3min remaining:  1.4min


misalign


[Parallel(n_jobs=18)]: Done  33 out of  56 | elapsed:  1.5min remaining:  1.1min
[Parallel(n_jobs=18)]: Done  39 out of  56 | elapsed:  1.7min remaining:   44.8s


misalign
misalign


[Parallel(n_jobs=18)]: Done  45 out of  56 | elapsed:  2.3min remaining:   33.3s
[Parallel(n_jobs=18)]: Done  51 out of  56 | elapsed:  2.6min remaining:   15.3s
[Parallel(n_jobs=18)]: Done  56 out of  56 | elapsed:  2.7min finished


misalign
misalign
misalign
misalign
