## Embed multiple rats in the same latent space

Motivation: if the learned embeddings are more or less independent of the rat, embeddings from different rats shouldn't be embedded
in different parts of the latent space - apart from extreme behaviors which occur only in a single rat, e.g. seizures

In [None]:
import sys

sys.path.insert(0, "/home/katharina/vame_approach/VAME")

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from vame.analysis.kinutils import KinVideo, create_grid_video
import os
from datetime import datetime
from vame.util.auxiliary import read_config
import logging
import re
from pathlib import Path
from IPython import display
from sklearn.manifold import TSNE
import umap
from fcmeans import FCM
from ipywidgets import Output, GridspecLayout
from scipy.spatial.distance import pdist, squareform
from vame.analysis.utils import (
    create_aligned_mouse_video,
    create_pose_snipplet,
    create_visual_comparison,
    thin_dataset_iteratively,
    find_percentile_threshold,
    estimate_fuzzifier,
    fukuyama_sugeno_index,
)
from matplotlib import cm
import seaborn as sns
from vame.initialize_project.themis_new import get_video_metadata

# fix the seed of numpy's global random number generator
np.random.seed(42)

# IPython magics: (re)load the autoreload extension and switch it to mode 2
# (per the IPython docs: reload all modules before executing code)
%reload_ext autoreload
%autoreload 2

## 1) Load latent vectors predicted from different videos

In [None]:
# Project to analyze. Other project folders that can be selected here instead:
#   "/home/katharina/vame_approach/tb_align_0044_0089"
#   "/home/katharina/vame_approach/themis_tail_belly_align"
PROJECT_PATH = "/home/katharina/vame_approach/tb_align_0089"

# Model folders are named after their training time stamp. Entries of the model
# directory which do not parse as such a time stamp (e.g. hidden files) are
# skipped instead of aborting the cell with a ValueError.
MODEL_DIR_TIME_FORMAT = "%m-%d-%Y-%H-%M"
trained_models = []
for element in os.listdir(os.path.join(PROJECT_PATH, "model")):
    try:
        time_stamp = datetime.strptime(element, MODEL_DIR_TIME_FORMAT)
    except ValueError:
        continue
    trained_models.append((time_stamp, element))
if not trained_models:
    raise FileNotFoundError(
        f"No trained model folder found in {os.path.join(PROJECT_PATH, 'model')}"
    )
# sort by time stamp and use the most recently trained model
trained_models.sort(key=lambda x: x[0])
latest_model = trained_models[-1][-1]

config_file = os.path.join(PROJECT_PATH, "model", latest_model, "config.yaml")
config = read_config(config_file)

latent_vec_dir = os.path.join(PROJECT_PATH, "inference", "results", latest_model)
# os.listdir returns the entries in arbitrary order; sort them so that the order
# of the videos (and with it labels, colors and stacking order) is reproducible
latent_vec_files = [
    os.path.join(latent_vec_dir, file) for file in sorted(os.listdir(latent_vec_dir))
]
# the 4th "_"-separated token of the file name is used as key (video id)
latent_vectors = {
    os.path.basename(file).split("_")[3]: np.load(file) for file in latent_vec_files
}

# use only 0089, 0088, 0087 - all from H06
# latent_vectors = {k: v for k, v in latent_vectors.items() if k in ["0087", "0088", "0089"]}

In [None]:
# show the keys (video ids) under which the latent vectors were loaded
print(latent_vectors.keys())

## 2) Data Dilution
Dilute each set of latent vectors separately

