#### Import required libraries and functions

In [None]:
import torch

from bams.data import KeypointsDataset
from bams.models import BAMS
from bams import compute_representations
from custom_dataset_w_labels import load_data, load_annotations

import numpy as np
import os 
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
'''
input the path to the model folder and the data folder 

the data folder should be organized with subfolders named with the species/label. 
each subfolder contains dlc csvs. each csv should contain the same number of datapoints (frames) 
'''

### input path to data folder here ###
data_folder = r"X:\Behavior\DeepLabCut\Body_Models_2D\Movement_Curated\completedata\rats_treeshrews\threshold_0.8\movement" 

### input path to model folder here ###
model_folder = r"X:\MaryBeth\BAMS\completedata\bams-custom-2024-08-29-15-48-48_0.8" 

### collect the file names of every checkpoint ('bams-custom*.pt') in the model folder ###
with os.scandir(model_folder) as entries:
    checkpoint_names = sorted(
        entry.name
        for entry in entries
        if entry.is_file() and entry.name.startswith('bams-custom') and entry.name.endswith('.pt')
    )

### fail with a clear message instead of a NameError further down ###
if not checkpoint_names:
    raise FileNotFoundError(
        f"No checkpoint matching 'bams-custom*.pt' was found in {model_folder}"
    )

### when several checkpoints exist, take the last one in name order so the choice ###
### does not depend on the order in which the file system lists the directory ###
model_name = checkpoint_names[-1]
print("Loading model", model_name)
print()

model_path = os.path.join(model_folder, model_name)
annotations_path = os.path.join(model_folder,"video_labels.csv")

### hoa_bins is forwarded to the dataset below ###
hoa_bins = 32

### load the keypoints; pass create_csv=True if you want to re create the mapping from bams to the original data csvs ###
keypoints = load_data(data_folder, model_folder, create_csv=False)

### load_annotations returns a pair; both members are handed to the dataset unchanged ###
annotation_pair = load_annotations(annotations_path)
annotations, eval_utils = annotation_pair

dataset_options = {
    "keypoints": keypoints,
    "cache": False,
    "hoa_bins": hoa_bins,
    "annotations": annotations,
    "eval_utils": eval_utils,
}
dataset = KeypointsDataset(**dataset_options)

### run on the GPU when one is available, otherwise on the CPU ###
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### build the network with the architecture used below, then move it to the device ###
model = BAMS(
        input_size=dataset.input_size,
        short_term=dict(num_channels=(64, 64, 64, 64), kernel_size=3),
        long_term=dict(num_channels=(64, 64, 64, 64, 64), kernel_size=3, dilation=4),
        predictor=dict(
            hidden_layers=(-1, 256, 512, 512, dataset.target_size * hoa_bins)
        ),
    ).to(device)

### map_location remaps the stored tensors onto `device`, so a checkpoint that was ###
### saved on a GPU can still be loaded on a CPU-only machine ###
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()

embeddings = compute_representations(model, dataset, device)

## Plot PCAs of sequence level embeddings

In [None]:
''' 
each embedding (short, long, all) should be shape (n samples (total csvs), n frames per sample, n bams features)
'''

### frame level embeddings: pull both time scales out of the result dict ###
### and join them along the last (feature) axis ###

short_term, long_term = embeddings['short_term'], embeddings['long_term']
all_embeddings = torch.cat((short_term, long_term), dim=2)

print("short_term: ", short_term.shape)
print("long_term: ", long_term.shape)
print("all_embeddings: ", all_embeddings.shape)
print("")

### sequence level embeddings: average every sample over its frame axis (dim 1), ###
### then join the two averaged tensors along the feature axis ###

short_term_seq = short_term.mean(dim=1)
long_term_seq = long_term.mean(dim=1)
all_embeddings_seq = torch.cat((short_term_seq, long_term_seq), dim=1)

