In [1]:
import numpy as np
from glob import glob
from pathlib import Path
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm
import pickle

In [2]:
data_path = Path('/common/users/dm1487/legged_manipulation_data/rollout_data/set_8')

In [3]:
all_pieces = sorted(glob(str(data_path/'*.npz')))
idxs = np.arange(550, 1050)
pieces = [all_pieces[i] for i in idxs]
len(pieces)

500

In [4]:
def split_and_pad_trajectories(tensor, dones):
    """ Splits trajectories at done indices. Then concatenates them and padds with zeros up to the length og the longest trajectory.
    Returns masks corresponding to valid parts of the trajectories
    Example: 
        Input: [ [a1, a2, a3, a4 | a5, a6],
                 [b1, b2 | b3, b4, b5 | b6]
                f]

        Output:[ [a1, a2, a3, a4], | [  [True, True, True, True],
                 [a5, a6, 0, 0],   |    [True, True, False, False],
                 [b1, b2, 0, 0],   |    [True, True, False, False],
                 [b3, b4, b5, 0],  |    [True, True, True, False],
                 [b6, 0, 0, 0]     |    [True, False, False, False],
                ]                  | ]    
            
    Assumes that the input has the following dimension order: [time, number of envs, aditional dimensions]
    """
    # dones = dones.clone()
    dones[-1] = 1
    # Permute the buffers to have order (num_envs, num_transitions_per_env, ...), for correct reshaping
    flat_dones = dones.transpose(1, 0).reshape(-1, 1)
    # flat_dones = dones.reshape(-1, 1)

    # Get length of trajectory by counting the number of successive not done elements
    done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero()[:, 0]))
    trajectory_lengths = done_indices[1:] - done_indices[:-1]
    trajectory_lengths_list = trajectory_lengths.tolist()
    # Extract the individual trajectories
    trajectories = torch.split(tensor.transpose(1, 0).flatten(0, 1),trajectory_lengths_list)
    padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories)
    
    trajectory_masks = trajectory_lengths > torch.arange(0, tensor.shape[0], device=tensor.device).unsqueeze(1)
    
    print(trajectory_masks.shape)
    return padded_trajectories, trajectory_masks

In [5]:
keys = list(np.load(pieces[0]).keys())
keys = keys[:4]
keys

['observations',
 'privileged_observations',
 'observation_histories',
 'full_seen_world']

In [6]:
def get_pieces(key):
    tensor_pieces = []
    done_pieces = []
    start = 0
    offset = 100
    with tqdm(total=len(pieces)) as pbar:
        while True:
            tensor = None
            if start >= len(pieces):
                break
            for p in pieces[start:start+offset]:
                traj_dict = np.load(p)
                if tensor is None:
                    if key == "observation_histories":
                        tensor = torch.tensor(traj_dict[key][:, :, -37:])
                    else:
                        tensor = torch.tensor(traj_dict[key])
                    dones = torch.tensor(traj_dict['dones'], dtype=torch.bool)
                    continue
                
                if key == "observation_histories":
                    tensor = torch.cat([tensor, torch.tensor(traj_dict[key])[:, :, -37:]], dim=0)
                else:
                    tensor = torch.cat([tensor, torch.tensor(traj_dict[key])], dim=0)
                dones = torch.cat([dones, torch.tensor(traj_dict['dones'], dtype=torch.bool)], dim=0)
            tensor_pieces.append(tensor)
            done_pieces.append(dones)
            start += offset
            pbar.update(offset)
            
        return tensor_pieces, done_pieces

In [7]:
import pickle
tmp_path = Path(f'/common/users/dm1487/tmp/500_pieces_next')
tmp_path.mkdir(parents=True, exist_ok=True)
save_dones = False
for key in keys:
    a, d = get_pieces(key)
    with open(tmp_path/f'{key}.pkl', 'wb') as f:
        pickle.dump(a, f)
    if not save_dones:
        with open(tmp_path/f'dones.pkl', 'wb') as f:
            pickle.dump(d, f)
        print('done')
        save_dones = True

100%|█████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [01:55<00:00,  4.35it/s]