In [None]:
# dilute the datasets: one percentile threshold and one thinned set of
# embeddings (plus the time ids of the remaining embeddings) per video
neighbor_percentiles = {}
latent_vectors_diluted = {}
time_ids_diluted = {}
# choose a subsampling factor for neighbor percentile estimation to save memory;
# clamp it to 1, since a factor of 0 (time windows shorter than 10) would be an
# invalid slice step
sub_sampling_factor = max(1, config["time_window"] // 10)
for video_id, latent_vec in latent_vectors.items():
    neighbor_percentiles[video_id] = find_percentile_threshold(
        latent_vec[::sub_sampling_factor],
        config["time_window"],
        time_idx=np.arange(0, len(latent_vec), sub_sampling_factor),
        test_fraction=0.01 * sub_sampling_factor,
    )
    remaining_embeddings, remaining_time_ids = thin_dataset_iteratively(
        latent_vec, 0.00001, neighbor_percentiles[video_id], config["time_window"]
    )

    latent_vectors_diluted[video_id] = remaining_embeddings
    time_ids_diluted[video_id] = remaining_time_ids

## 3) Visualize Diluted Data
Ideally the learned latent space should embed the same behavior, independent of the actual specimen, at the same place
in the latent space.

In [None]:
# video ids and their diluted latent vectors, both in the iteration order of
# the dict so that position i of one tuple belongs to position i of the other
labels = tuple(latent_vectors_diluted)
all_latent_vectors = tuple(latent_vectors_diluted.values())

In [None]:
# one video id per diluted latent vector, repeated as often as the video has
# remaining vectors
samples_per_label = [len(latent_vectors_diluted[video_id]) for video_id in labels]
labels_full = np.repeat(labels, samples_per_label)
# stack the per video vectors into a single array
all_latent_vectors = np.concatenate(all_latent_vectors)

In [None]:
# fit a 2D UMAP projection on the stacked latent vectors of all videos
umap_params = dict(
    n_components=2,
    min_dist=0.001,
    n_neighbors=30,
    random_state=config["random_state"],
)
umap_trafo = umap.UMAP(**umap_params).fit(all_latent_vectors)

In [None]:
# map each video id to the rat listed for its video file ("<video id>.MP4")
# in the project's video info table
video_info_file = os.path.join(PROJECT_PATH, "video_info.csv")
video_info = pd.read_csv(video_info_file)
video_id_rat_id = {
    video_label: video_info.loc[
        video_info["vid_file"] == f"{video_label}.MP4", "rat"
    ].values[0]
    for video_label in labels
}

In [None]:
%matplotlib widget
print(all_latent_vectors.shape)
umap_embeddings = umap_trafo.transform(all_latent_vectors)

cmap = cm.get_cmap("rainbow", len(labels))
for l in labels:
    print(l)
    plt.scatter(
        umap_embeddings[labels_full == l, 0],
        umap_embeddings[labels_full == l, 1],
        color=cmap(labels.index(l)),
        edgecolor="k",
        label=l,
    )
plt.legend()

In [None]:
%matplotlib widget

from numpy import linalg
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.colors as colors


fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection="3d")

pca = PCA(n_components=3)

pca_vectors = pca.fit_transform(all_latent_vectors)
# set colour map so each ellipsoid as a unique colour
norm = colors.Normalize(vmin=0, vmax=len(labels))
cmap = cm.get_cmap("tab10", len(labels))
m = cm.ScalarMappable(norm=norm, cmap=cmap)

cluster_means = [
    np.mean(pca_vectors[labels_full == i_cluster], axis=0) for i_cluster in labels
]
cluster_cov = [np.cov(pca_vectors[labels_full == i_cluster].T) for i_cluster in labels]

for i_cluster, label in enumerate(labels):
    # your ellispsoid and center in matrix form

    center = cluster_means[i_cluster]
    A = cluster_cov[i_cluster]
    # calc eigenvalues (the srt is the radius) and the eigenvectors (rotation) of the ellipsoid!
    eigen_vals, eigen_vec = linalg.eig(A)
    radii = np.sqrt(eigen_vals)

    # calculate cartesian coordinates for the ellipsoid surface
    u = np.linspace(0.0, 2.0 * np.pi, 60)
    v = np.linspace(0.0, np.pi, 60)
    x = radii[0] * np.outer(np.cos(u), np.sin(v))
    y = radii[1] * np.outer(np.sin(u), np.sin(v))
    z = radii[2] * np.outer(np.ones_like(u), np.cos(v))

    for i in range(len(x)):
        for j in range(len(x)):
            [x[i, j], y[i, j], z[i, j]] = (
                np.dot(eigen_vec, [x[i, j], y[i, j], z[i, j]]) + center
            )
    # ax.plot_surface(
    #    x,
    #    y,
    #    z,
    #    rstride=3,
    #    cstride=3,
    #    color=m.to_rgba(i_cluster),
    #    linewidth=0.1,
    #    alpha=0.3,
    #    shade=True,
    #    label=label_name,
    # )
    ax.plot(
        pca_vectors[labels_full == label, 0],
        pca_vectors[labels_full == label, 1],
        pca_vectors[labels_full == label, 2],
        ".",
        color=m.to_rgba(i_cluster),
        alpha=0.5,
        label=":".join([video_id_rat_id[label], label]),
    )

min_val = np.amin(pca_vectors)  # lowest number in the array
max_val = np.amax(pca_vectors)  # highest number in the array

ax.set_xlim3d(min_val, max_val)
ax.set_ylim3d(min_val, max_val)
ax.set_zlim3d(min_val, max_val)
ax.legend()

