In [46]:
mode = "c"

In [47]:
import os
import pandas as pd
from pathlib import Path

import keypoint_moseq as kpms

In [48]:
unsupervised_aging_dir = Path(os.environ["UNSUPERVISED_AGING"])
kpms_dir = unsupervised_aging_dir / "data/kpms_projects"

project_name, model_name = {
    "b6": (
        "2025-07-03_kpms-v2",
        "2025-07-07_model-2"
    ),
    "do": (
        "2025-07-16_kpms-v3",
        "2025-07-16_model-4"
    ),
    "c": (
        "2025-09-20_kpms-v5_150_6",
        "2025-09-21_model-1"
    )
}[mode]

In [49]:
kpms_model_dir = kpms_dir / project_name / model_name

def get_results(filename, path = kpms_model_dir):
    path = Path(path)
    if not path.exists():
         raise FileNotFoundError(f"Path not found: {path}")

    target_file = path / filename
    if not target_file.exists():
        raise FileNotFoundError(f"File not found: {target_file}")

    if filename == "results.h5":
        return kpms.load_results(path.parent, path.name)

    results_path = path / "results.h5"
    backup_path = path / "results.h5.bak"
    has_existing_results = results_path.exists()
    
    if has_existing_results and backup_path.exists():
        raise FileExistsError(f"Backup file {backup_path} already exists. Cannot safely proceed.")

    try:
        if has_existing_results:
            results_path.rename(backup_path)
        target_file.rename(results_path)

        return kpms.load_results(path.parent, path.name)

    finally:
        if results_path.exists():
            results_path.rename(target_file)
        if has_existing_results and backup_path.exists():
            backup_path.rename(results_path)

agg_results = {}
for result_file in kpms_model_dir.rglob("results*.h5"):
    result_filename = result_file.name
    print(f"Retrieving results for {result_filename}...")

    results = get_results(result_filename)
    syllable_dict = {
        name: entry["syllable"] for name, entry in results.items()
    }
    agg_results |= syllable_dict

Retrieving results for results-2.h5...
Retrieving results for results.h5...
Retrieving results for results-3.h5...


In [50]:
print(len(agg_results))

1223


In [51]:
def _get_syllable_frequency_statistics(th: float = 0.0):
    sequences = [info["syllable"] for info in results.values()]
    uniq = sorted({s for seq in sequences for s in seq})
    if th > 0.0:
        global_counts = {s: 0 for s in uniq}
        for seq in sequences:
            for s in seq:
                global_counts[s] += 1
        total = sum(global_counts.values())
        uniq = [s for s in uniq if total and global_counts[s] / total >= th]

    if not uniq:
        return {}

    idx = {s: i for i, s in enumerate(uniq)}
    n = len(uniq)

    out = {}
    for s in uniq:
        out[f"avg_bout_length_{s}"] = []
        out[f"total_duration_{s}"] = []
        out[f"num_bouts_{s}"] = []

    for _, info in tqdm(results.items()):
        seq = info["syllable"]
        dur = None
        for k in ("durations", "duration", "syllable_durations", "syllable_duration"):
            if k in info and hasattr(info[k], "__len__") and len(info[k]) == len(seq):
                dur = info[k]
                break

        total_len = [0]*n
        bout_cnt = [0]*n
        sum_dur = [0.0]*n

        prev = None
        run_len = 0
        for i, s in enumerate(seq):
            if s == prev:
                run_len += 1
            else:
                if prev in idx:
                    j = idx[prev]
                    total_len[j] += run_len
                    bout_cnt[j] += 1
                prev = s
                run_len = 1
            if s in idx and dur is not None:
                sum_dur[idx[s]] += float(dur[i])

        if prev in idx:
            j = idx[prev]
            total_len[j] += run_len
            bout_cnt[j] += 1

        if dur is None:
            for j in range(n):
                sum_dur[j] = float(total_len[j])

        for j, s in enumerate(uniq):
            abl = (total_len[j] / bout_cnt[j]) if bout_cnt[j] else 0.0
            out[f"avg_bout_length_{s}"].append(abl)
            out[f"total_duration_{s}"].append(sum_dur[j])
            out[f"num_bouts_{s}"].append(int(bout_cnt[j]))

    return out

syllable_frequency_statistics = _get_syllable_frequency_statistics()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:01<00:00, 66.87it/s]


In [52]:
data_dir = unsupervised_aging_dir / "final_data_curation"
feature_sheet_filename = {
    "b6": "2026-01-28_UA_B6_masterdf.csv",
    "do": "2026-01-28_UA_DO_masterdf-trans.csv",
    "c":  "2025-12-30_UA_combined_masterdf.csv"
}[mode]

