In [3]:
# Import required packages
import time
import numpy as np
import pandas as pd
import torch
from torch import nn, Tensor
from torch.optim import AdamW
from tqdm import tqdm, trange
import zuko
from zuko.flows import Distribution, NSF
from zuko.distributions import DiagNormal, BoxUniform, Minimum
from zuko.flows import DistributionModule, FlowModule, Unconditional
from hnne import HNNE

from utils.robot import Robot
from utils.settings import param
from utils.dataset import create_dataset

In [4]:
panda = Robot(verbose=False)
# data generation
X, y = panda.random_sample_joint_config(num_samples=100_0000, return_ee=True)
hnne = HNNE(dim=4, ann_threshold=1000)
X_transformed = hnne.fit_transform(X=X[:100_0000], dim=4, verbose=True)
y = np.column_stack((y, X_transformed))
ds = create_dataset(features=X, targets=y, enable_normalize=False)
loader = ds.create_loader(shuffle=True, batch_size=128)

100%|██████████| 1000000/1000000 [00:13<00:00, 76826.75it/s]


Building h-NNE hierarchy using FINCH...
Using PyNNDescent to compute 1st-neighbours at this step ...
Thu May 18 16:19:03 2023 Building RP forest with 32 trees
Thu May 18 16:19:12 2023 NN descent for 20 iterations
	 1  /  20
	 2  /  20
	 3  /  20
	Stopping threshold met -- exiting after 3 iterations
Step PyNNDescent done ...
Level 0: 265325 clusters
Using PyNNDescent to compute 1st-neighbours at this step ...
Thu May 18 16:19:35 2023 Building RP forest with 28 trees
Thu May 18 16:19:36 2023 NN descent for 18 iterations
	 1  /  18
	 2  /  18
	 3  /  18
	Stopping threshold met -- exiting after 3 iterations
Step PyNNDescent done ...
Level 1: 65931 clusters
Using PyNNDescent to compute 1st-neighbours at this step ...
Thu May 18 16:19:39 2023 Building RP forest with 21 trees
Thu May 18 16:19:39 2023 NN descent for 16 iterations
	 1  /  16
	 2  /  16
	 3  /  16
	Stopping threshold met -- exiting after 3 iterations
Step PyNNDescent done ...
Level 2: 15912 clusters
Using PyNNDescent to compute 

In [5]:
class Config:
    def __init__(self):
        # training
        self.num_epochs = 2
        self.num_steps_save = 2000
        self.num_test_data = 60
        self.num_test_samples = 40
        self.save_path = './weights/best_manifold_learning.pth'
    
    def __repr__(self):
        return str(self.__dict__)
        

In [6]:
config = Config()
# Neural spline flow (NSF) with 3 sample features and 5 context features
flow = NSF(features=7, context=7+1, transforms=12, randperm=True, activation=nn.LeakyReLU, hidden_features=[1024] * 4).to('cuda')
flow.load_state_dict(state_dict=torch.load(config.save_path))

<All keys matched successfully>

In [7]:
def add_small_noise_to_batch(batch, esp: float = 1e-3, eval: bool = False):
    x, y = batch
    if eval:
        std = torch.zeros((x.shape[0], 1)).to(x.device)
        y = torch.column_stack((y, std))
    else:
        std = torch.rand((x.shape[0], 1)).to(x.device)
        y = torch.column_stack((y, std))
        noise = torch.normal(mean=torch.zeros_like(input=x), std=torch.repeat_interleave(input=std, repeats=x.shape[1], dim=1))
        x = x + esp * noise
    return x, y

In [8]:
path, sample_ans_q = panda.path_generate_via_stable_joint_traj()
np.save('./data/sample_ans_q', sample_ans_q)
np.save('./data/path', path)

100%|██████████| 100/100 [00:00<00:00, 37359.08it/s]


In [None]:
sample_ans_q = np.load('./data/sample_ans_q.npy')
path = np.load('./data/path.npy')

In [34]:
def sample_jtraj(path, pidx, model=flow):
    path_len = len(path)
    pidx = np.tile(pidx, (path_len,1))
    cstd = np.zeros((path_len,))
    
    y = np.column_stack((path, pidx, cstd))
    y = torch.tensor(data=y, device='cuda', dtype=torch.float32)
    
    errs = np.zeros((len(path),))
    log_probs = np.zeros((len(path),))
    
    step = 0
    x_hat = model(y).sample((1,))
    log_prob = model(y).log_prob(x_hat)
    
    x_hat = x_hat.detach().cpu().numpy()[0]
    log_prob = -log_prob.detach().cpu().numpy()[0]
    print(path.shape)
    print(x_hat.shape)
    print(log_prob.shape)
    for q, lp, ee_pos in zip(x_hat, log_prob, path):
        errs[step] = panda.dist_fk(q=q, ee_pos=ee_pos)
        log_probs[step] = lp     
        step += 1
    print(f'step={step}')
    df = pd.DataFrame(np.column_stack((errs, log_probs)), columns=['l2_err', 'log_prob'])
    qs = x_hat
    return df, qs

In [40]:
pidx = hnne.transform(X=sample_ans_q[0:3])
pidx

array([[-0.44524488,  0.61229221,  0.21142905,  0.26101529],
       [-0.44524488,  0.61229221,  0.21142905,  0.26101529],
       [-0.44524489,  0.6122922 ,  0.21142904,  0.26101529]])

In [42]:
for i, px in enumerate(pidx):
    df, qs = sample_jtraj(path, px, flow)
    print(df.describe())
    np.save(f'./data/exp_qs_{i}', arr=qs)

(100, 3)
(100, 7)
(100,)
step=100
           l2_err    log_prob
count  100.000000  100.000000
mean     0.017932   -5.861188
std      0.010550    2.194074
min      0.003041   -9.723755
25%      0.010885   -7.426585
50%      0.017146   -6.232446
75%      0.022153   -4.845385
max      0.063643    0.473040
(100, 3)
(100, 7)
(100,)
step=100
           l2_err    log_prob
count  100.000000  100.000000
mean     0.018972   -5.751236
std      0.017647    2.021458
min      0.003151   -9.228773
25%      0.010400   -7.288369
50%      0.015109   -5.968640
75%      0.020470   -4.544196
max      0.123945   -0.395426
(100, 3)
(100, 7)
(100,)
step=100
           l2_err    log_prob
count  100.000000  100.000000
mean     0.022680   -5.307285
std      0.021689    2.384039
min      0.004120   -9.228888
25%      0.010919   -7.129104
50%      0.017539   -5.697484
75%      0.025518   -3.623688
max      0.164878    2.780489


In [11]:
panda = Robot(verbose=False)

path = np.load('./data/path.npy')
err = np.zeros((100,))

for i in range(3):
    step = 0
    qs = np.load(file=f'./data/exp_qs_{i}.npy')
    for i in range(100):
        err[i] = panda.dist_fk(q=qs[i], ee_pos=path[i])
    outliner = np.where(err > 0.05)
    print(outliner)
    print(err[outliner])
    print(np.sum(err))
    # panda.plot_qs(qs)

(array([4]),)
[0.06364301]
1.793220472269379
(array([54, 79, 90, 98]),)
[0.12394481 0.07852604 0.07414823 0.09467775]
1.8972022784952
(array([ 2,  3,  9, 72, 79, 83]),)
[0.07673373 0.05092229 0.1648779  0.08106042 0.06482196 0.10785701]
2.2680039664091587