# K8 (videos 0056 and 0053) has actually seizures and a very different behavior to the other rats
plt.show()

## Clustering with DBSCAN based on a reduced number of feature dimensions using PCA

In [None]:
# PCA is used below but is not part of the imports at the top of the notebook
from sklearn.decomposition import PCA

# PCA cannot extract more components than min(n_samples, n_features)
n_components_dim_red = min(20, *all_latent_vectors.shape)
pca_dim_red = PCA(n_components=n_components_dim_red)
pca_dim_red.fit(all_latent_vectors)
print(
    f"Explained variance cumulated over the dimensions: {pca_dim_red.explained_variance_ratio_.cumsum()}"
)

In [None]:
# total number of diluted latent vectors over all videos
num_latent_vectors = len(all_latent_vectors)
print(num_latent_vectors)

In [None]:
from sklearn.cluster import DBSCAN

# PCA is used below but is not part of the imports at the top of the notebook
from sklearn.decomposition import PCA

# reduce the latent vectors to (at most) 10 dimensions before clustering;
# PCA cannot extract more components than min(n_samples, n_features)
pca = PCA(n_components=min(10, *all_latent_vectors.shape))

pca_vectors = pca.fit_transform(all_latent_vectors)
dbscan = DBSCAN(eps=1.0, min_samples=10)

# one cluster id per sample
dbscan_labels = dbscan.fit_predict(pca_vectors)

In [None]:
# label -1 is counted as outlier; the number of labels is the highest label + 1
num_dbscan_outliers = np.count_nonzero(dbscan_labels == -1)
outlier_percentage = np.round(num_dbscan_outliers / len(dbscan_labels) * 100, 2)
print(
    "Num labels:",
    dbscan_labels.max() + 1,
    ", Outlier percentage:",
    outlier_percentage,
    "%",
)

In [None]:
# print per cluster number of assigned samples / fraction of samples from one rat total / relative
# Cluster ids run from 0 to max(label) inclusive, i.e. there are max(label) + 1
# clusters (the same count the summary cell reports as "Num labels");
# iterating over range(max(label)) would skip the last cluster.
num_dbscan_clusters = np.max(dbscan_labels) + 1
# number of samples per video, needed for the relative numbers
num_samples_of_video = {l: np.sum(labels_full == l) for l in labels}
for i_cluster in range(num_dbscan_clusters):
    in_cluster = dbscan_labels == i_cluster
    num_assigned = np.sum(in_cluster)
    # video id of every sample assigned to the current cluster
    cluster_video_ids = labels_full[in_cluster]

    print(f"Cluster: {i_cluster}")
    print(
        f"Assigned samples: {num_assigned} / {len(dbscan_labels)}; percentage {np.round(num_assigned / len(dbscan_labels) * 100,2)}%"
    )
    samples_per_video = [(l, np.sum(cluster_video_ids == l)) for l in labels]
    print(f"Total samples per video: {samples_per_video}")
    samples_per_video = [
        (l, str(np.round(num_in_cluster / num_samples_of_video[l] * 100)) + "%")
        for l, num_in_cluster in samples_per_video
    ]
    print(f"samples rel. to video length: {samples_per_video}")
    print("-" * 30)

### Observation: Many of the clusters found with DBSCAN are extremely small (eps ~2.3; min_samples=5), whereas the biggest cluster contains ~85% of all samples. Decreasing eps or increasing min_samples just increases the fraction of outliers and results in tiny clusters.

## Clustering with Fuzzy C-Means

In [None]:
n_clusters_fcm = 50
fcm = FCM(m=1.1, n_clusters=n_clusters_fcm)

# fit on the PCA reduced vectors, i.e. cluster the data in fewer dimensions
fcm.fit(pca_vectors)
fcm_centers = fcm.centers

# soft memberships are [N,K]: N number of latent embeddings and K the number of
# clusters; each entry is a membership score between 0...1
fcm_labels_soft = fcm.soft_predict(pca_vectors)
# hard label of a sample = cluster with its highest membership score
fcm_labels = np.argmax(fcm_labels_soft, axis=1)

