In [None]:
import keypoint_moseq as kpms

In [None]:
import h5py
import numpy as np
import os
import pandas as pd

from pathlib import Path
from tqdm import tqdm

In [None]:
# Experiment identity and filesystem layout. The experiment directory is named
# after the project, so it is derived from project_name.
project_name = "2025-06-13_kpms-feature-extraction"
project_dir = Path(f"/projects/kumar-lab/miaod/projects/uvFI/experiments/{project_name}")
outputs_dir = project_dir.joinpath("outputs")
data_dir = project_dir.joinpath("data", "2025-06-13_kpms-inference_data")

# Syllable-frequency cutoff passed as `min_frequency` to kpms.compute_stats_df and
# kpms.generate_transition_matrices below.
min_frequency = 0.005

### Data Loading and Preprocessing

In [None]:
kpms_project_dir = str(project_dir / "data")
kpms_model_name = "2025-06-13_kpms-inference_data"
kpms_group_csv_path = project_dir / "data/index.csv"

# Build the kpms index file: one row per pose CSV, with each recording placed in
# its own group (group = file stem).
# - Files are sorted because Path.iterdir() order is filesystem-dependent; a sorted
#   index makes the row order (and presumably the group order kpms derives from it)
#   reproducible across runs.
# - Only regular *.csv files are listed, so stray entries (hidden files, checkpoint
#   directories) do not become bogus recording names.
pose_dir = data_dir / "poses_csv"
pose_files = sorted(
    path for path in pose_dir.iterdir()
    if path.is_file() and path.suffix.lower() == ".csv"
)
csv_str = "name,group\n" + "\n".join(f"{file.name},{file.stem}" for file in pose_files) + "\n"
print(csv_str[:100])

kpms_group_csv_path.write_text(csv_str)

In [None]:
# Merged results file; kpms.load_results(kpms_project_dir, kpms_model_name) below
# reads the file written to this location.
kpms_result_path = Path(kpms_project_dir) / kpms_model_name / "results.h5"

# Earlier per-batch approach, kept for reference: symlink each batch file to
# kpms_result_path and compute the moseq/stats tables one batch at a time.
# The active cells below instead merge all batches into a single results.h5.
# def compute_batch_moseq_stats_df(batch_file: Path):
#     try:
#         os.symlink(batch_file, str(kpms_result_path))
#     except OSError as e:
#         print(f"creation of symlink {batch_file.name} -> {kpms_result_path.name} failed")
#         return

#     moseq_df = kpms.compute_moseq_df(kpms_project_dir, kpms_model_name, smooth_heading=True)
#     stats_df = kpms.compute_stats_df(kpms_project_dir, kpms_model_name, moseq_df, min_frequency=0.005, groupby=["name"])

#     kpms_result_path.unlink()
#     return (moseq_df, stats_df)

In [None]:
# Number of inference batch files (data_dir/out_1.h5 ... out_{num_batches}.h5)
# that are merged into results.h5 further down.
num_batches = 79

# Superseded per-batch computation (see the commented helper above):
# indices = range(1, num_batches+1)

# kpms_result_path.unlink(missing_ok=True)
# batch_dfs = [
#     compute_batch_moseq_stats_df(data_dir / f"out_{batch}.h5")
#     for batch in tqdm(indices, "computing batches", total=len(indices))
# ]

In [None]:
# moseq_dfs, stats_dfs = zip(*batch_dfs)
# moseq_df = pd.concat(moseq_dfs)
# stats_df = pd.concat(stats_dfs)

In [None]:
# Merge every inference batch (out_1.h5 ... out_{num_batches}.h5) into the single
# results.h5 consumed by the kpms calls below.
#
# Write to a temporary file and rename it into place at the end:
# - Path.replace swaps the directory entry itself, so a leftover symlink at
#   kpms_result_path (the earlier per-batch approach created such links) is replaced
#   rather than followed. Opening that symlink directly with mode "w" would truncate
#   the batch file it points to.
# - An interrupted merge never leaves a partial results.h5 behind.
kpms_result_path.parent.mkdir(parents=True, exist_ok=True)
tmp_result_path = kpms_result_path.with_name(kpms_result_path.name + ".tmp")

with h5py.File(tmp_result_path, "w") as h5out:
    for batch in tqdm(range(1, num_batches + 1), desc="merging batches"):
        input_file = data_dir / f"out_{batch}.h5"
        with h5py.File(input_file, "r") as h5in:
            for group_name in h5in:
                # Fail with a clear message if two batches contain the same recording.
                if group_name in h5out:
                    raise ValueError(
                        f"recording {group_name!r} from {input_file.name} was already "
                        "merged from an earlier batch"
                    )
                h5in.copy(group_name, h5out)

tmp_result_path.replace(kpms_result_path)

In [None]:
# Both kpms calls address the same project/model pair.
kpms_model_args = (kpms_project_dir, kpms_model_name)
results = kpms.load_results(*kpms_model_args)

