# 2. Curve comparison
In this notebook we:
- Load the DPS curves that were collected in step 1.
- Normalize the lengths so that they can be treated as datapoints.
- Use K-means to group this data into 3 clusters, and take each cluster's centroid to obtain 3 curves that represent the clusters.
- Calculate differences between curves.

## Setup

In [None]:
# Installs
import sys
!echo "Purging pip environment and installing packages..."
!{sys.executable} -m pip cache purge 
!{sys.executable} -m pip uninstall -y jhutils 
!{sys.executable} -m pip install -q git+https://github.com/jdchart/jh-py-utils.git

# Imports
print("Importing packages...")
import os
from jhutils.local_files import collect_files
import dps
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import utils
from sklearn.cluster import KMeans
print("Ready!")

## Analysis config
- `NORM_LEN`: the number of samples to use as dimensions for each curve.

In [None]:
# --- Input / output locations ------------------------------------------------
# Folder scanned for the saved curve files, and root folder for the figures.
DPS_CURVES = "/Users/jacob/Documents/Repos/dps/projects/data/output/dps_curves/dps_curve_32_1920_20"
OUTPUT_DEST = "/Users/jacob/Documents/Repos/dps/projects/data/output/curve_comparison"

# --- Resampling and clustering -----------------------------------------------
FPS = 32           # divisor used to turn sample counts into plot time values
NORM_LEN = 5000    # number of samples each curve is resampled to
NUM_CLUSTERS = 3   # number of k-means clusters

# --- Parameters forwarded to the "Mutations" plot helper ---------------------
TIME_WINDOW = 1
THRESHOLD = 0.3
# TIME_WINDOW scaled by FPS and truncated to an integer
# (presumably a window length in samples -- confirm against utils).
WINDOW_SIZE = int(FPS * TIME_WINDOW)

## Normalize curves

In [None]:
# Collect every saved curve (.npy) found under the input folder.
analysis_files = collect_files(DPS_CURVES, ["npy"])

# Resample each curve onto NORM_LEN evenly spaced positions so that curves of
# different lengths become equal-length vectors.
resampled = []
for curve_path in analysis_files:
    raw_curve = np.load(curve_path)
    source_positions = np.arange(len(raw_curve))
    target_positions = np.linspace(0, len(raw_curve) - 1, NORM_LEN)
    resampled.append(np.interp(target_positions, source_positions, raw_curve))

processed_curves = np.array(resampled)

# Replace NaNs with 0.0, but only touch the array when at least one is present.
if np.isnan(processed_curves).any():
    processed_curves = np.nan_to_num(processed_curves, nan=0.0)

print(processed_curves.shape)

## Clustering

In [None]:
# Fit k-means on the resampled curves (one row per curve). random_state is
# fixed so that the initialisation is reproducible between runs.
kmeans = KMeans(n_clusters=NUM_CLUSTERS, random_state=42).fit(processed_curves)

# Cluster index assigned to each curve, and one centroid curve per cluster.
cluster_labels = kmeans.labels_
cluster_centers = kmeans.cluster_centers_

print("👍 Clustering complete.")

In [None]:
# One subplot per cluster: the member curves (semi-transparent) with the
# cluster centroid drawn on top as a dashed black line.

# Curves and centroids all share the same number of samples, so a single
# x-axis is built once from the centroid array instead of relying on a
# variable left over from an inner loop.
n_samples = cluster_centers.shape[1]
time_max = n_samples / FPS
time_axis = np.linspace(0, time_max, n_samples)
x_ticks = np.arange(0, time_max + 1, 5)

# squeeze=False keeps `axes` two-dimensional even when NUM_CLUSTERS == 1.
fig, axes = plt.subplots(
    NUM_CLUSTERS, 1, figsize=(15, 5 * NUM_CLUSTERS), squeeze=False
)

for cluster_id, ax in enumerate(axes[:, 0]):
    # Indices of the curves assigned to this cluster, in file order.
    for i in np.flatnonzero(cluster_labels == cluster_id):
        ax.plot(
            time_axis,
            processed_curves[i],
            label=f"File {os.path.basename(analysis_files[i])}",
            alpha=0.6
        )

    ax.plot(
        time_axis,
        cluster_centers[cluster_id],
        label=f"Cluster {cluster_id} Center",
        linewidth=3,
        linestyle="--",
        color="black"
    )

    ax.set_title(f"Cluster {cluster_id} from the clustering of WPS curves from 13 runthroughs and performances of Le Malade Imaginaire directed by Arthur Nauzyciel")
    # NOTE(review): the ticks are plain numbers (resampled index / FPS), not
    # MM:SS values -- confirm whether this label should read "Time (s)".
    ax.set_xlabel("Time (MM:SS)")
    ax.set_ylabel("DPS")
    ax.legend(fontsize="small")
    ax.set_xticks(x_ticks)
    ax.grid()

# Lay out once, after every subplot has its final labels and ticks.
fig.tight_layout()

