In [None]:
import numpy as np
import pandas as pd
import torch
import sys, os, copy
from tqdm.auto import tqdm, trange
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
import warnings, time
warnings.simplefilter("ignore")
torch.set_printoptions(precision=3, sci_mode=False)

from datetime import datetime

import pybullet as p

from IPython.display import display
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
import plotly.subplots as sp
plotly_layout = {'margin': {'l': 20, 'r': 20, 't': 20, 'b': 20}}  # tight 20px margins on every side

from torch.autograd.functional import jacobian

from training.model import get_model
from training.loader import get_dataloader
from envs import get_env

from training.model.PairwiseNet import Pairwise2Global

from utils.cubic_spline import cubic_spline_curve_manual
from planning.rrt import RRTConnectplanar
from planning.optimize import TrajectoryOptimizer

# Seed the global RNGs so that any stochastic step further down (presumably the
# optimizer's sampling, cf. its `num_sample` setting) is reproducible on
# Restart & Run All. torch.manual_seed also seeds all CUDA devices.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the first GPU when available, otherwise fall back to CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
def get_path(start, end, N):
    """Cubic joint-space interpolation from ``start`` to ``end`` with zero end velocities.

    This is the cubic p(t) = a0 + a1 t + a2 t^2 + a3 t^3 that satisfies
    p(0) = start, p(T) = end, p'(0) = p'(T) = 0. It has the closed form
    p(t) = start + (end - start) * (3 s^2 - 2 s^3) with s = t / T,
    sampled at t = 0, 1, ..., N - 1 with T = N - 1.

    Args:
        start: 1-D sequence/tensor of length ``dim``; the first configuration.
        end: 1-D sequence/tensor of length ``dim``; the last configuration.
        N: number of samples along the path.

    Returns:
        float32 tensor of shape (N, dim) on ``start``'s device. ``path[0]`` is
        ``start``, and for N >= 2 ``path[-1]`` is exactly ``end``.

    Raises:
        AssertionError: if ``start`` and ``end`` have different lengths.
    """
    start = torch.as_tensor(start, dtype=torch.float)
    end = torch.as_tensor(end, dtype=torch.float, device=start.device)
    assert len(start) == len(end), 'dim mismatch'
    dim = len(start)

    # The previous version put the boundary condition at t = N but only sampled
    # t = 0..N-1, so the returned path stopped short of `end`. Place the endpoint
    # on the last sample instead (max(...) keeps N == 1 well defined -> [start]).
    T = max(N - 1, 1)
    s = torch.arange(N, dtype=torch.float, device=start.device) / T
    blend = 3 * s**2 - 2 * s**3  # smoothstep: 0 -> 1 with zero slope at both ends

    # lerp returns `end` exactly at weight 1 (start + 1*(end-start) may not);
    # operands are expanded explicitly to a common (N, dim) shape.
    weight = blend.unsqueeze(1).expand(-1, dim)
    return torch.lerp(start.expand(N, dim), end.expand(N, dim), weight)

In [None]:
MODEL_PATH = 'results/20240704-192301_fourarm_Pairwise'

# Find the run's config (*.yml) and its 'best_safe_FPR' checkpoint in MODEL_PATH.
# Iterate in sorted order: os.listdir order is arbitrary, so with several matching
# files the pick would otherwise differ between machines (the last match wins).
cfg, best_model = None, None
for file in sorted(os.listdir(MODEL_PATH)):
    if file.endswith('yml'):
        cfg = OmegaConf.load(os.path.join(MODEL_PATH, file))
    if file.endswith('best_safe_FPR.pkl'):
        # map_location lets a checkpoint saved on GPU load on the CPU fallback device.
        # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
        best_model = torch.load(os.path.join(MODEL_PATH, file), map_location=device)

assert cfg is not None, 'cfg file does not exist.'
assert best_model is not None, 'best_model does not exist.'

model = get_model(cfg.model).to(device)
model.load_state_dict(best_model['model_state'])
# Inference only: disable train-mode behaviour in case the architecture has
# dropout / batch-norm layers (TODO confirm against get_model). Gradients still flow.
model.eval();

In [None]:
# Environment used for ground-truth distances; GUI=True opens an interactive viewer.
ENV_CONFIG_PATH = 'configs/envs/env_config_4arm.yml'
env_cfg = OmegaConf.load(ENV_CONFIG_PATH)
env = get_env(env_cfg, GUI=True, device=device)

