In [45]:
import numpy as np

import tensorflow as tf

from tf_agents.environments import suite_gym, suite_pybullet, wrappers, py_environment, tf_environment, tf_py_environment, utils
from tf_agents.specs import array_spec
from tf_agents.trajectories import time_step as ts

In [12]:
# Three types of time_step
time_step = ts.restart(np.array([0.], dtype=np.float32))
print(time_step._fields)
print("start:\n", time_step, "\n")

time_step = ts.transition(np.array([0.], dtype=np.float32), 0)
print("middle:\n", time_step, "\n")

time_step = ts.termination(np.array([0.], dtype=np.float32), 1)
print("termination:\n", time_step, "\n")

('step_type', 'reward', 'discount', 'observation')
start:
 TimeStep(step_type=array(0, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([0.], dtype=float32)) 

middle:
 TimeStep(step_type=array(1, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([0.], dtype=float32)) 

termination:
 TimeStep(step_type=array(2, dtype=int32), reward=array(1., dtype=float32), discount=array(0., dtype=float32), observation=array([0.], dtype=float32)) 



## Create a Custom PyEnvironment

In [51]:
class BlackJack(py_environment.PyEnvironment):
    def __init__(self):
        self._action_spec = array_spec.BoundedArraySpec(shape=(), dtype=np.int32, minimum=0, maximum=1, name='action')
        self._observation_spec = array_spec.BoundedArraySpec(shape=(1,), dtype=np.int32, minimum=0, name='observation')
        self._state = 0
        self._episode_ended = False
        
    def action_spec(self):
        return self._action_spec
    
    def observation_spec(self):
        return self._observation_spec
    
    def time_step_spec(self):
        return ts.time_step_spec(self._observation_spec)
    
    def _reset(self):
        self._state = 0
        self._episode_ended = False
        return ts.restart(np.array([self._state], dtype=np.int32))
    
    def _step(self, action):
        if self._episode_ended:
            return self.reset()
        
        if action == 1:  # stop draw
            self._episode_ended = True
        elif action == 0:  # draw
            new_card = np.random.randint(1, 11)
            self._state += new_card
        else:
            raise ValueError("`actionn` should be 0 or 1 (received {0})".format(action))
            
        if self._episode_ended or self._state >= 21:
            reward = 21 - self._state if self._state <= 21 else - 21
            return ts.termination(np.array([self._state], dtype=np.int32), reward)
        else:
            return ts.transition(np.array([self._state], dtype=np.int32), reward=0., discount=1.)



### Validate environment

In [52]:
env = BlackJack()
utils.validate_py_environment(env)

In [54]:
# Run
get_new_card_action = np.array(0, dtype=np.int32)
end_round_action = np.array(1, dtype=np.int32)

env = BlackJack()
time_step = env.reset()
print("Reset:\n", time_step)
cummulative_reward = time_step.reward

for _ in range(3):
    time_step = env.step(get_new_card_action)
    print("Draw a card:\n", time_step)
    cummulative_reward += time_step.reward
    
time_step = env.step(end_round_action)
print("End round:\n", time_step)
cummulative_reward += time_step.reward
print("Final Reward:", cummulative_reward)

Reset:
 TimeStep(step_type=array(0, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([0], dtype=int32))
Draw a card:
 TimeStep(step_type=array(1, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([5], dtype=int32))
Draw a card:
 TimeStep(step_type=array(1, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([10], dtype=int32))
Draw a card:
 TimeStep(step_type=array(1, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([15], dtype=int32))
End round:
 TimeStep(step_type=array(2, dtype=int32), reward=array(6., dtype=float32), discount=array(0., dtype=float32), observation=array([15], dtype=int32))
Final Reward: 6.0


### Conver to TFEnvironment

In [57]:
tf_env = tf_py_environment.TFPyEnvironment(env)

print("Instance:", type(tf_env))
print("\nTimeStep Spec:\n", tf_env.time_step_spec())
print("\nAction Spec:\n", tf_env.action_spec())

Instance: <class 'tf_agents.environments.tf_py_environment.TFPyEnvironment'>

TimeStep Spec:
 TimeStep(step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), reward=TensorSpec(shape=(), dtype=tf.float32, name='reward'), discount=BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)), observation=BoundedTensorSpec(shape=(1,), dtype=tf.int32, name='observation', minimum=array(0, dtype=int32), maximum=array(2147483647, dtype=int32)))

Action Spec:
 BoundedTensorSpec(shape=(), dtype=tf.int32, name='action', minimum=array(0, dtype=int32), maximum=array(1, dtype=int32))


## Environment Wrapper

In [18]:
# Discretization
env = suite_gym.load("Pendulum-v0")
print("Action Spec:", env.action_spec())

discretized_action_env = wrappers.ActionDiscretizeWrapper(env, num_actions=5)
print("Discretized Action Spec:", discretized_action_env.action_spec())

discretized_action_env_2 = wrappers.ActionDiscretizeWrapper(env, num_actions=10)
print("Discretized Action Spec (2):", discretized_action_env_2.action_spec())

Action Spec: BoundedArraySpec(shape=(1,), dtype=dtype('float32'), name='action', minimum=-2.0, maximum=2.0)
Discretized Action Spec: BoundedArraySpec(shape=(), dtype=dtype('int32'), name='action', minimum=0, maximum=4)
Discretized Action Spec (2): BoundedArraySpec(shape=(), dtype=dtype('int32'), name='action', minimum=0, maximum=9)


In [22]:
# TimeLimit
env = suite_pybullet.load("HalfCheetahBulletEnv-v0")
env_tm = wrappers.TimeLimit(env, duration=500)

# Without limit
time_step = env.reset()
for i in range(5000):
    time_step = env.step(np.array([0.] * 6, dtype=np.float32))
    if time_step.is_last():
        print("terminated at i = {}".format(i))
        break
        
# With limit
time_step = env_tm.reset()
for i in range(5000):
    time_step = env_tm.step(np.array([0.] * 6, dtype=np.float32))
    if time_step.is_last():
        print("terminated at i = {}".format(i))
        break



terminated at i = 999
terminated at i = 499
