**PREPARATION**

In [28]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [29]:
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm.notebook import tqdm
import sys

import ddpm
import datasets

import os
import trimesh

In [30]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

# Report the installed versions, then warn about any package whose version
# does not start with the release this notebook was written against.
print(f"PyTorch version Installed: {torch.__version__}\nTorchvision version Installed: {torchvision.__version__}\n")

_version_expectations = (
    (
        torch.__version__,
        "1.11",
        "you are using an another version of PyTorch. We expect PyTorch 1.11.0. You may continue using your version but it"
        " might cause dependency and compatibility issues.",
    ),
    (
        torchvision.__version__,
        "0.12",
        "you are using an another version of torchvision. We expect torchvision 0.12. You can continue with your version but it"
        " might cause dependency and compatibility issues.",
    ),
)
for _installed, _expected_prefix, _warning in _version_expectations:
    if not _installed.startswith(_expected_prefix):
        print(_warning)

PyTorch version Installed: 2.0.1+cu117
Torchvision version Installed: 0.15.2+cu117

you are using an another version of PyTorch. We expect PyTorch 1.11.0. You may continue using your version but it might cause dependency and compatibility issues.
you are using an another version of torchvision. We expect torchvision 0.12. You can continue with your version but it might cause dependency and compatibility issues.


**DATASET & DATALOADER**

In [31]:
# DATA NORMALIZATION FUNCTION PRIOR TO DATASET GENERATION

# Import dataset class
from grasp_object_dataset import graspDataset
from torch.utils.data import DataLoader

# Function to obtain mean, std, max, and min of given dataset
def get_mean_std(main_dir, object_dir, dataset):
    """Compute per-dimension statistics of the joint vectors in ``dataset``.

    Only element ``0`` of every sample (the joint values) is read; the other
    elements of a sample are ignored.

    Args:
        main_dir: Unused. Kept so that existing calls keep working.
        object_dir: Unused. Kept so that existing calls keep working.
        dataset: Iterable of indexable samples whose element ``0`` is an
            array-like of joint values with the same shape for every sample.

    Returns:
        Tuple ``(mean, std, max, min)``; each entry is reduced over the
        sample axis (axis 0).

    Raises:
        ValueError: If ``dataset`` yields no samples.
    """
    # Load all samples - only joint values!
    full_joints = [sample[0] for sample in tqdm(dataset)]

    if len(full_joints) == 0:
        # Fail with a clear message instead of NumPy's "zero-size array" reduction error.
        raise ValueError("dataset is empty: cannot compute joint statistics")

    # Convert once so the four reductions below share a single array.
    joints_array = np.asarray(full_joints)

    # Local names deliberately avoid shadowing the built-in max() / min().
    joints_mean = np.mean(joints_array, axis=0)
    joints_std = np.std(joints_array, axis=0)
    joints_max = np.max(joints_array, axis=0)
    joints_min = np.min(joints_array, axis=0)

    return joints_mean, joints_std, joints_max, joints_min


In [38]:
# DATASET AND DATALOADER DEFINITION

# Import dataset class
from grasp_object_dataset import graspDataset
from torch.utils.data import DataLoader

# Directories of dataset grasps + objects: set to a small version of it to overfit!! Full version found in dataset_XX_full
main_dir = './dataset_grasps/'
object_dir = './dataset_objects/'

# Every sample is assigned to the training split; 'val' and 'test' get a fraction of 0.
# Each dataset receives its own copy so that no instance can alter another one's split.
FULL_TRAIN_SPLIT = {'train': 1, 'val': 0, 'test': 0}

# Generate dataset with all dataset samples (explicitly un-normalized, no transforms)
main_dataset = graspDataset(main_dir, object_dir, mode='train', split=dict(FULL_TRAIN_SPLIT),
                            normalization=None, transform_joint=None, transform_object=None)

# Calculate mean, std, max, min of full dataset
mean_std_max_min = list(get_mean_std(main_dir, object_dir, main_dataset))
# Print not norm data
print('NOT NORMALIZED: ', mean_std_max_min)

