In [None]:
import torch
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

# Walk up from the starting directory until the current directory contains
# `.git` (the repository root), presumably so the `src.*` imports below resolve.
# Stop at the filesystem root instead of looping forever when no `.git` exists
# (at the root, `os.chdir('..')` does not move).
import os
while not os.path.exists('.git'):
    here = os.getcwd()
    parent = os.path.dirname(here)
    if parent == here:
        raise FileNotFoundError(
            "Could not find a '.git' directory in the starting directory or any "
            "of its parents; start the notebook from inside the repository."
        )
    os.chdir(parent)

from src.data.addBiomechanicsDataset import AddBiomechanicsDataset, trial_id_mapping
from src.vaemodel import VAEModelWrapper

## Load the AddBiomechanics dataset

In [None]:
# Training split of the AddBiomechanics dataset (`train/With_Arm` folder).
# NOTE(review): absolute, machine-specific path; the meaning of the positional
# arguments `1` and `''` is defined by AddBiomechanicsDataset — confirm there.
ADDBIOMECHANICS_TRAIN_DIR = "/home/public/data/AddBiomechanicsDataset/train/With_Arm/"
dataset = AddBiomechanicsDataset(
    ADDBIOMECHANICS_TRAIN_DIR,
    1,
    '',
    testing_with_short_dataset=False,
)



## Dataset exploration

In [None]:
# Label distribution: total number of frames per coarse movement category,
# derived from each trial's original name (`getTrialOriginalName`).
#
# Consecutive trials of one subject that share an original name are grouped as
# a single trial (presumably segments of the same recording) and their frame
# counts are summed. Grouping restarts for every subject so that equally named
# trials of different subjects (e.g. two subjects' 'static' trials) stay separate.
trial_original_names = []
trial_num_of_frames = []
for sub in dataset.subjects:
    if sub is None:
        continue
    last_name = None
    for trial in range(sub.getNumTrials()):
        name = sub.getTrialOriginalName(trial)
        if name != last_name:
            trial_original_names.append(name)
            trial_num_of_frames.append(0)
            last_name = name
        trial_num_of_frames[-1] += sub.getTrialLength(trial)

# Ordered (substring, category) rules matched against the lower-cased name; the
# first match wins, so e.g. 'static_walk' counts as 'static', not 'walk'.
CATEGORY_RULES = (
    ('static', 'static'),
    ('walk', 'walk'),
    ('run', 'run'),
    ('gait', 'gait_any'),
    ('sts', 'sit_to_stand'),
    ('stair', 'stair'),
)
trial_numbers = {'static': 0, 'walk': 0, 'run': 0, 'other': 0, 'gait_any': 0, 'sit_to_stand': 0, 'stair': 0, 'tttt': 0}
other = []  # original names that matched no rule
for t, n in zip(trial_original_names, trial_num_of_frames):
    t_lower = t.lower()
    category = next((cat for key, cat in CATEGORY_RULES if key in t_lower), None)
    if category is None:
        # NOTE(review): unlike the rules above, this check is case-sensitive
        # (only a lowercase leading 't'); kept as-is — confirm intent.
        category = 'tttt' if t.startswith('t') else 'other'
    trial_numbers[category] += n
    if category == 'other':
        other.append(t)
print(trial_numbers)

## Latent space analysis

In [None]:
# Encode a random subset of the dataset with the pretrained VAE and collect,
# per batch (moved to CPU): latent means (`mu`), log-variances (`logvar`),
# per-sample reconstruction MSE (`recon_data` — errors, not reconstructions),
# trial-type ids (`label`), and the model inputs / reconstructions.
SEED = 0           # seeds the shuffle so the encoded subset is reproducible
MAX_BATCHES = 301  # number of batches to encode (batch indices 0..300)

dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    generator=torch.Generator().manual_seed(SEED),
)
state_keys = ["pos", "vel"]
mu, logvar, recon_data, label = [], [], [], []
data_in, data_out = [], []
# NOTE(review): the wrapper is constructed with a scaler, but encode/decode are
# called directly on the raw batch below — confirm whether inputs must be scaled
# first, and that the wrapper puts the model in eval mode. The `.pkl` scaler is
# presumably unpickled: only load trusted files.
vae_model = VAEModelWrapper("result/model/BiomechPriorVAE_best.pth", "result/model/scaler.pkl", num_dofs=54, latent_dim=24)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    progress = tqdm(dataloader, total=min(MAX_BATCHES, len(dataloader)))
    for batch_idx, (data, label_, seq_len, _) in enumerate(progress):
        # Model input: the selected state features concatenated on the last axis.
        data = torch.concat([data[key] for key in state_keys], dim=-1)
        mu_, logvar_ = vae_model.model.encode(data)
        # Decode the latent mean (not a sample) to measure reconstruction error.
        recon_data_ = vae_model.model.decode(mu_)
        recon_error = torch.mean((recon_data_ - data) ** 2, dim=(1, 2))
        mu.append(mu_.detach().cpu())
        logvar.append(logvar_.detach().cpu())
        recon_data.append(recon_error.detach().cpu())
        label.append(label_['trialname'].detach().cpu())
        data_in.append(data.detach().cpu())
        data_out.append(recon_data_.detach().cpu())
        if batch_idx + 1 >= MAX_BATCHES:
            break

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# Per-sample reconstruction RMSE (square root of the MSE collected by the
# encoding loop), grouped by trial type.
r_error = torch.concat(recon_data, dim=0).numpy() ** 0.5
labels_lin = torch.concat(label, dim=0).numpy()
# Map integer trial ids back to names through the inverse mapping rather than
# assuming ids equal the insertion index of trial_id_mapping's keys.
id_to_name = {v: k for k, v in trial_id_mapping.items()}
labels_text = [id_to_name[int(i)] for i in labels_lin]
sns.boxplot(x=labels_text, y=r_error)
plt.xlabel('Trial Type')
plt.ylabel('Reconstruction Error (RMSE)')
plt.title('Reconstruction Error by Trial Type')
plt.ylim(0, 3)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# Pose-only reconstruction RMSE per sample, grouped by trial type: the error is
# restricted to the first POSE_DIMS features of every frame.
# NOTE(review): assumes the first 27 features are the "pos" part of the
# concatenated [pos, vel] model input — confirm against the dataset's pose size.
POSE_DIMS = 27
data_in_np = torch.concat(data_in, dim=0).numpy()
data_out_np = torch.concat(data_out, dim=0).numpy()
pose_sq_err = (data_out_np[..., :POSE_DIMS] - data_in_np[..., :POSE_DIMS]) ** 2
r_error = np.mean(pose_sq_err, axis=(1, 2)) ** 0.5
labels_lin = torch.concat(label, dim=0).numpy()
# Map integer trial ids back to names through the inverse mapping rather than
# assuming ids equal the insertion index of trial_id_mapping's keys.
id_to_name = {v: k for k, v in trial_id_mapping.items()}
labels_text = [id_to_name[int(i)] for i in labels_lin]
sns.boxplot(x=labels_text, y=r_error)
plt.xlabel('Trial Type')
plt.ylabel('Reconstruction Error (RMSE)')
plt.title('Reconstruction Error by Trial Type, Pose Only')
plt.ylim(0, 2)
plt.show()

In [None]:
# Latent-space visualization: ALL movement types (UMAP if available, else t-SNE)
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
try:
    import umap
    use_umap = True
except Exception:
    # Best-effort optional dependency: any failure importing umap falls back to t-SNE.
    use_umap = False

MIN_KDE_POINTS = 20  # categories with fewer points get no density contour

# Assemble latent means and labels collected by the encoding loop.
# Flatten each latent to one row so UMAP and t-SNE both receive an
# (N, features) matrix; for an (N, 1, latent_dim) tensor this equals squeeze().
z = torch.concat(mu, dim=0).numpy()
z = z.reshape(z.shape[0], -1)
labels_arr = torch.concat(label, dim=0).numpy().astype(int).reshape(-1)

# Invert the name -> id mapping to get textual labels.
inv_map = {v: k for k, v in trial_id_mapping.items()}
label_names = np.array([inv_map[int(i)] for i in labels_arr])

# Categories sorted by sample count, most frequent first (the stable sort keeps
# ties in alphabetical order). `unique_labels`, `palette_list`, `color_map` and
# `colors` are reused by later cells.
unique_labels, counts = np.unique(label_names, return_counts=True)
order = np.argsort(-counts, kind='stable')
unique_labels = unique_labels[order]
counts = counts[order]
print('Found categories (label:count):', list(zip(unique_labels.tolist(), counts.tolist())))

# Embed the full latent set (no filtering) to 2-D.
z_sel = z
labels_sel = label_names
if z_sel.shape[0] == 0:
    print('No latent vectors found. Make sure earlier cells collected mu and label.')
