# DeepOF unsupervised pipeline: exploring the behavioral space

In [1]:
import os

# Move three directories up (presumably to the deepof repository root) so that the
# relative paths used below resolve. The target directory is resolved only once and
# remembered, so re-running this cell does not keep climbing the directory tree.
if "_NOTEBOOK_ROOT" not in globals():
    _NOTEBOOK_ROOT = os.path.abspath("../../..")
os.chdir(_NOTEBOOK_ROOT)

import deepof.data

In [2]:
%load_ext autoreload
%autoreload 2

In [None]:
import pickle
import numpy as np
# Location of a previously trained VQ-VAE solution. Judging from how it is used further
# down, vqvae_solution[0] holds embeddings and vqvae_solution[1] cluster assignments,
# both keyed by experiment — TODO confirm.
# NOTE: unpickling can execute arbitrary code; only load files you produced yourself.
vqvae_solution_path = "../train_models/deepof_unsupervised_VQVAE_encodings_input=coords_k=100_latdim=8_kmeans_loss=0.0_run=1.pkl"
with open(vqvae_solution_path, "rb") as pickled_solution:
    vqvae_solution = pickle.load(pickled_solution)

In [None]:
from sklearn.cluster import AgglomerativeClustering
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from deepof.post_hoc import get_transitions
from hmmlearn.hmm import GaussianHMM
from tqdm import tqdm_notebook as tqdm


def merge_and_smooth_clusters(
    n_clusters, centroids, embedding, concat_embedding, cluster_assignments
):
    """Merge clusters hierarchically, then smooth the assignments over time with a Gaussian HMM.

    Args:
        n_clusters (int): number of clusters to report.
        centroids (np.ndarray): precomputed means per cluster (one row per original cluster).
        embedding (tabdict): original deepof.TableDict object containing unsupervised embeddings.
            Values must expose ``.numpy()`` and a first-axis length via ``.shape[0]``.
        concat_embedding (np.ndarray): embeddings of all animals, concatenated along the first
            axis in the iteration order of ``embedding``.
        cluster_assignments (tabdict): original deepof.TableDict object containing cluster
            assignments. Only its keys are used; they must match the keys of ``embedding``
            and be in the same order.

    Returns:
        new_soft_assignments (dict): maps each key of ``cluster_assignments`` to the
            per-time-step state probabilities returned by ``GaussianHMM.predict_proba``
            (``n_clusters`` columns).
    """
    # Merge clusters using a hierarchical agglomerative approach on their centroids
    new_hard_assignments = AgglomerativeClustering(
        n_clusters=n_clusters, compute_distances=True
    ).fit_predict(centroids)

    # Fit an LDA classifier on the merged labels so every instance can be soft-assigned
    cluster_predictor = LinearDiscriminantAnalysis().fit(
        centroids, new_hard_assignments
    )
    merged_centroids = cluster_predictor.means_
    concat_soft_assignments = cluster_predictor.predict_proba(concat_embedding)

    # Rebuild the soft assignments dictionary per experimental animal. The last cumulative
    # length equals the total number of rows, so it is dropped: otherwise np.split would
    # return an extra, empty trailing chunk.
    split_points = np.cumsum([i.shape[0] for i in embedding.values()])[:-1]
    new_soft_assignments = dict(
        zip(
            cluster_assignments.keys(),
            np.split(concat_soft_assignments, split_points),
        )
    )

    # Smooth assignments across time using a Gaussian HMM on the embeddings, with priors
    # based on the clustering results
    for key, val in tqdm(new_soft_assignments.items()):
        hard_labels = val.argmax(axis=1)
        animal_embedding = embedding[key].numpy()

        hmm = GaussianHMM(
            # One count per state; minlength keeps the length at n_clusters even when a
            # merged cluster never wins the argmax for this animal
            startprob_prior=np.bincount(hard_labels, minlength=n_clusters),
            transmat_prior=get_transitions(hard_labels, n_states=n_clusters) + 10,
            means_prior=merged_centroids,
            n_components=n_clusters,
            covariance_type="diag",
            n_iter=100,
            tol=0.0001,
        )

        hmm.fit(animal_embedding)
        # Replacing the value of an existing key does not change the dict's size, so it is
        # safe while iterating
        new_soft_assignments[key] = hmm.predict_proba(animal_embedding)

    return new_soft_assignments