print("short_term: ", short_term_seq.shape)
print("long_term: ", long_term_seq.shape)
print("all_embeddings: ", all_embeddings_seq.shape)

### 2D PCA 

In [None]:
### fit a separate 2-component PCA on each sequence level embedding; for each one ###
### keep the projected points and the explained variance ratio of the two components ###

pca = PCA(n_components=2).fit(short_term_seq)
short_term_pca = pca.transform(short_term_seq)
ev_short_term = pca.explained_variance_ratio_

pca = PCA(n_components=2).fit(long_term_seq)
long_term_pca = pca.transform(long_term_seq)
ev_long_term = pca.explained_variance_ratio_

### fit() returns the estimator itself, so this re-fits the same object on the combined features ###
pca = pca.fit(all_embeddings_seq)
all_pca = pca.transform(all_embeddings_seq)
ev_all = pca.explained_variance_ratio_

### print the explained variance for each embedding ###

for caption, ratio in (
    ("ev short term: ", ev_short_term),
    ("ev long term: ", ev_long_term),
    ("ev all: ", ev_all),
):
    print(caption, ratio)

#### Create mapping for PCA legend

In [None]:
import pandas as pd
import re
import numpy as np
import matplotlib.pyplot as plt

### load the mapping from bams to the original data csvs ###

df = pd.read_csv(annotations_path)

num_segments = df.shape[0]
video_names = df['video_name'].unique()
unique_labels = df['label'].unique()

### number the labels in order of first appearance ###
label_mapping = dict(zip(unique_labels, range(len(unique_labels))))

print(label_mapping)

### first label listed for every video name ###
video_label_mapping = df.groupby('video_name')['label'].first().to_dict()

print(video_label_mapping)

### collapse the videos onto the text before the first underscore of their name; ###
### the first label met for each such prefix wins ###
consolidated_video_label_mapping = {}

for video_name, label in video_label_mapping.items():
    starting_number = video_name.split('_')[0]
    consolidated_video_label_mapping.setdefault(starting_number, label)

print(consolidated_video_label_mapping)

### split the prefixes by species label ###
videos_label_1 = []
videos_label_0 = []
for video, label in consolidated_video_label_mapping.items():
    if label == 'treeshrew':
        videos_label_1.append(video)
    elif label == 'rat':
        videos_label_0.append(video)

print(videos_label_1)
print(videos_label_0)

### number of rows in the annotation table per species ###
count_treeshrew = len(df[df['label'] == 'treeshrew'])
count_rat = len(df[df['label'] == 'rat'])

print("num rats: ", count_rat)
print("num treeshrews: ", count_treeshrew)

### make color gradients for each species: one evenly spaced color per prefix ###
num_rat_videos = len(videos_label_0)
num_treeshrew_videos = len(videos_label_1)

rat_colormap = plt.cm.summer(np.linspace(0, 1, num_rat_videos))
ts_colormap = plt.cm.cool(np.linspace(0, 1, num_treeshrew_videos))

### treeshrew prefixes first, then rat prefixes, each paired with its RGBA tuple ###
video_color_mapping = dict(zip(videos_label_1, map(tuple, ts_colormap)))
video_color_mapping.update(zip(videos_label_0, map(tuple, rat_colormap)))

print(video_color_mapping)

def extract_numbers(video_name):
    """Return the run of digits at the start of ``video_name``.

    When the name does not begin with a digit, the name itself is returned
    unchanged.
    """
    leading_digits = re.match(r'(\d+)', video_name)
    if leading_digits is None:
        return video_name
    return leading_digits.group(1)

### re-key the color table by the leading digits of every key (later keys overwrite earlier ones) ###
video_number_color_mapping = {}
for video, color in video_color_mapping.items():
    video_number_color_mapping[extract_numbers(video)] = color

#### 2D Visualization with Matplotlib

In [None]:
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(1, 3, figsize=(15, 5))