feature_df = pd.read_csv(data_dir / feature_sheet_filename)
feature_df.head()

Unnamed: 0,NetworkFilename,latent_embedding_mean_0,latent_embedding_mean_1,latent_embedding_mean_2,latent_embedding_mean_3,latent_embedding_mean_4,latent_embedding_mean_5,latent_embedding_mean_6,latent_embedding_mean_7,latent_embedding_mean_8,...,grooming_duration_secs,Rearing_supported_T5,Rearing_supported_T20,Rearing_supported_T55,Rearing_unsupported_T5,Rearing_unsupported_T20,Rearing_unsupported_T55,Grooming_T5,Grooming_T20,Grooming_T55
0,LL1-B2B/2019-12-24_SPD/LL1-1_AgedB6-0420,0.040169,0.198675,-0.037979,0.052876,-0.003299,-0.010149,-0.055371,0.060958,-0.085216,...,68.20147,0.047222,0.059278,0.045596,0.002,0.002111,0.004485,0.018111,0.01425,0.024364
1,LL1-B2B/2020-01-02_SPD/LL1-1_AgedB6-0744,-0.296932,0.263532,0.100476,-0.1474,0.061194,-0.023672,-0.257141,0.099431,-0.173461,...,170.43687,0.088111,0.076472,0.070101,0.011222,0.057806,0.114293,0.022778,0.024194,0.049485
2,LL1-B2B/2020-01-02_SPD/LL1-4_AgedB6-0746,0.013086,0.418435,0.148745,-0.155249,0.053694,-0.02038,-0.269689,0.178564,-0.192979,...,148.8153,0.044444,0.048361,0.045515,0.008111,0.0095,0.015737,0.016556,0.017889,0.018061
3,LL1-B2B/2020-06-16_SPD/AgedB6-0411,-0.151999,0.127667,0.05441,0.010174,0.019283,-0.106259,-0.251713,0.064764,-0.170414,...,69.73327,0.029667,0.018194,0.031596,0.001333,0.019694,0.024646,0.024778,0.016806,0.016273
4,LL1-B2B/2020-06-17_SPD/AgedB6-0420,-0.15182,0.409758,-0.160638,0.16057,0.014334,0.061617,-0.026318,0.140276,0.017099,...,72.26663,0.017667,0.024333,0.028747,0.005333,0.002417,0.005384,0.012333,0.005833,0.017707


In [53]:
from tqdm import tqdm

def _get_syllable_frequency_statistics(agg_results, th: float = 0.0):
    sequences = list(agg_results.values())
    uniq = sorted({s for seq in sequences for s in seq})
    
    if th > 0.0:
        global_counts = {s: 0 for s in uniq}
        for seq in sequences:
            for s in seq:
                global_counts[s] += 1
        total = sum(global_counts.values())
        uniq = [s for s in uniq if total and global_counts[s] / total >= th]

    if not uniq:
        return pd.DataFrame()

    idx = {s: i for i, s in enumerate(uniq)}
    n = len(uniq)

    out = {'name': []}
    for s in uniq:
        out[f"avg_bout_length_{s}"] = []
        out[f"total_duration_{s}"] = []
        out[f"num_bouts_{s}"] = []

    for name, seq in tqdm(agg_results.items(), desc="Processing sequences"):
        out['name'].append(name)
        
        total_len = [0]*n
        bout_cnt = [0]*n
        sum_dur = [0.0]*n

        prev = None
        run_len = 0
        for i, s in enumerate(seq):
            if s == prev:
                run_len += 1
            else:
                if prev in idx:
                    j = idx[prev]
                    total_len[j] += run_len
                    bout_cnt[j] += 1
                prev = s
                run_len = 1
            if s in idx:
                sum_dur[idx[s]] += 1.0

        if prev in idx:
            j = idx[prev]
            total_len[j] += run_len
            bout_cnt[j] += 1

        for j in range(n):
            sum_dur[j] = float(total_len[j])

        for j, s in enumerate(uniq):
            abl = (total_len[j] / bout_cnt[j]) if bout_cnt[j] else 0.0
            out[f"avg_bout_length_{s}"].append(abl)
            out[f"total_duration_{s}"].append(sum_dur[j])
            out[f"num_bouts_{s}"].append(int(bout_cnt[j]))

    df = pd.DataFrame(out)
    df.set_index('name', inplace=True)
    return df

