In [None]:
import pandas as pd
import torch
import sys
import matplotlib.pyplot as plt


sys.path.append("../../")
from src.vae_architectures.lstm import LSTMVariationalAutoEncoder
from src.vae_architectures.signal_cnn import SignalCNNVariationalAutoEncoder
from src.vae_architectures.graph_cnn import GraphVariationalAutoEncoder
from src.dataset import ExerciseDataset

## Load data and model

In [None]:
from src.utils.constants import (
    HIDDEN_SIZE,
    LATENT_SIZE,
    NUM_JOINTS,
    NUM_LAYERS,
    SEQUENCE_LENGTH,
)

In [None]:
def _load_vae(model, checkpoint_path):
    """Load CPU-mapped weights into ``model`` and switch it to inference mode.

    Args:
        model: A freshly constructed autoencoder whose architecture matches
            the saved checkpoint.
        checkpoint_path: Path to a ``state_dict`` saved with ``torch.save``.

    Returns:
        The same ``model`` instance with weights loaded (strict key matching)
        and ``eval()`` applied, so train-only layers (dropout, batch-norm, if
        any) behave as at inference time in the analysis cells below.
    """
    state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    return model.eval()


dct_lstm_autoencoder = _load_vae(
    LSTMVariationalAutoEncoder(
        SEQUENCE_LENGTH, NUM_JOINTS * 3, HIDDEN_SIZE, LATENT_SIZE, NUM_LAYERS
    ),
    "../../models/squat/dct_lstm.pt",
)

dct_cnn_autoencoder = _load_vae(
    SignalCNNVariationalAutoEncoder(
        SEQUENCE_LENGTH, NUM_JOINTS * 3, HIDDEN_SIZE, LATENT_SIZE
    ),
    "../../models/squat/dct_cnn.pt",
)

dct_graph_autoencoder = _load_vae(
    GraphVariationalAutoEncoder(
        SEQUENCE_LENGTH, NUM_JOINTS * 3, HIDDEN_SIZE, LATENT_SIZE
    ),
    "../../models/squat/dct_graph.pt",
)

In [None]:
# Compare model sizes: total parameter count (trainable and frozen) of every
# loaded autoencoder, including the LSTM one that was previously left out.
# The printed label is Polish for "Number of parameters of model".
total_params_lstm = sum(p.numel() for p in dct_lstm_autoencoder.parameters())
total_params_cnn = sum(p.numel() for p in dct_cnn_autoencoder.parameters())
total_params_graph = sum(p.numel() for p in dct_graph_autoencoder.parameters())

print("Liczba parametrów modelu dct_lstm_autoencoder:", total_params_lstm)
print("Liczba parametrów modelu dct_cnn_autoencoder:", total_params_cnn)
print("Liczba parametrów modelu dct_graph_autoencoder:", total_params_graph)

In [None]:
def _load_squat_split(csv_path):
    """Read one DCT-encoded squat split from CSV.

    Returns the raw DataFrame together with its ExerciseDataset wrapper
    (built with ``representation="dct"``).
    """
    frame = pd.read_csv(csv_path)
    return frame, ExerciseDataset(frame, representation="dct")


# Training split first, then the held-out test split.
squat_dct_df, squat_dct_dataset = _load_squat_split("../../data/train/squat/dct.csv")
squat_dct_test_df, squat_dct_dataset_test = _load_squat_split(
    "../../data/test/squat/dct.csv"
)

## Generation of the embedded instances

In [None]:
import numpy as np


def _stack_with_binary_labels(dataset):
    """Stack ``dataset.data`` into one tensor and derive binary targets.

    The target is 1 where the encoded label equals 0 and 0 otherwise; the
    plots below draw target 1 as "Correct" and target 0 as "Incorrect".
    """
    features = torch.stack(list(dataset.data))
    targets = np.array([int(code == 0) for code in dataset.labels_encoded])
    return features, targets


X, y = _stack_with_binary_labels(squat_dct_dataset)
X_test, y_test = _stack_with_binary_labels(squat_dct_dataset_test)

In [None]:
from torch.utils.data import DataLoader

# Shuffled mini-batches (size 8) over the training split; this loader is what
# get_random_sample and the Explainer consume in the counterfactual section.
train_squat_dct_dl = DataLoader(squat_dct_dataset, shuffle=True, batch_size=8)

## Visualization of the latent space

In [None]:
# Display name -> loaded autoencoder. Insertion order determines the order in
# which the models are plotted and evaluated below.
models = dict(
    lstm=dct_lstm_autoencoder,
    cnn=dct_cnn_autoencoder,
    graph=dct_graph_autoencoder,
)

In [None]:
from sklearn.manifold import TSNE