### color every sample by the text before the first underscore of its video name ###
df['starting_number'] = df['video_name'].apply(lambda x: x.split('_')[0])
c = df['starting_number'].map(video_color_mapping) 

### plot the pcas for the short term, long term, and all embeddings ###
### every panel gets the same axis labels; the title reports the summed explained variance ###
pca_panels = (
    ('Short term', short_term_pca, ev_short_term),
    ('Long term', long_term_pca, ev_long_term),
    ('All', all_pca, ev_all),
)
for axis, (panel_name, projection, explained_variance) in zip(ax, pca_panels):
    axis.scatter(projection[:, 0], projection[:, 1], c=c, s=10)
    axis.set_title(f'{panel_name} PCA explained variance: {round(sum(explained_variance)*100)}%')
    axis.set_xlabel('PC1')
    axis.set_ylabel('PC2')


def _legend_marker(color):
    """Return a round legend handle filled with ``color``."""
    return plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10)


### legend 1: one marker per video prefix ###
handles_numbers = [_legend_marker(color) for color in video_number_color_mapping.values()]
labels_numbers = list(video_number_color_mapping)

### legend 2: one marker per color of each species gradient ###
treeshrew_handle = [_legend_marker(color) for color in ts_colormap]
rat_handle = [_legend_marker(color) for color in rat_colormap]

handles_species = treeshrew_handle + rat_handle
labels_species = ['Treeshrew'] * len(ts_colormap) + ['Rat'] * len(rat_colormap)

### dataset specific names for the first treeshrew entries (presumably animal ids - TODO confirm). ###
### they are applied only to treeshrew positions, so a dataset with fewer than five ###
### treeshrew entries neither raises an IndexError nor relabels rat entries ###
treeshrew_label_overrides = [
    'Treeshrew (1)',
    'Treeshrew (2)',
    'Treeshrew (2)',
    'Treeshrew (2)',
    'Treeshrew (2)',
]
for position, override in zip(range(len(ts_colormap)), treeshrew_label_overrides):
    labels_species[position] = override

### fig.legend registers each legend on the figure, so both stay visible ###
### without adding them to an axes a second time ###
first_legend = fig.legend(handles_numbers, labels_numbers, loc='upper right', title='Video', bbox_to_anchor=(1, 0.8))
second_legend = fig.legend(handles_species, labels_species, loc='upper left', title='Species', bbox_to_anchor=(1, 0.8))

### uncomment if you want to save the 2d pca plot (save before plt.show(), ###
### because the figure may be closed once it has been shown) ###
# fig.savefig(os.path.join(model_folder,"2dpca.png"),bbox_inches='tight') 

plt.show()

## Create the Pearson correlation matrices for the short and long term sequence level embeddings

In [None]:
### convert the sequence level tensors to numpy arrays ###
short_term_seq_np, long_term_seq_np, all_embeddings_seq_np = (
    tensor.numpy()
    for tensor in (short_term_seq, long_term_seq, all_embeddings_seq)
)

print("short_term: ", short_term_seq_np.shape)
print("long_term: ", long_term_seq_np.shape)

def plot_correlation_heatmap(ax, data, title):
    """Draw the Pearson correlation matrix of the columns of ``data`` on ``ax``.

    Parameters
    ----------
    ax : matplotlib axes that receives the heatmap.
    data : 2-D array with one observation per row and one feature per column.
    title : title placed above the heatmap.
    """
    ### rowvar=False because the columns contain the features and the rows contain the observations ###
    corr_matrix = np.corrcoef(data, rowvar=False)
    ### fixed color limits [-1, 1] so that different heatmaps share one scale ###
    sns.heatmap(corr_matrix, annot=False, cmap="coolwarm", ax=ax, vmin=-1, vmax=1)
    ax.set_title(title)

    ### label every second feature. the heatmap draws cell i over [i, i + 1], so the ###
    ### tick has to sit at i + 0.5 to be centred on the cell it names ###
    ticks = np.arange(0, data.shape[1], 2)
    tick_positions = ticks + 0.5
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(ticks)
    ax.set_yticklabels(ticks)

