In [None]:
import sys
sys.path.insert(0, "/home/katharina/vame_approach/VAME")

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from vame.analysis.kinutils import KinVideo, create_grid_video
import os
from datetime import datetime
from vame.util.auxiliary import read_config
import logging
import re
from pathlib import Path
from IPython import display
from sklearn.manifold import TSNE
import umap
from fcmeans import FCM
from ipywidgets import Output, GridspecLayout
from scipy.spatial.distance import pdist, squareform
from vame.analysis.visualize import create_aligned_mouse_video, create_pose_snipplet, create_visual_comparison,thin_dataset_iteratively
from matplotlib import cm
import seaborn as sns

logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(asctime)s: %(message)s')


%reload_ext autoreload
%autoreload 2

## 1) Load latent vectors

In [None]:
PROJECT_PATH = "/home/katharina/vame_approach/themis_tail_belly_align"

# Frame index of the anchor whose latent-space neighbors are inspected below. The anchors are
# written as "time in seconds * 120"; 120 is presumably the recording frame rate -- TODO confirm
# against the video metadata (KinVideo.getfps()).
ANCHOR_FPS = 120
anchor_idx = int(335.6 * ANCHOR_FPS)  # standing mouse
# anchor_idx = int(373.32 * ANCHOR_FPS)  # walking mouse + tick at the end

SHOW_ALIGNED = True  # if True create an aligned video to create snipplets from; otherwise use original video
# config["time_window"] * min_dist_nn_factor is the min distance in frames between the anchor and
# the sampled neighbors, and between the sampled neighbors themselves
min_dist_nn_factor = 1
align_landmark_idx = [8, 16]  # landmarks to use for alignment of the videos

# Trained models live in model/<training time>/; other entries in model/ are skipped.
MODEL_DIR_TIME_FORMAT = "%m-%d-%Y-%H-%M"


def _parse_model_dir_time(dir_name):
    """Return the training time encoded in a model folder name, or None if the name does not match."""
    try:
        return datetime.strptime(dir_name, MODEL_DIR_TIME_FORMAT)
    except ValueError:
        return None


model_root = os.path.join(PROJECT_PATH, "model")
trained_models = []
for dir_name in os.listdir(model_root):
    trained_at = _parse_model_dir_time(dir_name)
    if trained_at is None:
        logging.info("Skipping %s: not a time-stamped model folder", dir_name)
        continue
    trained_models.append((trained_at, dir_name))
if not trained_models:
    raise FileNotFoundError(f"No time-stamped model folder found in {model_root}")
# sort by training time; the most recent model is used
trained_models.sort(key=lambda x: x[0])
latest_model = trained_models[-1][-1]

config_file = os.path.join(PROJECT_PATH, "model", latest_model, "config.yaml")
config = read_config(config_file)
# select landmark file (first video set of the config)
landmark_file = config["video_sets"][0]
data_path = os.path.join(
    PROJECT_PATH,
    "results",
    latest_model,
    landmark_file,
    config["model_name"],
    "kmeans-" + str(config["n_init_kmeans"]),
)
latent_vectors_all = np.load(
    os.path.join(data_path, "latent_vector_" + landmark_file + ".npy")
)

In [None]:
# Landmark names: the CSV has a two-row header (landmark, coordinate); keep one name per "x" column.
landmarks_orig = pd.read_csv(
    os.path.join(PROJECT_PATH, "landmarks", f"{landmark_file}.csv"),
    header=[0, 1],
)
column_names = landmarks_orig.columns
landmark_names = [name for name, coord in column_names if coord == "x"]

# Aligned landmark sequence; after transposing, axis 0 indexes the samples.
landmark_data_file = os.path.join(
    PROJECT_PATH, "data", landmark_file, f"{landmark_file}-PE-seq.npy"
)
landmark_data_aligned = np.load(landmark_data_file).T
# reshape to (N_samples, N_landmarks, 2)
landmark_data_aligned = landmark_data_aligned.reshape(len(landmark_data_aligned), -1, 2)

## 2) Select an anchor latent embedding and visualize it together with its nearest neighbors vs. distant samples