syllable_stats_df = _get_syllable_frequency_statistics(agg_results)
print(f"Computed statistics for {len(syllable_stats_df)} sequences")
print(f"Columns: {len(syllable_stats_df.columns)} ({len(syllable_stats_df.columns)//3} unique syllables)")
syllable_stats_df.head()

Processing sequences: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 1223/1223 [00:27<00:00, 44.99it/s]

Computed statistics for 1223 sequences
Columns: 285 (95 unique syllables)





Unnamed: 0_level_0,avg_bout_length_0,total_duration_0,num_bouts_0,avg_bout_length_1,total_duration_1,num_bouts_1,avg_bout_length_2,total_duration_2,num_bouts_2,avg_bout_length_3,...,num_bouts_91,avg_bout_length_92,total_duration_92,num_bouts_92,avg_bout_length_93,total_duration_93,num_bouts_93,avg_bout_length_94,total_duration_94,num_bouts_94
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
LL1-B6__2023-07-06_MFS__DO2271_DO_F_25075.csv,3.0,36.0,12,2.636364,29.0,11,2.5,25.0,10,3.851974,...,7,0.0,0.0,0,0.0,0.0,0,0.0,0.0,0
LL1-B6__2023-07-06_MFS__DO2272_DO_F_25076.csv,3.7,37.0,10,1.75,14.0,8,3.769231,49.0,13,3.868056,...,6,0.0,0.0,0,0.0,0.0,0,0.0,0.0,0
LL1-B6__2023-07-06_MFS__DO2273_DO_F_25077.csv,3.787879,125.0,33,2.657895,101.0,38,2.774194,86.0,31,4.694158,...,4,0.0,0.0,0,0.0,0.0,0,0.0,0.0,0
LL1-B6__2023-07-06_MFS__DO2274_DO_F_25078.csv,2.870968,623.0,217,2.574468,484.0,188,4.65625,745.0,160,3.674419,...,0,0.0,0.0,0,0.0,0.0,0,0.0,0.0,0
LL1-B6__2023-07-06_MFS__DO2331_DO_M_25047.csv,2.25,135.0,60,3.140351,179.0,57,2.928571,82.0,28,4.090909,...,38,0.0,0.0,0,0.0,0.0,0,0.0,0.0,0


In [55]:
def match_feature_df_to_stats(feature_df, syllable_stats_df, k=None):
    stat_cols = [col for col in syllable_stats_df.columns 
                 if col.startswith('avg_bout_length_')]
    
    common_cols = [col for col in stat_cols if col in feature_df.columns]
    
    if k is not None:
        syllable_ids = sorted(set([
            int(col.split('_')[-1]) 
            for col in common_cols 
            if col.split('_')[-1].isdigit()
        ]))[:k]
        
        common_cols = [col for col in common_cols 
                      if col.split('_')[-1].isdigit() and 
                         int(col.split('_')[-1]) in syllable_ids]
    
    print(f"Found {len(common_cols)} common syllable statistic columns between feature_df and syllable_stats_df")
    print(f"Total columns in syllable_stats_df: {len(stat_cols)}")
    if k is not None:
        print(f"Using only first {k} syllables ({len(common_cols)} columns)")
    
    if len(common_cols) == 0:
        print("WARNING: No common syllable statistic columns found!")
        return 0, len(feature_df), []
    
    success_count = 0
    failure_count = 0
    match_details = []
    
    for idx, row in tqdm(feature_df.iterrows(), total=len(feature_df), desc="Matching rows"):
        network_filename = row.get('NetworkFilename', f'row_{idx}')
        feature_values = row[common_cols]
        
        matches = []
        for stats_idx, stats_row in syllable_stats_df.iterrows():
            stats_values = stats_row[common_cols]
            
            if all(abs(float(fv) - float(sv)) < 0.0001 if isinstance(fv, (int, float)) and isinstance(sv, (int, float)) 
                   else fv == sv 
                   for fv, sv in zip(feature_values, stats_values)):
                matches.append(stats_idx)
        
        if len(matches) == 1:
            success_count += 1
            match_details.append({
                'feature_row': idx,
                'network_filename': network_filename,
                'status': 'SUCCESS',
                'matched_to': matches[0],
                'num_matches': 1
            })
        else:
            failure_count += 1
            match_details.append({
                'feature_row': idx,
                'network_filename': network_filename,
                'status': 'FAILURE',
                'matched_to': matches if matches else None,
                'num_matches': len(matches)
            })
    
    return success_count, failure_count, match_details

success_count, failure_count, match_details = match_feature_df_to_stats(feature_df, syllable_stats_df, k=10)