### one heatmap per time scale, side by side ###
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(24, 8))

heatmap_inputs = (
    (short_term_seq_np, "Short Term Sequence"),
    (long_term_seq_np, "Long Term Sequence"),
)
for panel, (features, heading) in zip(ax, heatmap_inputs):
    plot_correlation_heatmap(panel, features, heading)

### uncomment if you want to save the pearson correlation plot ###
#plt.savefig(os.path.join(model_folder,"pc_seq.png"),bbox_inches='tight') 

## Create 3D UMAP to explore frame level embeddings

In [None]:
### reshape frame array: merge the first two axes (sample, frame) so that every frame ###
### becomes one row. reshape is used instead of view because view raises an error ###
### when the tensor is not contiguous in memory ###

short_term_reshaped = short_term.reshape(short_term.size(0) * short_term.size(1), short_term.size(2))
long_term_reshaped = long_term.reshape(long_term.size(0) * long_term.size(1), long_term.size(2))
all_embeddings_reshaped = all_embeddings.reshape(all_embeddings.size(0) * all_embeddings.size(1), all_embeddings.size(2))

### print shapes before and after reshaping ###

print("short_term: ", np.shape(short_term))
print("long_term: ", np.shape(long_term))
print("all_embeddings: ", np.shape(all_embeddings))
print("")

print("short_term: ", np.shape(short_term_reshaped))
print("long_term: ", np.shape(long_term_reshaped))
print("all_embeddings: ", np.shape(all_embeddings_reshaped))

In [None]:
### create mapping from bams embeddings to the video/ csv space ###

### read the order in which the csvs were passed to bams ###
df = pd.read_csv(annotations_path)

### the second axis of the frame level embeddings is the number of frames per sample ###
frames_per_sample = short_term.shape[1]

### for each csv, repeat its name n frames per sample times ###
video = np.repeat(df['video_name'], frames_per_sample)
print(video.shape)

### for each csv, number the frames from 0 to n frames per sample ### 
frame_numbers = np.arange(frames_per_sample)
repeated_array = np.tile(frame_numbers, short_term.shape[0])
print(repeated_array.shape)

In [None]:
### pair every row with its origin: column 0 is the video name, column 1 the frame number ###
video_frames = np.column_stack([video, repeated_array])

print(video_frames.shape)

### check: show the first pair ###
first_pair = video_frames[0]
print(first_pair)

In [None]:
### just checking ###
### video_frames has one row per frame of every csv, so printing all of it can flood ###
### the notebook output; show only the first rows (raise preview_rows to see more) ###

preview_rows = 10

for row in video_frames[:preview_rows]:
    print(row)

print(f"showing {min(preview_rows, len(video_frames))} of {len(video_frames)} rows")

### UMAP

In [None]:
from sklearn.neighbors import NearestNeighbors
import kneed

### how many neighbors will be used for DBSCAN? ###
k = 5

### randomly sample the input data for time/computation efficiency ### 

### choose how many samples you want to plot ###
sample_size = 1000

### sampling without replacement cannot return more rows than exist ###
sample_size = min(sample_size, len(short_term_reshaped))

### fixed seed so that the same frames are drawn on every run ###
random_seed = 42
rng = np.random.default_rng(random_seed)
sample_indices = rng.choice(len(short_term_reshaped), size=sample_size, replace=False)

short_sample = short_term_reshaped[sample_indices]
video_frames_sample = video_frames[sample_indices]

### for automating eps selection, you can use the knee method ### 
### NOTE(review): this estimate is computed on short_sample (the embedding space); ###
### a value found here is not on the same scale as distances in a UMAP projection ###

### k + 1 neighbors because each point is returned as its own nearest neighbor ###
nbrs = NearestNeighbors(n_neighbors=k+1).fit(short_sample)

distances, _ = nbrs.kneighbors(short_sample)

