In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os
import glob
import numpy as np
import toml
import pandas as pd
import natsort
import sleap
import json
from tqdm.auto import tqdm
from joblib import delayed, Parallel

In [4]:
def get_nested_value(data, key_string, delimiter='.'):
    """
    Walk a nested dictionary along a delimited key path.

    Args:
        data (dict): Dictionary to search; values may themselves be dicts.
        key_string (str): Keys joined by ``delimiter``, e.g. "model.backbone.unet.filters".
        delimiter (str, optional): Separator between keys in ``key_string``. Defaults to '.'.

    Returns:
        The value at the end of the path, or None as soon as a step is not a dict
        or does not contain the requested key.
    """
    node = data
    for part in key_string.split(delimiter):
        # Stop at the first step that cannot be followed.
        if not isinstance(node, dict) or part not in node:
            return None
        node = node[part]
    return node

# Parse grid search

1. TODO: compare test vs. train mAP to check how robust the learning is (i.e., how well each model generalizes).
2. TODO: combine (1) with overall performance into a single ranking (e.g., equal weighting after normalizing each metric).

In [5]:
# Shared lab project space on the cluster; every grid-search folder below lives under it.
root_dir = "/storage/home/hcoda1/4/jmarkowitz30/shared_folder/active_lab_members/markowitz_jeffrey/active_projects/"

# All sweeps share one project / training-round parent; only the sweep folder differs.
_sweep_parent = "keypoints_basler_nir_plexiglass_arena/sleap_training_round2"
_active_sweeps = (
    "keypoints_from_qds_all_cameras_v3",
    "keypoints_from_qds_heldout_camera_v3",
    "keypoints_from_qds_same_camera_v3",
    "keypoints_from_qds_same_camera_kneejoints_v3",
    "keypoints_from_qds_all_cameras_kneejoints_v3",
    "keypoints_from_qds_heldout_camera_kneejoints_v3",
)
# Other sweeps under the same parent, currently excluded:
#   keypoints_from_qds_all_cameras_kneejoints, keypoints_from_qds_same_camera_kneejoints,
#   keypoints_from_qds_heldout_camera_kneejoints, keypoints_from_qds_all_cameras,
#   keypoints_from_qds_same_camera, keypoints_from_qds_heldout_camera,
#   keypoints_from_qds_all_cameras_large_models
# Alternative sweep set (manually labeled data), also under the same parent:
#   keypoints_from_manual_data_parameter_sweep,
#   keypoints_from_manual_data_fluo_aligned_param_sweep_v2
gridsearch_dir = [os.path.join(root_dir, f"{_sweep_parent}/{sweep}") for sweep in _active_sweeps]

# Collect every run's initial_config.json anywhere below each sweep folder.
training_configs_fnames = []
for _dir in gridsearch_dir:
    training_configs_fnames.extend(glob.glob(os.path.join(_dir, "**/initial_config.json"), recursive=True))

In [6]:
# Natural ordering (numeric parts compared as numbers) via natsort's sort key.
training_configs_fnames = sorted(training_configs_fnames, key=natsort.natsort_keygen())

In [7]:
# Per-split metrics file name, keyed by short split name.
metrics_fnames = {split: f"metrics.{split}.npz" for split in ("val", "train", "test")}

In [8]:
# Short split name -> prefix of the "<prefix>_inds" list in training_config["data"]["labels"].
label_keynames = dict(val="validation", train="training", test="test")

In [9]:
# Metric keys summarized per split; "pck.pcks.<t>" means PCK at threshold t.
# ("pck.thresholds" is deliberately not summarized.)
metric_list = (
    [f"dist.{stat}" for stat in ("p50", "p90", "p99", "avg")]
    + [f"vis.{stat}" for stat in ("precision", "recall")]
    + [f"oks_voc.{stat}" for stat in ("mAP", "mAR")]
    + ["pck.mPCK", "pck.pcks.10"]
    + ["oks.mOKS"]
)

In [10]:
# Load the first run's initial_config.json (note: bound to the name `training_config`).
first_config_fname = training_configs_fnames[0]
with open(first_config_fname, "r") as f:
    training_config = json.load(f)

In [11]:
# Output column -> dotted path into a run's initial config (see get_nested_value).
save_params = {
    **{
        f"head_{field}": f"model.heads.single_instance.{field}"
        for field in ("output_stride", "sigma")
    },
    **{
        f"backbone_{field}": f"model.backbone.unet.{field}"
        for field in ("output_stride", "filters", "filters_rate", "max_stride")
    },
}