else:
    if use_umap:
        reducer = umap.UMAP(n_components=2, random_state=0)
        method = 'UMAP'
    else:
        # t-SNE can be slow for many points; downsample z_sel if needed.
        reducer = TSNE(n_components=2, init='pca', random_state=0, perplexity=30)
        method = 't-SNE'
    z2 = reducer.fit_transform(z_sel)

    # Optional point sizes (20..100) scaled from each sample's reconstruction RMSE.
    sizes = None
    try:
        data_in_np = torch.concat(data_in, dim=0).numpy()
        data_out_np = torch.concat(data_out, dim=0).numpy()
    except (NameError, RuntimeError):
        # Inputs/reconstructions were not collected (or cannot be stacked):
        # fall back to a constant marker size.
        pass
    else:
        per_sample_err = np.mean((data_out_np - data_in_np) ** 2, axis=(1, 2)) ** 0.5
        if per_sample_err.shape[0] == z_sel.shape[0]:
            # np.ptp function: the ndarray.ptp method was removed in NumPy 2.0.
            err_range = np.ptp(per_sample_err)
            sizes = 20 + (per_sample_err - per_sample_err.min()) / (err_range + 1e-8) * 80

    # One categorical color per movement type, assigned in count order.
    n_cats = len(unique_labels)
    palette_list = sns.color_palette('tab20' if n_cats <= 20 else 'husl', n_colors=n_cats)
    color_map = dict(zip(unique_labels, palette_list))
    colors = np.array([color_map[lab] for lab in labels_sel])

    # Scatter plot of all movement types.
    fig, ax = plt.subplots(figsize=(10, 8))
    for lab in unique_labels:
        sel = labels_sel == lab
        n_sel = int(sel.sum())
        if n_sel == 0:
            continue
        ax.scatter(z2[sel, 0], z2[sel, 1],
                   s=sizes[sel] if sizes is not None else 30,
                   color=[color_map[lab]], label=f"{lab} (n={n_sel})",
                   alpha=0.19, edgecolors='none')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.set_title(f'Latent {method} projection — all movement types (n={z_sel.shape[0]})')
    ax.set_xlabel('dim 1')
    ax.set_ylabel('dim 2')
    sns.despine(ax=ax)
    fig.tight_layout(rect=[0, 0, 0.78, 1])
    plt.show()

    # KDE contour overlays for categories with enough points.
    fig, ax = plt.subplots(figsize=(10, 8))
    for lab in unique_labels:
        sel = labels_sel == lab
        if sel.sum() >= MIN_KDE_POINTS:
            sns.kdeplot(x=z2[sel, 0], y=z2[sel, 1], color=color_map[lab],
                        levels=3, thresh=0.05, fill=False, alpha=0.9, ax=ax)
    ax.scatter(z2[:, 0], z2[:, 1], s=8, c=colors, alpha=0.25)
    ax.set_title('Density contours (all movement types)')
    ax.set_xlabel('dim 1')
    ax.set_ylabel('dim 2')
    sns.despine(ax=ax)
    fig.tight_layout(rect=[0, 0, 0.78, 1])
    plt.show()

# End of all-movements visualization

In [None]:
# Pairwise grid of latent dimensions: per-movement-type KDEs on the diagonal and
# dim-i vs dim-j scatters in the upper triangle. Colors are looked up by movement
# name in `color_map` and ordered by `unique_labels` (both from the latent
# visualization cell), so the diagonal, the scatters and the legend agree.
import warnings

z = torch.concat(mu, dim=0).numpy()
z = z.reshape(z.shape[0], -1)   # (N, latent_dim); same as squeeze() for (N, 1, latent_dim)
label_ = torch.concat(label, dim=0).numpy().astype(int).reshape(-1)
id_to_name = {v: k for k, v in trial_id_mapping.items()}
label_names_grid = np.array([id_to_name[int(i)] for i in label_])
point_colors = np.array([color_map[name] for name in label_names_grid])
hue_order = list(unique_labels)
n_dims = z.shape[-1]

figure = plt.figure(figsize=(12, 12))
# Suppress warnings only while drawing this figure, not for the whole session.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for i in range(n_dims):
        for j in range(i, n_dims):
            ax = figure.add_subplot(n_dims, n_dims, i * n_dims + j + 1)
            if i == j:
                # KDE of dimension i, one curve per movement type.
                sns.kdeplot(x=z[:, i], hue=label_names_grid, hue_order=hue_order,
                            palette=color_map, fill=False, legend=False, ax=ax)
            else:
                # Scatter of dimension i (x) against dimension j (y).
                ax.scatter(z[:, i], z[:, j], c=point_colors, s=5, alpha=0.6)
            # Strip ticks, spines, grid, axis labels and per-axes legends.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            ax.grid(False)
            ax.set_xlabel('')
            ax.set_ylabel('')
            if ax.get_legend() is not None:
                ax.get_legend().remove()

    # Create a single legend for the entire figure in the bottom-left corner.
    handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color_map[lab],
                          markersize=10, label=lab) for lab in unique_labels]
    figure.legend(handles=handles, loc='lower left', bbox_to_anchor=(0.1, 0.1), fontsize='small')
    figure.suptitle('Latent Space Pairwise Scatter and KDE Plots', y=0.1)
    figure.tight_layout()
plt.show()





In [None]:
# Display the movement-type names found in the latent visualization cell,
# ordered by sample count (most frequent first).
unique_labels