In [2]:
from OthelloEnv import OthelloEnv

import gym
import torch as th
import torch.nn as nn

from stable_baselines3 import DQN
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.env_checker import check_env

In [3]:
# Build the Othello environment and run Stable-Baselines3's environment
# checker on it before any model is trained against it.
env = OthelloEnv()
check_env(env=env, warn=True)



In [6]:
class CustomCNN(BaseFeaturesExtractor):
    """
    Convolutional features extractor for Stable-Baselines3 policies.

    Three stride-1 ``padding='same'`` conv layers (64 output channels each,
    4x4 kernels, ReLU after each) are followed by a flatten and a single
    Linear + ReLU projection down to ``features_dim`` features.

    :param observation_space: (gym.Space) Observation space of the env. Its
        first axis is used as the number of input channels, i.e. a
        channel-first (C, H, W) layout is assumed -- TODO confirm for OthelloEnv.
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of units in the last layer.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)

        # Input channel count of each conv layer; every layer outputs 64.
        conv_layers = []
        for in_channels in (observation_space.shape[0], 64, 64):
            conv_layers.append(
                nn.Conv2d(in_channels, 64, kernel_size=(4, 4), stride=1, padding='same')
            )
            conv_layers.append(nn.ReLU())
        self.cnn = nn.Sequential(*conv_layers, nn.Flatten())

        # Infer the flattened size with one dummy forward pass on a sampled
        # observation (batch axis added via [None]) instead of hard-coding it.
        with th.no_grad():
            probe = th.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(probe).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Map a batch of observations to a ``features_dim``-wide feature tensor."""
        flat = self.cnn(observations)
        return self.linear(flat)

# Training configuration: tune these values here rather than inside the calls below.
SEED = 0                    # passed to DQN so repeated runs are reproducible
TOTAL_TIMESTEPS = 100_000   # number of environment steps passed to model.learn
EXPLORATION_FRACTION = 0.5  # DQN exploration schedule parameter
FEATURES_DIM = 128          # output width of CustomCNN's final linear layer
LOG_DIR = 'logs'            # TensorBoard log directory

# Use CnnPolicy, but swap its features extractor for CustomCNN.
policy_kwargs = dict(
    features_extractor_class=CustomCNN,
    features_extractor_kwargs=dict(features_dim=FEATURES_DIM),
)

model = DQN(
    'CnnPolicy',
    env,
    exploration_fraction=EXPLORATION_FRACTION,
    policy_kwargs=policy_kwargs,
    seed=SEED,
    verbose=0,
    tensorboard_log=LOG_DIR,
)
model.learn(total_timesteps=TOTAL_TIMESTEPS)

KeyboardInterrupt: 