def cluster_postprocessing(embedding, cluster_assignments, n_clusters="auto"):
    """Merge clusters using a hierarchical approach and smooth them over time.

    Args:
        embedding (dict-like): embeddings per animal in the dataset. Each value must expose
            ``.numpy()`` returning an array whose first axis is aligned with the rows of
            the matching cluster assignments.
        cluster_assignments (dict-like): cluster assignment scores per animal, with the same
            keys (in the same order) as ``embedding``. Each value's ``.numpy()`` must return
            a 2D array whose per-row argmax is the hard cluster label.
        n_clusters (int): number of clusters to report. Only integers are currently
            supported; the default ``"auto"`` is not implemented yet.

    Returns:
        new_soft_assignments (dict): new (merged and smoothed) soft cluster assignments per
            animal, as returned by ``merge_and_smooth_clusters``.

    Raises:
        ValueError: if embeddings and cluster assignments have different total lengths.
        NotImplementedError: if ``n_clusters`` is not an integer.
    """
    # Concatenate embeddings and cluster assignments into single np.ndarray objects
    concat_embedding = np.concatenate([tensor.numpy() for tensor in embedding.values()])
    hard_assignments = np.concatenate(
        [tensor.numpy().argmax(axis=1) for tensor in cluster_assignments.values()]
    )

    # Explicit check instead of assert, so it is not stripped when running with -O
    if concat_embedding.shape[0] != hard_assignments.shape[0]:
        raise ValueError(
            "embedding and cluster_assignments must contain the same number of instances, "
            f"got {concat_embedding.shape[0]} and {hard_assignments.shape[0]}"
        )

    # Get cluster centroids from the concatenated embeddings. np.unique yields every label
    # that actually occurs (including the highest one), in ascending order, so empty
    # clusters are skipped naturally.
    centroids = np.stack(
        [
            concat_embedding[hard_assignments == cluster].mean(axis=0)
            for cluster in np.unique(hard_assignments)
        ]
    )

    # Merge centroids using a hierarchical approach with the given resolution, and
    # soft-assign instances to clusters
    if isinstance(n_clusters, (int, np.integer)):
        return merge_and_smooth_clusters(
            int(n_clusters), centroids, embedding, concat_embedding, cluster_assignments
        )

    raise NotImplementedError(
        f"n_clusters={n_clusters!r} is not supported yet; pass an integer"
    )

In [None]:
# from hmmlearn.hmm import GaussianHMM

# new_ass = cluster_postprocessing(
#     vqvae_solution[0], 
#     vqvae_solution[1],
#     n_clusters=12
# )
# hcc = new_ass['20191203_Day1_SI_JB08_Test_54'].argmax(axis=1)

In [None]:
# import umap

# # Cluster on the original embedding space
# new_emb = umap.UMAP(n_components=2, n_neighbors=75).fit_transform(vqvae_solution[0]['20191203_Day1_SI_JB08_Test_54'])

# sns.scatterplot(x=new_emb[:, 0], y=new_emb[:, 1], hue=hcc, palette="tab20")

# plt.show()

In [None]:
# # How prevalent are these clusters?
# from collections import Counter
# print(Counter(hcc))

# new_ass = hcc

In [None]:
# # How often does the model change clusters?
# from collections import defaultdict

# lengths = defaultdict(list)
# cur = 0
# for i in range(1, len(new_ass)):
#     if new_ass[i-1] == new_ass[i]:
#         cur += 1
#     else:
#         lengths[new_ass[i-1]].append(cur)
#         cur = 1

# {key:np.mean(val) for key, val in lengths.items()}

In [None]:
# import pandas as pd

# # Duration histograms per cluster
# lengths_df = pd.DataFrame([lengths]).melt().explode("value").astype(int)
# sns.violinplot(data=lengths_df, x="variable", y="value")

# plt.axhline(25, linestyle="--", color="black")
        
# plt.show()

In [3]:
# Load a previously saved deepof project from disk
project_path = "../../Desktop/deepOF_CSDS_tutorial_dataset/deepof_tutorial_saved_project_1672667128.pkl"
my_deepof_project = deepof.data.load(project_path)

In [4]:
# Open questions for this analysis:
#   * Check scales across animals: can we tell which animal a given time series belongs to?
#   * Once happy with a solution, check that all animals show comparable cluster interpretations.
# TODO: add preprocessing options to include multiple animals, either concatenated or
# together in a graph.

tt = my_deepof_project.get_coords(center="Center", align="Spine_1")
# ss = my_deepof_project.get_coords(speed=1)

# tt = cc.merge(ss)

preprocessing_kwargs = dict(
    window_size=25,
    window_step=1,
    test_videos=1,
    scale="standard",
    # "concat" uses body parts from different animals as features;
    # "split" presumably handles each animal separately — TODO confirm.
    handle_ids="split",
)
tt = tt.preprocess(**preprocessing_kwargs)

# Keep only the first 25000 rows of each of the first four preprocessed arrays
tt = tuple(tt[index][:25000] for index in range(4))

In [44]:
from deepof.utils import connect_mouse_topview
import networkx as nx