In [None]:
# Frames closer than this to the anchor (or to an already selected neighbor) are excluded, so the
# neighbors are not just overlapping clips of the same movement.
min_frame_distance = int(config["time_window"] * min_dist_nn_factor)
n_neighbors = 8  # number of neighbors shown next to the anchor (anchor + 8 = 9 clips)

window_start = max(0, anchor_idx - min_frame_distance)
window_end = min(len(latent_vectors_all), anchor_idx + min_frame_distance)
selected_latent_vector = latent_vectors_all[anchor_idx, :]

# Euclidean distance of every latent vector to the anchor's latent vector
dist_orig = np.sqrt(np.sum((latent_vectors_all - selected_latent_vector.reshape(1, -1)) ** 2, axis=1))

# Drop the temporal window around the anchor. Slice with `[window_end:]` (not `[window_end:-1]`)
# so that the last frame of the recording remains a candidate.
time_points = np.arange(0, latent_vectors_all.shape[0])
time_points = np.concatenate([time_points[:window_start], time_points[window_end:]])
latent_vectors = np.concatenate([latent_vectors_all[:window_start], latent_vectors_all[window_end:]])
# distances to the anchor, excluding frames temporally close to the anchor
dist = np.concatenate([dist_orig[:window_start], dist_orig[window_end:]])

# Greedily pick the closest remaining frame, then discard every candidate within
# `min_frame_distance` frames of it, so that the selected neighbors are separated in time.
selected_neighbor_idx = []
while len(selected_neighbor_idx) < n_neighbors and len(dist) > 0:
    n_idx = np.argmin(dist)
    selected_neighbor_idx.append(time_points[n_idx])
    is_far_away = np.abs(time_points - time_points[n_idx]) > min_frame_distance
    dist = dist[is_far_away]
    time_points = time_points[is_far_away]

In [None]:
# Plot histogramm of distances
%matplotlib widget
# plot dist showing the neighbors but not the excluded ones?
# plotting dist removing only the nearby neighbors
bins = 50
hist_range = (min(dist_orig), max(dist_orig))
dist_wo_nearby = np.concatenate([dist_orig[0:window_start], dist_orig[window_end:-1]])
plt.hist(dist_orig,bins=bins, range=hist_range, label="all dist")
plt.hist(dist_wo_nearby, bins=bins, range=hist_range, alpha=0.5, label="dist w/o nearby")
plt.legend()

### 2.1) Visualize anchor vs nearest neighbors

In [None]:
# Map the landmark file to its video: the first number in the landmark file name is the
# video id used in video_info.csv.
video_df = pd.read_csv(os.path.join(PROJECT_PATH, "video_info.csv"))
video_id = int(re.findall(r"\d+", landmark_file)[0])
is_current_video = video_df["video_id"] == video_id
video_file = os.path.join(*video_df.loc[is_current_video, ["vid_folder", "vid_file"]].values[0])
# last four path components: subject / date / camera position / video file name
subject, date, camera_pos, video_name = Path(video_file).parts[-4:]


In [None]:
# Anchor frame followed by its selected nearest neighbors (frame indices).
time_ids = [anchor_idx, *selected_neighbor_idx]

if SHOW_ALIGNED:
    # The aligned video is cached as results/align/a<video name>; create it on first use.
    # os.path.splitext keeps file names containing several dots intact.
    video_stem, video_ext = os.path.splitext(os.path.basename(video_file))
    aligned_video_path = os.path.join(
        PROJECT_PATH, "results", "align", "a" + video_stem + video_ext
    )
    if not os.path.exists(aligned_video_path):
        # full path of the landmark CSV (the bare `landmark_file` name is not a readable path)
        landmark_file_path = os.path.join(PROJECT_PATH, "landmarks", landmark_file + ".csv")
        create_aligned_mouse_video(
            video_file,
            landmark_file_path,
            align_landmark_idx,
            os.path.dirname(aligned_video_path),
            crop_size=(300, 300),
        )
    selected_video_file = aligned_video_path
else:
    selected_video_file = video_file

