In [1]:
import numpy as np
import os
import matplotlib.pyplot as plt
import pickle
from tqdm import tqdm

In [3]:
data_dir = '/home/will/Classwork/EE227B/new_mounted_dir/data_dir'
tokenized_data_dir = '/home/will/Classwork/EE227B/new_mounted_dir/tokenized_data_dir'

num_time_bins = 512
num_space_bins = 1024

time_max, time_min = 0.5, 0
space_max, space_min = 1, -1

# Number of space tokens placed next to each delta-time token. Every control
# point contributes 3 values, so 18 corresponds to 6 control points per Bezier
# segment once the segment's first control point has been dropped. If a file
# does not match this layout the reshape/concatenate below raises and the file
# is reported as failed.
space_tokens_per_segment = 18

# Point that the very first control point of every trajectory is differenced
# against. NOTE(review): presumably the common initial position of all
# trajectories -- confirm against the data generator.
start = np.array([0, 0, 2]).reshape(3, 1)


def _quantize(values, low, high, num_bins):
    """Map values from the range [low, high] to integer bins 0..num_bins - 1.

    The normalised value is scaled by ``num_bins`` and rounded to the nearest
    integer; anything that lands outside ``0..num_bins - 1`` (out-of-range
    input, or input within half a bin of ``high``) is clipped to the nearest
    valid bin.
    """
    bins = np.round((values - low) / (high - low) * num_bins, 0).astype(int)
    return np.clip(bins, 0, num_bins - 1)


# Without the output directory every write below would fail.
os.makedirs(tokenized_data_dir, exist_ok=True)

# (file name, error description) for every file that could not be tokenized.
failed_files = []

# NOTE(review): `count` is not used in this cell; it is kept only in case a
# later cell reads it -- confirm and drop the enumerate if not.
for count, file_name in enumerate(tqdm(os.listdir(data_dir))):
    file_path = os.path.join(data_dir, file_name)
    tokenized_file_path = os.path.join(tokenized_data_dir, file_name)
    # Results are written here first and renamed into place only once complete,
    # so an interrupted run can never leave a truncated file that the
    # "already tokenized" check below would then skip forever.
    partial_file_path = tokenized_file_path + ".partial"

    # Resume support: files tokenized by an earlier run are not redone.
    if os.path.exists(tokenized_file_path):
        continue

    try:
        # SECURITY: allow_pickle=True lets np.load unpickle arbitrary objects,
        # which can execute code. Only point data_dir at trusted files.
        data = np.load(file_path, allow_pickle=True)

        space_list = []
        time_list = []
        total_time_list = []

        for curve_params in data["trajs"]["curve_params"]:
            delta_time_list = []
            control_points_list = []
            traj_end_time = None

            for start_time, end_time, control_points in curve_params:
                delta_time_list.append(end_time - start_time)
                # Drop each segment's first control point.
                # NOTE(review): presumably it duplicates the previous segment's
                # last control point -- confirm.
                control_points_list.append(control_points[:, 1:])
                traj_end_time = end_time

            if traj_end_time is None:
                raise ValueError("trajectory has no Bezier segments")

            # Encode every control point as its offset from the previous one
            # (the first one as its offset from `start`).
            control_points_arr = np.concatenate(control_points_list, axis=1)
            prev_control_points_arr = np.concatenate([start, control_points_arr[:, :-1]], axis=1)
            control_points_diff_arr = control_points_arr - prev_control_points_arr

            delta_time_arr = np.array(delta_time_list)

            # Transpose before flattening so the three coordinates of one
            # control point end up adjacent in the token stream.
            disc_diffs = _quantize(control_points_diff_arr, space_min, space_max, num_space_bins).T.reshape(-1)
            disc_delta_time = _quantize(delta_time_arr, time_min, time_max, num_time_bins)

            space_list.append(disc_diffs)
            time_list.append(disc_delta_time)
            # End time of the trajectory's last segment.
            total_time_list.append(traj_end_time)

        map_array = data["metadata"]["map"]
        map_array[0] += 12 # Shift the map cell type indicator by 12
        map_array = np.transpose(map_array, [1, 2, 0])

        H, W, C = map_array.shape
        map_array = map_array.reshape(1, H * W, C)

        # Trajectories are stored in order of increasing total time; compute
        # the ordering once so "times" and "trajs" are guaranteed to agree.
        sorted_indices = np.argsort(total_time_list)

        tokenized_results = {
            "map": map_array,
            "times": np.array(total_time_list)[sorted_indices],
            "trajs": {}
        }

        for idx, sorted_idx in enumerate(sorted_indices):
            disc_diffs = space_list[sorted_idx]
            disc_delta_time = time_list[sorted_idx]

            # One row per segment: [time token, space tokens...]. Time tokens
            # are offset by num_space_bins so they occupy ids above the space
            # tokens (0..num_space_bins - 1) and the two cannot collide.
            traj_array = np.concatenate(
                [disc_delta_time.reshape(-1, 1) + num_space_bins,
                 disc_diffs.reshape(-1, space_tokens_per_segment)],
                axis=1,
            ).reshape(1, -1)

            tokenized_results["trajs"][idx] = traj_array

        with open(partial_file_path, "wb") as f:
            pickle.dump(tokenized_results, f)
        os.replace(partial_file_path, tokenized_file_path)
    except Exception as exc:
        # Best effort: one bad file must not stop the batch, but it is recorded
        # so the failure is visible. KeyboardInterrupt/SystemExit are not
        # caught, so the run can still be interrupted.
        failed_files.append((file_name, repr(exc)))
    finally:
        # Only present if the write did not complete (error or interrupt).
        if os.path.exists(partial_file_path):
            os.remove(partial_file_path)

if failed_files:
    print("{} file(s) could not be tokenized; first few:".format(len(failed_files)))
    for failed_name, reason in failed_files[:10]:
        print("  {}: {}".format(failed_name, reason))

  0%|          | 0/80535 [00:00<?, ?it/s]

100%|██████████| 80535/80535 [03:35<00:00, 373.33it/s]   
