In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import time
import math
import pandas as pd
import tqdm

import network

In [68]:

N_USERS = 81
M_POSES = 1024  # Number of points to sample joint-space
K_TASK_POINTS = 256   # Number of points to generate in task-space
LATENT_DIM = 128
LEARNING_RATE = 1e-4
EPOCHS = 5000
BATCH_SIZE = 4096
BETA_KL = 0.001 # Weight for KL loss


# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

joint_limits = np.loadtxt("joint_limits.csv", delimiter=',', skiprows=1)
joint_limits = torch.tensor(joint_limits, dtype=torch.float32)

joint_configs = np.loadtxt("all_users_fROM_data.csv", delimiter=',', skiprows=1)
joint_configs = torch.tensor(joint_configs, dtype=torch.float32)

DATASET_SIZE = N_USERS
TRAIN_SIZE   = math.floor(0.8 * DATASET_SIZE)
VAL_SIZE     = DATASET_SIZE - TRAIN_SIZE


Using device: cuda


In [69]:
joint_configs = joint_configs.view(-1, M_POSES, 5)  # [N, M, 4]
joint_configs = joint_configs[:, :, 1:]  # Remove user ID column

joint_limits = joint_limits[:, 1:]  # Remove user ID column

print(joint_limits.shape)
print(joint_configs.shape)

torch.Size([81, 8])
torch.Size([81, 1024, 4])


In [70]:
dataset = TensorDataset(joint_configs, joint_limits)

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [TRAIN_SIZE, VAL_SIZE])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False)

model = network.fROM_VAE_joints(joint_dim=4, latent_dim=LATENT_DIM, global_feat_dim=1024, output_limits_dim=8).to(device)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
recon_loss_fn = nn.MSELoss()

In [None]:

for epoch in range(EPOCHS):
    model.train()
    total_loss_epoch = 0
    total_recon_loss_epoch = 0
    total_kl_loss_epoch = 0
    
    for x, y in tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False):
        print(x.shape, y.shape)
        x = x.to(device)  # [B, M, 4]
        y = y.to(device)  # [B, 8]
        
        # Forward pass
        y_recon, mu, logvar = model(x) # [B, 8], [B, D_z], [B, D_z]
        
        # Loss
        # Reconstruction Loss (MSE on the 8 limit values)
        recon_loss = recon_loss_fn(y_recon, y)
        
        # KL
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        
        loss = (recon_loss / BATCH_SIZE) + (BETA_KL * kl_loss / BATCH_SIZE)
        
        if epoch == EPOCHS - 1:
            print(f"{y_recon[:3].detach().cpu().numpy()}")
            print(f"{y[:3].detach().cpu().numpy()}")
            print(f"recon_loss: {recon_loss.item():.4f}, kl_loss: {kl_loss.item():.4f}, total_loss: {loss.item():.4f}")
            
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss_epoch += loss.item() * BATCH_SIZE
        total_recon_loss_epoch += recon_loss.item()
        total_kl_loss_epoch += kl_loss.item()

    # Print epoch stats
    avg_loss = total_loss_epoch / TRAIN_SIZE  # 2560
    avg_recon = total_recon_loss_epoch / TRAIN_SIZE
    avg_kl = total_kl_loss_epoch / TRAIN_SIZE
    
    if (epoch + 1) % 500 == 0:
        print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {avg_loss:.4f}, "
              f"Recon Loss (MSE): {avg_recon:.4f}, KL Loss: {avg_kl:.4f}")

                                                     

Epoch [500/5000], Loss: 7.9976, Recon Loss (MSE): 7.5825, KL Loss: 415.1252


                                                      

Epoch [1000/5000], Loss: 4.4038, Recon Loss (MSE): 4.0747, KL Loss: 329.0664


                                                      

Epoch [1500/5000], Loss: 2.3285, Recon Loss (MSE): 2.0382, KL Loss: 290.2384


                                                      

Epoch [2000/5000], Loss: 0.6849, Recon Loss (MSE): 0.4176, KL Loss: 267.2554


                                                      

Epoch [2500/5000], Loss: 0.4121, Recon Loss (MSE): 0.1655, KL Loss: 246.5277


                                                      

Epoch [3000/5000], Loss: 0.3271, Recon Loss (MSE): 0.0905, KL Loss: 236.6620


                                                      

Epoch [3500/5000], Loss: 0.3254, Recon Loss (MSE): 0.1004, KL Loss: 224.9945


                                                      

Epoch [4000/5000], Loss: 0.3099, Recon Loss (MSE): 0.0889, KL Loss: 221.0229


                                                      

Epoch [4500/5000], Loss: 0.2829, Recon Loss (MSE): 0.0689, KL Loss: 214.0048


                                                      