### column 0 is the distance of every point to itself, so column k holds the ###
### distance to its k-th neighbor; sort these to get the k-distance curve ###
distances = np.sort(distances[:, k])

ranks = np.arange(len(distances))
knee = kneed.KneeLocator(ranks, distances, curve='convex', direction='increasing', online = False)

### knee.knee is None when no knee is found; indexing with None would not give a distance ###
if knee.knee is None:
    x = None
    print("eps : no knee found, choose eps manually")
else:
    x = distances[knee.knee]
    print("eps :", x)
    plt.axhline(y=x, color='r', linestyle='-')

plt.plot(distances)
plt.show()

In [None]:
import umap
from sklearn.cluster import DBSCAN

### project the sampled short term embeddings onto 3 UMAP components (seeded) ###
umap_settings = {"n_components": 3, "random_state": 42}
umap_model = umap.UMAP(**umap_settings)
short_umap = umap_model.fit_transform(short_sample)

### set eps to a computed or experimentally determined value ###
### the clustering runs on the UMAP coordinates, so eps is a distance in that space ###
dbscan_model = DBSCAN(min_samples=k, eps=0.15)
dbscan_labels = dbscan_model.fit_predict(short_umap)

In [None]:
### check shapes to make sure everything makes sense: ###
### the sample, its (video, frame) pairs and its cluster labels ###

for checked in (short_sample, video_frames_sample, dbscan_labels):
    print(checked.shape)

In [None]:
import plotly.graph_objs as go