done


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [01:30<00:00,  5.51it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [09:22<00:00,  1.12s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [02:11<00:00,  3.81it/s]


In [8]:
def convert_to_traj_and_save(name, tensor_pieces, done_pieces, ctr = 0):
    DATA_PATH_TRAJ = Path(f'/common/users/dm1487/legged_manipulation_data/rollout_data/latest_trajectories_2/{name}')
    DATA_PATH_TRAJ.mkdir(parents=True, exist_ok=True)
    all_dones = torch.cat(done_pieces, dim=0)
    all_tensor = torch.cat(tensor_pieces, dim=0)
    all_tensor_traj = split_and_pad_trajectories(all_tensor, all_dones)
    if name == 'dones':
        traj = all_tensor_traj[1].permute(1, 0)
    else:
        traj = all_tensor_traj[0].permute(1, 0, 2)
    start = 0
    offset = 10000
    with tqdm(total=traj.shape[0]) as pbar:
        while True:
            if start >= traj.shape[0]:
                break

            np.savez_compressed(DATA_PATH_TRAJ/f'{name}_{ctr}.npz', data=traj[start:start+offset])
            start += offset
            pbar.update(offset)
            ctr += 1
    return all_tensor_traj[1], ctr

In [9]:
with open(tmp_path/f'dones.pkl', 'rb') as f:
    d = pickle.load(f)
    
for key in [*keys, 'dones']:
    a = None
    b = None
    
    print('loading...')
    with open(tmp_path/f'{key}.pkl', 'rb') as f:
        a = pickle.load(f)
    print('done')
    
    splits = 2
    start = 0
    offset = len(a)//splits
    ctr = 0
    for _ in range(splits):
        _, ctr = convert_to_traj_and_save(key, a[start:(start+offset)], d[start:(start+offset)], ctr)
        start += offset

loading...
done
torch.Size([5000, 153316])


160000it [02:13, 1195.69it/s]                                                                                                       


torch.Size([5000, 155860])


160000it [03:27, 770.51it/s]                                                                                                        


loading...
done
torch.Size([5000, 153316])


160000it [01:38, 1627.73it/s]                                                                                                       


torch.Size([5000, 155860])


160000it [01:37, 1641.44it/s]                                                                                                       


loading...
done
torch.Size([5000, 153316])


160000it [08:19, 320.40it/s]                                                                                                        


torch.Size([5000, 155860])


160000it [07:50, 340.28it/s]                                                                                                        


loading...
done
torch.Size([5000, 153316])


160000it [01:42, 1562.99it/s]                                                                                                       


torch.Size([5000, 155860])


160000it [01:41, 1581.79it/s]                                                                                                       


loading...
done
torch.Size([5000, 153316])


160000it [00:15, 10166.54it/s]                                                                                                      


torch.Size([5000, 155860])


160000it [00:19, 8099.97it/s]                                                                                                       


In [10]:
DATA_PATH_TRAJ = Path(f'/common/users/dm1487/legged_manipulation_data/rollout_data/latest_individual_traj_2')
DATA_PATH_TRAJ.mkdir(parents=True, exist_ok=True)

file_list = {}
for key in [*keys, 'dones']:
    file_list[key] = sorted(list(glob(f'/common/users/dm1487/legged_manipulation_data/rollout_data/latest_trajectories_2/{key}/*.npz')), 
                            key= lambda x: int(str(x).split('.npz')[0].split('/')[-1].split('_')[-1]))

traj_ctr = 0
with tqdm(total=len(list(file_list.values())[0])) as pbar:
    for observations, privileged_observations, observation_histories, full_seen_world, dones in zip(*file_list.values()):
        obs = np.load(observations)['data']
        priv_obs = np.load(privileged_observations)['data']
        obs_hist = np.load(observation_histories)['data']
        fsw = np.load(full_seen_world)['data']
        d = np.load(dones)['data']
        for idx in range(obs.shape[0]):
            # print(obs_hist[idx].shape, fsw[idx].shape, priv_obs[idx].shape, d[idx][:750].reshape(-1, 1).shape)
            np.savez_compressed(DATA_PATH_TRAJ/f'traj_{traj_ctr}', obs=obs[idx], priv_obs=priv_obs[idx], obs_hist=obs_hist[idx][:, -37:], fsw=fsw[idx], done=d[idx][:750].reshape(-1, 1))
            traj_ctr += 1
        pbar.update(1)

100%|████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [1:11:52<00:00, 134.78s/it]


In [15]:
123455 + 142496

265951

In [None]:
from pathlib import Path
import os
import numpy as np
from tqdm import tqdm
import pickle
DATA_PATH_TRAJ = Path(f'/common/users/dm1487/legged_manipulation_data/rollout_data/latest_individual_traj_mini')
all_files = os.listdir(DATA_PATH_TRAJ)
ignore_files = []
for i in tqdm(all_files[:]):
    traj = np.load(DATA_PATH_TRAJ/i)
    last_idx = traj['done'].nonzero()[0][-1]
    if np.sum(traj['fsw'][last_idx][:7]) != 0:
        if np.sum(traj['fsw'][last_idx][1:7] == traj['priv_obs'][last_idx][1:7]) != 6:
            
            ignore_files.append(DATA_PATH_TRAJ/i)
            
## save ignore_files as pickle

In [None]:
# all_obs_traj = split_and_pad_trajectories(all_obs, all_dones)
# all_priv_obs_traj = split_and_pad_trajectories(all_priv_obs, all_dones)
# all_obs_hist_traj = split_and_pad_trajectories(all_obs_hist, all_dones)
# # all_obs_traj = split_and_pad_trajectories(all_obs, all_dones)
# all_obs_traj.shape, all_priv_obs_traj.shape, all_obs_hist_traj.shape

In [None]:
import numpy as np
from glob import glob
from pathlib import Path
import torch
from matplotlib import pyplot as plt

In [None]:
data_path = Path('/common/users/dm1487/legged_manipulation_data/rollout_data/set_3_trajectories')

In [None]:
files_merge = sorted(list((data_path/'obs_hist').glob('*')), key= lambda x: int(str(x).split('.npz')[0].split('_')[-1]))
DATA_PATH_TRAJ = Path(f'/common/users/dm1487/legged_manipulation_data/rollout_data/set_3_trajectories/obs_hist_combined')
DATA_PATH_TRAJ.mkdir(parents=True, exist_ok=True)
# len(files_merge)
for i in range(10, len(files_merge), 10):
    files_length = 10
    if i == 50:
        files_length = 9
    files = [np.load(files_merge[j+i]) for j in range(files_length)] # np.load(files_merge[i]), np.load(files_merge[i+1])
    print(i, len(files))
    final_data = np.concatenate([f['data'] for f in files], axis=0)
    print(final_data.shape)
    np.savez_compressed(DATA_PATH_TRAJ/f'obs_hist_combined_{i}.npz', data=final_data)
    # break

In [None]:
trajectories = 200000//100000
traj_data = {
    'obs': None, 
    'obs_hist_combined': None, 
    'priv_obs': None, 
    'mask': None
}
# while length <= trajectory_length:
for i in sorted(data_path.glob('*')):
    name = str(i).split('/')[-1] 
    if  name == 'obs_hist' or name == 'mask1':
        continue
    print(i)
    for j in sorted(glob(f'{i}/*.npz'))[:trajectories]:
        if traj_data[name] is None:
            traj_data[name] = np.load(j)['data']
        else:
            traj_data[name] = np.concatenate([traj_data[name], np.load(j)['data']], axis=0)
        print(traj_data[name].shape)

In [None]:
traj_data['obs_hist_combined'].shape

In [None]:
traj_path = Path('/common/users/dm1487/legged_manipulation_data/rollout_data/set3_traj_200k')
# traj_path.mkdir(exist_ok=True, parents=True)

In [None]:
# np.savez_compressed(traj_path, **traj_data)

In [None]:
# traj_num = np.random.randint(0, 200000)
# print(traj_num)
# plt.scatter(traj_data['obs'][traj_num][:, 0], traj_data['obs'][traj_num][:, 1])

In [None]:
trajectories = 200000
obs_hist = torch.zeros(trajectories, traj_data['obs'].shape[1],  traj_data['obs'].shape[2] + traj_data['obs_hist_combined'].shape[2])
obs_hist[:, :, :traj_data['obs'].shape[2]] = torch.tensor(traj_data['obs'])
obs_hist[:, :, traj_data['obs'].shape[2]:] = torch.tensor(traj_data['obs_hist_combined'])

FILE_NAME = 'rnn_200k_data'
DATA_PATH_TRAJ = Path(f'/common/users/dm1487/legged_manipulation_data/rollout_data/set_3_trajectories/{FILE_NAME}')
np.savez_compressed(DATA_PATH_TRAJ, inp=obs_hist.numpy(), target=traj_data['priv_obs'], mask=traj_data['mask'])

In [None]:
print(DATA_PATH_TRAJ)
traj = np.load(DATA_PATH_TRAJ/'traj_3.npz')['obs']
plt.scatter(traj[:, 0], traj[:, 1])

In [None]:
traj.shape