In [None]:
# Import required packages
import time
import numpy as np
import pandas as pd
import torch
from torch import nn, Tensor
from torch.optim import AdamW
from tqdm import tqdm, trange
import zuko
from zuko.flows import Distribution, NSF
from zuko.distributions import DiagNormal, BoxUniform, Minimum
from zuko.flows import DistributionModule, FlowModule, Unconditional
from hnne import HNNE

from utils.utils import load_numpy
from utils.robot import Robot
from utils.settings import param
from utils.dataset import create_dataset

In [None]:

class Config:
    """Bundle of file locations and hyper-parameters used across the notebook.

    Every value is a plain instance attribute set in ``__init__``; ``repr``
    prints them all as a dictionary.
    """

    def __init__(self):
        # ---- dataset files ----
        self.x_data_path = './data/feature.npy'  # joint configurations
        self.y_data_path = './data/target.npy'  # end-effector positions

        # ---- flow architecture ----
        self.device = 'cuda'
        self.num_features = 7
        # conditioning vector = position (3) + posture (4) + noise level (1)
        self.num_conditions = 3 + 4 + 1
        self.num_transforms = 12
        self.subnet_shape = [1024 for _ in range(4)]
        self.activation = nn.LeakyReLU

        # ---- optimisation and periodic evaluation ----
        self.lr = 4e-5
        self.lr_decay = 3e-2  # handed to AdamW as its weight_decay
        self.noise_esp = 1e-3  # scale applied to the training-time input noise
        self.num_epochs = 2
        self.num_steps_save = 2000  # checkpoint / evaluate every this many steps
        self.num_test_data = 60
        self.num_test_samples = 40
        self.save_path = './weights/NSF.pth'

        # ---- history files ----
        self.err_his_path = './log/err_his.npy'
        self.train_loss_his_path = './log/train_loss_his.npy'

    def __repr__(self):
        """Return the attribute dictionary rendered as a string."""
        return str(vars(self))
        

In [None]:
# Notebook-wide settings object; later cells read paths and hyper-parameters from it.
config = Config()
# Robot helper from utils.robot; later cells use it through the name `panda`.
panda = Robot(verbose=False)

In [None]:
# Dataset: read both arrays from the paths in config, then regenerate them
# when the feature array comes back empty.
# NOTE(review): presumably load_numpy returns an empty array for a missing
# file -- confirm in utils.utils.
X, y = load_numpy(file_path=config.x_data_path), load_numpy(file_path=config.y_data_path)

if not len(X):
    # Nothing loaded: sample 1,000,000 joint configurations (with the extra
    # return requested by return_ee=True) and save both arrays for later runs.
    X, y = panda.random_sample_joint_config(return_ee=True, num_samples=1_000_000)
    np.save(config.x_data_path, X)
    np.save(config.y_data_path, y)

In [None]:
# build dimension reduction model
# The joint configurations are passed through HNNE with dim=4 and the result
# is appended to y as extra target columns (presumably the 4-dim "posture"
# part of the conditioning vector described in Config -- confirm).
hnne = HNNE(dim=4, ann_threshold=1000)
# NOTE(review): only the first 1_000_000 rows of X are transformed, while the
# dataset below receives all of X; the row counts only agree when
# len(X) <= 1_000_000.
X_transformed = hnne.fit_transform(X=X[:100_0000], dim=4, verbose=True)
# NOTE(review): y is rebuilt from itself, so running this cell a second time
# without re-running the data cell stacks the extra columns onto y again.
y = np.column_stack((y, X_transformed))
ds = create_dataset(features=X, targets=y, enable_normalize=False)
# Shuffled mini-batches of 128; every later cell draws its batches from `loader`.
loader = ds.create_loader(shuffle=True, batch_size=128)

In [None]:
# Build the generative model: a neural spline flow (NSF) over
# config.num_features sample dimensions, conditioned on
# config.num_conditions context dimensions.
flow = NSF(features=config.num_features,
           context=config.num_conditions,
           transforms=config.num_transforms,
           randperm=True,
           activation=config.activation,
           hidden_features=config.subnet_shape).to(config.device)

