In [72]:
from munch import Munch
from torch.utils.data import DataLoader
from plm_special.data.dataset import ExperienceDataset

In [73]:

class ExperiencePool:
    """
    Append-only store of environment transitions.

    Four parallel lists are kept (``states``, ``actions``, ``rewards``,
    ``dones``); entry ``i`` of each list belongs to the same transition.
    """

    def __init__(self):
        """Create the four empty per-field buffers."""
        for field_name in ("states", "actions", "rewards", "dones"):
            setattr(self, field_name, [])

    def add(self, state, action, reward, done):
        """
        Record one transition by appending each value to its own buffer.

        The state is what other code sometimes calls the observation ("obs").
        """
        per_field = (
            (self.states, state),
            (self.actions, action),
            (self.rewards, reward),
            (self.dones, done),
        )
        for buffer, value in per_field:
            buffer.append(value)

    def __len__(self):
        """Return how many transitions have been recorded so far."""
        return len(self.states)


In [74]:
import pickle

# Location of the pickled experience pool to inspect.
pickle_save_path = 'artifacts/exp_pools/exp_pool.pkl'

# SECURITY: pickle.load can execute arbitrary code while deserialising, so
# only load files that you created yourself or otherwise trust.
# The context manager closes the file handle as soon as loading finishes;
# a bare open(...) passed straight to pickle.load is never explicitly closed.
with open(pickle_save_path, 'rb') as pool_file:
    exp_pool = pickle.load(pool_file)

In [75]:
# Display every stored state.
# NOTE(review): the recorded run reports 19928 states, so this dumps a very
# large list; a slice such as exp_pool.states[:3] would keep the output readable.
exp_pool.states

[array([[0.        , 0.        , 0.        , 0.        , 0.        ,
         0.06976745],
        [0.        , 0.        , 0.        , 0.        , 0.        ,
         0.4       ],
        [0.        , 0.        , 0.        , 0.        , 0.        ,
         0.12347122],
        [0.        , 0.        , 0.        , 0.        , 0.        ,
         0.14724159],
        [0.15558   , 0.398865  , 0.611087  , 0.957685  , 1.431809  ,
         2.123065  ],
        [0.        , 0.        , 0.        , 0.        , 0.        ,
         0.9791667 ]], dtype=float32),
 array([[0.        , 0.        , 0.        , 0.        , 0.06976745,
         0.06976745],
        [0.        , 0.        , 0.        , 0.        , 0.4       ,
         0.6728411 ],
        [0.        , 0.        , 0.        , 0.        , 0.12347122,
         0.12235085],
        [0.        , 0.        , 0.        , 0.        , 0.14724159,
         0.12715891],
        [0.139857  , 0.350812  , 0.571051  , 0.877771  , 1.300868  ,
    

In [76]:
# Look at one state on its own; in the recorded run it is a 6x6 float32 array.
exp_pool.states[0]

array([[0.        , 0.        , 0.        , 0.        , 0.        ,
        0.06976745],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.4       ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.12347122],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.14724159],
       [0.15558   , 0.398865  , 0.611087  , 0.957685  , 1.431809  ,
        2.123065  ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.9791667 ]], dtype=float32)

In [77]:
# Spot-check a single reward entry.
exp_pool.rewards[10]

0.75

In [78]:
# Number of states stored in the pool.
len(exp_pool.states)

19928

In [79]:
# Python type used to hold an individual state.
type(exp_pool.states[0])

numpy.ndarray

In [80]:
# Dimensions of an individual state array.
exp_pool.states[0].shape

(6, 6)

In [81]:
# Scalar type used for the stored rewards.
type(exp_pool.rewards[0])

numpy.float64

In [82]:
# NOTE(review): repeats an earlier exp_pool.states[0] inspection with the same
# recorded output; candidate for removal.
exp_pool.states[0]

array([[0.        , 0.        , 0.        , 0.        , 0.        ,
        0.06976745],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.4       ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.12347122],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.14724159],
       [0.15558   , 0.398865  , 0.611087  , 0.957685  , 1.431809  ,
        2.123065  ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.9791667 ]], dtype=float32)

In [83]:
import pickle

# Switch to the "l4s" experience pool. This rebinds exp_pool, so every cell
# that runs after this one inspects this pool instead of a previously loaded one.
pickle_save_path = 'artifacts/exp_pools/exp_pool_l4s.pkl'
# pickle_save_path = 'artifacts/exp_pools/exp_pool.pkl'  # alternative pool

# SECURITY: pickle.load can execute arbitrary code while deserialising, so
# only load files that you created yourself or otherwise trust.
# The context manager closes the file handle as soon as loading finishes;
# a bare open(...) passed straight to pickle.load is never explicitly closed.
with open(pickle_save_path, 'rb') as pool_file:
    exp_pool = pickle.load(pool_file)

