In [1]:
import os

import numpy as np
from dm_control import suite
from dm_control.rl.control import Environment
from dm_env import StepType

In [2]:
class DataGenerator:
    """Collect random-policy rollouts of the dm_control cartpole/balance task.

    At every non-terminal step the generator records the observation
    ('velocity', 'position') and a 100x100 render taken *before* acting. It also
    records the discrete action index it sampled, plus the reward and
    ``env.physics.data.time`` returned after ``env.step``. Every history list
    therefore gets one entry per action taken. The terminal step's observation
    is not recorded.
    """

    # Order in which history series are written to disk and their shapes printed.
    _SAVE_ORDER = ('frame', 'velocity', 'position', 'action', 'reward', 'ticks')

    def __init__(self, action_count: int = 7, seed: int = 42, verbose: bool = True):
        """Load the environment and build the discrete action table.

        Args:
            action_count: number of evenly spaced action values between the
                action spec's minimum and maximum (np.linspace).
            seed: seed for the environment's RandomState. The action sampler is
                seeded with ``seed + 1`` so whole rollouts are reproducible.
            verbose: if True, print the reward after every step.
        """
        self.action_count = action_count
        self.verbose = verbose
        self.reset_history()
        self.env: Environment = suite.load(
            'cartpole', 'balance', task_kwargs={'random': np.random.RandomState(seed)})
        # Actions are drawn from their own seeded stream (offset from the env's
        # seed so the two streams are not identical copies). Drawing them from
        # np.random's unseeded global state would make episodes unrepeatable.
        self.action_rng = np.random.RandomState(seed + 1)
        spec = self.env.action_spec()
        # Indexed by action number in run_episode; presumably one row per
        # action when the spec bounds are arrays — TODO confirm shape.
        self.action_values: np.ndarray = np.linspace(spec.minimum, spec.maximum, action_count)

    def reset_history(self):
        """Replace ``self.history`` with fresh empty lists and return it."""
        self.history = {
            'velocity': [],
            'position': [],
            'frame': [],
            'action': [],
            'reward': [],
            'ticks': [],
        }
        return self.history

    def run_episode(self, episode_num: int):
        """Roll out one episode with uniformly random discrete actions.

        Appends to ``self.history``. It does not clear it first.

        Args:
            episode_num: used only in the start-of-episode message.
        """
        time_step = self.env.reset()
        while not time_step.last():
            if time_step.first():
                print(f'Episode {episode_num} start')
            self.history['velocity'].append(time_step.observation['velocity'])
            self.history['position'].append(time_step.observation['position'])
            self.history['frame'].append(self.env.physics.render(camera_id=0, height=100, width=100))
            action = int(self.action_rng.randint(self.action_count))
            self.history['action'].append(action)
            time_step = self.env.step(self.action_values[action])
            if self.verbose:
                print(f'Reward: {time_step.reward}')
            self.history['reward'].append(float(time_step.reward))
            self.history['ticks'].append(self.env.physics.data.time)

    def save_episode_results(self, episode_num: int, out_dir: str = 'data'):
        """Write the current history to ``out_dir``, creating the directory if needed.

        Output files:
            - Frames go to ``<out_dir>/e<N>_h_frame.npy`` via np.save.
            - Every other series goes to ``<out_dir>/e<N>_h_<key>.csv`` via
              np.savetxt.

        The shape of each saved array is printed.

        Args:
            episode_num: episode index embedded in the file names.
            out_dir: destination directory (default ``'data'``).
        """
        # Without this, a fresh checkout fails with FileNotFoundError.
        os.makedirs(out_dir, exist_ok=True)
        for key in self._SAVE_ORDER:
            # Convert once and reuse for both saving and the shape report.
            values = np.asarray(self.history[key])
            if key == 'frame':
                np.save(os.path.join(out_dir, f'e{episode_num}_h_{key}.npy'), values)
            else:
                np.savetxt(os.path.join(out_dir, f'e{episode_num}_h_{key}.csv'), values, delimiter=',')
            print(f'h_{key}', values.shape)

    def run_episodes(self, count: int):
        """Run and save ``count`` episodes numbered 0..count-1, then clear the history."""
        for episode_num in range(count):
            # Start clean so leftovers from a manual run_episode() call are not
            # written into this episode's files.
            self.reset_history()
            self.run_episode(episode_num)
            self.save_episode_results(episode_num)
            # Release the recorded frames between episodes.
            self.reset_history()

In [3]:
dg = DataGenerator()
# One frame is recorded per control step, so this is the number of recorded
# frames per second of simulated time.
control_steps_per_second = 1.0 / dg.env.control_timestep()
control_steps_per_second

100.0

In [4]:
N_EPISODES = 5  # each episode is written to data/e<N>_h_*.{npy,csv}
dg.run_episodes(N_EPISODES)

Episode 0 start
Reward: 0.9771612925221649
Reward: 0.9767068959075431
Reward: 0.9759945461244119
Reward: 0.9750196978431651
Reward: 0.9737761568430704
Reward: 0.9956734560807978
Reward: 0.9719799425200812
Reward: 0.9701210168895569
Reward: 0.7967884709478975
Reward: 0.9089165131874274
Reward: 0.909790980434301
Reward: 0.9085961729928027
Reward: 0.7989294162400916
Reward: 0.909979423218271
Reward: 0.7985319600353187
Reward: 0.796360301937472
Reward: 0.7925023502592089
Reward: 0.9661870211250139
Reward: 0.8955896659557252
Reward: 0.7790511127667163
Reward: 0.947909694621328
Reward: 0.7659579738455363
Reward: 0.8634449887319123
Reward: 0.9284395832244406
Reward: 0.8701088471033437
Reward: 0.7523983884313005
Reward: 0.9208630379815202
Reward: 0.8626988440448213
Reward: 0.9187945873737005
Reward: 0.9111898285928886
Reward: 0.7524901112364969
Reward: 0.9199283232835304
Reward: 0.7586837671609145
Reward: 0.8671351265984868
Reward: 0.8572354388904125
Reward: 0.7395449214116707
Reward: 0.919384

In [8]:
# Peek at a freshly reset TimeStep: it is a FIRST step with no reward or
# discount yet, and an observation dict holding 'position' and 'velocity'.
time_step = dg.env.reset()
time_step

TimeStep(step_type=<StepType.FIRST: 0>, reward=None, discount=None, observation=OrderedDict([('position', array([-6.00652436e-02,  9.99999532e-01,  9.67941661e-04])), ('velocity', array([-0.00544383,  0.00110923]))]))