# Resume from the saved checkpoint when one exists. map_location places the
# stored tensors on config.device regardless of the device they were saved
# from. A missing file is not an error: training then starts from the freshly
# initialised weights.
try:
    state_dict = torch.load(config.save_path, map_location=config.device)
except FileNotFoundError:
    print(f'no checkpoint found at {config.save_path}; starting from scratch')
else:
    flow.load_state_dict(state_dict=state_dict)

# Train to maximize the log-likelihood
# (config.lr_decay is used as AdamW's weight_decay).
optimizer = AdamW(flow.parameters(), lr=config.lr, weight_decay=config.lr_decay)

In [None]:
def add_small_noise_to_batch(batch, esp: float = config.noise_esp, eval: bool = False):
    """Append a noise-level column to the targets and optionally perturb the inputs.

    Args:
        batch: An ``(x, y)`` pair of 2-D tensors with one row per example.
        esp: Multiplier applied to the sampled noise before it is added to ``x``.
        eval: When true, ``x`` is returned untouched and the appended column is
            all zeros.

    Returns:
        ``(x, y)`` where ``y`` has one extra trailing column holding each row's
        noise level. In training mode that level is drawn uniformly from
        [0, 1) per row and ``x`` becomes ``x + esp * noise``, with ``noise``
        zero-mean Gaussian whose standard deviation is the row's level.
    """
    x, y = batch
    num_rows = x.shape[0]

    if eval:
        level = torch.zeros((num_rows, 1)).to(x.device)
        return x, torch.column_stack((y, level))

    level = torch.rand((num_rows, 1)).to(x.device)
    y = torch.column_stack((y, level))
    # Give every feature of a row the same standard deviation.
    per_feature_std = torch.repeat_interleave(input=level, repeats=x.shape[1], dim=1)
    noise = torch.normal(mean=torch.zeros_like(input=x), std=per_feature_std)
    return x + esp * noise, y

In [None]:
def test_l2_err(config, step=None, model=flow):
    """Sample joint configurations from ``model`` and measure their error.

    One batch is drawn from ``loader`` and prepared in eval mode (zero noise
    level). ``config.num_test_data`` rows are picked at random (with
    replacement), and for each one ``config.num_test_samples`` samples are
    drawn from the conditional distribution ``model(y_row)``. Every sample is
    scored with ``panda.dist_fk`` against the first three target entries.

    Args:
        config: Object providing ``num_test_data`` and ``num_test_samples``.
        step: Optional training step, printed for logging when given.
        model: Conditional flow to evaluate.

    Returns:
        ``(df, mean_err)``: ``df`` has one row per sample with columns
        ``l2_err`` and ``log_prob``; ``mean_err`` is the mean of ``l2_err``.
        Note that the ``log_prob`` column stores the *negated* log-probability.
    """
    num_data, num_samples = config.num_test_data, config.num_test_samples
    batch = next(iter(loader))
    x, y = add_small_noise_to_batch(batch, eval=True)
    assert num_data < len(x)

    errs = np.zeros((num_data * num_samples,))
    log_probs = np.zeros((num_data * num_samples,))
    rand = np.random.randint(low=0, high=len(x), size=num_data)

    # Dedicated write cursor. Reusing `step` for this would overwrite the
    # caller's training step before it could be reported.
    cursor = 0
    for nd in rand:
        # Evaluation only: build the conditional distribution once and skip
        # gradient tracking.
        with torch.no_grad():
            dist = model(y[nd])
            x_hat = dist.sample((num_samples,))
            log_prob = dist.log_prob(x_hat)

        x_hat = x_hat.detach().cpu().numpy()
        neg_log_prob = -log_prob.detach().cpu().numpy()
        # Only the first three target entries are passed to dist_fk as ee_pos.
        ee_pos = y[nd].detach().cpu().numpy()[:3]

        for q, nlp in zip(x_hat, neg_log_prob):
            errs[cursor] = panda.dist_fk(q=q, ee_pos=ee_pos)
            log_probs[cursor] = nlp
            cursor += 1

    if step is not None:
        print(f'step={step}')
    df = pd.DataFrame(np.column_stack((errs, log_probs)), columns=['l2_err', 'log_prob'])
    return df, errs.mean()

