# create_dataset.py
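# Builds an offline Atari dataset for return-conditioned (decision-transformer
# style) training: it samples logged DQN transitions from Dopamine
# FixedReplayBuffer checkpoints (the DQN Replay Dataset layout: one game,
# 50 replay checkpoints) and derives per-step returns-to-go and timesteps.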
import csv
import logging
# make deterministic
from mingpt.utils import set_seed
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
from torch.utils.data import Dataset
from mingpt.model_atari import GPT, GPTConfig
from mingpt.trainer_atari import Trainer, TrainerConfig
from mingpt.utils import sample
from collections import deque
import random
import pickle
import blosc
import argparse
from fixed_replay_buffer import FixedReplayBuffer
def create_dataset(num_buffers, num_steps, game, data_dir_prefix, trajectories_per_buffer):
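    """Assemble an offline dataset from saved Atari replay checkpoints.

    Samples transitions from the last `num_buffers` of the 50 replay
    checkpoints until at least `num_steps` transitions are collected, loading
    up to `trajectories_per_buffer` trajectories per checkpoint visit.

    Returns:
        obss: list of (4, 84, 84) uint8 frame stacks
        actions: np.array of actions, one per transition
        returns: np.array of per-trajectory returns (last entry is a partial accumulator)
        done_idxs: np.array of exclusive trajectory end indices
        rtg: np.array of per-step returns-to-go
        timesteps: np.array of length len(actions)+1 with within-trajectory timesteps
    """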
    # -- load data from memory (make more efficient)
    obss = []
    actions = []
    returns = [0]
    done_idxs = []
    stepwise_returns = []
    # sized to 50, not num_buffers: buffer_num below indexes checkpoints 0..49
    transitions_per_buffer = np.zeros(50, dtype=int)
    num_trajectories = 0
    while len(obss) < num_steps:
        buffer_num = np.random.choice(np.arange(50 - num_buffers, 50), 1)[0]
        i = transitions_per_buffer[buffer_num]
        print('loading from buffer %d which has %d already loaded' % (buffer_num, i))
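        # Dopamine's FixedReplayBuffer loads one saved replay checkpoint
        # (replay_suffix selects which) from the game's replay_logs directory.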
        frb = FixedReplayBuffer(
            data_dir=data_dir_prefix + game + '/1/replay_logs',
            replay_suffix=buffer_num,
            observation_shape=(84, 84),
            stack_size=4,
            update_horizon=1,
            gamma=0.99,
            observation_dtype=np.uint8,
            batch_size=32,
            replay_capacity=1000000)
        if frb._loaded_buffers:
            done = False
            curr_num_transitions = len(obss)
            trajectories_to_load = trajectories_per_buffer
            while not done:
                states, ac, ret, next_states, next_action, next_reward, terminal, indices = frb.sample_transition_batch(batch_size=1, indices=[i])
                states = states.transpose((0, 3, 1, 2))[0]  # (1, 84, 84, 4) --> (4, 84, 84)
                obss += [states]
                actions += [ac[0]]
                stepwise_returns += [ret[0]]
                if terminal[0]:
                    done_idxs += [len(obss)]
                    returns += [0]
                    if trajectories_to_load == 0:
                        done = True
                    else:
                        trajectories_to_load -= 1
                returns[-1] += ret[0]
                i += 1
                if i >= 1000000:
                    # ran off the end of the buffer without hitting a terminal:
                    # discard the partial trajectory and stop reading this buffer
                    obss = obss[:curr_num_transitions]
                    actions = actions[:curr_num_transitions]
                    stepwise_returns = stepwise_returns[:curr_num_transitions]
                    returns[-1] = 0
                    i = transitions_per_buffer[buffer_num]
                    done = True
            num_trajectories += (trajectories_per_buffer - trajectories_to_load)
            transitions_per_buffer[buffer_num] = i
        print('this buffer has %d loaded transitions and there are now %d transitions total divided into %d trajectories' % (i, len(obss), num_trajectories))

    actions = np.array(actions)
    returns = np.array(returns)
    stepwise_returns = np.array(stepwise_returns)
    done_idxs = np.array(done_idxs)

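    # Reward-to-go: rtg[j] is the undiscounted sum of rewards from step j to
    # the end of its trajectory. Example: rewards [1, 0, 2] give rtg [3, 2, 2].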
    # -- create reward-to-go dataset
    start_index = 0
    rtg = np.zeros_like(stepwise_returns)
    for i in done_idxs:
        i = int(i)  # done_idxs are exclusive ends: trajectory spans [start_index, i)
        curr_traj_returns = stepwise_returns[start_index:i]
        for j in range(i - 1, start_index - 1, -1):  # start from i-1
            rtg_j = curr_traj_returns[j - start_index:i - start_index]
            rtg[j] = sum(rtg_j)
        start_index = i
    print('max rtg is %d' % max(rtg))
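
    # Timesteps: position of each transition within its trajectory, consumed
    # downstream by the model's timestep embedding. The extra slot
    # (len(actions)+1) keeps the i+1 slices below in bounds at the last boundary.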
    # -- create timestep dataset
    start_index = 0
    timesteps = np.zeros(len(actions)+1, dtype=int)
    for i in done_idxs:
        i = int(i)
        timesteps[start_index:i+1] = np.arange(i+1 - start_index)
        start_index = i+1
    print('max timestep is %d' % max(timesteps))

    return obss, actions, returns, done_idxs, rtg, timesteps
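

# Hypothetical usage sketch: in the Decision Transformer codebase this
# function is normally driven by a training script (run_dt_atari.py); the CLI
# below is a stand-in whose flag names and defaults are assumptions, shown
# only to illustrate the expected arguments.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--game', type=str, default='Breakout')
    parser.add_argument('--num_buffers', type=int, default=50,
                        help='use the last N of the 50 replay checkpoints')
    parser.add_argument('--num_steps', type=int, default=500000,
                        help='minimum number of transitions to collect')
    parser.add_argument('--trajectories_per_buffer', type=int, default=10,
                        help='trajectories to sample per checkpoint visit')
    parser.add_argument('--data_dir_prefix', type=str, default='./dqn_replay/')
    args = parser.parse_args()

    obss, actions, returns, done_idxs, rtg, timesteps = create_dataset(
        args.num_buffers, args.num_steps, args.game,
        args.data_dir_prefix, args.trajectories_per_buffer)
    print('loaded %d transitions across %d trajectories' % (len(obss), len(done_idxs)))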