# Fixed seed so the t-SNE layouts are identical on every Restart & Run All.
TSNE_SEED = 42

tsne = TSNE(n_components=2, random_state=TSNE_SEED)
# squeeze=False keeps `axes` 2-D even for a single model, so the zip always works.
fig, axes = plt.subplots(ncols=len(models), figsize=(14, 4), squeeze=False)

# Labels do not depend on the model: train rows first, then test rows, in the
# same order as the embeddings concatenated inside the loop.
all_y = np.concatenate([y, y_test])
is_correct = all_y == 1
is_incorrect = all_y == 0

for (model_name, model), axis in zip(models.items(), axes[0]):
    # Inference only: no autograd graph is needed for the embeddings.
    with torch.no_grad():
        X_embedded = model.encoder(X)[0].numpy()
        X_test_embedded = model.encoder(X_test)[0].numpy()

    latent_space = tsne.fit_transform(np.concatenate([X_embedded, X_test_embedded]))
    axis.scatter(
        latent_space[is_correct, 0],
        latent_space[is_correct, 1],
        c="green",
        label="Correct",
    )
    axis.scatter(
        latent_space[is_incorrect, 0],
        latent_space[is_incorrect, 1],
        c="red",
        label="Incorrect",
    )

    axis.legend()
    axis.set_title(model_name)
    axis.set_xlabel("t-SNE 1")
    axis.set_ylabel("t-SNE 2")

fig.tight_layout()

## Train and evaluate a classifier on the latent embeddings

In [None]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report, f1_score

# Number of repeated fits per model; mean and std of F1 over these runs are reported.
N_RUNS = 50

for model_name, model in models.items():
    f1 = []
    for run in range(N_RUNS):
        # Embeddings are recomputed every run, as before, in case the encoder
        # is stochastic — TODO confirm; if it is deterministic this can be
        # hoisted out of the inner loop.
        with torch.no_grad():
            X_embedded = model.encoder(X)[0].numpy()
            X_test_embedded = model.encoder(X_test)[0].numpy()

        # Seed each tree with its run index: runs still differ from each other,
        # but the classifier side of the experiment is now reproducible.
        clf = DecisionTreeClassifier(random_state=run).fit(X_embedded, y)
        y_pred = clf.predict(X_test_embedded)
        # Binary F1 with the positive class 1 ("Correct").
        f1.append(f1_score(y_test, y_pred))
    print(f"{model_name} mean f1-score: {np.mean(f1)}, std: {np.std(f1)}")

In [None]:
import pickle

# Pre-trained classifier handed to the Explainer below; it replaces the last
# `clf` fitted in the evaluation loop above.
# SECURITY: pickle.load can execute arbitrary code — only load trusted files.
# NOTE(review): confirm this classifier was trained on the same latent space
# as the autoencoder passed to the Explainer.
clf_path = "../../models/clf.pkl"
with open(clf_path, "rb") as clf_file:
    clf = pickle.load(clf_file)

## Generate a counterfactual explanation (CFE)

In [None]:
from src.explainer import Explainer
from src.utils.data import get_random_sample

# Draw one sample from the (shuffled) training loader.
# NOTE(review): the meaning of the second argument (3) is defined in
# src.utils.data.get_random_sample — presumably a class/label selector; confirm.
# `sample_length` and `label` are not used anywhere else in this notebook.
wrong_sample, sample_length, label = get_random_sample(train_squat_dct_dl, 3)
# Explainer built on the LSTM autoencoder, the classifier `clf` loaded in the
# previous cell, and the training loader.
explainer = Explainer(dct_lstm_autoencoder, clf, train_squat_dct_dl, "squat")
# Returns the query's latent vector, the counterfactual (plotted alongside the
# latent embeddings later, so presumably also a latent vector), and —
# judging by the name — the counterfactual decoded back to input space.
latent_query, cf_sample, cf_sample_decoded = explainer.generate_cf(wrong_sample)

In [None]:
from sklearn.manifold import TSNE

# The Explainer was built on the LSTM autoencoder, so embed the data with that
# same encoder. Do NOT reuse the X_embedded / X_test_embedded globals here:
# they are left over from the per-model loops above and hold the embeddings of
# the last model iterated ("graph"), i.e. a different latent space.
with torch.no_grad():
    lstm_train_embedded = dct_lstm_autoencoder.encoder(X)[0].numpy()
    lstm_test_embedded = dct_lstm_autoencoder.encoder(X_test)[0].numpy()

