In [1]:
from copy import deepcopy
import os
import ray
from ray import tune
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.env import PettingZooEnv
from rlskyjo.environment import simple_skyjo_env
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env
from gym.spaces import Box
from ray.rllib.agents.dqn.dqn_torch_model import DQNTorchModel
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import FLOAT_MAX
from supersuit.multiagent_wrappers import pad_action_space_v0

# Bind the two objects returned by RLlib's try_import_torch helper.
# `torch` is used by the masked-action model below; `nn` is not referenced
# in the cells visible here.
torch, nn = try_import_torch()


In [2]:
# Start the local Ray runtime with two CPUs. `ignore_reinit_error=True`
# keeps this cell re-runnable: executing it again in a live kernel no longer
# raises because Ray has already been initialised.
ray.init(num_cpus=2, ignore_reinit_error=True)

{'node_ip_address': '172.17.87.73',
 'raylet_ip_address': '172.17.87.73',
 'redis_address': '172.17.87.73:26653',
 'object_store_address': '/tmp/ray/session_2022-01-27_21-05-49_823098_31263/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2022-01-27_21-05-49_823098_31263/sockets/raylet',
 'webui_url': None,
 'session_dir': '/tmp/ray/session_2022-01-27_21-05-49_823098_31263',
 'metrics_export_port': 46654,
 'node_id': 'f8e10565ffb72de1a96d558f39e184626bb4d6ff26441d0c38682574'}

In [3]:
class TorchMaskedActions(DQNTorchModel):
    """DQN torch model that suppresses actions flagged as unavailable.

    The model expects dict observations with two entries:

    * ``"observation"`` - the features fed to the inner fully connected net.
    * ``"action_mask"`` - one entry per action; presumably 1 for a legal
      action and 0 for an illegal one (TODO confirm against the environment).

    The inner network produces one logit per action, and the log of the mask
    is added to those logits so that masked-out actions receive a very large
    negative value.
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 **kw):
        """Build the wrapped fully connected network.

        Args:
            obs_space: Flat observation space handed in by RLlib. It is
                assumed to be the observation concatenated with an action
                mask of length ``action_space.n`` (TODO confirm the ordering;
                the bounds below are sliced from the front).
            action_space: Discrete action space; ``action_space.n`` is used
                as the number of outputs of the inner network.
            num_outputs: Forwarded unchanged to ``DQNTorchModel``.
            model_config: Model config dict, shared with the inner network.
            name: Model name; the inner network is named
                ``name + "_action_embed"``.
            **kw: Extra keyword arguments forwarded to ``DQNTorchModel``.
        """
        DQNTorchModel.__init__(self, obs_space, action_space, num_outputs,
                               model_config, name, **kw)

        # Length of the observation part once the mask entries are removed.
        obs_len = obs_space.shape[0] - action_space.n

        # Space describing only the observation part, used to size the inner net.
        orig_obs_space = Box(shape=(obs_len,), low=obs_space.low[:obs_len], high=obs_space.high[:obs_len])
        self.action_embed_model = TorchFC(orig_obs_space, action_space, action_space.n, model_config, name + "_action_embed")

    def forward(self, input_dict, state, seq_lens):
        """Return masked action logits and the unchanged ``state``.

        Args:
            input_dict: Must provide ``input_dict["obs"]["observation"]`` and
                ``input_dict["obs"]["action_mask"]``.
            state: Passed through untouched.
            seq_lens: Unused.

        Returns:
            Tuple ``(masked_logits, state)``.
        """
        # Extract the available actions tensor from the observation.
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the predicted action embedding
        action_logits, _ = self.action_embed_model({
            "obs": input_dict["obs"]['observation']
        })
        # Turn the mask into an additive term: log(1) = 0 leaves a logit as
        # is, while log(0) = -inf is clamped to -1e10 so the sum stays finite.
        inf_mask = torch.clamp(torch.log(action_mask), -1e10, FLOAT_MAX)

        return action_logits + inf_mask, state

    def value_function(self):
        """Delegate to the value function of the inner network."""
        return self.action_embed_model.value_function()

In [4]:
if __name__ == "__main__":
    # Names of the RLlib trainer and of the environment registered below.
    alg_name = "DQN"
    env_name = "pettingzoo_skyjo"

    # Expose the masked-action model to RLlib under the key "pa_model".
    ModelCatalog.register_custom_model("pa_model", TorchMaskedActions)

    def env_creator():
        """Create a new two-player skyjo environment."""
        return simple_skyjo_env.env(num_players=2)

    # Work on a private copy of the trainer's default configuration.
    config = deepcopy(get_agent_class(alg_name)._default_config)

    # Registered environments are built wrapped in RLlib's PettingZooEnv.
    register_env(env_name,
                 lambda env_config: PettingZooEnv(env_creator()))

    # Throwaway instance, only used to read the spaces for the policy specs.
    test_env = PettingZooEnv(env_creator())
    obs_space = test_env.observation_space
    print("obs_space", obs_space)
    act_space = test_env.action_space
    print("act_space", act_space)

    # Both policies share the same spec; an agent is mapped to the policy
    # named by the part of its id before the first underscore.
    policy_spec = (None, obs_space, act_space, {})
    config["multiagent"] = {
        "policies": {
            "draw": policy_spec,
            "place": policy_spec,
        },
        "policy_mapping_fn": lambda agent_id: agent_id.split("_")[0]
    }

    # Epsilon-greedy exploration, annealed over `epsilon_timesteps` steps.
    exploration_config = {
        "type": "EpsilonGreedy",
        "initial_epsilon": 0.1,
        "final_epsilon": 0.0,
        "epsilon_timesteps": 100000,
    }

    config.update({
        # Resources and logging.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "log_level": "DEBUG",
        "num_workers": 1,
        # Sampling and batching.
        "rollout_fragment_length": 30,
        "train_batch_size": 200,
        "horizon": 200,
        "no_done_at_end": False,
        # Model and algorithm settings.
        "framework": "torch",
        "model": {
            "custom_model": "pa_model",
        },
        "n_step": 1,
        "exploration_config": exploration_config,
        "hiddens": [],
        "dueling": False,
        "env": env_name,
    })

    # Launch training through Tune, checkpointing every 10 iterations.
    tune.run(
        alg_name,
        name="DQN",
        stop={"timesteps_total": 10000000},
        checkpoint_freq=10,
        config=config,
    )



TypeError: __init__() missing 1 required positional argument: 'name'