In [12]:
# Folder of the first run (the head of its config path).
use_dir = os.path.split(training_configs_fnames[0])[0]

In [14]:
def get_metrics(
    config_fname,
    metrics_fnames=metrics_fnames,
    label_keynames=label_keynames,
    metric_list=metric_list,
    save_params=save_params,
    node_names=None,
    oks_scale=150,
):
    """Summarize the evaluation metrics and hyperparameters of one grid-search run.

    Args:
        config_fname (str): Path to the run's ``initial_config.json``. The same
            directory must hold ``training_config.json`` and the per-split files named
            by ``metrics_fnames`` (or, for re-evaluated runs, the matching
            ``labels_gt.*.slp`` / ``labels_pr.*.slp`` files).
        metrics_fnames (dict): Split name (e.g. "val") -> metrics ``.npz`` file name.
        label_keynames (dict): Split name -> prefix of the ``<prefix>_inds`` list in
            ``training_config["data"]["labels"]``; its length is stored as
            ``nframes_<split>``.
        metric_list (list of str): Metric keys to average per split. Keys of the form
            "pck.pcks.<threshold>" select PCK at that threshold.
        save_params (dict): Output column -> dotted path into the initial config,
            resolved with ``get_nested_value``.
        node_names (list of str, optional): Skeleton node names used to label the
            per-keypoint distance columns. Read from ``labels_gt.val.slp`` when None.
        oks_scale (float, optional): Passed to ``sleap.nn.evals.evaluate`` for runs
            re-evaluated from label files (paths containing "hindleg" or "joint").
            Defaults to 150.

    Returns:
        dict or None: One flat record with ``<split>_<metric>`` values, per-node
        distances, hyperparameters and run metadata; None when a required input
        file is missing or unreadable (the error is printed).
    """
    use_dir = os.path.dirname(config_fname)
    param_dir = os.path.basename(use_dir)
    param_dct = {}

    # A run without a readable validation label file is skipped like any other
    # missing input, instead of raising and aborting the whole Parallel batch.
    if node_names is None:
        try:
            slp_data = sleap.load_file(os.path.join(use_dir, "labels_gt.val.slp"))
            node_names = [_.name for _ in slp_data.skeletons[0].nodes]
        except (IndexError, FileNotFoundError) as e:
            print(e)
            return None

    if ("hindleg" in config_fname) or ("joint" in config_fname):
        # Re-evaluate from the saved ground-truth / prediction label files with
        # `oks_scale` instead of reading the stored metrics .npz files.
        try:
            metrics = {}
            for k, v in metrics_fnames.items():
                gt_fname = v.replace("metrics.", "labels_gt.").replace("npz", "slp")
                pr_fname = v.replace("metrics.", "labels_pr.").replace("npz", "slp")
                tmp_gt = sleap.load_file(os.path.join(use_dir, gt_fname))
                tmp_pr = sleap.load_file(os.path.join(use_dir, pr_fname))
                metrics[k] = sleap.nn.evals.evaluate(
                    tmp_gt, tmp_pr, oks_scale=oks_scale, user_labels_only=True
                )
        except (IndexError, FileNotFoundError) as e:
            print(e)
            return None
    else:
        try:
            metrics = {
                k: sleap.load_metrics(os.path.join(use_dir, v)) for k, v in metrics_fnames.items()
            }
        except FileNotFoundError as e:
            print(e)
            return None

    # Config files are required too; a missing one skips the run consistently.
    try:
        with open(config_fname, "r") as f:
            initial_config = json.load(f)
        with open(os.path.join(use_dir, "training_config.json"), "r") as f:
            training_config = json.load(f)
    except FileNotFoundError as e:
        print(e)
        return None

    nframes = {
        k: len(training_config["data"]["labels"][f"{v}_inds"]) for k, v in label_keynames.items()
    }

    for _metric in metric_list:
        for _metric_type, _metric_data in metrics.items():
            key = f"{_metric_type}_{_metric}"
            if "pcks" in _metric:
                # "pck.pcks.<threshold>": split at most twice so fractional
                # thresholds such as "pck.pcks.2.5" keep their decimal part.
                parts = _metric.split(".", 2)
                use_name = ".".join(parts[:2])
                use_threshold = float(parts[-1])
                # isclose rather than == so float round-off in stored thresholds still matches.
                use_idx = np.flatnonzero(np.isclose(_metric_data["pck.thresholds"], use_threshold))
                if use_idx.size == 0:
                    # Threshold not evaluated for this run: record NaN explicitly
                    # rather than averaging an empty selection.
                    param_dct[key] = np.nan
                    continue
                # NOTE(review): axis 1 is treated as the threshold axis, i.e. this
                # assumes pck.pcks is (n_instances, n_thresholds, n_nodes) — confirm
                # against the installed SLEAP version.
                try:
                    param_dct[key] = np.nanmean(_metric_data[use_name][:, use_idx, :])
                except IndexError:
                    # PCK array lacks the expected layout; leave this column out for this split.
                    pass
            else:
                param_dct[key] = np.nanmean(_metric_data[_metric])

    # Per-keypoint mean distance for each split, one column per skeleton node.
    for _metric_type, _metric_data in metrics.items():
        kpoint_dist = np.nanmean(_metric_data["dist.dists"], axis=0)
        if len(kpoint_dist) != len(node_names):
            # zip() below silently truncates to the shorter sequence; make that visible.
            print(
                f"{config_fname}: {_metric_type} has {len(kpoint_dist)} keypoint "
                f"distances but {len(node_names)} node names"
            )
        for _dist, _node in zip(kpoint_dist, node_names):
            param_dct[f"{_metric_type}_dist.parts.{_node}"] = _dist

    param_dct["basename"] = param_dir
    param_dct["condition"] = "all"
    param_dct["filename"] = config_fname

    for k, v in save_params.items():
        param_dct[k] = get_nested_value(initial_config, v)

    # The camera condition and keypoint set are encoded in the run's path.
    if "same_camera" in config_fname:
        param_dct["condition"] = "same"
    elif "heldout_camera" in config_fname:
        param_dct["condition"] = "different"
    param_dct["is_joint"] = "kneejoints" in config_fname
    for k, v in nframes.items():
        param_dct[f"nframes_{k}"] = v
    return param_dct

