In [1]:
# Import required packages
from os import path 
import wandb
import time
import numpy as np
import pandas as pd
import torch
from torch import nn, Tensor
from torch.optim import AdamW
from tqdm import tqdm, trange
import zuko
from zuko.flows import Distribution, NSF
from zuko.distributions import DiagNormal, BoxUniform, Minimum
from zuko.flows import DistributionModule, FlowModule, Unconditional
from hnne import HNNE

from utils.settings import config
from utils.utils import *
from utils.model import *
from utils.robot import Robot
from utils.dataset import create_dataset

In [2]:
# Robot model handed to the data loader here and to the evaluation helpers later.
panda = Robot(verbose=False)

# Dataset for the robot: 2.5 million samples.
X, y = load_data(robot=panda, num_samples=2_500_000)

# h-NNE dimensionality-reduction model, plus the dataset object and the batch
# loader that the training loop iterates over.
hnne, ds, loader = get_hnne_model(X, y)

# Generative model (neural spline flow, NSF) and the optimizer that trains it.
flow, optimizer = get_flow_model()

Building h-NNE hierarchy using FINCH...
Using PyNNDescent to compute 1st-neighbours at this step ...
Sat May 20 02:27:29 2023 Building RP forest with 32 trees
Sat May 20 02:27:46 2023 NN descent for 21 iterations
	 1  /  21
	 2  /  21
	 3  /  21
	Stopping threshold met -- exiting after 3 iterations
Step PyNNDescent done ...
Level 0: 667390 clusters
Using PyNNDescent to compute 1st-neighbours at this step ...
Sat May 20 02:28:29 2023 Building RP forest with 32 trees
Sat May 20 02:28:32 2023 NN descent for 19 iterations
	 1  /  19
	 2  /  19
	 3  /  19
	Stopping threshold met -- exiting after 3 iterations
Step PyNNDescent done ...
Level 1: 167858 clusters
Using PyNNDescent to compute 1st-neighbours at this step ...
Sat May 20 02:28:39 2023 Building RP forest with 25 trees
Sat May 20 02:28:40 2023 NN descent for 17 iterations
	 1  /  17
	 2  /  17
	 3  /  17
	Stopping threshold met -- exiting after 3 iterations
Step PyNNDescent done ...
Level 2: 41187 clusters
Using PyNNDescent to compute

In [3]:
def train_step(model, batch, optim=None):
    """Run one optimisation step on ``model`` and return the batch loss.

    Args:
        model: conditional flow; ``model(y)`` must return a distribution that
            exposes ``log_prob``.
        batch: batch forwarded to ``add_small_noise_to_batch``, which returns
            the ``(x, y)`` pair used below.
        optim: optimizer to step. When omitted, the notebook-level
            ``optimizer`` (created together with ``flow``) is used, which is
            what the original implementation always did.

    Returns:
        float: mean negative log-likelihood ``-log p(x | y)`` over the batch.
    """
    if optim is None:
        optim = optimizer

    x, y = add_small_noise_to_batch(batch)

    # Evaluate the model that was passed in. The previous body ignored the
    # ``model`` argument and always used the global ``flow``.
    nll = -model(y).log_prob(x)  # -log p(x | y)
    loss = nll.mean()

    optim.zero_grad()
    loss.backward()
    optim.step()

    return loss.item()

