-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_data.py
62 lines (52 loc) · 1.94 KB
/
generate_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""
Data Generation Script for the VAE training
"""
import os
import argparse
import gymnasium as gym
import numpy as np
from multiprocessing import Pool
def rollout(data):
    """Generate random-action CarRacing rollouts and save each one to disk.

    Designed to be used as a ``multiprocessing.Pool.map`` worker, hence the
    single packed argument.

    Args:
        data: Tuple ``(data_dir, seq_len, rollouts)`` where ``data_dir`` is
            the directory the ``rollout_<i>.npz`` files are written to,
            ``seq_len`` is the maximum number of environment steps per
            rollout, and ``rollouts`` is how many rollouts to generate.

    Each ``.npz`` file holds four arrays: ``observations``, ``rewards`` and
    ``terminals`` (one entry per step actually taken) and ``actions`` (all
    ``seq_len`` pre-sampled actions, even if the episode ended earlier).
    """
    data_dir, seq_len, rollouts = data
    # exist_ok: re-running into an existing output tree must not crash.
    os.makedirs(data_dir, exist_ok=True)
    env = gym.make("CarRacing-v2")
    try:
        for i in range(rollouts):
            env.reset()
            # get random actions
            actions_rollout = [env.action_space.sample() for _ in range(seq_len)]
            observations_rollout = []
            rewards_rollout = []
            dones_rollout = []
            # Iterate over the pre-sampled actions so the rollout stops after
            # seq_len steps instead of indexing past the end of the list when
            # the environment has not terminated/truncated by then.
            for action in actions_rollout:
                obs, reward, done, truncated, _ = env.step(action)
                observations_rollout.append(obs)
                rewards_rollout.append(reward)
                dones_rollout.append(done)
                if done or truncated:
                    break
            t = len(observations_rollout)
            print(f"End of rollout {i} | {t} frames")
            np.savez(
                os.path.join(data_dir, f"rollout_{i}"),
                observations=np.array(observations_rollout),
                rewards=np.array(rewards_rollout),
                actions=np.array(actions_rollout),
                terminals=np.array(dones_rollout),
            )
    finally:
        # Release the environment's resources even if a rollout fails.
        env.close()
def plan_work(out_dir, seq_len, rollouts, workers):
    """Split ``rollouts`` across ``workers`` as evenly as possible.

    Args:
        out_dir: Base output directory; worker ``i`` gets the sub-directory
            ``thread_<i>`` inside it.
        seq_len: Sequence length forwarded unchanged to every worker.
        rollouts: Total number of rollouts to generate (negative values are
            treated as zero).
        workers: Number of worker processes; must be at least 1.

    Returns:
        A list of ``(data_dir, seq_len, n_rollouts)`` tuples, one per worker,
        whose ``n_rollouts`` values sum to exactly ``rollouts``.

    Raises:
        ValueError: If ``workers`` is smaller than 1.
    """
    if workers < 1:
        raise ValueError("number of workers must be at least 1")
    # The first `extra` workers take one additional rollout so the total is
    # exact rather than rounded up for every worker.
    base, extra = divmod(max(rollouts, 0), workers)
    return [
        (os.path.join(out_dir, f"thread_{i}"), seq_len, base + (1 if i < extra else 0))
        for i in range(workers)
    ]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--rollouts', help="number of rollouts", type=int, default=10_000)
    parser.add_argument('--threads', help="number of threads", type=int, default=20)
    parser.add_argument('--seq_len', help="sequence length", type=int, default=1000)
    parser.add_argument('--dir', help="output directory", type=str, default="data/vae")
    args = parser.parse_args()

    # exist_ok: allow re-running with an output directory that already exists.
    os.makedirs(args.dir, exist_ok=True)

    work = plan_work(args.dir, args.seq_len, args.rollouts, args.threads)
    print(work)

    # "--threads" actually sizes a process pool; the context manager makes
    # sure the worker processes are cleaned up once map() has returned.
    with Pool(args.threads) as pool:
        pool.map(rollout, work)