In [15]:
# Project-wide directories come from the preprocessing config.
config = toml.load("../preprocessing/config.toml")
analysis_dir = config["dirs"]["analysis"]
save_file = os.path.join(analysis_dir, "sleap_metrics_qd_training.parquet")
# When True, the next cell recomputes metrics even if save_file already exists.
force = True

In [16]:
if not os.path.exists(save_file) or force:
    # One lazy job per run; get_metrics returns None for runs with missing inputs.
    n_jobs = 15
    delays = [delayed(get_metrics)(_config_fname) for _config_fname in training_configs_fnames]
    results = Parallel(n_jobs=n_jobs, backend="loky", verbose=10)(delays)

    records = [_result for _result in results if _result is not None]
    n_skipped = len(results) - len(records)
    if n_skipped:
        # Report dropped runs instead of silently shrinking the table.
        print(f"Skipped {n_skipped} of {len(results)} runs (missing or unreadable outputs)")

    df = pd.DataFrame(records)
    # Force recall/precision columns to float before writing.
    convert_cols = df.filter(regex="(recall|precision)").columns
    df[convert_cols] = df[convert_cols].astype("float")
    # The analysis directory may not exist yet on a fresh checkout.
    os.makedirs(os.path.dirname(save_file) or ".", exist_ok=True)
    df.to_parquet(save_file)
else:
    df = pd.read_parquet(save_file)

[Parallel(n_jobs=15)]: Using backend LokyBackend with 15 concurrent workers.
[Parallel(n_jobs=15)]: Done   2 tasks      | elapsed:    6.3s
[Parallel(n_jobs=15)]: Done  11 tasks      | elapsed:    8.3s
[Parallel(n_jobs=15)]: Done  20 tasks      | elapsed:   11.2s
[Parallel(n_jobs=15)]: Done  31 tasks      | elapsed:   12.0s
[Parallel(n_jobs=15)]: Done  42 tasks      | elapsed:   13.0s
[Parallel(n_jobs=15)]: Done  55 tasks      | elapsed:   14.3s
[Parallel(n_jobs=15)]: Done  68 tasks      | elapsed:   15.2s
[Parallel(n_jobs=15)]: Done  83 tasks      | elapsed:   17.0s
[Parallel(n_jobs=15)]: Done  98 tasks      | elapsed:   22.1s
[Parallel(n_jobs=15)]: Done 115 tasks      | elapsed:   23.2s
[Parallel(n_jobs=15)]: Done 132 tasks      | elapsed:   23.8s
[Parallel(n_jobs=15)]: Done 151 tasks      | elapsed:   24.5s
[Parallel(n_jobs=15)]: Done 170 tasks      | elapsed:   25.4s
[Parallel(n_jobs=15)]: Done 191 tasks      | elapsed:   26.7s
[Parallel(n_jobs=15)]: Done 212 tasks      | elapsed:  