In [None]:
def train_step(model, batch):
    """Run one optimisation step on ``batch`` and return the scalar loss.

    The loss is the mean negative log-likelihood ``-log p(x | y)`` of the
    noise-augmented batch under ``model``.

    Args:
        model: Conditional flow; ``model(y)`` must expose ``log_prob``.
        batch: ``(x, y)`` pair as produced by ``loader``.

    Returns:
        The loss value as a Python float.

    Note:
        The update is applied through the notebook-level ``optimizer``, so
        only the parameters registered with that optimizer are changed.
    """
    x, y = add_small_noise_to_batch(batch)

    # Use the model that was passed in rather than the global `flow`.
    loss = -model(y).log_prob(x)  # -log p(x | y)
    loss = loss.mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()

In [None]:
# Training loop.
# `step` counts optimisation steps within this run only: it restarts at 0,
# while the two history arrays are loaded from disk and appended to.
step = 0
err_his = load_numpy(file_path=config.err_his_path)
loss_his = load_numpy(file_path=config.train_loss_his_path)
for ep in range(config.num_epochs):
    t = tqdm(loader)
    for batch in t:
        loss = train_step(model=flow, batch=batch)
        
        # NOTE(review): the history array is re-allocated on every step.
        loss_his = np.concatenate((loss_his, [loss]))
        # Progress bar shows "<current loss>/<mean over the whole history>".
        bar = {
            "loss": f"{np.round(loss, 3)}/{np.round(loss_his.mean(), 3)}",
            "ep": ep,
        }
        t.set_postfix(bar, refresh=True)

        step += 1
        # Every config.num_steps_save steps: checkpoint the weights, save the
        # loss history, evaluate the flow, and save the error history.
        if step % config.num_steps_save == 0:
            torch.save(flow.state_dict(), config.save_path)
            np.save(config.train_loss_his_path, loss_his)
            df, err = test_l2_err(config, step=step)
            print(df.describe())
            err_his = np.concatenate((err_his, [err]))
            np.save(config.err_his_path, err_his)

In [None]:
# Evaluate the trained flow once and plot the per-sample error against the
# `log_prob` column (which test_l2_err fills with the negated log-probability).
df, err = test_l2_err(config)
ax1 = df.plot.scatter(x='log_prob', y='l2_err')
df.describe()

In [None]:
# Variant of the flow that reuses the trained transforms but replaces the base
# distribution with a BoxUniform on [-0.5, 0.5] in every feature dimension.
nflow = FlowModule(
    transforms=flow.transforms,
    base=Unconditional(
        BoxUniform,
        -torch.ones((config.num_features,)) * .5,
        torch.ones((config.num_features,)) * .5,
        buffer=True,
    ))

# Same device as the original flow (taken from config instead of hard-coded).
nflow.to(config.device)
# test_l2_err takes the config object; the test sizes are read from it.
df, err = test_l2_err(config, model=nflow)
ax1 = df.plot.scatter(x='log_prob', y='l2_err')
df.describe()