In [4]:
# Open a new Weights & Biases run for this notebook: metrics are logged under
# the "ikpflow" project and `config` is attached as the run's hyperparameters.
wandb.init(project="ikpflow", config=config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mducklyu0301[0m ([33mluca_nthu[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
step = 0
# Names for the rows of `DataFrame.describe()` after the leading "count" row.
l2_label = ['mean', 'std', 'min', '25%', '50%', '75%', 'max']

for ep in range(config.num_epochs):
    progress = tqdm(loader)
    for batch in progress:
        loss = train_step(model=flow, batch=batch)

        # Show the rounded loss on the progress bar and send it to wandb.
        progress.set_postfix({"loss": f"{np.round(loss, 3)}"}, refresh=True)
        wandb.log({"loss": np.round(loss, 3)})

        step += 1
        if step % config.num_steps_save != 0:
            continue

        # Every `config.num_steps_save` steps: checkpoint the flow, evaluate
        # it, and log the summary statistics of the first column of `df`.
        torch.save(flow.state_dict(), config.save_path)
        df, err = test_l2_err(config, robot=panda, loader=loader, model=flow, step=step)
        l2_val = df.describe().values[1:, 0]  # drop the "count" row
        l2_info = dict(zip(l2_label, l2_val))
        wandb.log(l2_info)

 61%|██████▏   | 5999/9765 [08:25<04:57, 12.65it/s, loss=-28.7]    

In [None]:
# Finish the wandb run. The upstream note marks this call as optional in
# general but necessary when running inside a notebook.
wandb.finish()

In [None]:
# Derive the evaluation model from the trained flow via the project helper
# `get_nflow_model` (what it changes is not visible in this notebook).
nflow = get_nflow_model(flow=flow)
# Evaluate it; `df` must expose the 'log_prob' and 'l2_err' columns plotted below.
df, err = test_l2_err(config, robot=panda, loader=loader, model=nflow)
# Scatter plot of log_prob (x axis) against l2_err (y axis).
ax1 = df.plot.scatter(x='log_prob', y='l2_err')
# Summary statistics of `df`, displayed as the cell output.
df.describe()

In [None]:
def save_show_pose_data(config, num_data, num_samples, model=flow):
    """Sample joint configurations for a few random targets and save them.

    One batch is taken from the notebook-level ``loader``. ``num_data`` rows of
    its context tensor ``y`` are picked at random (with replacement), and for
    each picked row ``num_samples`` configurations are drawn from ``model``
    conditioned on that row.

    Four arrays are written with ``save_numpy`` to the paths held by ``config``:

    * ``show_pose_features_path``: the sampled configurations, reshaped to
      ``(-1, panda.dof)``.
    * ``show_pose_pidxs_path``: entries ``[3:-1]`` of the picked context row,
      repeated once per sample (one row per sampled configuration).
    * ``show_pose_errs_path``: ``panda.dist_fk`` of every sample against the
      first three entries of the context row.
    * ``show_pose_log_probs_path``: the *negated* log-probability of every
      sample under ``model``.

    Args:
        config: settings object providing the four ``show_pose_*_path`` values.
        num_data: number of context rows to pick; must be positive and smaller
            than the batch size.
        num_samples: number of configurations sampled per picked row.
        model: conditional flow; ``model(c)`` must return a distribution with
            ``sample`` and ``log_prob``. Defaults to the global ``flow`` as it
            was when this function was defined.

    Raises:
        ValueError: if ``num_data`` is not positive.
        AssertionError: if ``num_data`` is not smaller than the batch size.
    """
    if num_data <= 0:
        # Fail early with a clear message; an empty selection previously
        # surfaced as an opaque reshape error further down.
        raise ValueError("num_data must be a positive integer")

    batch = next(iter(loader))
    x, y = add_small_noise_to_batch(batch, eval=True)
    assert num_data < len(x)

    picked_rows = np.random.randint(low=0, high=len(x), size=num_data)

    # Collect per-target pieces in lists and join them once at the end, rather
    # than re-allocating growing arrays on every iteration.
    x_hat_parts, pidx_parts, err_parts, log_prob_parts = [], [], [], []

    for nd in picked_rows:
        # Build the conditional distribution once and reuse it; no gradients
        # are needed because every result is detached right after.
        with torch.no_grad():
            dist = model(y[nd])
            x_hat = dist.sample((num_samples,))
            log_prob = dist.log_prob(x_hat)

        x_hat = x_hat.detach().cpu().numpy()
        # NOTE(review): the sign is flipped here, so the saved values are
        # negative log-probabilities - confirm this is what consumers expect.
        log_prob = -log_prob.detach().cpu().numpy()
        target = y[nd].detach().cpu().numpy()
        # NOTE(review): the first three context entries are used as the
        # end-effector position as-is, with no de-normalisation - confirm.
        ee_pos = target[:3]

        err_parts.extend(panda.dist_fk(q=q, ee_pos=ee_pos) for q in x_hat)
        x_hat_parts.append(x_hat.reshape(-1))
        pidx_parts.append(np.tile(target[3:-1], (num_samples, 1)).reshape(-1))
        log_prob_parts.append(log_prob)

    # Saved arrays are float64, matching the dtype the previous
    # concatenate-onto-empty-array approach produced.
    x_hats = np.concatenate(x_hat_parts).astype(np.float64, copy=False)
    x_hats = x_hats.reshape((-1, panda.dof))
    pidxs = np.concatenate(pidx_parts).astype(np.float64, copy=False)
    pidxs = pidxs.reshape((len(x_hats), -1))
    errs = np.asarray(err_parts, dtype=np.float64)
    log_probs = np.concatenate(log_prob_parts).astype(np.float64, copy=False)

    save_numpy(config.show_pose_features_path, x_hats)
    save_numpy(config.show_pose_pidxs_path, pidxs)
    save_numpy(config.show_pose_errs_path, errs)
    save_numpy(config.show_pose_log_probs_path, log_probs)

    print('Save pose successfully')

In [None]:
# Save 10 sampled configurations for each of 5 randomly picked targets, drawn from `nflow`.
save_show_pose_data(config, num_data=5, num_samples=10, model=nflow)

In [None]:
def inside_same_pidx():
    """Plot the saved configurations that belong to the first pidx group.

    Loads the arrays stored at ``config.show_pose_features_path`` and
    ``config.show_pose_pidxs_path``, walks them in order, and keeps
    configurations until the pidx row differs from the previous one. Each kept
    configuration ``q`` is passed to ``panda.plot(q, q)``.

    Raises:
        ValueError: if no configurations were saved.
    """
    x_hats = load_numpy(file_path=config.show_pose_features_path)
    pidxs = load_numpy(file_path=config.show_pose_pidxs_path)

    if len(x_hats) == 0:
        raise ValueError("lack show pose data")

    leading_group = []
    previous = None
    for q, pidx in zip(x_hats, pidxs):
        # Stop at the first row whose pidx no longer matches its predecessor.
        if previous is not None and not np.array_equal(previous, pidx):
            break
        leading_group.append(q)
        previous = pidx

    qs = np.asarray(leading_group, dtype=np.float64).reshape((-1, panda.dof))
    for q in qs:
        panda.plot(q, q)

In [None]:
# Plot the first group of saved configurations that share the same pidx.
inside_same_pidx()