In [272]:
from gym.envs.classic_control.cartpole import CartPoleEnv
import gym

In [273]:
import torch
import math

# Creating the CartPole environments (CartPole-v0, CartPole-v1 and a raw `CartPoleEnv()`)

In [274]:
# Two registered variants (gym.make wraps each in TimeLimit, as the printed
# reprs show) plus an unwrapped CartPoleEnv, which is what `env` refers to.
env_v0 = gym.make('CartPole-v0')
env_v1 = gym.make('CartPole-v1')
env_raw = CartPoleEnv()
env = env_raw

for label, registered in ((' v0', env_v0), (' v1', env_v1)):
    spec = registered.spec
    print(label, '=>', registered)
    print('\t', 'max_episode_steps = {}'.format(spec.max_episode_steps))
    print('\t', 'reward_threshold  = {}'.format(spec.reward_threshold))
del label, registered, spec  # keep the notebook namespace free of loop temporaries
print('env', '=>', env)

 v0 => <TimeLimit<CartPoleEnv<CartPole-v0>>>
	 max_episode_steps = 200
	 reward_threshold  = 195.0
 v1 => <TimeLimit<CartPoleEnv<CartPole-v1>>>
	 max_episode_steps = 500
	 reward_threshold  = 475.0
env => <CartPoleEnv instance>


In [69]:
# Seed every environment used in this notebook so Restart & Run All reproduces
# the same episodes. Seeding only `env` would leave the `env_v0` policy rollout
# non-reproducible. torch is seeded too, for any sampling done with torch RNG.
# NOTE(review): if cartpole.choice_act samples with numpy/random instead of
# torch, seed that RNG here as well.
SEED = 1
torch.manual_seed(SEED)
env_v0.seed(SEED)
env_v1.seed(SEED)
env.seed(SEED)

[1]

# State
1. Cart position (x)
2. Cart velocity (x_dot)
3. Pole angle in radians (theta)
4. Pole angular velocity, i.e. the rate of change of the angle (theta_dot)

In [70]:
state = env.reset()
# Show the observation returned by reset() next to the env's internal state;
# compare the two printed lines.
print(state, env.state, sep='\n')

[ 0.03073904  0.00145001 -0.03088818 -0.03131252]
[ 0.03073904  0.00145001 -0.03088818 -0.03131252]


# Observation

In [275]:
def format_state(state):
    """Render a CartPole state as a compact one-line string.

    Args:
        state: any 4-element iterable ``(x, x_dot, theta, theta_dot)``;
            ``theta`` is in radians and is displayed in degrees.

    Returns:
        A string such as ``"(x = +0.0003, dx = +0.0350, θ =  -2.6161, dθ = +0.0484)"``.
    """
    pos, vel, angle_rad, ang_vel = state
    return "(x = {:+.4f}, dx = {:+.4f}, θ = {:+8.4f}, dθ = {:+.4f})".format(
        pos, vel, math.degrees(angle_rad), ang_vel)

In [276]:
# Discrete action indices passed to env.step() in the rollouts below.
ACTION_LEFT, ACTION_RIGHT = 0, 1

In [277]:
# Baseline rollout: apply ACTION_LEFT on every step and print each transition.
# `env` is the raw CartPoleEnv (no TimeLimit wrapper), so the loop stops only
# when step() itself reports done.
state = env.reset()
print('initial state', '->', format_state(state))
t = 0
while True:
    state, reward, done, info = env.step(ACTION_LEFT)
    print('t = {:3}'.format(t), '->', format_state(state), reward, done, info)
    t += 1  # after the loop, t equals the number of steps taken
    if done:
        break

initial state -> (x = +0.0003, dx = +0.0350, θ =  -2.6161, dθ = +0.0484)
t =   0 -> (x = +0.0010, dx = -0.1595, θ =  -2.5606, dθ = +0.3264) 1.0 False {}
t =   1 -> (x = -0.0022, dx = -0.3539, θ =  -2.1867, dθ = +0.6046) 1.0 False {}
t =   2 -> (x = -0.0093, dx = -0.5485, θ =  -1.4938, dθ = +0.8850) 1.0 False {}
t =   3 -> (x = -0.0203, dx = -0.7432, θ =  -0.4797, dθ = +1.1694) 1.0 False {}
t =   4 -> (x = -0.0351, dx = -0.9383, θ =  +0.8604, dθ = +1.4595) 1.0 False {}
t =   5 -> (x = -0.0539, dx = -1.1336, θ =  +2.5328, dθ = +1.7568) 1.0 False {}
t =   6 -> (x = -0.0766, dx = -1.3292, θ =  +4.5459, dθ = +2.0629) 1.0 False {}
t =   7 -> (x = -0.1032, dx = -1.5250, θ =  +6.9098, dθ = +2.3790) 1.0 False {}
t =   8 -> (x = -0.1337, dx = -1.7210, θ =  +9.6360, dθ = +2.7062) 1.0 False {}
t =   9 -> (x = -0.1681, dx = -1.9168, θ = +12.7371, dθ = +3.0451) 1.0 True {}


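# TimeLimit

`gym.make('CartPole-v0')` and `gym.make('CartPole-v1')` return the environment wrapped in `TimeLimit` (see the printed reprs above). The wrapper ends an episode once `max_episode_steps` is reached: 200 for v0 and 500 for v1. The raw `CartPoleEnv()` has no such cap.

https://github.com/openai/gym/blob/master/gym/wrappers/time_limit.py#L19

TimeLimit.truncated — when the wrapper cuts an episode off at the step limit, it records this key in the `info` dict. This lets a time-limit cut be told apart from an episode the environment itself terminated.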

In [278]:
import cartpole

In [279]:
# Load ConvNet4 weights from the 'policy' entry of the Conv-v3 checkpoint.
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a
# CPU-only machine too. The model is built on CPU, so the result is unchanged.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
checkpoint = torch.load('Conv-v3.pt', map_location='cpu')
policy = cartpole.ConvNet4()
policy.load_state_dict(checkpoint['policy'])

<All keys matched successfully>

In [284]:
# Load FCNet0 weights from the 'policy' entry of the FC-150-v1 checkpoint.
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a
# CPU-only machine too. The model is built on CPU, so the result is unchanged.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# NOTE(review): this cell (In [284]) was executed AFTER the rollout cell below
# (In [280]). Run top-to-bottom, it overwrites `checkpoint` and `policy`, so the
# next rollout uses FCNet0 rather than the ConvNet4 its saved output came from.
checkpoint = torch.load('FC-150-v1.pt', map_location='cpu')
policy = cartpole.FCNet0()
policy.load_state_dict(checkpoint['policy'])

<All keys matched successfully>

In [280]:
# Roll out one episode of `policy` on the TimeLimit-wrapped CartPole-v0 env.
# Each step prints the state and the action probabilities from cartpole.ask.
# This is evaluation only, so torch.no_grad() avoids building an autograd
# graph on every step.
# NOTE(review): assumes cartpole.ask / cartpole.choice_act need no gradients.
state = env_v0.reset()
done = False
t = 0
print('initial state', '->', format_state(state))
with torch.no_grad():
    while not done:
        act_prob = cartpole.ask(policy, state)
        act = cartpole.choice_act(act_prob)
        state, reward, done, info = env_v0.step(act)
        print('t = {:3}'.format(t), '->', format_state(state), reward, done, info,
              act_prob.detach().numpy())
        t += 1

initial state -> (x = -0.0409, dx = +0.0271, θ =  -0.4716, dθ = -0.0291)
t =   0 -> (x = -0.0404, dx = -0.1679, θ =  -0.5050, dθ = +0.2609) 1.0 False {} [9.99865564e-01 1.34435786e-04]
t =   1 -> (x = -0.0438, dx = +0.0273, θ =  -0.2060, dθ = -0.0345) 1.0 False {} [0.00273336 0.99726664]
t =   2 -> (x = -0.0432, dx = -0.1677, θ =  -0.2456, dθ = +0.2570) 1.0 False {} [9.99867864e-01 1.32135932e-04]
t =   3 -> (x = -0.0466, dx = +0.0274, θ =  +0.0489, dθ = -0.0370) 1.0 False {} [0.00249785 0.99750215]
t =   4 -> (x = -0.0460, dx = -0.1677, θ =  +0.0065, dθ = +0.2559) 1.0 False {} [9.99863565e-01 1.36434550e-04]
t =   5 -> (x = -0.0494, dx = +0.0274, θ =  +0.2998, dθ = -0.0367) 1.0 False {} [0.00199504 0.99800496]
t =   6 -> (x = -0.0488, dx = -0.1678, θ =  +0.2577, dθ = +0.2576) 1.0 False {} [9.99851832e-01 1.48167766e-04]
t =   7 -> (x = -0.0522, dx = +0.0273, θ =  +0.5530, dθ = -0.0336) 1.0 False {} [0.00146397 0.99853603]
t =   8 -> (x = -0.0516, dx = -0.1680, θ =  +0.5144, dθ = +0.26

t =  78 -> (x = +0.5189, dx = +0.9379, θ =  +0.7422, dθ = -0.0632) 1.0 False {} [9.99020369e-01 9.79630524e-04]
t =  79 -> (x = +0.5377, dx = +1.1328, θ =  +0.6698, dθ = -0.3517) 1.0 False {} [0.10114256 0.89885744]
t =  80 -> (x = +0.5603, dx = +0.9375, θ =  +0.2667, dθ = -0.0554) 1.0 False {} [9.99068024e-01 9.31976237e-04]
t =  81 -> (x = +0.5791, dx = +1.1326, θ =  +0.2033, dθ = -0.3466) 1.0 False {} [0.07589367 0.92410633]
t =  82 -> (x = +0.6017, dx = +0.9374, θ =  -0.1939, dθ = -0.0528) 1.0 False {} [9.99191088e-01 8.08912385e-04]
t =  83 -> (x = +0.6205, dx = +1.1326, θ =  -0.2544, dθ = -0.3465) 1.0 False {} [0.07447348 0.92552652]
t =  84 -> (x = +0.6431, dx = +0.9375, θ =  -0.6515, dθ = -0.0553) 1.0 False {} [9.99357531e-01 6.42469130e-04]
t =  85 -> (x = +0.6619, dx = +1.1328, θ =  -0.7148, dθ = -0.3515) 1.0 False {} [0.09897351 0.90102649]
t =  86 -> (x = +0.6845, dx = +0.9379, θ =  -1.1176, dθ = -0.0628) 1.0 False {} [9.99528959e-01 4.71040849e-04]
t =  87 -> (x = +0.7033,

t = 169 -> (x = +1.6304, dx = +0.0195, θ =  -2.5459, dθ = +0.1390) 1.0 False {} [0.98468499 0.01531501]
t = 170 -> (x = +1.6308, dx = +0.2152, θ =  -2.3866, dθ = -0.1673) 1.0 False {} [0.02170338 0.97829662]
t = 171 -> (x = +1.6351, dx = +0.0207, θ =  -2.5783, dθ = +0.1119) 1.0 False {} [0.98746646 0.01253354]
t = 172 -> (x = +1.6355, dx = +0.2164, θ =  -2.4501, dθ = -0.1946) 1.0 False {} [0.0799689 0.9200311]
t = 173 -> (x = +1.6399, dx = +0.0220, θ =  -2.6730, dθ = +0.0843) 1.0 False {} [0.98822792 0.01177208]
t = 174 -> (x = +1.6403, dx = -0.1725, θ =  -2.5764, dθ = +0.3619) 1.0 False {} [0.28959616 0.71040384]
t = 175 -> (x = +1.6369, dx = +0.0233, θ =  -2.1617, dθ = +0.0554) 1.0 False {} [6.82861389e-05 9.99931714e-01]
t = 176 -> (x = +1.6373, dx = -0.1713, θ =  -2.0983, dθ = +0.3359) 1.0 False {} [0.50285254 0.49714746]
t = 177 -> (x = +1.6339, dx = +0.0243, θ =  -1.7133, dθ = +0.0319) 1.0 False {} [9.33351191e-05 9.99906665e-01]
t = 178 -> (x = +1.6344, dx = -0.1704, θ =  -1.676

In [285]:
# Animated run of the current `policy` on the raw (un-wrapped) CartPoleEnv.
# Run top-to-bottom, `policy` is FCNet0 here, because the FC-150-v1 load cell
# replaced the ConvNet4 model.
# NOTE(review): `limit=200` presumably caps the episode length, mirroring
# CartPole-v0's max_episode_steps. Confirm in cartpole.sim.
cartpole.sim(policy, env, animation=True, limit=200)