# Dataset for animal "B" as returned by get_graph_dataset, centered on "Center",
# aligned with "Spine_1" and standard-scaled
graph_dataset_kwargs = dict(
    animal_id="B",
    center="Center",
    align="Spine_1",
    preprocess=True,
    scale="standard",
)
pp = my_deepof_project.get_graph_dataset(**graph_dataset_kwargs)

# Body-part connectivity graph for the same animal (tail body parts excluded),
# converted to a dense adjacency matrix
G = connect_mouse_topview(
    animal_ids=["B"], exclude_bodyparts=["Tail_1", "Tail_2", "Tail_tip"]
)
adj = nx.adjacency_matrix(G).todense()

# Keep only the first 25000 rows of each of the six returned arrays
pp = tuple(pp[index][:25000] for index in range(6))

  adj = nx.adjacency_matrix(G).todense()


In [45]:
# Inspect the shape of the second array in pp; presumably
# (instances, time steps per window, features) — TODO confirm.
pp[1].shape

(25000, 25, 11)

In [None]:
%%time
# Train a VaDE model with a TCN encoder on the dataset and adjacency matrix prepared above
vade_hyperparameters = dict(
    embedding_model="VaDE",
    encoder_type="TCN",
    epochs=10,
    n_components=15,
    latent_dim=8,
    kl_warmup=10,
    kl_annealing_mode="linear",
    batch_size=128,
    kmeans_loss=0.0,
    reg_cat_clusters=0.0,
)
cons = my_deepof_project.deep_unsupervised_embedding(
    pp, adjacency_matrix=adj, **vade_hyperparameters
)

2023-01-03 15:10:00.490775: I tensorflow/core/profiler/lib/profiler_session.cc:99] Profiler session initializing.
2023-01-03 15:10:00.490786: I tensorflow/core/profiler/lib/profiler_session.cc:114] Profiler session started.
2023-01-03 15:10:00.490882: I tensorflow/core/profiler/lib/profiler_session.cc:126] Profiler session tear down.


Epoch 1/10


2023-01-03 15:10:02.160424: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


Epoch 2/10
Epoch 3/10
  1/195 [..............................] - ETA: 22s - total_loss: 42.1775 - reconstruction_loss: 41.8540 - clustering_loss: 0.0000e+00 - prior_loss: 0.0000e+00 - kl_weight: 0.2000 - kl_divergence: -5.8449 - kmeans_loss: 1.3096 - number_of_populated_clusters: 15.0000 - confidence_in_selected_cluster: 0.3850

2023-01-03 15:10:42.934361: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


Epoch 4/10
Epoch 5/10
  1/195 [..............................] - ETA: 27s - total_loss: 46.2827 - reconstruction_loss: 47.1904 - clustering_loss: 0.0000e+00 - prior_loss: 0.0000e+00 - kl_weight: 0.4000 - kl_divergence: -5.7315 - kmeans_loss: 1.2273 - number_of_populated_clusters: 15.0000 - confidence_in_selected_cluster: 0.3609

2023-01-03 15:11:15.993809: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


Epoch 6/10
Epoch 7/10
  1/195 [..............................] - ETA: 25s - total_loss: 41.7762 - reconstruction_loss: 43.9682 - clustering_loss: 0.0000e+00 - prior_loss: 0.0000e+00 - kl_weight: 0.6000 - kl_divergence: -6.6221 - kmeans_loss: 1.4290 - number_of_populated_clusters: 15.0000 - confidence_in_selected_cluster: 0.4322

2023-01-03 15:11:47.288774: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


Epoch 8/10
Epoch 9/10
  1/195 [..............................] - ETA: 22s - total_loss: 52.1037 - reconstruction_loss: 56.1550 - clustering_loss: 0.0000e+00 - prior_loss: 0.0000e+00 - kl_weight: 0.8000 - kl_divergence: -8.8627 - kmeans_loss: 1.8298 - number_of_populated_clusters: 15.0000 - confidence_in_selected_cluster: 0.5976

2023-01-03 15:12:18.958510: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


Epoch 10/10
Epoch 1/10


2023-01-03 15:12:52.342067: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


  1/195 [..............................] - ETA: 10:28 - total_loss: 44.4326 - reconstruction_loss: 37.9789 - clustering_loss: -1.0000 - prior_loss: 2.7080 - kl_weight: 0.0000e+00 - kl_divergence: -10.8601 - kmeans_loss: 2.3813 - number_of_populated_clusters: 4.0000 - confidence_in_selected_cluster: 1.0000

