In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import time
import math
import pandas as pd
import tqdm

import network
import kinematics

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
_backend_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend_name)
print("Using device:", device)

In [None]:
# Pretrained task-space VAE: its encoder maps a point set to (mu, logvar) and its
# decoder maps a latent vector back to a point cloud (see the cells below).
task_vae_model = network.fROM_VAE_task(input_dim=3, latent_dim=128, num_task_points=256).to(device)
# Frozen, pretrained network -> inference mode (affects dropout/normalisation layers, if any).
task_vae_model.eval()
# weights_only=True: the checkpoint is a plain state_dict, so refuse arbitrary pickled objects.
task_vae_model.load_state_dict(
    torch.load("task_vae_76_4096.pth", map_location=device, weights_only=True)
)

In [None]:

# ---- Experiment configuration -------------------------------------------------
N_USERS = 76          # nominal number of users; the real count is derived from the data below
M_POSES = 4096        # number of joint-space poses sampled per user
K_TASK_POINTS = 256   # number of points generated in task space by the task VAE
LATENT_DIM = 128
LEARNING_RATE = 1e-4
EPOCHS = 1000
BATCH_SIZE = 76
BETA_KL = 0.001       # weight for KL loss
SEED = 42

# Seed torch so model initialisation and any unseeded random ops are reproducible.
torch.manual_seed(SEED)

# ---- Load raw data (one CSV row per pose, header row skipped) -------------------
joint_limits = torch.tensor(
    np.loadtxt("joint_limits.csv", delimiter=',', skiprows=1), dtype=torch.float32
)
joint_configs = torch.tensor(
    np.loadtxt("joints_data_81_4096.csv", delimiter=',', skiprows=1), dtype=torch.float32
)
task_points = torch.tensor(
    np.loadtxt("task_sp_points.csv", delimiter=',', skiprows=1), dtype=torch.float32
)

# Every user contributes M_POSES consecutive rows (they are later viewed as [N, M_POSES, ...]),
# so derive the number of users from the data instead of trusting N_USERS: a mismatch
# would otherwise only surface later as a random_split / TensorDataset size error.
if joint_configs.shape[0] % M_POSES != 0:
    raise ValueError(
        f"joint_configs has {joint_configs.shape[0]} rows, not a multiple of M_POSES={M_POSES}"
    )
if task_points.shape[0] != joint_configs.shape[0]:
    raise ValueError(
        f"task_points has {task_points.shape[0]} rows but joint_configs has {joint_configs.shape[0]}"
    )
DATASET_SIZE = joint_configs.shape[0] // M_POSES
if DATASET_SIZE != N_USERS:
    print(f"Warning: data contains {DATASET_SIZE} users but N_USERS = {N_USERS}")

TRAIN_SIZE = math.floor(0.8 * DATASET_SIZE)
VAL_SIZE = DATASET_SIZE - TRAIN_SIZE

print(joint_configs.shape)
print(task_points.shape)
print(TRAIN_SIZE, VAL_SIZE)

In [None]:
# Drop the leading user-ID column and group the flat per-pose rows into per-user sets.
# Guarded on the raw 2-D layout so that re-running this cell does not strip a real
# data column from already-processed (3-D) tensors.
if joint_configs.dim() == 2:
    joint_configs = joint_configs[:, 1:]                     # remove user ID column
    joint_configs = joint_configs.view(-1, M_POSES, 4)       # [N, M, 4]

if task_points.dim() == 2:
    task_points = task_points[:, 1:]                         # remove user ID column
    task_points = task_points.view(-1, M_POSES, 3)           # [N, M, 3]

# joint_limits is kept as loaded (its ID-column handling is not applied here).

# Sanity check: forward kinematics of the stored joint configurations should
# reproduce the stored task-space points.
task_points_FK = kinematics.forward_kinematics_pytorch(joint_configs.reshape(-1, 4))
task_points_FK = task_points_FK.view(-1, M_POSES, 3)         # [N, M, 3]

print(joint_limits.shape)
print(joint_configs.shape)
print(task_points.shape)
print(task_points_FK.shape)

error_FK = F.mse_loss(task_points_FK.reshape(-1, 3), task_points.reshape(-1, 3))
error_FK

In [None]:
# Per-user dataset: each item is (joint configurations [M, 4], task-space points [M, 3]).
dataset = TensorDataset(joint_configs, task_points)

# Seeded split so the train/validation users are the same on every run.
split_generator = torch.Generator().manual_seed(0)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [TRAIN_SIZE, VAL_SIZE], generator=split_generator
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Joint-space encoder: the only network trained in this notebook.
joints_encoder = network.Encoder(input_dim=4, latent_dim=LATENT_DIM).to(device)

