In [1]:
import os

import gym
import torch as th
from stable_baselines3 import DQN, PPO

In [2]:
env = gym.make("CartPole-v0")

# NOTE: the agent trained here is PPO; the checkpoint keeps its historical
# "dqn_cartpole" name because later cells load it by that exact name.
MODEL_PATH = "dqn_cartpole"
SEED = 0                  # fixed seed so a retrain is reproducible
TOTAL_TIMESTEPS = 100_000

# Train only when no checkpoint exists yet. `model.save` writes
# "<MODEL_PATH>.zip", which is the file checked for here.
if not os.path.exists(MODEL_PATH + ".zip"):
    model = PPO(
        "MlpPolicy",
        env,
        verbose=1,
        device='cpu',
        tensorboard_log='../logs',
        seed=SEED,
    )
    model.learn(total_timesteps=TOTAL_TIMESTEPS)
    model.save(MODEL_PATH)
    # Drop the in-memory agent; the agent is reloaded from disk afterwards.
    del model

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ../logs\PPO_11
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 42.3        |
|    ep_rew_mean          | 42.3        |
| time/                   |             |
|    fps                  | 1360        |
|    iterations           | 4           |
|    time_elapsed         | 6           |
|    total_timesteps      | 8192        |
| train/                  |             |
|    approx_kl            | 0.009837633 |
|    clip_fraction        | 0.119       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.636      |
|    explained_variance   | 0.265       |
|    learning_rate        | 0.0003      |
|    loss                 | 20.2        |
|    n_updates            | 30          |
|    policy_gradient_loss | -0.0213     |
|    value_loss           | 49          |
----------------------------------------

In [11]:
def evaluate(model, env, render=True):
    """Roll out one episode with ``model``'s deterministic policy on ``env``.

    Parameters
    ----------
    model : object
        Anything exposing ``predict(obs, deterministic=True)`` returning an
        ``(action, state)`` pair (e.g. a stable-baselines3 algorithm).
    env : gym-style environment
        Uses the 4-tuple ``step`` API: ``obs, reward, done, info``.
    render : bool, default True
        If true, ``env.render()`` is called after every step.

    Returns
    -------
    float
        Sum of the rewards collected during the episode.

    Notes
    -----
    ``env.close()`` is always called before returning, even if a step
    raises, so a render window is never left open.
    """
    total_reward = 0.0
    try:
        obs = env.reset()
        done = False
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done, info = env.step(action)
            total_reward += reward
            if render:
                env.render()
    finally:
        # Release render/env resources even if predict/step/render raises.
        env.close()
    return total_reward

In [12]:
# Reload the saved agent from disk so this cell does not depend on the training
# cell having kept a model in memory, then watch one rendered episode.
model = PPO.load(path="dqn_cartpole", device="cpu")
evaluate(model, env, render=True)

In [17]:
from stable_baselines3.common.utils import obs_as_tensor
def get_action_prob(obaservation):
    """Return the current policy's action probabilities for one observation.

    Reads the notebook-level globals ``model`` (the loaded agent) and ``env``
    (for the observation shape); both must be defined before calling.

    Parameters
    ----------
    obaservation : numpy.ndarray
        A single observation from ``env``. It is reshaped to a batch of one.
        (The misspelled name is kept so existing keyword callers still work.)

    Returns
    -------
    numpy.ndarray
        Array of shape ``(1, n_actions)`` holding the probability of each
        action. Assumes a discrete action space (relies on ``.probs``).
    """
    obs = obaservation.reshape((-1,) + env.observation_space.shape)
    obs = obs_as_tensor(obs, model.device)
    policy = model.policy
    # Pure inference: skip building an autograd graph.
    with th.no_grad():
        if hasattr(policy, "get_distribution"):
            # Public API; the private helpers below changed signature / were
            # removed in later stable-baselines3 releases.
            distribution = policy.get_distribution(obs)
        else:
            # Fallback for older releases without get_distribution.
            latent_pi, _, latent_sde = policy._get_latent(obs)
            distribution = policy._get_action_dist_from_latent(latent_pi, latent_sde)
    action_prob = distribution.distribution.probs
    # .cpu() is a no-op on CPU and required before .numpy() on a GPU tensor.
    return action_prob.cpu().numpy()

In [18]:
# Step through (at most 50 steps of) one episode, printing the policy's action
# probabilities alongside the deterministic action it actually takes.
obs = env.reset()

for step_idx in range(50):
    action, _states = model.predict(obs, deterministic=True)
    probs = get_action_prob(obs)
    print(probs, action)
    obs, reward, done, info = env.step(action)
    if done:
        break

[[0.81320137 0.18679859]] 0
[[0.03986129 0.9601387 ]] 1
[[0.87811095 0.121889  ]] 0
[[0.06040519 0.93959486]] 1
[[0.9190455  0.08095451]] 0
[[0.09024318 0.9097568 ]] 1
[[0.9451997  0.05480025]] 0
[[0.13307494 0.86692506]] 1
[[0.9623296 0.0376704]] 0
[[0.19357651 0.8064235 ]] 1
[[0.9738495  0.02615046]] 0
[[0.2766591  0.72334087]] 1
[[0.98177516 0.01822487]] 0
[[0.38511208 0.61488795]] 1
[[0.98731774 0.01268221]] 0
[[0.5152731  0.48472694]] 0
[[0.0143902 0.9856098]] 1
[[0.59154177 0.40845814]] 0
[[0.01713817 0.9828619 ]] 1
[[0.6413422  0.35865778]] 0
[[0.01920558 0.9807944 ]] 1
[[0.6683792 0.3316208]] 0
[[0.02024192 0.979758  ]] 1
[[0.6757409  0.32425916]] 0
[[0.02006352 0.9799365 ]] 1
[[0.6645577  0.33544236]] 0
[[0.01870253 0.9812975 ]] 1
[[0.63381684 0.3661832 ]] 0
[[0.01639649 0.9836035 ]] 1
[[0.5808036  0.41919646]] 0
[[0.01352314 0.9864769 ]] 1
[[0.502667 0.497333]] 0
[[0.01050215 0.9894979 ]] 1
[[0.4002229  0.59977716]] 1
[[0.98289466 0.01710535]] 0
[[0.33993742 0.6600626 ]] 1
[[