In [None]:
from synthetic_task_gen.create_task import *
%matplotlib widget

# Draw one randomized waypath from the sampling configuration below and
# render it. `generate_waypath`, `display_trajectory` and `np` are expected to
# come from the star import of synthetic_task_gen.create_task (NOTE(review):
# the meaning of each key is defined there and is not visible in this notebook).
sample_params = {
    "waypoints_range": [2, 5],
    "obj_waypoint_bound": 150,
    "coplanar_scale": 2,
    "keep_ori_prob": 0.6,
    "screw_prob": 0.4,
    "screw_step_mult": 2,
    "step_range": [15, 30],
    "obj_num_range": [1, 2],
    "arc_radius_range": [0.8, 150],
    'arc_tilt_radian': [-np.pi, np.pi],
    "arc_radius_noise": 0.001,
    'arc_tilt_noise': 0.001,
    # Seven values: presumably position (x, y, z) + quaternion (qx, qy, qz, qw) -- TODO confirm
    "start_point": [-400, -400, 100, 0, 0, 0, 1],
}
waypath_samples, waypt_samples = generate_waypath(sample_params)
# The second argument is passed as a list (one entry here), matching how the
# later cell calls display_trajectory with a list of object sequences.
display_trajectory(waypath_samples, [waypt_samples])

In [None]:
# Generate and save task dataset 
import os
import pickle as pkl

# Output directory for pickled datasets (created on demand when saving).
SAVE_FILE_PATH = "./saved_data/"
# Set True to write every generated dataset to SAVE_FILE_PATH. False keeps this
# cell free of file-system side effects (generation only).
SAVE_DATASETS = False
NUM_DEMOS = 10
UNIQUE_OBJS = ["obj0", "obj1", "obj2"]
N_OBJS = len(UNIQUE_OBJS)

split_types = ["train", "valid", "test"]
task_sizes = [10, 25, 50]
dataset_ids = [0, 1, 2]

# One fresh TaskGenerator per (task count, dataset id). Every split is filled
# by its own sample_demos call on that same generator instance, so the splits
# presumably share task definitions but draw separate demos -- TODO confirm.
for tsize in task_sizes:
    for i in dataset_ids:
        generator = TaskGenerator(sample_params, UNIQUE_OBJS, n_tasks=tsize)
        generated_dataset = {t: {} for t in generator.task_names}
        for split in split_types:
            generated_data_split = generator.sample_demos(NUM_DEMOS)
            # Create train/valid/test split entries per task.
            for task_id, dataset in generated_data_split.items():
                # NOTE(review): assumes dataset is an (object poses, trajectories) pair -- confirm
                generated_dataset[task_id][f"{split}_obj"] = dataset[0]
                generated_dataset[task_id][f"{split}_seq"] = dataset[1]

        if SAVE_DATASETS:
            # The filename uses the loop's task count (`tsize`); the old
            # commented-out version referenced an undefined `dsize`.
            os.makedirs(SAVE_FILE_PATH, exist_ok=True)
            save_path = os.path.join(SAVE_FILE_PATH, f"data_{tsize}tasks_{NUM_DEMOS}demos_set{i}.pkl")
            with open(save_path, "wb") as fout:
                pkl.dump(generated_dataset, fout)

# After the loop, `generated_dataset` holds only the LAST dataset generated
# (task_sizes[-1] tasks, dataset_ids[-1]); the following cells inspect it.

In [None]:
# Gather the first training demo of task0 for 3-D display, keeping only the
# first 7 columns of each array (presumably an x, y, z, qx, qy, qz, qw pose
# -- TODO confirm).
n_shown = 1
task0_data = generated_dataset['task0']
obj_sequence = [task0_data['train_obj'][k][:, :7] for k in range(n_shown)]
traj_sequence = [task0_data['train_seq'][k][:, :7] for k in range(n_shown)]
display_trajectory(traj_sequence, obj_sequence)

In [None]:
# Plot every pose component of each collected trajectory against its step
# index, one subplot per component, overlaying all trajectories.
dim_names = ['x', 'y', 'z', 'qx', 'qy', 'qz', 'qw']
dims = len(dim_names)
# Step indices marked with red vertical lines (currently only the start).
check_pts = [0]
# Vertical margin the marker lines extend beyond the plotted data range.
VLINE_PAD = 0.5

fig, ax = plt.subplots(dims, 1, figsize=(6, 11))
for i, name in enumerate(dim_names):
    # Bound the marker lines by the data actually plotted. Seeding the bounds
    # with 0 (as before) forced every panel's y-range to include zero, which
    # squashes panels whose values sit far from it. `default=0` keeps the old
    # result when traj_sequence is empty.
    lower = min((traj[:, i].min() for traj in traj_sequence), default=0)
    upper = max((traj[:, i].max() for traj in traj_sequence), default=0)
    for traj in traj_sequence:
        ax[i].plot(range(traj.shape[0]), traj[:, i], '--', color='blue', alpha=.5)
    # vlines broadcasts scalar ymin/ymax across all marker positions.
    ax[i].vlines(check_pts, lower - VLINE_PAD, upper + VLINE_PAD, color='red')
    ax[i].set_ylabel(name)
ax[-1].set_xlabel('step')
fig.suptitle('Trajectory pose components per step')
plt.show()