# Figures for this curve set go in a sub-folder named after the input folder.
output_dir = os.path.join(OUTPUT_DEST, os.path.basename(DPS_CURVES))
os.makedirs(output_dir, exist_ok=True)
fig.savefig(os.path.join(output_dir, "Curve clusters.png"), dpi=300, bbox_inches='tight')
plt.show()
plt.close(fig)

## Centroid comparison

In [None]:
# Level above which a centroid-to-centroid difference is highlighted.
THRESH = 0.7

# Absolute point-wise difference for every unordered pair of centroids,
# keyed "i-j" with i < j.
n_centers = len(cluster_centers)
differentials = {
    f"{a}-{b}": np.abs(cluster_centers[a] - cluster_centers[b])
    for a in range(n_centers)
    for b in range(a + 1, n_centers)
}

# All pairwise centroid differences overlaid on a single set of axes.
fig, ax = plt.subplots(figsize=(15, 5))
for pair_key, gap in differentials.items():
    gap_times = np.linspace(0, len(gap) / FPS, len(gap))
    ax.plot(gap_times, gap, label=f"Différence intégrale entre Cluster {pair_key}")

ax.set_title("Differences entre centroids des clusters")
ax.set_xlabel("Time (s)")
ax.set_ylabel("Difference in Cumulative Variation")
ax.legend()
fig.tight_layout()

# One tick every 5 units along the x-axis, covering the full centroid length.
time_max = len(cluster_centers[0]) / FPS
x_ticks = np.arange(0, time_max + 1, 5)
ax.set_xticks(x_ticks)
ax.grid()

comparison_png = os.path.join(
    OUTPUT_DEST, f"{os.path.basename(DPS_CURVES)}", "Difference curves.png"
)
fig.savefig(comparison_png, dpi=300, bbox_inches='tight')
plt.show()
plt.close(fig)

# One subplot per centroid pair, with the samples exceeding THRESH marked in red.
n_pairs = len(differentials)
# squeeze=False keeps `axes` two-dimensional even when there is a single pair.
fig, axes = plt.subplots(n_pairs, 1, figsize=(15, 5 * n_pairs), squeeze=False)

for ax, (pair, differ) in zip(axes[:, 0], differentials.items()):
    # The line and the highlighted points share this axis so that every marker
    # sits exactly on the curve (index / FPS uses a slightly different step
    # than linspace and drifts by up to one frame at the right edge).
    time_axis = np.linspace(0, len(differ) / FPS, len(differ))
    ax.plot(time_axis, differ, label=f"Différence intégrale entre Cluster {pair}")

    high_indices = np.flatnonzero(differ > THRESH)
    ax.scatter(
        time_axis[high_indices],
        differ[high_indices],
        color='red',
        label=f"High Variation (>{THRESH})",
        zorder=5
    )

    ax.axhline(y=THRESH, color='red', linestyle='--', label=f"Threshold = {THRESH}")  # threshold line
    ax.set_title(f"Différentiel entre Cluster {pair}")
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Difference in Cumulative Variation")
    ax.legend()
fig.tight_layout()

# Make sure the destination exists so this cell does not depend on an earlier
# cell having created it.
output_dir = os.path.join(OUTPUT_DEST, os.path.basename(DPS_CURVES))
os.makedirs(output_dir, exist_ok=True)
fig.savefig(os.path.join(output_dir, "Difference curves (each cluster).png"), dpi=300, bbox_inches='tight')
plt.show()
plt.close(fig)

In [None]:
# Comparison to highlight; replaced further down when it is not available.
single_graph_key = "0-1 vs 0-2"

# Absolute point-wise difference for every unordered pair of centroids,
# keyed "i-j" with i < j.
center_count = len(cluster_centers)
differentials = {
    f"{a}-{b}": np.abs(cluster_centers[a] - cluster_centers[b])
    for a in range(center_count)
    for b in range(a + 1, center_count)
}

print("Clés des différentiels :", list(differentials.keys()))

# Second-order comparison: absolute difference between every unordered pair of
# the first-order differences, keyed "<pair> vs <pair>".
diff_keys = list(differentials.keys())
differentials_of_differentials = {}
for position, first_key in enumerate(diff_keys):
    for second_key in diff_keys[position + 1:]:
        differentials_of_differentials[f"{first_key} vs {second_key}"] = np.abs(
            differentials[first_key] - differentials[second_key]
        )

print("Clés des différentiels des différentiels :", list(differentials_of_differentials.keys()))

# Use a valid key: fall back to the first available comparison if the
# requested one does not exist.
if single_graph_key not in differentials_of_differentials:
    print(f"Clé '{single_graph_key}' introuvable. Utilisation d'une clé par défaut.")
    single_graph_key = list(differentials_of_differentials.keys())[0]

# Plotting is delegated to the helper in utils; the figure path points at
# "Mutations.png" inside this curve set's output folder.
mutations_png = os.path.join(
    OUTPUT_DEST, f"{os.path.basename(DPS_CURVES)}", "Mutations.png"
)
utils.plot_differentials_with_segments(
    differentials_of_differentials,
    FPS,
    THRESHOLD,
    TIME_WINDOW,
    WINDOW_SIZE,
    mutations_png,
    key=single_graph_key,
)