In [None]:
# Threshold on the checker's output, later passed to the optimizer as a lower bound.
col_thr = 0.0
# Wraps the learned pairwise model for whole-configuration queries; its output is
# compared against env.calculate_min_distance in the next cell.
checker = Pairwise2Global(model, cfg, env)

In [None]:
# Start/goal configurations. Each full configuration is the concatenation of four
# 7-value per-arm poses (one block per arm, in arm order).
ARM_POSES = {
    't1_a': [0.0, -0.148, 0.0, -1.271, 0.0, 2.304, 0.0],
    't1_b': [0.0, -0.186, 0.0, -1.792, 0.0, 1.867, 0.0],
    't1_c': [0.0, -0.278, 0.0, -2.440, 0.0, 2.562, 0.0],
    't1_d': [0.0,  0.186, 0.0, -2.550, 0.0, 2.958, 0.0],
    't2_a': [0.244, 0.445, 0.0, -1.839, 0.0, 2.343, 0.0],
    't2_b': [0.335, 0.148, 0.0, -1.239, 0.0, 2.720, 0.0],
    't3_a': [-1.250, 0.724, 0.945, -2.187, 0.793, 1.669, 0.0],
    't3_b': [0.457, 0.130, 0.0, -1.176, -1.464, 1.094, 0.0],
}

# trajectory id -> (start pose names, goal pose names), one name per arm.
TRAJECTORIES = {
    1: (['t1_a', 't1_b', 't1_c', 't1_d'], ['t1_d', 't1_c', 't1_b', 't1_a']),
    2: (['t2_a', 't2_a', 't2_b', 't2_b'], ['t2_b', 't2_b', 't2_a', 't2_a']),
    3: (['t3_a', 't3_a', 't3_b', 't3_b'], ['t3_b', 't3_b', 't3_a', 't3_a']),
}
TRAJ_ID = 3  # which start/goal pair to plan between


def _stack_arm_poses(names):
    """Concatenate the named per-arm poses into one flat float32 joint vector."""
    return torch.tensor([joint for name in names for joint in ARM_POSES[name]])


start_names, goal_names = TRAJECTORIES[TRAJ_ID]
q_start = _stack_arm_poses(start_names)
q_end = _stack_arm_poses(goal_names)

# Sanity check: ground-truth min distance from the env vs. the learned estimate.
for label, q in (('q_start', q_start), ('q_end', q_end)):
    gt = env.calculate_min_distance(q.unsqueeze(0)).squeeze().item()
    est = checker(q.unsqueeze(0).to(device)).squeeze().item()
    print(f'{label:<7} : GT = {gt:.4f} | est = {est:.4f}')

In [None]:
# Dense joint-space path between q_start and q_end, placed on the compute device.
N_PATH_POINTS = 100
traj = get_path(q_start, q_end, N_PATH_POINTS).to(device)

In [None]:
# Keep `curve_order` evenly spaced samples of the dense path (always including the
# first and last one) as spline knots, then fit a cubic spline through them.
curve_order = 10
# More knots than samples would repeat indices, i.e. duplicate knots.
assert curve_order <= len(traj), f'curve_order ({curve_order}) exceeds path length ({len(traj)})'
# Round instead of truncating: linspace values like 32.9999... must map to 33, not 32.
knot_idx = np.round(np.linspace(0, len(traj) - 1, curve_order)).astype(int)
traj = traj[knot_idx]
curve = cubic_spline_curve_manual(traj, device)

In [None]:
# Trajectory-optimization settings, passed as-is to TrajectoryOptimizer.
to_cfg = dict(
    length='joint',
    num_sample=10000,
    mu_g=10,
    mu_v=1,
    col_thr=(col_thr, 'lower'),
)
N_ITERATIONS = 3000

optimizer = TrajectoryOptimizer(to_cfg)
# NOTE(review): `curve` is overwritten by the second return value, so re-running
# this cell feeds the previous run's returned curve back in -- not idempotent.
results, curve, min_loss, best_curve = optimizer.optimize(
    curve,
    checker,
    device=device,
    env=env,
    iteration=N_ITERATIONS,
    pbar=True,
)