In [None]:
# mark samples below a certain membership thr as outliers
min_membership_thr = 0.7
# Derive the hard labels from the soft memberships again before marking, so
# that re-running this cell with another threshold does not keep the outlier
# marks (-1) of a previous run.
fcm_labels = np.argmax(fcm_labels_soft, axis=1)
fcm_labels[np.max(fcm_labels_soft, axis=1) < min_membership_thr] = -1
print(f"Samples assigned to outlier group: {np.sum(fcm_labels == -1)} / {len(fcm_labels)}; percentage {np.round(100 * np.sum(fcm_labels == -1) / len(fcm_labels),2)}%")
# print num samples assigned to outliers per video
for l in labels:
    is_video = labels_full == l
    num_samples = np.sum(is_video)
    num_outliers = np.sum(fcm_labels[is_video] == -1)
    print(f"Video: {l}, Rat:{video_id_rat_id[l]}")
    # percentage rounded to two digits like in the other summaries
    print(f"Num Outliers {num_outliers} / {num_samples}; Percentage {np.round(num_outliers / num_samples * 100, 2)}%")
    print("-"*20)

In [None]:
# per cluster: number of assigned samples, then per video the absolute number of
# samples in the cluster and that number relative to all samples of the video
num_samples_of_video = {l: np.sum(labels_full == l) for l in labels}
for i_cluster in range(n_clusters_fcm):
    in_cluster = fcm_labels == i_cluster
    num_assigned = np.sum(in_cluster)
    # video id of every sample assigned to the current cluster
    cluster_video_ids = labels_full[in_cluster]

    print(f"Cluster: {i_cluster}")
    print(
        f"Assigned samples: {num_assigned} / {len(fcm_labels)}; percentage {np.round(100 * num_assigned / len(fcm_labels),2)}%"
    )
    samples_per_video = [(l, np.sum(cluster_video_ids == l)) for l in labels]
    print(f"Total samples per video: {samples_per_video}")
    samples_per_video = [
        (l, str(np.round(num_in_cluster / num_samples_of_video[l] * 100)) + "%")
        for l, num_in_cluster in samples_per_video
    ]
    print(f"samples rel. to video length: {samples_per_video}")
    print("-" * 30)

In [None]:
# videos the sampling of example clips can be restricted to
sub_set = ["0087", "0088", "0089"]

# By default the samples of all videos are used; set to True to select only the
# samples which stem from the videos in sub_set.
restrict_to_sub_set = False
if restrict_to_sub_set:
    is_in_sub_set = np.isin(labels_full, sub_set)
else:
    is_in_sub_set = np.ones(len(labels_full), dtype=bool)

In [None]:
# two example videos per row, one grid cell per cluster
grid = GridspecLayout(int(np.ceil(n_clusters_fcm / 2.0)), 2)
# aligned video file per video id
aligned_video_files = {
    l: os.path.join(PROJECT_PATH, "videos", "aligned_videos", "a" + l + ".MP4")
    for l in labels
}
# time ids of the diluted samples, stacked in the order of labels
time_idx_stacked = np.concatenate([time_ids_diluted[l] for l in labels])

# Probe one video for frame rate and frame size. A video of sub_set is
# preferred, but if none of them is part of the loaded videos the first loaded
# video is used instead of failing with a KeyError.
reference_video_id = next(
    (l for l in sub_set if l in aligned_video_files), labels[0]
)
video = KinVideo(aligned_video_files[reference_video_id], view="Down")
video.probevid()
# NOTE(review): frame rate and frame size of this one video are applied to the
# clips of all videos - confirm that all videos share them
fps = video.getfps()
video_clip_duration = config["time_window"] / fps


for i_cluster_id in range(n_clusters_fcm):
    # indices of all selectable samples assigned to the current cluster
    candidate_idx = np.flatnonzero((fcm_labels == i_cluster_id) & is_in_sub_set)
    if len(candidate_idx) == 0:
        # nothing to show for this cluster, its grid cell stays empty
        continue

    # up to 16 random samples fill the 4 x 4 grid video
    sampled_idx = np.random.choice(
        candidate_idx, min(16, len(candidate_idx)), replace=False
    )
    # per clip: video file, start time (time id / fps), crop (x, y, width, height)
    video_clip_data_cluster = [
        (
            aligned_video_files[labels_full[idx]],
            time_idx_stacked[idx] / fps,
            (0, 0, video.width, video.height),
        )
        for idx in sampled_idx
    ]
    grid_video_cluster = create_grid_video(
        video_clip_data_cluster, video_clip_duration, speed=0.5, nrows=4, ncols=4,
    )
    out = Output()
    with out:
        display.display(
            display.Video(
                grid_video_cluster,
                embed=True,
                html_attributes="loop autoplay",
                width=450,
                height=450,
            )
        )
    # even cluster ids go to the left column, odd ones to the right column
    idx_row, idx_col = divmod(i_cluster_id, 2)
    grid[idx_row, idx_col] = out

# show the grid of example videos
grid