# Datasets definition.
# NOTE: normalization is currently switched off here. The previous revision of these
# lines carried a commented-out `normalization=mean_std_max_min` argument; pass it to
# all three datasets to train on normalized joints.
train_dataset = graspDataset(main_dir, object_dir, mode='train', split=dict(FULL_TRAIN_SPLIT))
val_dataset = graspDataset(main_dir, object_dir, mode='val', split=dict(FULL_TRAIN_SPLIT))
test_dataset = graspDataset(main_dir, object_dir, mode='test', split=dict(FULL_TRAIN_SPLIT))

# Comparison for normalized dataset.
# While normalization stays disabled above, these statistics equal the 'NOT NORMALIZED' ones.
mean_std_max_min_2 = list(get_mean_std(main_dir, object_dir, train_dataset))
print('NORMALIZED: ', mean_std_max_min_2)

# Dataloader definition
# The training loader reshuffles every epoch so that batches are not always made of the
# same, storage-ordered samples. Validation / test loaders keep a deterministic order.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2, drop_last=False)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2, drop_last=False)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2, drop_last=False)

  0%|          | 0/8000 [00:00<?, ?it/s]

NOT NORMALIZED:  [array([ 0.0392823 ,  0.31361202,  0.70816582,  0.29692368,  0.07207266,
        0.31826818,  0.49531689,  0.22077443,  0.1136933 ,  0.30901805,
        0.71981804,  0.18223994,  0.24598885, -0.00452761,  0.15958449,
        0.59792489,  0.23905943,  0.38096319,  0.97673997, -0.0487463 ,
       -0.39349605, -0.07124564, -0.17433854,  0.36593643, -2.36187123,
       -0.05151258, -0.05443905, -0.0173604 ]), array([0.02517107, 0.09351697, 0.02072358, 0.28532102, 0.00295238,
       0.04999459, 0.28136656, 0.22077773, 0.01399565, 0.04496688,
       0.36168404, 0.18209226, 0.17415622, 0.04522972, 0.06109539,
       0.2885344 , 0.23899754, 0.07810427, 0.10572949, 0.05585855,
       0.09310748, 0.07385562, 1.37517608, 0.13581058, 0.7259894 ,
       0.06679289, 0.10399817, 0.01794473]), array([ 6.44533709e-02,  4.07128990e-01,  7.28889406e-01,  5.82244694e-01,
        7.50250369e-02,  3.68262768e-01,  7.76683450e-01,  4.41552162e-01,
        1.27688959e-01,  3.53984922e-01,  1.

  0%|          | 0/8000 [00:00<?, ?it/s]

NORMALIZED:  [array([ 0.0392823 ,  0.31361202,  0.70816582,  0.29692368,  0.07207266,
        0.31826818,  0.49531689,  0.22077443,  0.1136933 ,  0.30901805,
        0.71981804,  0.18223994,  0.24598885, -0.00452761,  0.15958449,
        0.59792489,  0.23905943,  0.38096319,  0.97673997, -0.0487463 ,
       -0.39349605, -0.07124564, -0.17433854,  0.36593643, -2.36187123,
       -0.05151258, -0.05443905, -0.0173604 ]), array([0.02517107, 0.09351697, 0.02072358, 0.28532102, 0.00295238,
       0.04999459, 0.28136656, 0.22077773, 0.01399565, 0.04496688,
       0.36168404, 0.18209226, 0.17415622, 0.04522972, 0.06109539,
       0.2885344 , 0.23899754, 0.07810427, 0.10572949, 0.05585855,
       0.09310748, 0.07385562, 1.37517608, 0.13581058, 0.7259894 ,
       0.06679289, 0.10399817, 0.01794473]), array([ 6.44533709e-02,  4.07128990e-01,  7.28889406e-01,  5.82244694e-01,
        7.50250369e-02,  3.68262768e-01,  7.76683450e-01,  4.41552162e-01,
        1.27688959e-01,  3.53984922e-01,  1.0815

In [48]:
# DATASET INDICES
#
# Layout of a single item, accessed as train_dataset[sample][index]:
#   index 0 -> joint angles
#   index 1 -> label (grasp type) as a one-hot vector
#   index 2 -> matrix of distances
#   index 3 -> file name

# Show the label of one arbitrary sample and the total number of training samples.
example_label = train_dataset[1000][1]
n_train_samples = len(train_dataset)

print(example_label)
print(n_train_samples)

[0. 1. 0.]
8000


**DIFFUSION MODEL**

In [69]:
# HYPERPARAMETERS FOR TUNING

# --- Model (forwarded to ddpm.MLP) ---
hidden_size, hidden_layers, emb_size = 128, 3, 128
# Both the timestep and the input use the same embedding type.
time_emb = input_emb = "sinusoidal"

# --- Noise scheduler (forwarded to ddpm.NoiseScheduler) ---
num_timesteps, beta_schedule = 50, 'linear'

# --- Optimizer ---
learning_rate = 1e-5

# --- Training ---
num_epochs = 30

In [70]:
# Network trained to predict the noise added to a joint vector, given the timestep,
# the one-hot grasp label and the object's distance matrix (see the loss below).
model = ddpm.MLP(
    hidden_size=hidden_size,
    hidden_layers=hidden_layers,
    emb_size=emb_size,
    time_emb=time_emb,
    input_emb=input_emb,
)

noise_scheduler = ddpm.NoiseScheduler(
    num_timesteps=num_timesteps,
    beta_schedule=beta_schedule,
)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

global_step = 0
frames = []  # only referenced by the disabled snippet at the end of this cell
losses = []  # per-step training loss
print("Training model...")
for epoch in range(num_epochs):
    model.train()
    progress_bar = tqdm(total=len(train_dataloader))
    progress_bar.set_description(f"Epoch {epoch}")
    try:
        for step, batch in enumerate(train_dataloader):
            # Variables
            joint_angles_batch = batch[0]
            label_one_hot = batch[1]
            mat_distances = batch[2]

            # Noise and timesteps: one random timestep per sample in the batch
            noise = torch.randn(joint_angles_batch.shape)
            timesteps = torch.randint(0, noise_scheduler.num_timesteps, (joint_angles_batch.shape[0],)).long()

            # Corrupt the clean joints, then ask the model to recover the noise that was added
            noisy = noise_scheduler.add_noise(joint_angles_batch, noise, timesteps)
            noisy = noisy.double()  # the model is fed float64 inputs
            noise_pred = model(noisy, timesteps, label_one_hot, mat_distances)
            loss = F.mse_loss(noise_pred, noise)

            # Plain backward(): passing `loss` as the `gradient` argument would
            # multiply every parameter gradient by the current loss value.
            loss.backward()

            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

            loss_value = loss.detach().item()
            losses.append(loss_value)

            progress_bar.update(1)
            progress_bar.set_postfix(loss=loss_value, step=global_step)
            global_step += 1
    finally:
        # Close the bar even when the epoch is interrupted (e.g. KeyboardInterrupt),
        # so that no stale progress widgets are left behind.
        progress_bar.close()


# NOTE: everything between the triple quotes below is a plain string literal, so none
# of it executes. It refers to a `config` object that is not defined in the visible
# part of this notebook, samples 2-D points and calls `model(sample, t)` with two
# arguments, so it needs adapting to the 4-argument model above before being re-enabled.
"""
    if epoch % config.save_images_step == 0 or epoch == config.num_epochs - 1:
        # generate data with the model to later visualize the learning process
        model.eval()
        sample = torch.randn(config.eval_batch_size, 2)
        timesteps = list(range(len(noise_scheduler)))[::-1]
        for i, t in enumerate(tqdm(timesteps)):
            t = torch.from_numpy(np.repeat(t, config.eval_batch_size)).long()
            with torch.no_grad():
                residual = model(sample, t)
            sample = noise_scheduler.step(residual, t[0], sample)
        frames.append(sample.numpy())

print("Saving model...")
outdir = f"exps/{config.experiment_name}"
os.makedirs(outdir, exist_ok=True)
torch.save(model.state_dict(), f"{outdir}/model.pth")

print("Saving images...")
imgdir = f"{outdir}/images"
os.makedirs(imgdir, exist_ok=True)
frames = np.stack(frames)
xmin, xmax = -6, 6
ymin, ymax = -6, 6
for i, frame in enumerate(frames):
    plt.figure(figsize=(10, 10))
    plt.scatter(frame[:, 0], frame[:, 1])
    plt.xlim(xmin, xmax)
    plt.ylim(ymin, ymax)
    plt.savefig(f"{imgdir}/{i:04}.png")
    plt.close()
    
print("Saving loss as numpy array...")
np.save(f"{outdir}/loss.npy", np.array(losses))
print("Saving frames...")
np.save(f"{outdir}/frames.npy", frames)
"""

Training model...


  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [9]:
# Store the trained weights under exps/<experiment_name>/model.pth
experiment_name = 'full_normalization'
outdir = f"exps/{experiment_name}"

print("Saving model...")
os.makedirs(outdir, exist_ok=True)
model_weights = model.state_dict()
torch.save(model_weights, f"{outdir}/model.pth")

Saving model...


**MODEL INFERENCE --> EVALUATION OF RESULTS**

In [43]:
# NOTE: this cell is disabled - its whole content is a string literal and nothing runs.
# The error captured below the cell shows why: the checkpoint './model_norm.pth' contains
# `conv_layers.*` weights and a `joint_mlp.0.weight` of shape [128, 4227], whereas the
# current ddpm.MLP has no such layers and expects [128, 3715]. The checkpoint therefore
# comes from a different model definition than the one instantiated here.
'''
# Load a model

model = ddpm.MLP(
        hidden_size= hidden_size,
        hidden_layers=hidden_layers,
        emb_size= emb_size,
        time_emb= time_emb,
        input_emb= input_emb
        )


model.load_state_dict(torch.load('./model_norm.pth'))
model.eval()
'''

RuntimeError: Error(s) in loading state_dict for MLP:
	Unexpected key(s) in state_dict: "conv_layers.0.weight", "conv_layers.0.bias", "conv_layers.3.weight", "conv_layers.3.bias", "conv_layers.6.weight", "conv_layers.6.bias", "conv_layers.10.weight", "conv_layers.10.bias". 
	size mismatch for joint_mlp.0.weight: copying a param with shape torch.Size([128, 4227]) from checkpoint, the shape in current model is torch.Size([128, 3715]).

In [71]:
# Directory for grasping validation
main_dir = './dataset_grasps/'
object_dir = './dataset_objects/'

# Grasp code (object identifier) of the object used for this overfitting experiment
grasp_code = 'core-bottle-1ae823260851f7d9ea600d1a6d9f6e07'

# Load matrix distance for given object.
# Directory and file name are passed to os.path.join as separate components, so the
# path is also correct when `object_dir` has no trailing slash.
# 0.11 is the same factor that the visualization cell passes to `apply_scale`.
# NOTE: allow_pickle=True lets np.load unpickle arbitrary objects - only use it on trusted files.
matriz_distancias = np.load(os.path.join(object_dir, grasp_code + ".npy"), allow_pickle=True)*0.11
matriz_distancias = torch.from_numpy(matriz_distancias)
# Add a leading batch dimension of 1 in front of the 50 x 50 x 50 grid
matriz_distancias = matriz_distancias.reshape(1, 50, 50, 50)

# One-hot grasp-type label used to condition the model, with a batch dimension of 1
label = torch.tensor([1, 0, 0])
label = label.reshape(1, 3)

In [72]:
# Model inference: start from Gaussian noise and denoise it step by step,
# conditioning on `label` and `matriz_distancias`.
model.eval()

eval_batch_size = 1
num_timesteps = 50
plot_step = 5  # keep every 5th intermediate result

noise_scheduler = ddpm.NoiseScheduler(num_timesteps=num_timesteps)
sample = torch.randn(eval_batch_size, 28)
# Timesteps are visited from the last one down to 0
timesteps = list(reversed(range(num_timesteps)))

samples, steps = [], []

for steps_done, timestep in enumerate(tqdm(timesteps), start=1):
    # The model expects one timestep per batch element
    t = torch.from_numpy(np.repeat(timestep, eval_batch_size)).long()
    with torch.no_grad():
        residual = model(sample, t, label, matriz_distancias)

    sample = noise_scheduler.step(residual, t[0], sample)

    if steps_done % plot_step != 0:
        continue
    samples.append(sample.numpy())
    steps.append(steps_done)

# Generated test grasp is saved in samples[-1]

  0%|          | 0/50 [00:00<?, ?it/s]

In [79]:
# Scratch inspection of the joint vector (index 0) of two samples.
# Only the value of the last expression of a cell is displayed, so the result of the
# first line is computed and discarded; the output shown is train_dataset[5000][0].
main_dataset[0][0]
train_dataset[5000][0]

array([ 0.06445337,  0.22009505,  0.72888941,  0.58224469,  0.07502504,
        0.26827359,  0.77668345,  0.44155216,  0.12768896,  0.35398492,
        1.08150208,  0.3643322 ,  0.07183263,  0.0407021 ,  0.0984891 ,
        0.88645929,  0.47805697,  0.30285892,  0.87101048, -0.10460485,
       -0.30038857,  0.00260998, -1.54951462,  0.23012585, -1.63588184,
       -0.11830547,  0.04955912, -0.03530513])

In [75]:
# Raw model output for the generated grasp: first batch item of the last stored denoising step
generated_grasp = samples[-1][0]
print(generated_grasp)

# Rescale from [-1, 1] to the per-dimension [min, max] range of the dataset
# (mean_std_max_min is ordered mean, std, max, min).
# NOTE(review): train_dataset is built with its normalization argument commented out,
# and the raw values printed above already sit close to the un-normalized dataset means
# (e.g. entry 24 is about -2.3 against a mean of -2.36) - confirm that this rescaling
# should really be applied to the current model's output.
joint_max = mean_std_max_min[2]
joint_min = mean_std_max_min[3]
sample = (generated_grasp + 1)/2
sample = sample * (joint_max - joint_min) + joint_min

print(sample)

[-0.00694007  0.08205713  0.93102944  0.07395889 -0.08273968  0.29157165
  0.64628536  0.5997809   0.1282427   0.2702911   0.469738    0.01185727
  0.03065512  0.02953949 -0.0228818   0.7533447   0.3089369   0.26445988
  0.88864577  0.10281598 -0.39260814 -0.1127336  -0.15285872  0.29987124
 -2.3153622   0.1183454   0.1081029   0.09568693]
[ 3.91076078e-02  3.21285754e-01  7.27460089e-01  3.18025697e-01
  7.18283798e-02  3.32845183e-01  6.77159959e-01  3.53192702e-01
  1.15488145e-01  3.21172192e-01  8.89714780e-01  1.84399053e-01
  2.51327634e-01 -3.19155000e-03  1.58186516e-01  8.15290773e-01
  3.12894570e-01  4.01618635e-01  1.07069604e+00 -4.30031483e-02
 -4.30050807e-01 -7.95716508e-02 -3.84546214e-01  4.06662116e-01
 -4.04279965e+00 -4.36079503e-02 -4.31965433e-02 -1.56433268e-02]


**VISUALIZATION OF RESULT**

In [18]:
import random
from utils.hand_model_lite import HandModelMJCFLite
import numpy as np
import transforms3d
import torch
import trimesh

In [19]:
# Folders holding the object meshes and the stored grasp files
mesh_path = './test_files/meshdata/'
data_path = './test_files/dexgraspnet/'

# Choose which hand description file to load
use_visual_mesh = False
if use_visual_mesh:
    hand_file = "mjcf/shadow_hand_vis.xml"
else:
    hand_file = "mjcf/shadow_hand_wrist_free.xml"

# Joint indices per finger prefix, listed in the order the hand pose vector uses them
_finger_joint_indices = (
    ('FF', (3, 2, 1, 0)),
    ('MF', (3, 2, 1, 0)),
    ('RF', (3, 2, 1, 0)),
    ('LF', (4, 3, 2, 1, 0)),
    ('TH', (4, 3, 2, 1, 0)),
)
joint_names = [
    f'robot0:{finger}J{joint}'
    for finger, joints in _finger_joint_indices
    for joint in joints
]

# qpos keys of the wrist translation and of the wrist rotation (x, y, z each)
translation_names = [f'WRJT{axis}' for axis in 'xyz']
rot_names = [f'WRJR{axis}' for axis in 'xyz']

In [20]:
# Build the hand model from the selected MJCF file and its mesh folder
hand_mesh_dir = "mjcf/meshes"
hand_model = HandModelMJCFLite(hand_file, hand_mesh_dir)

ValueError: Invalid geometry type box.

In [21]:
# Load the stored grasps and the object mesh that belong to one grasp code
grasp_code = 'core-bottle-1ae823260851f7d9ea600d1a6d9f6e07'

grasp_file = os.path.join(data_path, grasp_code+".npy")
grasp_data = np.load(grasp_file, allow_pickle=True)

object_mesh_file = os.path.join(mesh_path, grasp_code, "coacd/decomposed.obj")
object_mesh_origin = trimesh.load(object_mesh_file)

{'robot0:FFJ3': -0.24548496305942535, 'robot0:FFJ2': 0.6087559461593628, 'robot0:FFJ1': 0.7188290357589722, 'robot0:FFJ0': 0.2019304633140564, 'robot0:MFJ3': -0.17577102780342102, 'robot0:MFJ2': 0.38283923268318176, 'robot0:MFJ1': 1.1127177476882935, 'robot0:MFJ0': 0.403972327709198, 'robot0:RFJ3': -0.07210764288902283, 'robot0:RFJ2': 0.1400628387928009, 'robot0:RFJ1': 0.9195994138717651, 'robot0:RFJ0': 0.41342633962631226, 'robot0:LFJ4': 0.18373392522335052, 'robot0:LFJ3': 0.025994563475251198, 'robot0:LFJ2': 0.13054056465625763, 'robot0:LFJ1': 0.9286351203918457, 'robot0:LFJ0': 0.47552821040153503, 'robot0:THJ4': -0.0618005096912384, 'robot0:THJ3': 1.1335370540618896, 'robot0:THJ2': 0.041698239743709564, 'robot0:THJ1': -0.13501165807247162, 'robot0:THJ0': -2.997061756104813e-06, 'WRJRx': 3.0913850314695135, 'WRJRy': 0.2305903394170629, 'WRJRz': -2.3657796059005416, 'WRJTx': -0.043085746467113495, 'WRJTy': 0.09048712253570557, 'WRJTz': 0.1175876185297966}


In [124]:
# Select first grasp - could be any; it is only used as a template for the qpos keys
index = 0

# Set sample to predicted test sample
sample = samples[-1][0]

# Return normalized values to the initial values: [-1, 1] -> [min, max] per dimension
# (mean_std_max_min is ordered mean, std, max, min)
sample = (sample + 1)/2
sample = sample * (mean_std_max_min[2]- mean_std_max_min[3]) + mean_std_max_min[3]


# Substitute values in grasp for the predicted ones.
# A new dict is built instead of writing into grasp_data[index]['qpos'], so the loaded
# grasp stays untouched and the cell can be re-run safely.
# Assumes the entries of `sample` follow the key order of the stored qpos dict - TODO confirm.
qpos_template = grasp_data[index]['qpos']
qpos = {key: sample[i] for i, key in enumerate(qpos_template)}


# Plot the results
# Euler angles -> 3x3 rotation matrix, of which the first two columns (6 values) are kept
rot = np.array(transforms3d.euler.euler2mat(
    *[qpos[name] for name in rot_names]))
rot = rot[:, :2].T.ravel().tolist()
# Hand pose layout: 3 translation values, 6 rotation values, then the joint angles
hand_pose = torch.tensor([qpos[name] for name in translation_names] + rot + [qpos[name]
                         for name in joint_names], dtype=torch.float, device="cpu").unsqueeze(0)
hand_model.set_parameters(hand_pose)
hand_mesh = hand_model.get_trimesh_data(0)
# Fixed object scale of 0.11; the per-grasp alternative would be grasp_data[index]["scale"]
object_mesh = object_mesh_origin.copy().apply_scale(0.11)

(hand_mesh+object_mesh).show()