In [None]:
import os
import random
from pathlib import Path
from random import randint

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch

from keypoint_dataset import KeypointDataset, keypoint_normalization,gaussian_jitter, length_variance, rotation_2D, scaling, horizontal_flip



In [None]:

# Visualize keypoints
def plot_keypoints(keypoints, title="Keypoints"):
    """Overlay every frame of a 2D keypoint sequence on one figure.

    Args:
        keypoints: sequence indexed as ``[frame, keypoint, coord]`` with
            coord 0 = x and coord 1 = y, i.e. shape (T, N, 2) where T is the
            number of frames and N the number of keypoints. Torch tensors
            (including ones on a GPU or requiring grad) and NumPy arrays are
            both accepted.
        title: figure title.
    """
    if hasattr(keypoints, "detach"):
        # torch.Tensor: drop the autograd graph and copy to host memory;
        # a bare .numpy() raises for GPU tensors and tensors requiring grad.
        keypoints = keypoints.detach().cpu().numpy()
    keypoints = np.asarray(keypoints)

    fig, ax = plt.subplots(figsize=(10, 6))
    for frame_points in keypoints:  # one polyline per frame of the sequence
        # 'bo-': blue markers joined by lines, in keypoint index order
        ax.plot(frame_points[:, 0], frame_points[:, 1], 'bo-', alpha=0.6)
    ax.set_title(title)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    plt.show()


In [None]:
# List the contents of the data directory (IPython shell escape).
# NOTE(review): absolute, user-specific path — consider deriving it from a configurable data directory.
! dir "/home/giorgio6846/Code/Sign-AI/data/"

In [None]:
# Dataset location. Override with the SIGN_AI_DATA_DIR environment variable
# instead of editing this user-specific default; the default yields the same path.
DATA_DIR = Path(os.environ.get("SIGN_AI_DATA_DIR", "/home/giorgio6846/Code/Sign-AI/data"))
h5_file = str(DATA_DIR / "dataset_clean_clean.hdf5")

# Seed every RNG the split and the augmentations might draw from, so that
# Restart & Run All reproduces the same split and the same previews.
# NOTE(review): assumes keypoint_dataset uses Python `random`, NumPy or torch RNGs — confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

kd = KeypointDataset(h5_file, data_augmentation=True, return_label=True)
# 0.5 is passed straight to split_dataset (presumably the train fraction — confirm).
train_dataset, validation_dataset, train_length, val_length = kd.split_dataset(0.5)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

def keypoints_to_numpy(keypoints):
    """Return a keypoint sequence as a float NumPy array.

    Torch tensors are detached and copied to host memory first, so tensors
    that live on a GPU or require grad are accepted; anything else goes
    through ``np.asarray``.
    """
    if hasattr(keypoints, "detach"):  # duck-typed torch.Tensor check
        keypoints = keypoints.detach().cpu().numpy()
    return np.asarray(keypoints, dtype=float)


def keypoint_axis_limits(*sequences, pad=0.05):
    """Compute one (xlim, ylim) pair that contains every frame of every sequence.

    Each sequence is a NumPy array indexed as ``[frame, keypoint, coord]``
    (coord 0 = x, coord 1 = y). Points whose x or y is not finite (e.g. NaN)
    are ignored. Each side is widened by ``pad`` times the data span, with a
    minimum of 1e-3 so that a single point still gives a non-degenerate range.

    Returns:
        ``((xmin, xmax), (ymin, ymax))``, or ``None`` if there is no finite point.
    """
    xs = np.concatenate([np.ravel(seq[..., 0]) for seq in sequences])
    ys = np.concatenate([np.ravel(seq[..., 1]) for seq in sequences])
    finite = np.isfinite(xs) & np.isfinite(ys)
    if not finite.any():
        return None
    limits = []
    for values in (xs[finite], ys[finite]):
        low, high = float(values.min()), float(values.max())
        margin = max((high - low) * pad, 1e-3)
        limits.append((low - margin, high + margin))
    return tuple(limits)


def visualize_augmentation(keypoints, f_keypoint, augmentation_name="", interval=50):
    """Animate an original keypoint sequence side by side with its augmented version.

    Args:
        keypoints: original sequence of shape (T, N, 2) (array-like or torch tensor).
        f_keypoint: augmented sequence of shape (T', N', 2).
        augmentation_name: label appended to both panel titles.
        interval: delay between animation frames, in milliseconds.

    Returns:
        An ``IPython.display.HTML`` object holding the JavaScript animation of
        the first ``min(T, T')`` frames.

    Raises:
        ValueError: if either sequence has no frames.
    """
    original = keypoints_to_numpy(keypoints)
    augmented = keypoints_to_numpy(f_keypoint)
    frames = min(len(original), len(augmented))  # animate only the frames both sequences have
    if frames == 0:
        raise ValueError("visualize_augmentation needs at least one frame in each sequence")

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))  # ax is an array of 2 Axes
    sc1 = ax[0].scatter(original[0, :, 0], original[0, :, 1], s=10, c='blue', alpha=0.5)
    sc2 = ax[1].scatter(augmented[0, :, 0], augmented[0, :, 1], s=10, c='red', alpha=0.5)

    # set_offsets() never rescales the axes, so limits autoscaled from frame 0
    # would let later (e.g. rotated or scaled) points leave the view. Fix shared
    # limits over all frames of both sequences instead.
    limits = keypoint_axis_limits(original, augmented)
    for a in ax:
        a.grid(True)
        if limits is not None:
            a.set_xlim(*limits[0])
            a.set_ylim(*limits[1])
        a.set_aspect('equal')
        a.invert_yaxis()  # flip Y so it grows downward (image coordinate convention)

    suffix = f' - {augmentation_name}' if augmentation_name else ''

    def update(frame):
        """Move both scatters to `frame` and relabel the panels."""
        sc1.set_offsets(original[frame])
        sc2.set_offsets(augmented[frame])
        ax[0].set_title(f'Frame {frame} - Original{suffix}')
        ax[1].set_title(f'Frame {frame} - Aumentado{suffix}')
        return sc1, sc2

    update(0)  # consistent titles for the initial frame
    # blit=False: the titles change every frame but are not among the returned artists.
    anim = FuncAnimation(fig, update, frames=frames, interval=interval, blit=False)
    html = HTML(anim.to_jshtml())
    # Close the figure so Jupyter does not also render it as a static image
    # below the animation, and so repeated calls do not accumulate open figures.
    plt.close(fig)
    return html


#visualize_augmentation(keypoints, f_keypoint, augmentation_name="Gaussian Jitter")


In [None]:
# Side-by-side view of two *different* training samples (indices 0 and 4).
# No augmentation function is called in this cell, so the label names the two
# samples rather than a specific augmentation.
# NOTE(review): if the intent was original-vs-augmented of the same sequence,
# pass the same index twice (train_dataset was built with data_augmentation=True).
visualize_augmentation(train_dataset[0][0], train_dataset[4][0], augmentation_name="train[0] vs train[4]")

In [None]:

# Sample used for all augmentation previews below; .clone() gives this
# notebook its own copy of the dataset tensor.
# NOTE(review): `kd` was built with data_augmentation=True — confirm whether
# kd[i] already returns an augmented sequence, in which case this is not a raw baseline.
keypoints = kd[1200][0].clone()

# Jitter a copy, in case the augmentation modifies its input in place.
f_keypoint = gaussian_jitter(keypoints.clone(), sigma=0.004, clip=3.0)

visualize_augmentation(keypoints, f_keypoint, augmentation_name="Gaussian Jitter")


In [None]:
# Apply augmentation: Length Variance — on a copy, in case it modifies its
# input in place (which would otherwise corrupt `keypoints` for later cells).
f_keypoint = length_variance(keypoints.clone())

visualize_augmentation(keypoints, f_keypoint, augmentation_name="Length Variance")

In [None]:

# Apply augmentation: Rotation 2D — on a copy, in case it modifies its
# input in place (which would otherwise corrupt `keypoints` for later cells).
f_keypoint = rotation_2D(keypoints.clone())
visualize_augmentation(keypoints, f_keypoint, augmentation_name="Rotation 2D")


In [None]:


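# Horizontal Flip preview — currently disabled.
# TODO(review): re-enable or delete this cell; `horizontal_flip` is still imported in the first cell.
# # Apply augmentation: Horizontal Flip
# f_keypoint = horizontal_flip(keypoints)
# visualize_augmentation(keypoints, f_keypoint, augmentation_name="Horizontal Flip")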


In [None]:

# Apply augmentation: Scaling — on a copy, in case it modifies its
# input in place (which would otherwise corrupt `keypoints` for later cells).
f_keypoint = scaling(keypoints.clone())
visualize_augmentation(keypoints, f_keypoint, augmentation_name="Scaling")

In [None]:


import time

# Fresh copy of dataset sample 1200 (the same index as the previews above) to time the augmentations on.
# NOTE(review): `kd` was built with data_augmentation=True — confirm whether indexing it already applies augmentations.
keypoints = kd[1200][0].clone()

def measure_time(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` once, print how long it took, and return its result.

    Uses ``time.perf_counter`` (monotonic, highest available resolution)
    instead of ``time.time``, which can jump with system clock changes.
    The printed line includes the function's name so several timings in
    one cell can be told apart.
    NOTE: a single call includes any first-call warm-up cost of ``func``.
    """
    start = time.perf_counter()
    result = func(*args, **kwargs)
    elapsed_ms = (time.perf_counter() - start) * 1000
    name = getattr(func, "__name__", repr(func))
    print(f"{name} - Time taken: {elapsed_ms:.4f} ms")
    return result

# Time each augmentation on its own fresh copy of the same sample, so one that
# modifies its input in place cannot change the input of the next. The clone is
# made while evaluating the arguments, i.e. before measure_time starts its timer.
# NOTE(review): each augmentation is timed once, so the first call may include warm-up cost.
augmentation_benchmarks = [
    (gaussian_jitter, {"sigma": 0.004, "clip": 3.0}),
    (rotation_2D, {}),
    (scaling, {}),
    (length_variance, {}),
]
for augment, augment_kwargs in augmentation_benchmarks:
    # After the loop, f_keypoint holds the last (length_variance) result.
    f_keypoint = measure_time(augment, keypoints.clone(), **augment_kwargs)