# Pretrained task-space VAE. Freeze every parameter (encoder and decoder are
# submodules, so one call covers them) and keep it in eval mode, since its encoder
# and decoder are only used for inference below.
task_vae_model = network.fROM_VAE_task(
    input_dim=3, latent_dim=LATENT_DIM, num_task_points=K_TASK_POINTS
).to(device)
task_vae_model.load_state_dict(
    torch.load("task_vae_76_4096.pth", map_location=device, weights_only=True)
)
task_vae_model.requires_grad_(False)
task_vae_model.eval()
task_encoder = task_vae_model.encoder
task_decoder = task_vae_model.decoder

optimizer = optim.Adam(joints_encoder.parameters(), lr=LEARNING_RATE)
recon_loss_fn = nn.MSELoss()

In [None]:
# Hyper-parameters for the joint-encoder training run (override the config cell).
EPOCHS = 10000
LEARNING_RATE = 5e-6
# The optimizer already exists (built with the config-cell learning rate), so reassigning
# LEARNING_RATE alone would have no effect -- push the new value into its param groups.
for param_group in optimizer.param_groups:
    param_group["lr"] = LEARNING_RATE
# joints_encoder.load_state_dict(torch.load("joints_encoder.pth", map_location=device))

In [None]:
# Train the joint encoder so that its latent Gaussian matches the frozen task-space
# encoder's Gaussian for the same user, i.e. minimise KL(q_joint || q_task).
task_encoder.eval()  # frozen target network: kept in inference mode throughout

for epoch in range(EPOCHS):
    joints_encoder.train()
    total_loss_epoch = 0.0

    for x, y in tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False):
        x = x.to(device)  # [B, M, 4] joint configurations
        y = y.to(device)  # [B, M, 3] task-space points

        # Target posterior from the frozen task encoder; no autograd graph is needed.
        with torch.no_grad():
            mu_y, logvar_y = task_encoder(y)
        var_y = torch.exp(logvar_y)

        # Forward pass of the trainable encoder (assumed outputs: [B, D_z], [B, D_z]).
        mu_pred, logvar_pred = joints_encoder(x)
        var_pred = torch.exp(logvar_pred)

        # Closed-form KL( N(mu_pred, var_pred) || N(mu_y, var_y) ) for diagonal Gaussians.
        kl_div = 0.5 * torch.sum(
            logvar_y - logvar_pred + (var_pred + (mu_pred - mu_y).pow(2)) / var_y - 1,
            dim=1,  # sum over the latent dimensions
        )
        loss = kl_div.mean()

        if epoch == EPOCHS - 1:
            print(f"kl_loss: {loss.item():.4f}")

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss_epoch += loss.item()

    avg_loss = total_loss_epoch / len(train_loader)
    if (epoch + 1) % 500 == 0:
        print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {avg_loss:.10f}")

In [None]:
from network import fROM_VAE_task
import utils

# Evaluate on held-out users:
#   * Chamfer distance between point clouds decoded from the joint-encoder latent and
#     from the task-encoder latent;
#   * KL( q_joint || q_task ) between the two latent Gaussians.
# The training loop leaves joints_encoder in train mode, and the task decoder is used
# here as well, so switch both networks to inference mode first.
joints_encoder.eval()
task_vae_model.eval()
task_vae_model.requires_grad_(False)

recon_error = 0.0
latent_error = 0.0
with torch.no_grad():
    for x, y in val_loader:
        x = x.to(device)
        y = y.to(device)

        mu_pred, logvar_pred = joints_encoder(x)  # [B, D_z], [B, D_z]
        mu_y, logvar_y = task_encoder(y)

        var_y = torch.exp(logvar_y)
        var_pred = torch.exp(logvar_pred)

        # Latents come from the VAE's reparameterize(); whether this samples or returns
        # the mean depends on network.fROM_VAE_task -- TODO confirm.
        z_pred = task_vae_model.reparameterize(mu_pred, logvar_pred)
        z_y = task_vae_model.reparameterize(mu_y, logvar_y)

        pc_pred = task_decoder(z_pred)
        pc_y = task_decoder(z_y)

        recon_error += utils.chamfer_loss(pc_pred, pc_y).item()

        kl_div = 0.5 * torch.sum(
            logvar_y - logvar_pred + (var_pred + (mu_pred - mu_y).pow(2)) / var_y - 1,
            dim=1,  # sum over the latent dimensions
        )
        # Accumulate a Python float, not a tensor, so the running sum holds no device memory.
        latent_error += kl_div.mean().item()

avg_recon_error = recon_error / len(val_loader)
avg_latent_error = latent_error / len(val_loader)
print(f"Chamfer dist: {avg_recon_error:.8f}, Latent KL: {avg_latent_error:.8f}")

    
    