# Also export the loaded results as CSV via kpms.
kpms.save_results_as_csv(results, *kpms_model_args)

In [None]:
# Spot-check the latent_state array shape for one known recording.
example_recording = "LL1-B2B_LL1-B2B__2019-09-04_SPD__LL1-1_AgedB6-0396_pose_est_v6.csv"
results[example_recording]["latent_state"].shape

In [None]:
# Per-frame MoSeq table built from the merged results.
moseq_df = kpms.compute_moseq_df(
    kpms_project_dir,
    kpms_model_name,
    smooth_heading=True,
)

In [None]:
# Per-recording, per-syllable summary statistics (grouped by recording name only).
stats_df = kpms.compute_stats_df(
    kpms_project_dir,
    kpms_model_name,
    moseq_df,
    min_frequency=min_frequency,
    groupby=["name"],
)

In [None]:
# Sanity check: dimensions and first rows of the MoSeq table.
n_rows, n_cols = moseq_df.shape
print((n_rows, n_cols))
moseq_df.head()

In [None]:
# moseq_df.to_csv(outputs_dir / "moseq_df.csv.gz", index=False, compression="gzip")
# stats_df.to_csv(outputs_dir / "stats_df.csv.gz", index=False, compression="gzip")

In [None]:
%%capture
normalize = "bigram"

trans_mats, usages, groups, syll_include = kpms.generate_transition_matrices(
    kpms_project_dir,
    kpms_model_name,
    normalize=normalize,
    min_frequency=min_frequency,
)

In [None]:
# Optional: uncomment to visualize the bigram transition matrices per group.
# kpms.visualize_transition_bigram(
#     kpms_project_dir,
#     kpms_model_name,
#     groups,
#     trans_mats,
#     syll_include,
#     normalize=normalize,
#     show_syllable_names=True,
# )

In [None]:
# kpms.plot_transition_graph_group(
#     kpms_project_dir,
#     kpms_model_name,
#     groups,
#     trans_mats,
#     usages,
#     syll_include,
#     layout="circular",
#     show_syllable_names=False,
#     save_dir = project_dir / "outputs/transition_graphs.png"
# )

In [None]:
# kpms.plot_transition_graph_difference(
#     kpms_project_dir, kpms_model_name, groups, trans_mats, usages, syll_include, layout="circular",
#     save_dir = project_dir / "outputs/transition_graphs_diff.png"
# )

### Feature Extraction

In [None]:
# Unique recording names, in order of first appearance in stats_df.
videos = stats_df["name"].drop_duplicates().tolist()
videos[:5]

In [None]:
# Inspect the per-row recording names in stats_df.
stats_df.loc[:, "name"]

In [None]:
# Sizes used for the feature tables below. num_syllables is the row length of the
# first transition matrix (presumably len(syll_include) — confirm).
num_videos = len(videos)
first_trans_mat = trans_mats[0]
num_syllables = len(first_trans_mat[0])
print(num_videos, num_syllables)

In [None]:
# One row per recording, in the order names first appear in stats_df.
videos_df = pd.DataFrame(videos, columns=["name"])

# trans_mats comes back from kpms.generate_transition_matrices alongside `groups`
# (presumably one matrix per group, same order), and those groups come from
# index.csv, whose row order is not the stats_df order used for `videos`.
# Assigning trans_mats positionally can therefore pair a recording with another
# recording's matrix. Instead, look up each recording's matrix through its
# index.csv group.
index_df = pd.read_csv(kpms_group_csv_path)
name_to_group = dict(zip(index_df["name"], index_df["group"]))
group_to_trans_mat = dict(zip(groups, trans_mats))
videos_df["trans_mat"] = [
    group_to_trans_mat[name_to_group[name]] for name in videos_df["name"]
]

In [None]:
# Wide frequency table: one row per recording name, one column per syllable,
# holding the "frequency" value; syllables missing for a recording become 0.
freq_wide = stats_df.pivot(
    index="name",
    columns="syllable",
    values="frequency",
).fillna(0)
freq_array = freq_wide.to_numpy()  # rows follow freq_wide.index order

# pivot() orders rows by its own (sorted) index, which need not match videos_df's
# row order, so select rows by name instead of assigning freq_array positionally.
videos_df["freqs"] = list(freq_wide.loc[videos_df["name"].tolist()].to_numpy())
videos_df

In [None]:
# Persist the per-recording feature table; create the outputs directory first so
# a fresh checkout does not fail with FileNotFoundError.
outputs_dir.mkdir(parents=True, exist_ok=True)
videos_df.to_pickle(outputs_dir / f"{project_name}_videos-df.pkl")

In [None]:
# Map each recording name to its syllable-frequency vector (freq_array rows follow
# freq_wide.index, so the zip is aligned).
name_to_freq = dict(zip(freq_wide.index, freq_array))

# Preview a few entries instead of rendering every recording's full vector.
dict(list(name_to_freq.items())[:5])