# Compute syllable similarities

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm.auto import tqdm
from toolz import valmap
from scipy.spatial.distance import squareform, pdist
from scipy.cluster.hierarchy import dendrogram, complete, ward, leaves_list
from tslearn.barycenters import softdtw_barycenter
from tslearn.utils import to_time_series_dataset
from aging.moseq_modeling.pca import apply_whitening, get_whitening_params_from_training_data
from aging.plotting import format_plots, figure, save_factory, PlotConfig

In [None]:
# apply the project's plot styling and build a saver that writes into the fig-s2 folder
format_plots()
fig_dir = PlotConfig().save_path / "fig-s2"
saver = save_factory(fig_dir, tight_layout=False)

In [None]:
# whitening parameters (mu, L) estimated from the training data; later passed to apply_whitening
training_dir = Path('/n/groups/datta/win/longtogeny/data/ontogeny/version_11')
mu, L = get_whitening_params_from_training_data(training_dir)

In [None]:
df = pd.read_parquet('/n/groups/datta/win/longtogeny/data/ontogeny/version_11/ontogeny_males_syllable_df_v00.parquet')
keep_syllables = np.loadtxt('/n/groups/datta/win/longtogeny/data/ontogeny/version_11/to_keep_syllables_raw.txt', dtype=int)

# compute durations: give every row the length (in rows/frames) of the syllable
# instance it belongs to. An instance spans from its onset row up to, but not
# including, the next onset row; the final instance runs to the end of the table.
idx = np.where(df['onsets'])[0]  # positional row indices of onsets
# The last duration must be measured from the LAST onset (idx[-1]). Measuring
# from idx[0] gave the final instance a length of ~len(df), which could also
# wrap around in the int16 cast below.
durs = np.diff(idx).tolist() + [len(df) - idx[-1]]
# NOTE(review): assigning via df.index[idx] assumes index labels are unique
df.loc[df.index[idx], 'dur'] = durs
# non-onset rows inherit the duration of the instance they belong to
df['dur'] = df['dur'].ffill().astype('int16')

In [None]:
# raw syllable usage matrix for the same dataset version
usage_path = '/n/groups/datta/win/longtogeny/data/ontogeny/version_11/ontogeny_males_raw_usage_matrix_v00.parquet'
usage_df = pd.read_parquet(usage_path)

In [None]:
# column names of the first ten PCs: pc_00 ... pc_09
pc_keys = list(map("pc_{:02d}".format, range(10)))

In [None]:
# fill missing PC values by linear interpolation down the rows
# NOTE(review): this interpolates across whatever rows are adjacent in df
# (including any session boundaries) — confirm that is intended
df[pc_keys] = df[pc_keys].interpolate(method='linear')

In [None]:
# median syllable-instance duration (one value per onset row), before any filtering
onset_durs = df.query('onsets')['dur']
onset_durs.median()

In [None]:
# keep only instances lasting 7-24 frames (drops both very short and very long ones)
df = df[(df['dur'] > 6) & (df['dur'] < 25)]

In [None]:
# restrict to the syllables retained for the paper
in_paper = df['syllables'].isin(keep_syllables)
df = df.loc[in_paper]

In [None]:
# number of instances (onset rows) of each syllable within each age
onset_rows = df.query("onsets")
syll_counts = onset_rows.groupby("age")["syllables"].value_counts()

In [None]:
# age/syllable pairs with more than 50 instances; their (age, syllable) index
# decides which ages are sampled for each syllable
syll_data = syll_counts.loc[syll_counts.gt(50)]

In [None]:
# syllable x age table of instance counts (NaN where an age never used the syllable)
# NOTE(review): relies on value_counts naming its result "count" (pandas >= 2.0)
syll_counts.reset_index().pivot_table(
    index="syllables", columns="age", values="count"
)