video = KinVideo(selected_video_file, view=camera_pos)
video.probevid()
video_clip_duration = config["time_window"] / video.getfps()  # clip length in seconds

# Clip start = frame index / fps (seconds); (0, 0, width, height) presumably selects the full frame.
video_clip_data = [
    (selected_video_file, t_id / video.getfps(), (0, 0, video.width, video.height))
    for t_id in time_ids
]
grid_video_name = create_grid_video(video_clip_data, video_clip_duration, speed=0.5)  # duration is in seconds!!
# pairwise Euclidean distances between the latent vectors of all shown clips
dist_matrix = np.round(squareform(pdist(latent_vectors_all[time_ids])), 3)
print(f"Distances:\n {dist_matrix}")
display.Video(grid_video_name, embed=True, html_attributes="loop autoplay", width=600, height=600)


In [None]:
## Create the full pose video once (cached on disk), then cut the same clips from it
video_name = os.path.basename(video_file)
pose_video_file = os.path.join(PROJECT_PATH, "results", "poses_" + video_name)
if not os.path.exists(pose_video_file):
    crop_size = 400
    # Min-max normalize the aligned landmarks (loaded in section 1 with shape
    # (N_samples, N_landmarks, 2)) to pixel coordinates of a crop_size x crop_size canvas.
    landmark_min = landmark_data_aligned.min()
    landmark_max = landmark_data_aligned.max()
    landmark_data_trafo = (
        (landmark_data_aligned - landmark_min)
        / (landmark_max - landmark_min)
        * (crop_size - 1)
    )
    # Render every frame. This uses its own name: `time_ids` holds the anchor + neighbor frames
    # that are cut out of the pose video below and must not be overwritten here.
    all_frame_ids = np.arange(0, len(landmark_data_trafo))
    # `landmark_names` were read from the landmark CSV in section 1
    create_pose_snipplet(
        landmark_data_trafo,
        landmark_names,
        all_frame_ids,
        pose_video_file,
        crop_size=(crop_size, crop_size),
    )
pose_video = KinVideo(pose_video_file, view=camera_pos)
pose_video.probevid()

pose_video_clip_data = [
    (pose_video_file, t_id / pose_video.getfps(), (0, 0, pose_video.width, pose_video.height))
    for t_id in time_ids
]
pose_grid_video_name = create_grid_video(pose_video_clip_data, video_clip_duration, speed=0.5)  # duration is in seconds!!
display.Video(pose_grid_video_name, embed=True, html_attributes="loop autoplay", width=600, height=600)

### 2.2) Visualize anchor vs distant samples

In [None]:
## Visualize the anchor together with distant embeddings:
# sample 8 frames whose latent distance to the anchor lies above the 80th percentile.
dist_percentile = 80
dist_thr = np.percentile(dist_orig, dist_percentile)

time_idx_other = np.flatnonzero(dist_orig > dist_thr)
sampled_idx = np.random.choice(time_idx_other, 8, replace=False)

# anchor first, then the distant samples
anchor_and_distant_ids = [anchor_idx, *sampled_idx]
video_clip_data_distant = [
    (selected_video_file, frame_id / video.getfps(), (0, 0, video.width, video.height))
    for frame_id in anchor_and_distant_ids
]
grid_video_name_distant = create_grid_video(video_clip_data_distant, video_clip_duration, speed=0.5)  # duration is in seconds!!

dist_matrix = np.round(squareform(pdist(latent_vectors_all[anchor_and_distant_ids])), 3)
print(f"Distances:\n {dist_matrix}")
display.Video(grid_video_name_distant, embed=True, html_attributes="loop autoplay", width=600, height=600)


### 2.3) Compare video clips of several randomly selected anchors with their nearest neighbors and distant samples

In [None]:
pick_n_anchors = 3  # how many anchor ids to pick randomly
random_anchor_ids = np.random.choice(np.arange(0, latent_vectors_all.shape[0]), pick_n_anchors, replace=False)
min_frame_distance = int(config["time_window"] * min_dist_nn_factor)

video = KinVideo(selected_video_file, view=camera_pos)
video.probevid()
video_clip_duration = config["time_window"] / video.getfps()