tsne = TSNE(n_components=2, random_state=42)
# NOTE(review): assumes cf_sample and latent_query are each a single latent row
# (1, LATENT_SIZE) — exactly one label is appended for each below.
latent_space = tsne.fit_transform(
    np.concatenate([lstm_train_embedded, lstm_test_embedded, cf_sample, latent_query])
)
# 1 = correct, 0 = incorrect, 4 = counterfactual, 5 = query.
all_y = np.concatenate([y, y_test, [4], [5]])

fig, ax = plt.subplots()
for group, color, group_name in (
    (1, "green", "Correct"),
    (0, "red", "Incorrect"),
    (4, "yellow", "Counterfactual"),
    (5, "black", "Query"),
):
    points = latent_space[all_y == group]
    ax.scatter(points[:, 0], points[:, 1], c=color, label=group_name)

ax.legend()
ax.set_title("t-SNE of the LSTM latent space with the counterfactual");

In [None]:
from scipy.spatial.distance import cdist

# Index (into the test split) of the sample we want a counterfactual for.
sample_id = 3

# Search in the LSTM latent space, the space the LSTM decoder maps back to
# DCT in the next section. The X_test_embedded global is left over from the
# per-model loops above (last model: "graph") and must not be used here.
with torch.no_grad():
    lstm_test_embedded = dct_lstm_autoencoder.encoder(X_test)[0].numpy()

query_instance = lstm_test_embedded[sample_id]

# Euclidean distance from the query to every test embedding.
distances_to_query = cdist(
    np.expand_dims(query_instance, 0), lstm_test_embedded
).squeeze(0)

# Candidates: "correct" (y == 1) test samples other than the query itself, so a
# correct query never gets itself back as its own counterfactual.
mask = np.where((y_test == 1) & (np.arange(len(y_test)) != sample_id))[0]
if mask.size == 0:
    raise ValueError("No correct test samples available to use as a counterfactual.")

mask_argmin = distances_to_query[mask].argmin()
cf_id = mask[mask_argmin]
cf_instance = lstm_test_embedded[cf_id]

## Decode latent vectors to DCT

In [None]:
# Decode both latent vectors with the LSTM decoder.
# The "original" is the latent query the counterfactual was searched for
# (`query_instance`). `wrong_sample` is drawn from the data loader, so it is an
# input-space (DCT) sample — it is what `explainer.generate_cf` consumes to
# produce `latent_query`. A decoder expects latent vectors, so `wrong_sample`
# must not be passed to it.
with torch.no_grad():
    cf_dct = dct_lstm_autoencoder.decoder(torch.tensor(cf_instance).unsqueeze(0))
    original_dct = dct_lstm_autoencoder.decoder(
        torch.tensor(query_instance).unsqueeze(0)
    )

## Decode DCT to pose

In [None]:
# Sanity check: shape of the decoded DCT tensor before converting it to poses.
original_dct.size()

In [None]:
from src.utils.visualization import get_3D_animation
from src.utils.data import decode_dct

# `sample_id` indexes the *test* split (the query is X_test's embedding at
# `sample_id`), so the target length must come from the test dataset, not
# the training one. Both sequences use the same length argument so they can
# be compared directly.
# NOTE: this rebinds `cf_sample` (earlier: the Explainer's latent counterfactual)
# to the decoded pose sequence used by the visualization cells below.
query_length = squat_dct_dataset_test.lengths[sample_id]

cf_sample = decode_dct(cf_dct.detach().numpy().squeeze(), query_length)
original_sample = decode_dct(original_dct.detach().numpy().squeeze(), query_length)

## Results visualization

In [None]:
# One 3D animation per pose sequence: red = original, green = counterfactual.
anim_specs = ((original_sample, "red"), (cf_sample, "green"))
original_sample_anim, cf_sample_anim = (
    get_3D_animation(sequence, color=colour) for sequence, colour in anim_specs
)

In [None]:
from IPython.display import HTML


# Embed two pre-rendered clips next to each other. No values are interpolated,
# so this is a plain string (the previous f-prefix had no placeholders).
# NOTE(review): nothing in this notebook writes original_sample.mp4 or
# cf_sample.mp4 — they must already exist next to the notebook (e.g. saved from
# the animations above), otherwise the players render empty.
html_code = """
<video width="400" height="300" controls>
  <source src="original_sample.mp4" type="video/mp4">
</video>
<video width="400" height="300" controls>
  <source src="cf_sample.mp4" type="video/mp4">
</video>
"""

HTML(html_code)

In [None]:
from src.utils.visualization import get_3D_animation_comparison

# Single comparison animation of the original vs. counterfactual poses,
# rendered as an interactive JavaScript player.
comparison_anim = get_3D_animation_comparison(original_sample, cf_sample)
comparison_html = comparison_anim.to_jshtml()
HTML(comparison_html)