In [1]:
import gym
import torch as th
import torch.nn as nn

from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.env_util import make_vec_env

In [2]:
class CustomCNN(BaseFeaturesExtractor):
    """
    :param observation_space: (gym.Space)
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of unit for the last layer.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
        super(CustomCNN, self).__init__(observation_space, features_dim)
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by pre-preprocessing or wrapper
        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(None, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Compute shape by doing one forward pass
        with th.no_grad():
            n_flatten = self.cnn(
                th.as_tensor(observation_space.sample()[None]).float()
            ).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.linear(self.cnn(observations))

env = gym.make('BreakoutNoFrameskip-v4')
# env = make_vec_env('BreakoutNoFrameskip-v4', 4)
policy_kwargs = dict(
    features_extractor_class=CustomCNN,
    features_extractor_kwargs=dict(features_dim=128),
)
# model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs)
model = PPO("CnnPolicy", env)
print(model.policy)
model.learn(500)
model.save("Breakout_PPO")
del model

ActorCriticCnnPolicy(
  (features_extractor): NatureCNN(
    (cnn): Sequential(
      (0): Conv2d(3, 32, kernel_size=(8, 8), stride=(4, 4))
      (1): ReLU()
      (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
      (3): ReLU()
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (5): ReLU()
      (6): Flatten(start_dim=1, end_dim=-1)
    )
    (linear): Sequential(
      (0): Linear(in_features=22528, out_features=512, bias=True)
      (1): ReLU()
    )
  )
  (mlp_extractor): MlpExtractor(
    (shared_net): Sequential()
    (policy_net): Sequential()
    (value_net): Sequential()
  )
  (action_net): Linear(in_features=512, out_features=4, bias=True)
  (value_net): Linear(in_features=512, out_features=1, bias=True)
)


In [3]:
env = gym.make('BreakoutNoFrameskip-v4')
model = PPO.load("Breakout_PPO")
for episode in range(10):
    step = 0
    observation = env.reset()
    while True:
        env.render()
        action = model.predict(observation)
        observation, reward, done, info = env.step(action)
        if done:
            print(f"Episode{episode} finished after {step} timesteps")
            break
        step += 1

env.close()



Episode0 finished after 504 timesteps
Episode1 finished after 691 timesteps
Episode2 finished after 1189 timesteps
Episode3 finished after 789 timesteps
Episode4 finished after 681 timesteps
Episode5 finished after 810 timesteps
Episode6 finished after 494 timesteps
Episode7 finished after 506 timesteps
Episode8 finished after 634 timesteps
Episode9 finished after 491 timesteps


In [6]:
for i in range(10):
    print(env.observation_space.sample().shape)

(210, 160, 3)
(210, 160, 3)
(210, 160, 3)
(210, 160, 3)
(210, 160, 3)
(210, 160, 3)
(210, 160, 3)
(210, 160, 3)
(210, 160, 3)
(210, 160, 3)