In [None]:
# First state of the currently loaded pool.
# NOTE(review): this cell has no execution count, i.e. it was not run in the
# saved session.
exp_pool.states[0]

In [84]:
# Display every stored state of the currently loaded pool.
# NOTE(review): the recorded run reports 721858 states, so this dumps a very
# large list; a slice such as exp_pool.states[:10] would keep the output readable.
exp_pool.states

[array([0, 0, 0, 0, 0, 0, 0, 0], dtype=int64),
 array([     0, 150000,      0,      0,      0,      0,     60,      0],
       dtype=int64),
 array([     0, 150000,      0,      0,      0,      0,    120,      0],
       dtype=int64),
 array([     0, 150000,      0,      0,      0,      0,    180,      0],
       dtype=int64),
 array([     0, 150000,      0,  10000,      0,      0,    240,      0],
       dtype=int64),
 array([     0, 150000,      0,      0,      0,      0,    300,      0],
       dtype=int64),
 array([     0, 150000,      0,  10000,      0,      0,    360,      0],
       dtype=int64),
 array([     0, 150000,      0,      0,      0,      0,    420,      0],
       dtype=int64),
 array([     0, 150000,      0,  10000,      0,      0,    480,      0],
       dtype=int64),
 array([1, 0, 0, 0, 0, 0, 0, 0], dtype=int64),
 array([     0, 150000,      0,      0,      0,      0,    532,      0],
       dtype=int64),
 array([     0, 150000,      0,      0,      0,      0,    5

In [85]:
# Number of states stored in the currently loaded pool.
len(exp_pool.states)

721858

In [86]:
# Display every stored reward.
# NOTE(review): this prints the whole list; a slice such as
# exp_pool.rewards[:20] would keep the output readable.
exp_pool.rewards

[0,
 0,
 0,
 0,
 10000,
 0,
 10000,
 0,
 10000,
 0,
 0,
 0,
 10000,
 0,
 0,
 10000,
 30000,
 0,
 10000,
 0,
 10000,
 0,
 30000,
 30000,
 10000,
 0,
 10000,
 10000,
 10000,
 0,
 0,
 0,
 10000,
 0,
 0,
 0,
 10000,
 0,
 10000,
 0,
 0,
 10000,
 0,
 0,
 10000,
 0,
 0,
 10000,
 0,
 10000,
 0,
 0,
 0,
 0,
 0,
 0,
 10000,
 0,
 0,
 0,
 10000,
 0,
 0,
 0,
 0,
 0,
 10000,
 0,
 0,
 0,
 10000,
 10000,
 10000,
 10000,
 10000,
 10000,
 10000,
 10000,
 10000,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 30000,
 0,
 0,
 20000,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 10000,
 0,
 10000,
 0,
 0,
 0,
 0,
 0,
 10000,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 10000,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 10000,
 10000,
 10000,
 10000,
 10000,
 10000,
 10000,
 10000,
 10000,
 0,
 10000,
 10000,
 10000,
 0,
 10000,
 0,
 0,
 10000,
 10000,
 10000,
 10000,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 10000,
 10000,
 10000,
 10000,
 30000,
 30000,
 30000,
 30000,
 30000,
 30000,
 30000,


In [87]:
# Spot-check the first reward entry.
exp_pool.rewards[0]

0

In [88]:
# Scalar type used for the stored rewards.
type(exp_pool.rewards[0])

numpy.int64

In [89]:
# Dimensions of an individual state array; the recorded run shows a flat
# vector of length 8.
exp_pool.states[0].shape

(8,)

In [90]:
# Spot-check a state further into the pool.
exp_pool.states[100]

array([     1, 150000,      0,  10000,      0,      0,  23326,      0],
      dtype=int64)

In [91]:
# Report the type of a raw state straight from the pool (before any dataset or
# DataLoader wrapping). Index 0 is used so the "first element" label matches
# the element that is actually inspected.
print("Type of first element in states:", type(exp_pool.states[0]))

Type of first element in states: <class 'numpy.ndarray'>


In [92]:
# Build the dataset from the currently loaded experience pool.
exp_dataset = ExperienceDataset(
    exp_pool,
    gamma=1.0,
    scale=1000,
    max_length=20,
    sample_step=10,
)
batch_size = 1
# Attribute-style access to the dataset's summary info.
exp_dataset_info = Munch(exp_dataset.exp_dataset_info)
dataloader = DataLoader(
    exp_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
)

# Pull a single batch and report which container / element types come out of
# the DataLoader, then stop iterating.
for step, batch in enumerate(dataloader):
    states, actions, returns, timesteps = batch
    print("Type of first element in states:", type(states[0]))
    print("process_batch states type:", type(states))
    break

Type of first element in states: <class 'torch.Tensor'>
process_batch states type: <class 'list'>