print(f"\n{'='*60}")
print(f"MATCHING RESULTS:")
print(f"{'='*60}")
print(f"Total rows in feature_df: {len(feature_df)}")
print(f"SUCCESS (matched to exactly 1 element): {success_count}")
print(f"FAILURE (matched to 0 or >1 elements): {failure_count}")
print(f"Success rate: {100*success_count/(success_count+failure_count):.2f}%")
print(f"{'='*60}")

failures = [m for m in match_details if m['status'] == 'FAILURE']
if failures:
    print(f"\nFirst 5 failures:")
    for i, fail in enumerate(failures[:5]):
        print(f"  {i+1}. Row {fail['feature_row']} ({fail['network_filename']}): "
              f"{fail['num_matches']} matches")


Found 10 common syllable statistic columns between feature_df and syllable_stats_df
Total columns in syllable_stats_df: 95
Using only first 10 syllables (10 columns)


Matching rows: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1138/1138 [04:39<00:00,  4.07it/s]


MATCHING RESULTS:
Total rows in feature_df: 1138
SUCCESS (matched to exactly 1 element): 1074
FAILURE (matched to 0 or >1 elements): 64
Success rate: 94.38%

First 5 failures:
  1. Row 638 (/DO2279_DO_F_25083): 0 matches
  2. Row 639 (/DO2343_DO_M_25058): 0 matches
  3. Row 640 (/DO2342_DO_M_25057): 0 matches
  4. Row 641 (/DO2281_DO_F_25085): 0 matches
  5. Row 642 (/DO2282_DO_F_25086): 0 matches





In [59]:
def create_network_to_syllable_mapping(feature_df, agg_results, match_details):
    mapping_data = []
    
    for match in match_details:
        network_filename = match['network_filename']
        
        if match['status'] == 'SUCCESS':
            matched_name = match['matched_to']
            
            if matched_name in agg_results:
                syllable_sequence = agg_results[matched_name]
                mapping_data.append({
                    'NetworkFilename': network_filename,
                    'matched_name': matched_name,
                    'syllable_sequence': syllable_sequence
                })
        else:
            # For unmatched elements, include them with empty syllable sequence
            mapping_data.append({
                'NetworkFilename': network_filename,
                'matched_name': None,  # or '' if you prefer empty string
                'syllable_sequence': []
            })
    
    mapping_df = pd.DataFrame(mapping_data)
    return mapping_df

network_syllable_mapping = create_network_to_syllable_mapping(feature_df, agg_results, match_details)

print(f"\n{'='*60}")
print(f"NETWORK TO SYLLABLE MAPPING:")
print(f"{'='*60}")
print(f"Successfully mapped {len(network_syllable_mapping)} NetworkFilenames to syllable sequences")
print(f"{'='*60}")
network_syllable_mapping.head()


NETWORK TO SYLLABLE MAPPING:
Successfully mapped 1138 NetworkFilenames to syllable sequences


Unnamed: 0,NetworkFilename,matched_name,syllable_sequence
0,LL1-B2B/2019-12-24_SPD/LL1-1_AgedB6-0420,LL1-B2B__2019-12-24_SPD__LL1-1_AgedB6-0420.csv,"[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1..."
1,LL1-B2B/2020-01-02_SPD/LL1-1_AgedB6-0744,LL1-B2B__2020-01-02_SPD__LL1-1_AgedB6-0744.csv,"[35, 35, 35, 35, 44, 44, 13, 13, 13, 13, 13, 1..."
2,LL1-B2B/2020-01-02_SPD/LL1-4_AgedB6-0746,LL1-B2B__2020-01-02_SPD__LL1-4_AgedB6-0746.csv,"[45, 45, 45, 45, 45, 45, 9, 18, 18, 18, 18, 18..."
3,LL1-B2B/2020-06-16_SPD/AgedB6-0411,LL1-B2B__2020-06-16_SPD__AgedB6-0411.csv,"[30, 30, 30, 30, 30, 12, 12, 12, 12, 12, 54, 5..."
4,LL1-B2B/2020-06-17_SPD/AgedB6-0420,LL1-B2B__2020-06-17_SPD__AgedB6-0420.csv,"[54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 5..."


In [60]:
mapping_csv = network_syllable_mapping.copy()
mapping_csv['syllable_sequence'] = mapping_csv['syllable_sequence'].apply(
    lambda x: ','.join(map(str, x))
)
mapping_csv.to_csv(data_dir / f"2025-01-28_{mode}_syll-seq.csv", index=False)