In [None]:
# sample a subset of syllable instances across ages
def _sample_instances_per_age(age_groups, n_samples_per_age):
    """Pick up to ``n_samples_per_age`` random instance ids per age group and return their rows."""
    picked = []
    for group in age_groups:
        ids = np.random.permutation(group["unique_id"].unique())[:n_samples_per_age]
        picked.append(group[group["unique_id"].isin(ids)])
    return pd.concat(picked)


def sample_syllable(df, syllable, ages, n_samples_per_age=6, length=20):
    """Randomly sample instances of one syllable from each of the given ages.

    Rows of ``syllable`` in ``ages`` are numbered into instances by a running
    count of ``onsets``. Up to ``n_samples_per_age`` instances per age are then
    drawn (using the global NumPy RNG, so ``np.random.seed`` makes it
    reproducible). All rows of each drawn instance are returned. Drawing is
    repeated until at least one sampled row has ``dur >= length``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame-level table with at least ``syllables``, ``age``, ``onsets`` and
        ``dur`` columns.
    syllable : scalar
        Syllable label to sample.
    ages : iterable
        Ages to sample from.
    n_samples_per_age : int
        Maximum number of instances drawn per age.
    length : int
        Minimum ``dur`` that at least one sampled instance must reach.

    Returns
    -------
    pandas.DataFrame
        Rows of the sampled instances, with an added ``unique_id`` column.

    Raises
    ------
    ValueError
        If nothing matches ``syllable``/``ages``, or if no matching row has
        ``dur >= length``. In that case resampling could never succeed.
    """
    df = df.query("syllables == @syllable")
    df = df[df["age"].isin(ages)].copy()
    # onsets mark the first row of each instance, so their running count
    # gives every row of an instance the same id
    df["unique_id"] = df["onsets"].cumsum()

    if df.empty or df["dur"].max() < length:
        raise ValueError(
            f"no instance of syllable {syllable} in ages {list(ages)} has "
            f"dur >= {length}; resampling would never terminate"
        )

    # group once; the draw order (and therefore the RNG stream) matches
    # re-grouping on every attempt
    age_groups = [group for _, group in df.groupby("age", sort=False)]

    sample = _sample_instances_per_age(age_groups, n_samples_per_age)
    # redraw until at least one instance is long enough; terminates because
    # the check above guarantees such an instance exists in the pool
    while sample["dur"].max() < length:
        sample = _sample_instances_per_age(age_groups, n_samples_per_age)
    return sample

In [None]:
def construct_ts(df):
    """Whiten each syllable instance's PCs and pack them into a tslearn dataset.

    Rows are grouped by ``unique_id`` in order of first appearance. Each
    group's ``pc_keys`` columns are whitened with ``apply_whitening`` using the
    module-level ``L`` and ``mu``. The resulting arrays are handed to
    ``to_time_series_dataset`` (which, per tslearn, pads variable-length series
    to a common length).
    """
    instances = df.groupby('unique_id', sort=False)
    whitened = [
        apply_whitening(frames[pc_keys].to_numpy(), L, mu)
        for _, frames in instances
    ]
    return to_time_series_dataset(whitened)

In [None]:
np.random.seed(0)

# Quick check on a single syllable: sample the first kept syllable across every
# age that has enough examples and build its whitened time-series dataset
# (`ts` is used by the barycenter cell below).
syllable = keep_syllables[0]
has_ages = syll_data.loc[pd.IndexSlice[:, syllable]].index
sample = sample_syllable(df, syllable, has_ages, n_samples_per_age=8)
ts = construct_ts(sample)

In [None]:
# soft-DTW barycenter of the sampled instances; plot its first PC
traj = softdtw_barycenter(ts, max_iter=1000, gamma=7)
first_pc = traj[:, 0]
plt.plot(first_pc)

In [None]:
np.random.seed(0)

