In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset
import seaborn as sns
from sklearn.manifold import TSNE


import kinematics
import FK_data_generator
import network
from loss_functions import vae_loss_function
import visualization

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Restore the trained PointVAE from its saved state_dict for evaluation.
LATENT_DIM = 32
MODEL_PATH = './models/point_vae_model.pth'

model = network.PointVAE(latent_dim=LATENT_DIM, num_points_k=1024)
model = model.to(device)

# map_location remaps the stored tensors onto the active device, so a GPU-saved
# checkpoint also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; if the checkpoint may come from an
# untrusted source, pass weights_only=True (available in torch >= 1.13).
state_dict = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(state_dict)

# Switch to inference behaviour (affects layers such as dropout/batch-norm, if present).
model.eval()
model

In [None]:
# --- Generate new validation data ---
# These settings must match the ones used for the training data; otherwise the
# validation loss is not comparable to the training loss.

num_users = 1000
num_points_per_user = 1024  # presumably must equal the model's num_points_k — TODO confirm
epsilon = 0.01  # meters
workspace_min = np.array([-1, -1, -1])
workspace_max = np.array([ 1,  1,  1])
VAL_BATCH_SIZE = 100

# NOTE(review): no random seed is set, so every run draws a fresh validation set.
# If a seed is added, make sure it differs from the one used to generate training data.
val_joints_per_user = FK_data_generator.generate_data(
    num_users=num_users,
    num_points_per_user=num_points_per_user,
    epsilon=epsilon,
    workspace_min=workspace_min,
    workspace_max=workspace_max,
    device=device,
)

# Stack the per-user samples into one float32 array: (num_users, num_points_per_user, 4).
val_joints_np = np.stack(val_joints_per_user, axis=0, dtype=np.float32)
assert val_joints_np.shape[:2] == (num_users, num_points_per_user), val_joints_np.shape

# The array is already float32, so share its memory instead of copying it again.
val_joints = torch.from_numpy(val_joints_np)

dataset = TensorDataset(val_joints)
# shuffle=False keeps batches in a fixed order so the validation pass is repeatable.
val_loader = DataLoader(dataset, batch_size=VAL_BATCH_SIZE, shuffle=False)


In [None]:
# --- Validation loop ---
# Evaluate the VAE on val_loader and report dataset-level average losses.
KL_WEIGHT = 1e-3  # weight of the KL term passed to vae_loss_function

# Running sums weighted by batch size, so a smaller final batch does not skew
# the dataset-level average.
total_loss_sum = 0.0
recon_loss_sum = 0.0
kl_loss_sum = 0.0
num_samples = 0

with torch.no_grad():
    for (x,) in val_loader:
        x = x.to(device)  # shape: (batch_size, num_points_per_user, 4)
        recon_x, mu, logvar = model(x)

        loss, recon_loss, kl_loss = vae_loss_function(recon_x, x, mu, logvar, KL_WEIGHT)

        # NOTE(review): assumes each returned loss is (or .mean() reduces it to) a
        # per-sample mean over the batch — confirm against vae_loss_function.
        n_in_batch = x.shape[0]
        total_loss_sum += loss.mean().item() * n_in_batch
        recon_loss_sum += recon_loss.mean().item() * n_in_batch
        kl_loss_sum += kl_loss.mean().item() * n_in_batch
        num_samples += n_in_batch

assert num_samples > 0, "val_loader yielded no samples"
avg_total_loss = total_loss_sum / num_samples
avg_recon_loss = recon_loss_sum / num_samples
avg_kl_loss = kl_loss_sum / num_samples

# Adjacent f-strings concatenate without the stray indentation that a backslash
# continuation inside a single string literal would embed in the output.
print(
    f"Validation Total Loss: {avg_total_loss}, "
    f"Recon Loss: {avg_recon_loss}, "
    f"KL Loss: {avg_kl_loss}"
)
        
    

In [None]:
# Pair plots for the first few validation users, using marker size 3.
NUM_USERS_TO_PLOT = 10
users_to_plot = val_joints[:NUM_USERS_TO_PLOT].cpu().numpy()
visualization.plot_N_joint_pairplots(users_to_plot, marker_size=3)