2023-01-03 15:12:55.692480: I tensorflow/core/profiler/lib/profiler_session.cc:99] Profiler session initializing.
2023-01-03 15:12:55.692489: I tensorflow/core/profiler/lib/profiler_session.cc:114] Profiler session started.


  4/195 [..............................] - ETA: 6:25 - total_loss: 47.7165 - reconstruction_loss: 41.2788 - clustering_loss: -0.9950 - prior_loss: 2.7080 - kl_weight: 7.6923e-04 - kl_divergence: -10.8530 - kmeans_loss: 2.3820 - number_of_populated_clusters: 4.0000 - confidence_in_selected_cluster: 0.9967

2023-01-03 15:13:01.586760: I tensorflow/core/profiler/lib/profiler_session.cc:66] Profiler session collecting data.
2023-01-03 15:13:01.587378: I tensorflow/core/profiler/lib/profiler_session.cc:126] Profiler session tear down.
2023-01-03 15:13:01.588350: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: /Users/lucas_miranda/PycharmProjects/deepof/unsupervised_trained_models/fit/deepof_unsupervised_VaDE_TCN_encodings_input_type=coords_kmeans_loss=0.0_encoding=8_k=15_20230103-151000/plugins/profile/2023_01_03_15_13_01

2023-01-03 15:13:01.589270: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for trace.json.gz to /Users/lucas_miranda/PycharmProjects/deepof/unsupervised_trained_models/fit/deepof_unsupervised_VaDE_TCN_encodings_input_type=coords_kmeans_loss=0.0_encoding=8_k=15_20230103-151000/plugins/profile/2023_01_03_15_13_01/MC-C9791E.local.trace.json.gz
2023-01-03 15:13:01.589866: I tensorflow/core/profiler/rpc/client/

Epoch 2/10

2023-01-03 15:13:34.139225: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


Epoch 3/10

In [None]:
# Run the trained encoder and grouper on the first 25000 instances of the first two inputs
model_inputs = [pp[0][:25000], pp[1][:25000]]
emb = cons.encoder(model_inputs)
cls = cons.grouper(model_inputs)

In [None]:
import umap as umap_lib

# 2D UMAP projection of the latent embeddings, used for plotting below.
# NOTE: the projection is bound to `umap` (as later cells expect), so the module
# itself is kept available under the alias `umap_lib`.
latent_points = emb.numpy()
umap = umap_lib.UMAP(n_components=2, n_neighbors=50, min_dist=1.0).fit_transform(
    latent_points
)
# umap = emb.numpy()

In [None]:
from collections import Counter

# Number of instances whose highest grouper score falls on each cluster
hard_cluster_labels = np.argmax(cls.numpy(), axis=1)
Counter(hard_cluster_labels)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Scatter the 2D UMAP projection, colored by each instance's most likely cluster
sns.scatterplot(x=umap[:, 0], y=umap[:, 1], hue=cls.numpy().argmax(axis=1), palette="tab20")

# Presumably the component means of the latent Gaussian mixture, read from the grouper's
# "gaussian_mixture_latent" layer.
# NOTE(review): these likely live in the latent space rather than in UMAP space, so
# overlaying their first two coordinates on the projection is only a rough visual aid.
means = cons.get_layer("grouper").get_layer("gaussian_mixture_latent").c_mu.numpy()
# sns.scatterplot(x=means[:, 0], y=means[:, 1], s=250, c="black")

# Alternative: fit a Gaussian mixture directly on the embeddings instead, e.g.
# gmm = GaussianMixture(n_components=5, covariance_type="diag", reg_covar=1e-04).fit(emb.numpy())
# means = gmm.means_

plt.title("GMVAE embeddings")

# plt.legend("")
plt.show()

In [None]:
# Sanity check on the mixture weights (expected to sum to 1 if they are normalized)
gmm_weights = cons.get_gmm_params['weights'].numpy()
np.sum(gmm_weights)

In [None]:
# Split the concatenated data back into per-experiment chunks, using the lengths of the
# VQ-VAE embeddings of the experiments present in `cc`, and print per-chunk statistics
# (presumably to check the scaling of the data across animals).
# NOTE(review): `cc` is not defined anywhere in this notebook (it only appears in a
# commented-out line), so this cell raises NameError as written — define it first.
# NOTE(review): `tt` may still be the truncated preprocessing tuple from an earlier cell;
# confirm its elements can be concatenated along the first axis.
selected_keys = set(cc.keys())  # built once instead of once per comprehension step
chunk_lengths = [
    table.shape[0] for key, table in vqvae_solution[0].items() if key in selected_keys
]
tt = np.split(np.concatenate(tt), np.cumsum(chunk_lengths))

for chunk in tt:
    # np.split returns an empty trailing chunk when the lengths cover every row;
    # its statistics would only be NaN (with a RuntimeWarning)
    if chunk.shape[0] == 0:
        continue
    print(chunk.shape)
    print(np.max(np.abs(chunk.mean(axis=0))))
    # standard deviations are non-negative, so no abs() is needed
    print(np.mean(chunk.std(axis=0)))