In [None]:
def show_pose(num_data, num_samples, model=flow):
    """Sample joint configurations for a few random targets of one batch.

    One batch is drawn from ``loader`` and prepared in eval mode.
    ``num_data`` rows are picked at random (with replacement) and
    ``num_samples`` samples are drawn from ``model(y_row)`` for each.

    Args:
        num_data: Number of target rows to use; must be smaller than the batch.
        num_samples: Number of samples drawn per target row.
        model: Conditional flow to sample from.

    Returns:
        ``(x_hats, pidxs, errs, log_probs)``, all float64 NumPy arrays:
        ``x_hats`` has shape ``(num_data * num_samples, panda.dof)``;
        ``pidxs`` has one row per sample holding target entries ``[3:-1]``;
        ``errs`` holds ``panda.dist_fk`` of each sample against target
        entries ``[:3]``; ``log_probs`` holds the *negated* log-probabilities.
    """
    batch = next(iter(loader))
    x, y = add_small_noise_to_batch(batch, eval=True)
    assert num_data < len(x)

    # Collect per-target pieces in lists and join them once at the end instead
    # of re-allocating the result arrays on every iteration.
    x_hat_chunks, pidx_chunks, err_values, log_prob_chunks = [], [], [], []
    rand = np.random.randint(low=0, high=len(x), size=num_data)

    for nd in rand:
        # Evaluation only: build the conditional distribution once and skip
        # gradient tracking.
        with torch.no_grad():
            dist = model(y[nd])
            x_hat = dist.sample((num_samples,))
            log_prob = dist.log_prob(x_hat)

        x_hat = x_hat.detach().cpu().numpy()
        neg_log_prob = -log_prob.detach().cpu().numpy()
        target = y[nd].detach().cpu().numpy()
        ee_pos = target[:3]

        err_values.extend(panda.dist_fk(q=q, ee_pos=ee_pos) for q in x_hat)
        x_hat_chunks.append(x_hat.reshape(-1))
        # Repeat the target's middle entries once per sample.
        pidx_chunks.append(np.tile(target[3:-1], (num_samples, 1)).reshape(-1))
        log_prob_chunks.append(neg_log_prob)

    # Leading empty float64 array keeps the outputs float64 and lets the
    # concatenation work when nothing was collected.
    seed = np.array([])
    x_hats = np.concatenate([seed, *x_hat_chunks]).reshape((-1, panda.dof))
    pidxs = np.concatenate([seed, *pidx_chunks]).reshape((len(x_hats), -1))
    errs = np.concatenate([seed, err_values])
    log_probs = np.concatenate([seed, *log_prob_chunks])
    return x_hats, pidxs, errs, log_probs

In [None]:
# Draw 10 samples for each of 5 random targets.
# NOTE(review): `log_porbs` is a misspelling of "log_probs"; the save and load
# cells use the same spelling (also in the file name), so rename all together.
x_hats, pidxs, errs, log_porbs = show_pose(num_data=5, num_samples=10)

In [None]:
# Persist the sampled poses and their scores under ./data.
# The paths carry no extension here; the load cell reads them back as *.npy.
np.save('./data/x_hats', arr=x_hats)
np.save('./data/pidxs', pidxs)
np.save('./data/errs', errs)
np.save('./data/log_porbs', log_porbs)

In [None]:
# Reload the saved samples, overwriting the in-memory arrays of the same names.
x_hats = np.load('./data/x_hats.npy')
pidxs = np.load('./data/pidxs.npy')
errs = np.load('./data/errs.npy')
log_porbs = np.load('./data/log_porbs.npy')

In [None]:
def inside_same_pidx(x_hats, pidxs):
    """Plot the leading run of rows in ``x_hats`` whose ``pidxs`` rows are equal.

    Rows are taken from the start for as long as each ``pidx`` equals the one
    before it; the first differing row ends the run. Each kept row ``q`` is
    drawn with ``panda.plot(q, q)``.
    """
    kept = [np.array([])]  # float64 seed; also covers the "nothing kept" case
    previous = None
    for joints, pidx in zip(x_hats, pidxs):
        if previous is not None and not np.array_equal(previous, pidx):
            break
        kept.append(joints)
        previous = pidx

    qs = np.concatenate(kept).reshape((-1, panda.dof))
    for q in qs:
        panda.plot(q, q)

In [None]:
# Both arguments are required; pass the sampled poses and their pidx rows.
inside_same_pidx(x_hats, pidxs)

In [None]:
def plot_all(x_hats):
    """Plot every row of ``x_hats`` and print its position and pidx.

    The pidx is looked up by position in the notebook-level ``pidxs`` array,
    which is not a parameter of this function.
    """
    position = 0
    for q in x_hats:
        panda.plot(q, q)
        print(f'step={position}, pidx={pidxs[position]}')
        position += 1

In [None]:
# plot_all requires the array of joint configurations to draw.
plot_all(x_hats)

In [None]:
# Ask the robot helper for a path and an accompanying joint configuration.
# NOTE(review): the exact meaning and shapes of both return values are
# defined in utils.robot -- confirm there.
path, sample_ans_q = panda.path_generate_via_stable_joint_traj()