TypeError: unsupported format string passed to numpy.ndarray.__format__

In [None]:
model.eval()
error = 0.0
for x, y in val_loader:
    with torch.no_grad():
        x = x.to(device)  # [B, M, 4]
        y = y.to(device)  # [B, 8]
        y_recon, mu, logvar = model(x) # [B, 8], [B, D_z], [B, D_z]
        
        error += recon_loss_fn(y_recon, y).item()
        
        print(f"{y_recon[:3].detach().cpu().numpy()}")
        print(f"{y[:3].detach().cpu().numpy()}")
        
avg_error = error / VAL_SIZE
print(f"Validation Reconstruction Error (MSE): {avg_error:.4f}")
    
    

[[-1.0767770e+02  2.9013262e+01 -4.7637493e+01  1.0009959e+02
  -5.3098145e+01  5.3158630e+01 -2.8632693e-02  1.4951682e+02]
 [-8.2999916e+01  2.2308062e+01 -8.0873634e+01  1.6301045e+02
  -4.1522953e+01  4.1347366e+01 -1.4251691e-01  3.5909512e+01]
 [-7.9858574e+01  1.9954647e+01 -5.4268040e+01  1.0796791e+02
  -9.8724815e+01  9.8368866e+01 -1.2270026e-02  9.0486626e+01]]
[[-110.   30.  -50.  100.  -60.   60.    0.  150.]
 [ -80.   20.  -90.  180.  -30.   30.    0.   40.]
 [ -80.   20.  -50.  100.  -90.   90.    0.   90.]]
Validation Reconstruction Error (MSE): 2.1194


In [65]:
model.eval()
with torch.no_grad():
    x_test_sample, y_test_sample = next(iter(val_dataset)) # [B, M, 4], [B, 8]
    
    x_input = x_test_sample[0].unsqueeze(0).to(device) # [1, M, 4]
    y_truth = y_test_sample[0].cpu().numpy()          # [8]
    
    y_pred, _, _ = model(x_input) # [1, 8]
    y_pred = y_pred.squeeze().cpu().numpy()
    
    y_truth_deg = np.rad2deg(y_truth)
    y_pred_deg = np.rad2deg(y_pred)

    y_truth_deg[0:4] = -1 * (y_truth_deg[0:4] % 360)
    y_truth_deg[4:] = y_truth_deg[4:] % 360
    y_pred_deg[0:4] = -1 * (y_pred_deg[0:4] % 360)
    y_pred_deg[4:] = y_pred_deg[4:] % 360

    
    # print("Example from Test Set (in deg):")
    # print("---------------------------------")
    # print(f"JOINT         | TRUTH (min/max)  | PRED (min/max)")
    # print(f"Shoulder Abd  | {y_truth_deg[0]:.1f} / {y_truth_deg[4]:.1f}   | {y_pred_deg[0]:.1f} / {y_pred_deg[4]:.1f}")
    # print(f"Shoulder Flex | {y_truth_deg[1]:.1f} / {y_truth_deg[5]:.1f}  | {y_pred_deg[1]:.1f} / {y_pred_deg[5]:.1f}")
    # print(f"Shoulder Rot  | {y_truth_deg[2]:.1f} / {y_truth_deg[6]:.1f}   | {y_pred_deg[2]:.1f} / {y_pred_deg[6]:.1f}")
    # print(f"Elbow Flex    | {y_truth_deg[3]:.1f} / {y_truth_deg[7]:.1f}    | {y_pred_deg[3]:.1f} / {y_pred_deg[7]:.1f}")
    
    print("Example from Test Set (in deg):")
    print("---------------------------------")
    print(f"JOINT         | TRUTH (min/max)  | PRED (min/max)")
    print(f"Shoulder Abd  | {y_truth[0]:.1f} / {y_truth[4]:.1f}   | {y_pred[0]:.1f} / {y_pred[4]:.1f}")
    print(f"Shoulder Flex | {y_truth[1]:.1f} / {y_truth[5]:.1f}   | {y_pred[1]:.1f} / {y_pred[5]:.1f}")
    print(f"Shoulder Rot  | {y_truth[2]:.1f} / {y_truth[6]:.1f}   | {y_pred[2]:.1f} / {y_pred[6]:.1f}")
    print(f"Elbow Flex    | {y_truth[3]:.1f} / {y_truth[7]:.1f}   | {y_pred[3]:.1f} / {y_pred[7]:.1f}")

    mse_rad = np.mean((y_truth - y_pred)**2) # MSE in rad
    print(f"\nExample's MSE (deg^2): {mse_rad:.6f}")
    # mse_deg = np.mean((y_truth_deg - y_pred_deg)**2) # MSE in deg=
    # print(f"Example's MSE (deg^2): {mse_deg:.6f}")

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x1 and 1024x128)