In [1]:
import os
import re
import json
import yaml
import argparse
from test_metric import process_trial
from wbfm.utils.projects.finished_project_data import ProjectData
import matplotlib.pyplot as plt


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def build_final_dict(gt_path, result_dir):
    """Score every trial folder under ``result_dir`` against a single ground truth.

    Each entry of ``result_dir`` is treated as one trial whose result project is
    expected at ``<result_dir>/<entry>/project_config.yaml``. The ground-truth
    ``final_tracks`` are loaded once from ``gt_path``. Each trial is then scored with
    ``process_trial(trial_num, df_gt, result_path)``.

    Parameters
    ----------
    gt_path : str
        Path handed to ``ProjectData.load_final_project_data`` for the ground truth.
    result_dir : str
        Directory whose entries are the per-trial result folders.

    Returns
    -------
    dict[str, list]
        Column-oriented results with keys ``hyperparams`` (entry name), ``trial``
        (0-based trial number), and one key per metric. Every list holds exactly one
        element per entry of ``result_dir``, in sorted order, so index ``i`` refers
        to the same trial in every list. Trials with no ``project_config.yaml``, or
        whose evaluation raised, get ``None`` in every metric column.
    """
    # Load GT once; it is shared by all trials.
    project_data_gt = ProjectData.load_final_project_data(gt_path)
    df_gt = project_data_gt.final_tracks

    # Result column -> key in the stats dict returned by process_trial.
    metric_keys = {
        "accuracy": "accuracy",
        "per_neuron_accuracy": "accuracy_per_neuron",
        "per_timepoint_accuracy": "accuracy_per_timepoint",
        "misses_per_neuron_norm": "misses_per_neuron_norm",
        "misses_per_timepoint_norm": "misses_per_timepoint_norm",
        "mismatches_per_neuron_norm": "mismatches_per_neuron_norm",
        "mismatches_per_timepoint_norm": "mismatches_per_timepoint_norm",
    }

    result_dict = {"hyperparams": [], "trial": []}
    result_dict.update({column: [] for column in metric_keys})

    # os.listdir order is arbitrary; sort so trial numbers and row order are reproducible.
    for trial_num, entry in enumerate(sorted(os.listdir(result_dir))):
        trial_name = str(entry)
        result_path = os.path.join(result_dir, entry, "project_config.yaml")

        stats = None
        if not os.path.isfile(result_path):
            print(f"{trial_name}: project_config.yaml not found.")
        else:
            try:
                stats = process_trial(trial_num, df_gt, result_path)
            except Exception as e:
                # Best-effort sweep: report the failure and continue with the next trial.
                print(f"{trial_name}: ERROR -> {e}")

        # Always append one complete row so every list stays index-aligned.
        result_dict["hyperparams"].append(trial_name)
        result_dict["trial"].append(trial_num)
        for column, stat_key in metric_keys.items():
            result_dict[column].append(stats.get(stat_key) if stats is not None else None)

    return result_dict

In [10]:
# Inputs for build_final_dict: the ground-truth NWB file, and the directory whose
# entries are scanned for per-trial project_config.yaml files.
ground_truth_path, result_parent_dir = (
    "/lisc/scratch/neurobiology/zimmer/fieseler/wbfm_projects_future/flavell_data/images_for_charlie/flavell_data.nwb",
    "/lisc/scratch/neurobiology/zimmer/schwartz/traces_mit_flip_finetune/",
)

In [11]:

# Evaluate every trial folder against the ground truth, then dump each result column.
final_dict = build_final_dict(gt_path=ground_truth_path, result_dir=result_parent_dir)

print("\nFinal dictionary:")
for column_name, column_values in final_dict.items():
    print(f"{column_name}: {column_values}")

Loaded red and green data from NWB file: (1600, 64, 284, 120)
using calc_bipartite_from_ids
     81947  134171  218356  271021  170506  167371  270971  226429  154445  \
0.0   56.0    25.0   118.0    41.0   153.0    14.0   122.0   146.0    13.0   
1.0   26.0   167.0    72.0     7.0    74.0    53.0   170.0   117.0    47.0   
2.0   47.0   157.0   176.0    57.0    20.0    87.0     4.0   150.0    58.0   
3.0  156.0    86.0    90.0   172.0    45.0    98.0     7.0    63.0   147.0   
4.0   22.0    68.0   165.0    32.0   129.0    75.0   126.0   146.0     7.0   

     236547  ...  239878  193467  172626  213891  213136  212363  183592  \
0.0    94.0  ...   123.0    88.0   114.0   121.0    64.0   151.0     NaN   
1.0    33.0  ...   136.0     NaN    93.0    32.0   143.0   139.0     NaN   
2.0   128.0  ...   115.0    88.0     NaN     NaN   136.0     NaN     NaN   
3.0    94.0  ...    76.0    71.0     NaN     NaN   127.0    67.0   132.0   
4.0   169.0  ...     NaN    30.0     NaN     NaN    66.0   

In [12]:
# Per-trial overall accuracy values, in the order build_final_dict produced them.
accuracy_per_trial = final_dict["accuracy"]
print(accuracy_per_trial)

[0.9290520089247107]


In [9]:
# NOTE(review): duplicates the In [12] cell. Its saved output was produced with execution
# count 9, i.e. before the cells above it were last run, so it is stale.
all_trial_accuracies = final_dict["accuracy"]
print(all_trial_accuracies)

[0.9222503761869346, 0.9368870747358043, 0.906015418279158, 0.9280618157289379, None, 0.9040399069434503, 0.9294454918103672, 0.904790939238152, 0.9196948994240448, 0.927006762716848, 0.915942413587209, 0.9204937521228481, None, 0.9376826884826263, 0.9133110432717934, 0.930249753532698, None, 0.9220990366155283, None, 0.9100133784318345, 0.908576498161643, 0.9093171802237408, 0.9294541397858761, None, 0.9252685196395524, 0.9287536537696526, 0.904109709967735, 0.9080209690344127, 0.9270370306311293, None, None, 0.9096707136380725]


In [13]:
# Show the hyperparameter string (trial folder name) of one trial.
# NOTE(review): index 3 is hard-coded; confirm it is the trial of interest and that it
# lines up with the same index in final_dict["accuracy"].
trial_index = 3
selected_hyperparams = final_dict["hyperparams"][trial_index]
print(selected_hyperparams)

2025_08_01trial_27_db-min_cluster_size0.5_min_samples0.01_umap-n_components10_n_neighbors5_min_dist0.1
