In [14]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d, CubicSpline
import math
from mpl_toolkits.mplot3d import Axes3D


# 创建模拟数据和参数
class Args:
    def __init__(self):
        self.arm_steps_length = [0.1] * 7  # 7个关节
        self.chunk_size = 50

args = Args()

# 模拟前一个动作和当前动作
pre_action = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
actions = np.array([[1, 1, 1, 1, 1, 1, 7, 1, 1, 1, 1, 1, 1, 7]])

# 模拟统计数据
stats = {
    'qpos_mean': np.zeros(14),
    'qpos_std': np.ones(14)
}

# 假设这是您的原始函数
steps = np.concatenate((np.array(args.arm_steps_length), np.array(args.arm_steps_length)), axis=0)
pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
post_process = lambda a: a * stats['qpos_std'] + stats['qpos_mean']

result = [pre_action]
post_action = post_process(actions[0]).reshape(1, -1)
print("post_action:", post_action)

# 分离夹爪和其他关节
gripper_indices = [6, 13]
arm_indices = [i for i in range(len(pre_action)) if i not in gripper_indices]

# 计算arm和gripper的最大差异
arm_diffs = np.abs(pre_action[arm_indices] - post_action[:, arm_indices])
gripper_diffs = np.abs(pre_action[gripper_indices] - post_action[:, gripper_indices])
max_arm_diff_index = np.argmax(np.sum(arm_diffs, axis=1))
max_gripper_diff_index = np.argmax(np.sum(gripper_diffs, axis=1))

# 使用更大的索引确保同步
max_diff_index = max(max_arm_diff_index, max_gripper_diff_index)
max_diff_index


post_action: [[1. 1. 1. 1. 1. 1. 7. 1. 1. 1. 1. 1. 1. 7.]]


0

In [16]:

for i in range(max_diff_index, post_action.shape[0]):
    # 计算arm的插值步数
    arm_steps = max([math.ceil(math.fabs(result[-1][j] - post_action[i][j])/steps[j]) for j in arm_indices])
    arm_steps = max(arm_steps, 1)  # 确保至少有一步

    # 对arm进行插值
    arm_inter = []
    for j in arm_indices:
        arm_inter.append(np.linspace(result[-1][j], post_action[i][j], arm_steps + 1)[1:])
    arm_inter = np.array(arm_inter).T

    # 对夹爪进行二次插值
    gripper_inter = []
    for j in gripper_indices:
        x = np.array([0, arm_steps])
        y = np.array([pre_action[j], post_action[max_diff_index][j]])
        f = interp1d(x, y, kind='quadratic', fill_value="extrapolate")
        x_new = np.linspace(0, arm_steps, arm_steps + 1)[1:]
        gripper_inter.append(f(x_new))
    gripper_inter = np.array(gripper_inter).T

    # 合并arm和gripper的插值结果
    inter = np.zeros((arm_steps, len(pre_action)))
    inter[:, arm_indices] = arm_inter
    inter[:, gripper_indices] = gripper_inter

    result.extend(inter)

# 确保结果长度正确
result = np.array(result[1:args.chunk_size+1])
if len(result) < args.chunk_size:
    result = np.pad(result, ((0, args.chunk_size - len(result)), (0, 0)), mode='edge')

result = pre_process(result)

# print(result[np.newaxis, :])

ValueError: The number of derivatives at boundaries does not match: expected 1, got 0+0

In [None]:

# 准备绘图数据
original_trajectory = np.vstack((pre_action, actions[0]))
interpolated_trajectory = result[0]

# 创建图形
fig = plt.figure(figsize=(15, 10))

# 2D plot for each joint
for i in range(14):
    ax = fig.add_subplot(3, 5, i+1)
    ax.plot(range(2), original_trajectory[:, i], 'ro-', label='Original')
    ax.plot(np.linspace(0, 1, len(interpolated_trajectory)), interpolated_trajectory[:, i], 'b-', label='Interpolated')
    ax.set_title(f'Joint {i+1}')
    ax.legend()

# 3D plot for the first 3 joints
ax = fig.add_subplot(3, 5, 15, projection='3d')
ax.plot(original_trajectory[:, 0], original_trajectory[:, 1], original_trajectory[:, 2], 'ro-', label='Original')
ax.plot(interpolated_trajectory[:, 0], interpolated_trajectory[:, 1], interpolated_trajectory[:, 2], 'b-', label='Interpolated')
ax.set_xlabel('Joint 1')
ax.set_ylabel('Joint 2')
ax.set_zlabel('Joint 3')
ax.legend()

plt.tight_layout()
plt.show()

# 打印一些统计信息
print(f"Original trajectory shape: {original_trajectory.shape}")
print(f"Interpolated trajectory shape: {interpolated_trajectory.shape}")
print(f"Number of interpolated steps: {len(interpolated_trajectory)}")