In [None]:
# Persist only the trained joint-encoder weights (its state_dict), not the module object.
JOINTS_ENCODER_PATH = 'joints_encoder.pth'
torch.save(joints_encoder.state_dict(), JOINTS_ENCODER_PATH)

In [None]:
# Display the joint-encoder module (its repr lists the layers) as this cell's output.
joints_encoder

In [None]:
# Visualise one user: decode (a) the joint-encoder latent of that user's joint
# configurations and (b) the task-encoder latent of the FK points of the same
# configurations, then plot both point clouds on shared axes.
VIS_USER_IDX = 13  # index along the user dimension of joint_configs

joints_encoder.eval()
task_vae_model.eval()
with torch.no_grad():
    x = joint_configs[VIS_USER_IDX].to(device)                    # [M, 4]
    y = kinematics.forward_kinematics_pytorch(x).view(1, -1, 3)   # [1, M, 3]
    # One user is one set of M poses, matching the [B, M, 4] training batches.
    # (unsqueeze(1) would give [M, 1, 4]: M "users" with a single pose each.)
    x = x.unsqueeze(0)                                            # [1, M, 4]

    mu_pred, logvar_pred = joints_encoder(x)
    mu_y, logvar_y = task_encoder(y)

    z_pred = task_vae_model.reparameterize(mu_pred, logvar_pred)
    z_y = task_vae_model.reparameterize(mu_y, logvar_y)

    pc_pred = task_decoder(z_pred)
    pc_y = task_decoder(z_y)

pc_pred_np = pc_pred[0].cpu().numpy()
pc_y_np = pc_y[0].cpu().numpy()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 10), subplot_kw={"projection": "3d"})
panels = (
    (ax1, pc_pred_np, 'r', "Predicted point cloud"),
    # The right panel is the task VAE's reconstruction of the FK points, not raw FK output.
    (ax2, pc_y_np, 'b', "Ground truth (task-VAE reconstruction)"),
)
for ax, pts, colour, title in panels:
    ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], s=2, c=colour)
    ax.set_title(title)
    ax.set_xlabel("X"); ax.set_ylabel("Y"); ax.set_zlabel("Z")

# Shared limits so both panels are directly comparable.
all_pts = np.vstack([pc_pred_np, pc_y_np])
mins = all_pts.min(axis=0)
maxs = all_pts.max(axis=0)
for ax in (ax1, ax2):
    ax.set_xlim(mins[0], maxs[0])
    ax.set_ylim(mins[1], maxs[1])
    ax.set_zlim(mins[2], maxs[2])
    ax.view_init(elev=20, azim=120)
plt.show()


In [None]:
# Joint-limit regression example carried over from another experiment.
# NOTE(review): `model` is never defined in this notebook, and here val_loader yields
# task-space point sets rather than [B, 8] joint-limit targets. The cell therefore only
# runs when a compatible `model` (and loader) has been provided; otherwise it is skipped
# so that Restart & Run All completes.
if "model" not in globals():
    print("Skipping joint-limit example: `model` is not defined in this notebook.")
else:
    model.eval()
    with torch.no_grad():
        # Take one *batch* (indexing a Dataset would return a single sample instead).
        x_batch, y_batch = next(iter(val_loader))        # expected [B, M, 4], [B, 8]

        x_input = x_batch[0].unsqueeze(0).to(device)     # [1, M, 4]
        y_truth = y_batch[0].cpu().numpy()               # [8]: columns i and i+4 = (min, max)

        y_pred, _, _ = model(x_input)                    # [1, 8]
        y_pred = y_pred.squeeze().cpu().numpy()

    # Values are printed as returned by the data/model (radians), not converted to degrees.
    print("Example from Test Set (in rad):")
    print("---------------------------------")
    print(f"JOINT         | TRUTH (min/max)  | PRED (min/max)")
    print(f"Shoulder Abd  | {y_truth[0]:.1f} / {y_truth[4]:.1f}   | {y_pred[0]:.1f} / {y_pred[4]:.1f}")
    print(f"Shoulder Flex | {y_truth[1]:.1f} / {y_truth[5]:.1f}   | {y_pred[1]:.1f} / {y_pred[5]:.1f}")
    print(f"Shoulder Rot  | {y_truth[2]:.1f} / {y_truth[6]:.1f}   | {y_pred[2]:.1f} / {y_pred[6]:.1f}")
    print(f"Elbow Flex    | {y_truth[3]:.1f} / {y_truth[7]:.1f}   | {y_pred[3]:.1f} / {y_pred[7]:.1f}")

    mse_rad = np.mean((y_truth - y_pred) ** 2)  # MSE in rad^2
    print(f"\nExample's MSE (rad^2): {mse_rad:.6f}")