## Syllable Analysis

#### This notebook demonstrates common analyses of syllable statistics derived from keypoint-MoSeq. 

- Make a copy of this notebook if you plan to make changes and want them saved
- Go to "Runtime">"change runtime type" and select "Python 3". This notebook does not require a GPU.
- If you have not already run the keypoint-MoSeq tutorial, download the [example output](https://drive.google.com/drive/folders/1Fh9gWCsIqvV8Kl2BDtZxVNIH08sn7yxz?usp=share_link) to your drive or create a shortcut to it.


### Load modeling results

- `project_dir` should point to the example data or to the project directory you used for modeling
- `name` specifies which modeling output to use; it should be a subdirectory of `project_dir`

In [None]:
import keypoint_moseq as kpm
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 100

In [None]:
project_dir = 'demo_project'
name = '2022_11_19-17_58_32'

# define config loader
config = lambda: kpm.load_config(project_dir)

# load modeling results: `results` is a dict mapping each 
# experiment to a sub-dict with time-series of model outputs, 
# including syllable sequence, centroid and heading. 
results = kpm.load_results(project_dir=project_dir, name=name)

### Collate results into a dataframe


Combine syllables and scalar measures of behavior (velocity, centroid, etc.) into a single `DataFrame` that includes all experimental recordings. The `DataFrame` format is required for the remainder of this notebook and is generally a useful way to organize data and export it for downstream analysis.
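
To make the target layout concrete, here is a minimal sketch of how a per-recording dictionary of time series could be flattened into one long table with one row per frame. It is not how `kpm.compute_moseq_df` is implemented: the toy data, the `toy_results_to_df` helper, the dictionary keys and the column names are all assumptions made for illustration.

```python
import numpy as np
import pandas as pd

def toy_results_to_df(results):
    """Flatten {recording: {key: time series}} into one row per frame (illustrative helper)."""
    frames = []
    for session_name, outputs in results.items():
        n_frames = len(outputs['syllables'])
        centroid = np.asarray(outputs['centroid'])
        # frame-to-frame displacement as a simple speed proxy (first frame set to 0)
        step = np.linalg.norm(np.diff(centroid, axis=0), axis=1)
        frames.append(pd.DataFrame({
            'session_name': session_name,
            'frame_index': np.arange(n_frames),
            'syllable': outputs['syllables'],
            'centroid_x': centroid[:, 0],
            'centroid_y': centroid[:, 1],
            'heading': outputs['heading'],
            'speed': np.concatenate([[0.0], step]),
        }))
    return pd.concat(frames, ignore_index=True)

# hypothetical toy data: two recordings with 4 and 3 frames
toy_results = {
    'rec_a': {'syllables': np.array([0, 0, 1, 1]),
              'centroid': np.array([[0., 0.], [3., 4.], [3., 4.], [6., 8.]]),
              'heading': np.array([0.0, 0.1, 0.2, 0.3])},
    'rec_b': {'syllables': np.array([2, 2, 0]),
              'centroid': np.array([[1., 1.], [1., 2.], [1., 2.]]),
              'heading': np.array([1.0, 1.1, 1.2])},
}
toy_df = toy_results_to_df(toy_results)
print(toy_df)
```

Keeping every frame as a row makes later steps (grouping by recording, syllable or condition) ordinary `groupby` operations.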

#### *The first few rows of `moseq_df` are rendered after the following cell is executed*

In [None]:
moseq_df = kpm.compute_moseq_df(results, **config())
print('moseq_df has shape', moseq_df.shape)
moseq_df.head()

### Add group labels
Use the following cell to assign a group label to each recording. Group labels are required for group-wise comparisons, e.g. to analyze differences in syllable frequency across conditions. 
- Based on the list of sessions in the dataframe below, create a matching list of group labels.
- The n'th group label in your list will be assigned to the session in the n'th row of the dataframe
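
Positional assignment is easy to get wrong if the row order of the session table changes. As an alternative, the labels could be keyed by session name, as in this minimal sketch. The `sessions` table stands in for `group_lbl_df`, and the session names, uuids and labels are placeholders rather than values from the example data.

```python
import pandas as pd

# hypothetical session table standing in for `group_lbl_df`
sessions = pd.DataFrame({
    'session_name': ['rec_a', 'rec_b', 'rec_c'],
    'uuid': ['u1', 'u2', 'u3'],
})

# explicit name -> label mapping (labels are placeholders)
label_by_session = {'rec_a': 'group1', 'rec_b': 'group2', 'rec_c': 'group1'}

# fail early if any session has no label
missing = set(sessions['session_name']) - set(label_by_session)
if missing:
    raise ValueError(f'no group label for: {sorted(missing)}')

sessions['group'] = sessions['session_name'].map(label_by_session)
print(sessions)
```

With a mapping, reordering or adding sessions cannot silently shift labels onto the wrong recording.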

In [None]:
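# copy the selection so that adding a column later does not trigger a pandas SettingWithCopyWarning
group_lbl_df = moseq_df[['session_name', 'uuid']].drop_duplicates(['uuid']).copy()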
print('Session information below')
group_lbl_df

#### *The first few rows of `moseq_df` are rendered after the following cell is executed*

In [None]:
# assign group labels.
group_labels = [
    'group1',
    'group2',
    'group1',
    'group2',
    'group1',
    'group1',
    'group2',
    'group2',
    'group1',
    'group2',
]

# merge group labels into moseq_df
assert len(group_labels) == len(group_lbl_df), 'provide exactly one group label per session'
group_lbl_df['group'] = group_labels
group_lbl_df = group_lbl_df.drop(['session_name'], axis=1)
moseq_df = pd.merge(moseq_df, group_lbl_df, how='right', on='uuid')
moseq_df.head()

#### [optional] export moseq_df

The following command saves `moseq_df` to the csv file `[project_dir]/[name]/moseq_df.csv`

In [None]:
# moseq_df.to_csv(os.path.join(project_dir, name, 'moseq_df.csv'))

### Aggregate statistics for each recording

`kpm.compute_stats_df` creates a `stats_df` dataframe with statistical summaries for each recording session and group, including the min, max, mean, and std of scalar values associated with each syllable, as well as each syllable's frequency. `stats_df` is then used to plot syllable statistics and perform hypothesis testing.
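
The kind of summary described above can be sketched with plain pandas on a toy frame-level table. This is only an illustration under stated assumptions: the column names are hypothetical, and frequency is computed here as the fraction of frames spent in each syllable, which may differ from how `kpm.compute_stats_df` defines it.

```python
import pandas as pd

# hypothetical frame-level data for two recordings
toy = pd.DataFrame({
    'uuid':     ['u1'] * 4 + ['u2'] * 4,
    'group':    ['group1'] * 4 + ['group2'] * 4,
    'syllable': [0, 0, 1, 1, 0, 1, 1, 1],
    'speed':    [1.0, 3.0, 2.0, 4.0, 2.0, 5.0, 5.0, 8.0],
})

keys = ['group', 'uuid', 'syllable']

# min, max, mean and std of a scalar for each (recording, syllable) pair
toy_stats = (
    toy.groupby(keys)['speed']
       .agg(['min', 'max', 'mean', 'std'])
       .add_prefix('speed_')
       .reset_index()
)

# frame-wise frequency: fraction of a recording's frames spent in each syllable
counts = toy.groupby(keys).size().rename('n_frames').reset_index()
counts['frequency'] = counts['n_frames'] / counts.groupby('uuid')['n_frames'].transform('sum')

toy_stats = toy_stats.merge(counts, on=keys)
print(toy_stats)
```

Note that a syllable observed on a single frame has an undefined standard deviation, which pandas reports as `NaN`.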

#### *The first few rows of `stats_df` are rendered after the following cell is executed*

In [None]:
stats_df = kpm.compute_stats_df(moseq_df, **config())
stats_df.head()

#### [optional] export stats_df

The following command saves `stats_df` to the csv file `[project_dir]/[name]/stats_df.csv`

In [None]:
# stats_df.to_csv(os.path.join(project_dir, name, 'stats_df.csv'))

### Generate behavioral fingerprint plots
Fingerprint plots summarize behavior by showing distributions of kinematic scalars and syllable frequencies.

In [None]:
summary, range_dict = kpm.create_fingerprint_dataframe(moseq_df, stats_df)
kpm.plotting_fingerprint(summary, range_dict)

### Generate syllable frequency plots
Compare the frequency of each syllable across groups.
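
A group comparison of this kind rests on a per-group mean and standard error of the mean (SEM) for each syllable, computed across recordings. The following is a minimal pandas sketch on hypothetical per-recording frequencies; it does not reproduce the internals of `kpm.plot_syll_stats_with_sem`, and the values and column names are made up.

```python
import pandas as pd

# hypothetical per-recording syllable frequencies (two recordings per group)
toy_stats = pd.DataFrame({
    'group':     ['group1', 'group1', 'group2', 'group2'] * 2,
    'uuid':      ['u1', 'u2', 'u3', 'u4'] * 2,
    'syllable':  [0] * 4 + [1] * 4,
    'frequency': [0.6, 0.4, 0.2, 0.4, 0.4, 0.6, 0.8, 0.6],
})

# mean and SEM across recordings, one row per syllable and one column block per group
summary = (
    toy_stats.groupby(['syllable', 'group'])['frequency']
             .agg(['mean', 'sem'])
             .unstack('group')
)
print(summary)
```

The SEM here uses the sample standard deviation divided by the square root of the number of recordings in the group, so it shrinks as more recordings are added.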

In [None]:
# groups to be plotted
groups = ['group1','group2']

# name of the control group
ctrl_group = ['group1']

# name of the experimental group
exp_group = ['group2']

kpm.plot_syll_stats_with_sem(stats_df, groups=groups, ctrl_group=ctrl_group, exp_group=exp_group)

## Transition matrices
Transition matrices compactly represent how frequently each syllable transitions into every other syllable, and are one way to describe structure in behavior.
Each row of the transition matrix represents an incoming syllable and each column an outgoing syllable; the value at a given row and column is the frequency with which the incoming syllable transitions into the outgoing syllable.

These plots can help visualize gross changes in the structure of behavior between two experimental groups. For example, certain syllables that frequently transition into one set of syllables in one experimental condition might transition into a completely different set in another experimental condition.
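
To illustrate the idea, here is a minimal NumPy sketch that counts transitions in a single frame-wise syllable sequence and applies three normalizations named after the `bigram`, `rows` and `columns` options used in this notebook. The `toy_transition_matrix` helper is hypothetical, and whether its conventions (dropping self-transitions, what each normalization divides by) match `kpm.get_group_trans_mats` is an assumption.

```python
import numpy as np

def toy_transition_matrix(labels, n_syllables, normalize='bigram'):
    """Count syllable-to-syllable transitions in a frame-wise label sequence.

    Illustrative helper: rows index the syllable being left, columns the
    syllable being entered. Self-transitions (repeated frames) are dropped.
    """
    labels = np.asarray(labels)
    # keep only the first frame of each run so that repeats do not count
    onsets = np.concatenate([[True], labels[1:] != labels[:-1]])
    seq = labels[onsets]
    counts = np.zeros((n_syllables, n_syllables))
    np.add.at(counts, (seq[:-1], seq[1:]), 1)
    if normalize == 'bigram':      # all entries sum to 1
        total = counts.sum()
    elif normalize == 'rows':      # each non-empty row sums to 1
        total = counts.sum(axis=1, keepdims=True)
    elif normalize == 'columns':   # each non-empty column sums to 1
        total = counts.sum(axis=0, keepdims=True)
    else:
        raise ValueError(f'unknown normalization: {normalize}')
    # leave all-zero rows/columns at 0 instead of dividing by zero
    return np.divide(counts, total, out=np.zeros_like(counts), where=total > 0)

# hypothetical sequence: runs 0, 1, 2, 0, 1 give transitions 0->1, 1->2, 2->0, 0->1
labels = [0, 0, 1, 1, 2, 0, 0, 1]
bigram = toy_transition_matrix(labels, 3)
by_row = toy_transition_matrix(labels, 3, normalize='rows')
print(bigram)
print(by_row)
```

Bigram normalization preserves how common each transition is overall, while row normalization answers "given this syllable, what comes next?" regardless of how often the syllable occurs.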

In [None]:
# set maximum syllable to include
max_syllables = int(stats_df.syllable.max())
print('maximum syllable to include:', max_syllables)
# select a transition matrix normalization method
normalize='bigram' # options: bigram, columns, rows

# Get modeled session uuids to compute group-mean transition graph for each group
syll_key = 'syllables_reindexed'
uuid_groups = stats_df[['uuid', 'group', 'session_name']].drop_duplicates(['uuid']).to_numpy()
label_group, uuids, sessions = uuid_groups[:,1], uuid_groups[:, 0], uuid_groups[:,2]
group = sorted(set(label_group))  # sorted so that the group order is reproducible across runs
print('Group(s):', ', '.join(group))
model_labels = [results[session][syll_key] for session in sessions]
# compute transition matrices and usages for each group
trans_mats, usages = kpm.get_group_trans_mats(model_labels, label_group, group, max_sylls=max_syllables, normalize=normalize)


In [None]:
# plot the transition matrices
fig, ax = plt.subplots(1, len(group), figsize=(12, 15), sharex=False, sharey=True, squeeze=False)
ax = ax[0]  # squeeze=False keeps `ax` indexable even when there is only one group
title_map = dict(bigram='Bigram', columns='Incoming', rows='Outgoing')

# upper limit of the color scale: defaults to the largest value across groups,
# but can be set to any value
color_lim = max([x.max() for x in trans_mats])

for i, g in enumerate(group):
    h = ax[i].imshow(trans_mats[i][:max_syllables,:max_syllables], cmap='cubehelix', vmax=color_lim)
    if i == 0:
        ax[i].set_ylabel('Incoming syllable')
        # set ticks on this axis explicitly; the y-axis is shared across panels
        ax[i].set_yticks(np.arange(0, max_syllables, 4))
    cb = fig.colorbar(h, ax=ax[i], fraction=0.046, pad=0.04)
    cb.set_label(f'{title_map[normalize]} transition probability')
    ax[i].set_xlabel('Outgoing syllable')
    ax[i].set_title(g)
    ax[i].set_xticks(np.arange(0, max_syllables, 4))