In [1]:
import os
import numpy as np
from environments import Environments
from trajectory_generator import DataGenerator

from tqdm import tqdm

In [2]:
# Configuration: dataset type, output location and RNG seed.

# True  -> generate whole paths with DataGenerator.generate_paths (arrays `r` and `v`);
# False -> generate independent points with DataGenerator.generate_points (array `r`).
trajectories = True

dataset_type = "trajectories" if trajectories else "points"

# Default path for storing datasets, relative to the notebook's working directory.
loc = f"../datasets/{dataset_type}/"

# exist_ok keeps the cell idempotent on re-runs; it still raises if `loc` exists as a file.
os.makedirs(loc, exist_ok=True)

# Seed numpy's global RNG so re-running the notebook reproduces the same datasets.
# The shuffle in the generation cell draws from this RNG; whether DataGenerator also
# draws from numpy's global RNG is not visible here — TODO confirm.
seed = 42
np.random.seed(seed)

In [3]:
# Geometry definitions that samples are drawn in.
environments = Environments()

# Generator that samples paths (or points) within a given geometry.
data_generator = DataGenerator()

In [4]:
# Total number of samples across both splits and all environments.
data_samples = 15000

if trajectories:
    timesteps = 501  # number of timesteps requested from generate_paths
else:
    space_samples = 500  # points per sample requested from generate_points (more samples if few timesteps)

# Dataset splits: 80:20 train/validation.
# The validation count is the remainder, so train + val always equals data_samples.
# Truncating both products independently can lose a sample, e.g. int(7*0.8) + int(7*0.2) == 6.
train_fraction = 0.8
total_samples = {}
total_samples["train"] = int(data_samples * train_fraction)
total_samples["val"] = data_samples - total_samples["train"]

# Number of samples per environment in a given split. Integer division: if a split is
# not divisible by the number of environments, the remainder is dropped.
split_samples = {split: total_samples[split] // len(environments) for split in total_samples}
assert min(split_samples.values()) > 0, "data_samples is too small for the number of environments"

In [5]:
# For each split: sample every environment, concatenate, shuffle, and save to
# f"{loc}{split}_dataset.npz" (np.savez appends the .npz extension).
for split in tqdm(total_samples):

    # Per-environment arrays are collected in lists and concatenated once at the end.
    if trajectories:
        dataset = {"r": [], "v": [], "c": []}
    else:
        dataset = {"r": [], "c": []}

    n_split = split_samples[split]  # samples per environment for this split

    for env in environments.envs:
        if trajectories:
            r, v = data_generator.generate_paths(n_split, timesteps, environments.envs[env])
            dataset["v"].append(v)
            # Skip the initial position. `v` is stored unsliced; presumably it already has
            # one fewer step than `r`, so the two stay aligned — TODO confirm.
            r = r[:, 1:]
        else:
            r = data_generator.generate_points(n_split, space_samples, environments.envs[env])
        dataset["r"].append(r)

        # One-hot context signal: the environment's encoding is broadcast over every sample
        # and step of `r` (assumes encoding(env) has length len(environments) — TODO confirm).
        c = environments.encoding(env) * np.ones((n_split, r.shape[1], len(environments)))
        dataset["c"].append(c)

    # Concatenate the per-environment lists along the sample axis.
    # Passing the list directly skips the redundant stacked copy that np.array(...) made.
    # It also does not require every environment to produce identically shaped arrays.
    dataset = {key: np.concatenate(arrays, axis=0, dtype="float32") for key, arrays in dataset.items()}

    # One permutation shared by all keys keeps the rows of r / v / c aligned.
    shuffle_inds = np.random.permutation(len(dataset["r"]))
    dataset = {key: array[shuffle_inds] for key, array in dataset.items()}

    np.savez(f"{loc}{split}_dataset", **dataset)

100%|██████████████████████████████████████████████████████████████████████████████| 2/2 [07:45<00:00, 232.56s/it]