# For every kept syllable, compare soft-DTW barycenters of the first PC across
# smoothing strengths (gamma); one figure per syllable.
gammas = (0.1, 1, 2, 4, 8)
for syllable in tqdm(keep_syllables):
    has_ages = syll_data.loc[pd.IndexSlice[:, syllable]].index
    sample = sample_syllable(df, syllable, has_ages, n_samples_per_age=15)
    ts = construct_ts(sample)

    plt.figure()
    for gamma in gammas:
        traj = softdtw_barycenter(ts, gamma=gamma, max_iter=1_000)
        plt.plot(traj[:, 0], label=gamma)
    plt.title(syllable)
    plt.legend()
    plt.show()

In [None]:
np.random.seed(0)

# One soft-DTW barycenter (gamma=1) per kept syllable, built from whitened PC
# trajectories sampled across the ages with enough examples. length=24 makes
# sample_syllable keep redrawing until some instance has dur >= 24 (presumably
# so every barycenter has the same length when flattened below — confirm).
syllable_traj = {}
for syllable in tqdm(keep_syllables):
    has_ages = syll_data.loc[pd.IndexSlice[:, syllable]].index
    ts = construct_ts(
        sample_syllable(df, syllable, has_ages, n_samples_per_age=16, length=24)
    )
    syllable_traj[syllable] = softdtw_barycenter(ts, gamma=1, max_iter=1_000)

In [None]:
# flatten each barycenter into a 1-D feature vector, keeping syllable order
flattened = {syll: bary.flatten() for syll, bary in syllable_traj.items()}

In [None]:
# one row per syllable (in syllable_traj order)
mtx = np.array([*flattened.values()])

In [None]:
# condensed pairwise correlation distances (1 - Pearson r) between syllable rows
dists = pdist(X=mtx, metric="correlation")

In [None]:
# Cluster with the complete linkage of the correlation distances themselves.
# Passing only method='complete' would make clustermap re-cluster the ROWS of
# the square matrix with Euclidean distance, which is a different tree.
_link = complete(dists)
sns.clustermap(squareform(dists), row_linkage=_link, col_linkage=_link)

In [None]:
# complete-linkage tree on the correlation distances, and its leaf order
Z = complete(dists)
leaves = leaves_list(Z)
fig, ax = plt.subplots(figsize=(2, 10))

# color_threshold=0 with above_threshold_color='k' draws every link in black
dendrogram(Z, orientation='right', above_threshold_color='k', color_threshold=0, ax=ax)
ax.set(xlabel='MoSeq distance')
sns.despine()

In [None]:
# leaf order of the linkage tree (row indices into mtx)
leaves = leaves_list(Z=Z)


In [None]:
# one mosaic row per leaf, in leaf order: a usage panel named after the leaf, plus the shared 'D' column
mosaic = list(map(lambda leaf: [str(leaf), 'D'], leaves))

In [None]:
# mean usage of each kept syllable per age
# NOTE(review): only syllable columns are selected, so 'age' must be an index level of usage_df
syllable_usage = usage_df.loc[:, keep_syllables]
avg_use = syllable_usage.groupby('age').mean()

In [None]:
fig, ax = plt.subplot_mosaic(mosaic, figsize=(3, 12), width_ratios=[3, 5])

# Leaf ids index rows of mtx, which follow the key order of syllable_traj.
row_syllables = list(syllable_traj)

dendrogram(Z, orientation='right', above_threshold_color='k', color_threshold=0, ax=ax['D'])
# With orientation='right', scipy places leaves[0] at the BOTTOM, whereas the
# mosaic stacks the usage panels top-to-bottom in leaf order. Flip the y axis
# so each usage panel sits beside its own leaf.
ax['D'].invert_yaxis()

last_leaf = leaves[-1]
for l in leaves:
    a = ax[str(l)]
    a.plot(avg_use.index, avg_use[row_syllables[l]], color='k', lw=0.75)
    if l != last_leaf:
        a.set(yticks=[], xticks=[])
    else:
        # only the bottom panel keeps its age ticks
        a.set(yticks=[])
ax[str(last_leaf)].set(xlabel="Age (weeks)")
ax['D'].set(xlabel='MoSeq distance')
sns.despine()
saver(fig, "behavior-similarity-dendrogram");