# One entry per anchor: its pair of grid videos (closest neighbors, distant samples).
video_stack = [
    create_visual_comparison(
        a_idx,
        latent_vectors_all,
        min_frame_distance,
        selected_video_file,
        video_clip_duration,
        upper_dist_percentile=80,
    )
    for a_idx in random_anchor_ids
]

# One grid row per anchor. Left column: anchor and its 8 closest neighbors; right column:
# anchor and 8 samples from the 20% most distant latent vectors wrt. the anchor embedding.
grid = GridspecLayout(pick_n_anchors, 2)
for row, video_pair in enumerate(video_stack):
    for col, clip_file in enumerate(video_pair):
        cell_output = Output()
        with cell_output:
            display.display(display.Video(clip_file, embed=True, html_attributes="loop autoplay", width=450, height=450))
        grid[row, col] = cell_output
grid



## 3) Thinning the latent space by iteratively removing samples

In [None]:
# Iteratively thin the latent space; the removal rule is described by the parameters below.
min_remaining_dataset = 0.001  # minimum fraction of remaining samples, e.g. 0.1 = 10%
# remove vectors that are temporally close to the sampled anchor if they belong to its closest
# N% percentile of embeddings in the latent space
neighbor_percentile = 1
# presumably the temporal proximity threshold in frames (despite "rate" in the name) -- TODO confirm
min_frame_rate = config["time_window"]

remaining_embeddings, remaining_time_ids = thin_dataset_iteratively(
    latent_vectors_all,
    min_remaining_dataset,
    neighbor_percentile,
    min_frame_rate,
)
print(len(remaining_embeddings), "remaining samples")

In [None]:
# 2-D t-SNE projection of the thinned embeddings, seeded with the config's random_state (the
# same seed the UMAP projection uses) so that the plots are reproducible across runs.
tsne_thinned = TSNE(perplexity=30, random_state=config["random_state"]).fit_transform(remaining_embeddings)

In [None]:
# 2-D densMAP (density-preserving UMAP) projection of the thinned embeddings, seeded from the config.
umap_func = umap.UMAP(
    n_components=2,
    n_neighbors=30,
    min_dist=0.0001,
    densmap=True,
    random_state=config["random_state"],
)
umap_vectors = umap_func.fit_transform(remaining_embeddings)

### 3.1) Sample and visualize anchors with their nearest neighbors / distant samples from the thinned dataset

In [None]:
pick_n_anchors = 5  # how many anchor ids to pick randomly

random_anchor_ids = np.random.choice(np.arange(0, remaining_embeddings.shape[0]), pick_n_anchors, replace=False)
min_frame_distance = int(config["time_window"] * min_dist_nn_factor)

video = KinVideo(selected_video_file, view=camera_pos)
video.probevid()
video_clip_duration = config["time_window"] / video.getfps()

# For every anchor keep both grid videos (close / distant) and the frame ids shown in them.
video_stack, sampled_idx_stack = [], []
for a_idx in random_anchor_ids:
    video_close, video_distant, samples_close, samples_distant = create_visual_comparison(
        a_idx,
        remaining_embeddings,
        min_frame_distance,
        selected_video_file,
        video_clip_duration,
        upper_dist_percentile=80,
        time_idx=remaining_time_ids,
        return_sampled_idx=True,
    )
    video_stack.append((video_close, video_distant))
    sampled_idx_stack.append((samples_close, samples_distant))

# Two grid rows per anchor: the videos (left: anchor + 8 closest neighbors, right: anchor + 8
# samples from the 20% most distant latent vectors) and below each video the pairwise latent
# distances of the shown clips.
grid = GridspecLayout(pick_n_anchors * 2, 2)
for row, (videos, sampled_ids) in enumerate(zip(video_stack, sampled_idx_stack)):
    for col, (clip_file, clip_ids) in enumerate(zip(videos, sampled_ids)):
        video_out = Output()
        with video_out:
            display.display(display.Video(clip_file, embed=True, html_attributes="loop autoplay", width=450, height=450))
        grid[2 * row, col] = video_out

        dist_matrix = np.round(squareform(pdist(latent_vectors_all[clip_ids])), 2)
        dist_out = Output()
        with dist_out:
            display.display(display.Pretty(f"Distances: \n {dist_matrix}"))
        grid[2 * row + 1, col] = dist_out