def index_to_time(index, frame_rate=30):
    """Format a frame index as a ``minutes:seconds`` string.

    The index is converted to whole seconds with ``frame_rate`` frames per
    second (rounded down); the seconds part is zero padded to two digits.
    """
    minutes, seconds = divmod(index // frame_rate, 60)
    return f"{minutes}:{seconds:02d}"

### one color per DBSCAN label, spread evenly over the Spectral colormap ###
unique_labels = np.unique(dbscan_labels)
colors = plt.cm.Spectral(np.linspace(0, 1, len(unique_labels)))

### build one 3D scatter trace per label ###
traces = []
for label, color in zip(unique_labels, colors):
    ### rows of the sample that carry the current label ###
    in_cluster = dbscan_labels == label
    cluster_points = short_umap[in_cluster]
    video_frame_text = video_frames_sample[in_cluster]

    ### the label -1 is shown as 'Noise' ###
    trace_name = 'Noise' if label == -1 else f'Cluster {label}'

    ### text attached to each point: video name, frame number and its m:ss time ###
    point_text = [f'{vf[0]} {vf[1]} {index_to_time(vf[1])}' for vf in video_frame_text]

    marker_style = {
        'size': 3,
        'color': f'rgb({color[0]*255},{color[1]*255},{color[2]*255})',
        'opacity': 0.8,
    }

    trace = go.Scatter3d(
        x=cluster_points[:, 0],
        y=cluster_points[:, 1],
        z=cluster_points[:, 2],
        mode='markers',
        name=trace_name,
        marker=marker_style,
        text=point_text,
    )
    traces.append(trace)

### axis titles, no outer margin, titled legend ###
scene_axes = {
    'xaxis': {'title': 'UMAP 1'},
    'yaxis': {'title': 'UMAP 2'},
    'zaxis': {'title': 'UMAP 3'},
}
layout = go.Layout(
    scene=scene_axes,
    margin={'l': 0, 'r': 0, 'b': 0, 't': 0},
    legend={'title': 'DBSCAN Clusters'},
)

fig = go.Figure(data=traces, layout=layout)
fig.show()

### uncomment if you want to save the umap ###
# fig.write_html(os.path.join(model_folder,"short_term_dbscan_5_0.15.html"))

### Save the samples from each cluster as gifs

In [None]:
# import cv2

# '''
# to save each data point as a clip in its cluster folder, you will need the mp4s corresponding to the sample csvs in the data folder.
# the mp4 folder should contain all of the mp4s directly (do not organize them into subfolders for each label/species)
# '''

# mp4_folder = r"X:\Behavior\DeepLabCut\Body_Models_2D\Movement_Curated\completedata\rats_treeshrews\threshold_0.8\0.8_mp4s_combined"

# output_dir = os.path.join(model_folder, "output_gifs")
# os.makedirs(output_dir, exist_ok=True)

# ### get all of the unique video names ###
# unique_video_names = np.unique([vf[0] for vf in video_frames_sample])

# ### get all of the unique dbscan labels ###
# unique_labels = np.unique(dbscan_labels)

# for video_name in unique_video_names:
#     ### get the indices that correspond to the current video name ###
#     relevant_indices = [i for i, vf in enumerate(video_frames_sample) if vf[0] == video_name]

#     ### extract the video file name from the video_name variable ###
#     video_base_name = video_name.split('DLC')[0]
#     start_frame = video_name.split('_')[-2]
#     end_frame = video_name.split('_')[-1]

#     new_video_name = video_base_name + '_' + start_frame + '_' + end_frame

#     ### construct the full path to the video file ###
#     mp4_path = os.path.join(mp4_folder, new_video_name + '.mp4')

#     ### get the mp4 ###
#     cap = cv2.VideoCapture(mp4_path)
#     if not cap.isOpened():
#         print(f"Error opening video file {mp4_path}")
#         continue

#     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
#     fps = int(cap.get(cv2.CAP_PROP_FPS))
#     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
#     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

#     ### for each label in all of the dbscan labels ###
#     for label in unique_labels:

#         ### make a directory for it ###
#         label_dir = os.path.join(output_dir, f"cluster_{label}")
#         os.makedirs(label_dir, exist_ok=True)

#         ### search relevant indeces (the indeces that contain the video) for dbscan labels that match the current label ###
#         label_indices = [i for i in relevant_indices if dbscan_labels[i] == label]
#         print(label_indices)
        
#         #### for each indeces that matches both the video and the dbscan label ###
#         for i in label_indices:

#             vf = video_frames_sample[i]
#             frame_index = int(vf[1])

#             ### get 0.5 sec before and after ###
#             frame_start = max(frame_index - 14, 0)
#             frame_end = frame_index + 15

#             ### construct the output path for the gif ###
#             output_video_path = os.path.join(label_dir, f"{new_video_name}_frame_{frame_index}.mp4")
#             out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
            
#             ### read the first frame ###
#             cap.set(cv2.CAP_PROP_POS_FRAMES, frame_start)
            
#             ### read and write each frame within the sec range ###
#             for j in range(frame_start, frame_end + 1):
#                 ret, frame = cap.read()
#                 if not ret:
#                     print(f"Error reading frame {j} from {new_video_name}")
             
#                 out.write(frame)
#             ### release the VideoWriter for this segment ###
#             out.release()  
#     ### release the video capture object after processing the current video ###
#     cap.release()  

### Plot the centroids of each data point on x,y plot

In [None]:
### merge the first two axes of the keypoints so that every frame becomes one row ###
print(keypoints.shape)
num_samples, num_frames, num_coordinates = keypoints.shape[:3]
keypoints_reshaped = keypoints.reshape(num_samples * num_frames, num_coordinates)
print(keypoints_reshaped.shape)

In [None]:
### index by the same sample indices as above so that row i here matches row i of the sample ### 
keypoints_reshaped_sample = keypoints_reshaped[sample_indices]

### even columns are taken as x values and odd columns as y values ###
xs = keypoints_reshaped_sample[:, ::2]
ys = keypoints_reshaped_sample[:, 1::2]

print(xs.shape)
print(ys.shape)

### average the x and y coordinates over the columns for each data point ###
### NOTE(review): a NaN coordinate makes the whole centroid NaN - confirm the keypoints contain none ###
x_centroids = np.mean(xs, axis=1)
y_centroids = np.mean(ys, axis=1)

print(x_centroids.shape)
print(y_centroids.shape)

#### Using Matplotlib

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors as mcolors

### remove the cluster with label -1 ###
mask = dbscan_labels != -1
filtered_x_centroids = x_centroids[mask]
filtered_y_centroids = y_centroids[mask]
filtered_labels = dbscan_labels[mask]

### get all of the unique dbscan labels ###
unique_labels = np.unique(filtered_labels)

### make a color map for the unique labels ###
### plt.get_cmap is used because plt.cm.get_cmap was removed in matplotlib 3.9; ###
### at least one color is requested so the call also works when every point is noise ###
colors = plt.get_cmap('tab20', max(len(unique_labels), 1))

### map each unique label to a color ###
color_mapping = {label: colors(i) for i, label in enumerate(unique_labels)}

### map the filtered labels to the corresponding colors ###
color_list = [color_mapping[label] for label in filtered_labels]

plt.figure(figsize=(10, 8))
scatter = plt.scatter(filtered_x_centroids, filtered_y_centroids, c=color_list, s=50, alpha=0.7)

# ### add the cluster labels next to each dot ###
# for i, label in enumerate(filtered_labels):
#      plt.text(filtered_x_centroids[i], filtered_y_centroids[i], str(label), fontsize=9, ha='right')

### adding plot title and labels ###
plt.xlabel("X Coordinate")
plt.ylabel("Y Coordinate")

#### create a custom legend: one round marker per cluster label ###
handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color_mapping[label], markersize=10) for label in unique_labels]
plt.legend(handles, [str(label) for label in unique_labels], title="Cluster Label")