grid

### 3.2) Visualize the sampled anchors with their neighbors and distant samples in a t-SNE and UMAP

Idea: check whether samples that are neighbors in the latent space are also projected close to each other by t-SNE / UMAP.
To use the t-SNE / UMAP projections for clustering, the anchor and its close neighbors should also be close in the t-SNE / UMAP projection, and samples that are distant from the anchor in the latent space should also be distant from it in the projection.

In [None]:
# visualize the sampled points in the TSNE plot
%matplotlib widget

fig, ax = plt.subplots(len(sampled_idx_stack), 2)
fig.set_size_inches(4 * 2, len(sampled_idx_stack) * 4)
for i_row, (s_close, s_distant) in enumerate(sampled_idx_stack):
    anchor_idx = s_close[0]
    ax[i_row, 0].plot (tsne_thinned[:, 0], tsne_thinned[:, 1], 'k.', label='TSNE')
    ax[i_row, 0].plot (tsne_thinned[np.isin(remaining_time_ids, s_close[1:]), 0], tsne_thinned[np.isin(remaining_time_ids, s_close[1:]), 1], 'go', label='Close Neighbors')
    ax[i_row, 0].plot (tsne_thinned[np.isin(remaining_time_ids, s_distant[1:]), 0], tsne_thinned[np.isin(remaining_time_ids, s_distant[1:]), 1], 'ro', label='Distant Samples')
    ax[i_row, 0].plot (tsne_thinned[remaining_time_ids == anchor_idx, 0], tsne_thinned[remaining_time_ids == anchor_idx, 1], 'yo', label='Anchor')
    
    ax[i_row, 1].plot (umap_vectors[:, 0], umap_vectors[:, 1], 'k.', label='UMAP')
    ax[i_row, 1].plot (umap_vectors[np.isin(remaining_time_ids, s_close[1:]), 0], umap_vectors[np.isin(remaining_time_ids, s_close[1:]), 1], 'go', label='Close Neighbors')
    ax[i_row, 1].plot (umap_vectors[np.isin(remaining_time_ids, s_distant[1:]), 0], umap_vectors[np.isin(remaining_time_ids, s_distant[1:]), 1], 'ro', label='Distant Samples')
    ax[i_row, 1].plot (umap_vectors[remaining_time_ids == anchor_idx, 0], umap_vectors[remaining_time_ids == anchor_idx, 1], 'yo', label='Anchor')
    
    ax[i_row, 0].legend()
    ax[i_row, 1].legend()
    
plt.tight_layout()


Observation: after running the code multiple times with different anchors, the close neighbors in the latent space also tend to lie closer to the anchor than the distant samples in the t-SNE/UMAP projections. However, the anchor and its close neighbors are often not concentrated in one spot in the projections.

## 4) Cluster the thinned dataset
Test different clustering approaches and visualize samples from the clusters for visual inspection.

In [None]:
### 4.1) fuzzy c-means

In [None]:
# Fuzzy c-means with 30 clusters and m=1.1, fitted on the thinned embeddings.
n_clusters = 30
fcm = FCM(m=1.1, n_clusters=n_clusters)
fcm.fit(remaining_embeddings)

# Fitted cluster centers and the soft assignment matrix (one membership score per sample and cluster).
fcm_centers = fcm.centers
fcm_labels_soft = fcm.soft_predict(remaining_embeddings)

In [None]:
fcm_labels_soft = fcm.soft_predict(remaining_embeddings)
%matplotlib widget
# plot dist showing the neighbors but not the excluded ones?
# plotting dist removing only the nearby neighbors
bins_clustering = 20
hist_range_fuzzy = (0, 1)

sns.histplot(data=fcm_labels_soft,binrange=(0,1), bins=20)


In [None]:
# histogramm of the largest cluster assignment score per sample
# How many samples have one cluster assignment score of at least P?
%matplotlib widget
cluster_assignment_thr = 0.7

fcm_max_val = np.max(fcm_labels_soft, axis=1)
plt.hist(fcm_max_val, bins=20, range=(0,1), density=True, cumulative=-1)

thr_fraction = round(sum(fcm_max_val > cluster_assignment_thr) / len(fcm_max_val) * 100, 2)
print(f"{thr_fraction}% of the samples have a largest cluster assignment score above {cluster_assignment_thr}")

In [None]:
# Hard-assign each thinned sample to its best cluster, but only when that cluster's
# membership score exceeds `cluster_assignment_thr`.
fcm_max_score = np.max(fcm_labels_soft, axis=1)
is_confident = fcm_max_score > cluster_assignment_thr
fcm_hard_labels = np.argmax(fcm_labels_soft, axis=1)

# time ids and labels of the confidently assigned samples only
potential_time_idx = remaining_time_ids[is_confident]
fcm_labels = fcm_hard_labels[is_confident]

# integer label per thinned sample; -1 marks samples without a confident cluster assignment
fcm_labels_all = np.full(fcm_labels_soft.shape[0], -1, dtype=int)
fcm_labels_all[is_confident] = fcm_labels


In [None]:
# overlay found clusters with t-SNE / UMAP Projection
%matplotlib widget

fig, ax = plt.subplots(1, 2)
fig.set_size_inches(4.5 * 2, 4.5)

ax[0].plot(tsne_thinned[:, 0], tsne_thinned[:, 1], 'k.', label='TSNE')
ax[1].plot(umap_vectors[:, 0], umap_vectors[:, 1], "k.", label=f'UMAP')
cmap = cm.get_cmap("nipy_spectral", n_clusters)
for i_cluster in range(n_clusters):
    anchor_idx = s_close[0]
    ax[0].plot(tsne_thinned[fcm_labels_all==i_cluster, 0], tsne_thinned[fcm_labels_all==i_cluster, 1],".", color=cmap(i_cluster), label=f'Cluster: {i_cluster}')
    ax[1].plot(umap_vectors[fcm_labels_all==i_cluster, 0], umap_vectors[fcm_labels_all==i_cluster, 1],".", color=cmap(i_cluster), label=f'Cluster: {i_cluster}')

        
#ax[0].legend()
#ax[1].legend()
    
plt.tight_layout()

In [None]:
Observation: some clusters found by fuzzy c-means also tend to form clusters in the t-SNE/UMAP projections.

In [None]:
# Select a cluster, sample clips from it and visualize them.
# Only samples with a confident assignment (max score > cluster_assignment_thr) are candidates.
cluster_id = 4
n_cluster_samples = 9  # number of clips in the grid video

cluster_time_ids = potential_time_idx[fcm_labels == cluster_id]
if cluster_time_ids.size == 0:
    raise ValueError(
        f"Cluster {cluster_id} has no samples with an assignment score above {cluster_assignment_thr}"
    )
# Sample without replacement so that no clip appears twice (which would also put spurious zeros
# into the distance matrix); clusters smaller than n_cluster_samples are shown completely.
time_ids_cluster = np.random.choice(
    cluster_time_ids, min(n_cluster_samples, cluster_time_ids.size), replace=False
)
print(time_ids_cluster)

video_clip_data_cluster = [
    (selected_video_file, t_id / video.getfps(), (0, 0, video.width, video.height))
    for t_id in time_ids_cluster
]
grid_video_cluster = create_grid_video(video_clip_data_cluster, video_clip_duration, speed=0.5)  # duration is in seconds!!

# print Euclidean distances between the samples from the cluster
dist_matrix_cluster = np.round(squareform(pdist(latent_vectors_all[time_ids_cluster])), 3)
print(f"Distances close:\n {dist_matrix_cluster}")

display.Video(grid_video_cluster, embed=True, html_attributes="loop autoplay", width=600, height=600)

In [None]:
# overlay found clusters with t-SNE / UMAP Projection


In [None]:
### 4.2) FLAME clustering