### uncomment if you want to save the centroids plot (save before plt.show(), ###
### because the figure may be closed once it has been shown) ###
#plt.savefig(os.path.join(model_folder,"centroids.png"),bbox_inches='tight') 

plt.show()

#### Using Plotly

In [None]:
import plotly.graph_objects as go
import numpy as np

### keep only the points that DBSCAN assigned to a cluster (label -1 is dropped) ###
mask = dbscan_labels != -1
filtered_x_centroids, filtered_y_centroids = x_centroids[mask], y_centroids[mask]
filtered_labels = dbscan_labels[mask]

### the cluster labels that remain ###
unique_labels = np.unique(filtered_labels)


def _random_rgba():
    """Return an 'rgba(r, g, b, 0.8)' string with random integer color channels."""
    red = int(np.random.rand()*255)
    green = int(np.random.rand()*255)
    blue = int(np.random.rand()*255)
    return f'rgba({red}, {green}, {blue}, 0.8)'


### one random color per label (the colors differ from run to run) ###
colors = [_random_rgba() for _ in unique_labels]
color_mapping = dict(zip(unique_labels, colors))

### color of every point, looked up from its label ###
color_list = [color_mapping[label] for label in filtered_labels]

### scatter of the centroids, each point annotated with its cluster label ###
centroid_trace = go.Scatter(
    x=filtered_x_centroids,
    y=filtered_y_centroids,
    mode='markers+text',
    text=filtered_labels,
    textposition='top right',
    marker={'color': color_list, 'size': 10, 'opacity': 0.7},
    showlegend=False,
)

fig = go.Figure()
fig.add_trace(centroid_trace)

### axis titles, plot title and theme ###
layout_options = {
    'xaxis_title': "X Coordinate",
    'yaxis_title': "Y Coordinate",
    'title': "DBSCAN Cluster Centroids",
    'template': "plotly_white",
}
fig.update_layout(**layout_options)

### tie the y axis scale to the x axis so one unit has the same length on both ###
fig.update_yaxes(scaleanchor="x", scaleratio=1)

fig.show()

### uncomment if you want to save the centroids plot ###
#fig.write_html(os.path.join(model_folder,"centroids_plotly.html"))