In [5]:
!pip install ray
!pip install ray[rllib]
!pip install ray[tune]
!pip install numba




In [6]:
!pip install gymnasium



In [7]:
import numpy as np
from numba import njit, prange
from typing import List, Dict, Any, Tuple
import gymnasium
from gymnasium.spaces import Discrete

# Constants
# Hint tokens: GameState starts with this many by default, and discarding
# never raises the count above it.
MAX_HINTS = 8
# Multiplier applied to the board total when the game ends on MAX_ERRORS.
SCORE_3_ERRORS = 0.1
# Number of misplays at which check_game_ended() reports the game as over.
MAX_ERRORS = 3
# Number of hands dealt by GameState and size of its turn rotation.
DEFAULT_PLAYERS = 2
# Copies of each rank per color; Deck and Trash tile this across 5 colors to
# build a table indexed as [rank - 1, color].
CARD_QUANTITIES = np.array([3, 2, 2, 2, 1])  # Number of cards per rank (1-5)

# Card representation using structured NumPy array.
# An all-zero record (rank == 0) is what Deck.draw() returns when no card is
# available, so rank 0 never denotes a real card.
card_dtype = np.dtype([
    ('rank', np.int32),          # True rank of the card (1-5)
    ('color', np.int32),         # True color of the card (0-4)
    ('rank_known', np.bool_),    # Whether the holder has been told the rank
    ('color_known', np.bool_),   # Whether the holder has been told the color
    ('color_options', np.float32, (5,)),  # Belief over colors; Deck.draw() starts it uniform
    ('rank_options', np.float32, (5,))   # Belief over ranks; Deck.draw() starts it uniform
])

# Helper Functions

from numba import njit, prange
import numpy as np

@njit
def draw_card(table, valid_ranks, valid_colors):
    """
    Draw one card at random from the deck table.

    Every physical copy still in the deck is equally likely, i.e. a
    (rank, color) cell is chosen with probability proportional to its
    remaining count. Only cells whose rank index is in ``valid_ranks`` and
    whose color index is in ``valid_colors`` are eligible.

    Args:
        table: 2-D integer array of remaining copies, indexed [rank, color].
        valid_ranks: 1-D array of admissible rank indices (0-based).
        valid_colors: 1-D array of admissible color indices.

    Returns:
        ``(rank_index, color_index)`` of the drawn card, or ``(-1, -1)`` when
        no eligible card remains. The table itself is not modified.
    """
    # NOTE: `parallel=True` was dropped from the decorator: the function has
    # no prange loop, so there is nothing for numba to parallelise.
    n_ranks = table.shape[0]
    n_colors = table.shape[1]

    # First pass: count how many eligible physical cards are left.
    total = 0
    for r in range(n_ranks):
        for c in range(n_colors):
            if table[r, c] > 0 and r in valid_ranks and c in valid_colors:
                total += table[r, c]

    if total == 0:
        return -1, -1

    # Second pass: pick the `ticket`-th eligible card by walking the same
    # cells and subtracting each cell's count until the ticket is used up.
    ticket = np.random.randint(0, total)
    for r in range(n_ranks):
        for c in range(n_colors):
            if table[r, c] > 0 and r in valid_ranks and c in valid_colors:
                ticket -= table[r, c]
                if ticket < 0:
                    return r, c

    # Unreachable while `total` matches the table; kept so every path returns.
    return -1, -1

@njit
def update_card_with_hint(card, hint_type, hint_value):
    """
    Record on a single card the knowledge conveyed by a hint.

    The card's true ``rank`` / ``color`` are never modified. A hint only
    marks the corresponding ``*_known`` flag, and only when the card
    actually matches the hinted value, so the function can safely be applied
    to every card of the hinted hand.

    Args:
        card: A ``card_dtype`` record; updated in place.
        hint_type: ``'value'`` for a rank hint or ``'color'`` for a color
            hint. Any other string leaves the card untouched.
        hint_value: The hinted rank (compared against ``card['rank']``) or
            color (compared against ``card['color']``).
    """
    if hint_type == 'value':
        if card['rank'] == hint_value:
            card['rank_known'] = True
    elif hint_type == 'color':
        if card['color'] == hint_value:
            card['color_known'] = True


@njit
def check_game_ended(errors, board, last_turn_played):
    """
    Decide whether the game is over and, if so, what the final score is.

    Returns:
        ``(True, score)`` when the game has ended, otherwise ``(False, None)``.
        The score is the board total, scaled by ``SCORE_3_ERRORS`` when the
        game ended because ``errors`` reached ``MAX_ERRORS``.
    """
    board_total = sum(board)

    # Too many misplays: the game stops with a discounted score.
    if errors >= MAX_ERRORS:
        return True, board_total * SCORE_3_ERRORS

    # Either every stack is complete or every player has had a final turn.
    if np.all(board == 5) or np.all(last_turn_played):
        return True, board_total

    return False, None


# Classes

class Deck:
    """Draw pile stored as a table of remaining copies, indexed [rank - 1, color]."""

    def __init__(self):
        # One column of per-rank quantities, repeated once per color (5 colors).
        per_rank = CARD_QUANTITIES.reshape(-1, 1)
        self._table = np.tile(per_rank, 5).astype(np.int8)

    def draw(self, rank=None, color=None):
        """
        Remove one random card from the deck and return it.

        Args:
            rank: If given, only cards of this rank (1-5) may be drawn.
            color: If given, only cards of this color (0-4) may be drawn.

        Returns:
            A 0-d ``card_dtype`` array. When no matching card is left the
            record is all zeros (rank 0) and the deck is left unchanged.
        """
        if rank is None:
            valid_ranks = np.arange(5)
        else:
            # The table is 0-based on rank, the public API is 1-based.
            valid_ranks = np.array([rank - 1], dtype=np.int8)

        if color is None:
            valid_colors = np.arange(5)
        else:
            valid_colors = np.array([color], dtype=np.int8)

        rank_idx, color_idx = draw_card(self._table, valid_ranks, valid_colors)

        drawn = np.zeros((), dtype=card_dtype)
        if rank_idx == -1 or color_idx == -1:
            # No cards available: hand back the empty record as-is.
            return drawn

        self._table[rank_idx, color_idx] -= 1
        drawn['rank'] = rank_idx + 1
        drawn['color'] = color_idx
        # A freshly drawn card starts with a uniform belief over colors and ranks.
        drawn['color_options'] = np.full(5, 1.0 / 5, dtype=np.float32)
        drawn['rank_options'] = np.full(5, 1.0 / 5, dtype=np.float32)
        return drawn

    def add_cards(self, cards):
        """Put every card in ``cards`` back into the deck."""
        for returned in cards:
            row = returned['rank'] - 1
            self._table[row, returned['color']] += 1

    def get_table(self):
        """Return a copy of the remaining-copies table."""
        return self._table.copy()


class Trash:
    """
    Discard pile.

    Attributes:
        list: The discarded cards, in the order they were appended.
        _table: Copies of each card that have NOT been trashed yet, indexed
            [rank - 1, color]; it starts at the full card quantities and is
            decremented on every append.
    """

    def __init__(self):
        self.list = []
        col = CARD_QUANTITIES.reshape(-1, 1)
        self._table = np.tile(col, 5).astype(np.int8)  # 5 colors

    def append(self, card):
        """Add ``card`` to the pile and decrement its not-yet-trashed count."""
        self.list.append(card)
        self._table[card['rank'] - 1, card['color']] -= 1

    def get_table(self):
        """Return a copy of the not-yet-trashed table."""
        return self._table.copy()

    def __iter__(self):
        # HanabiEnv.render() iterates a Trash directly; delegate to the list.
        return iter(self.list)

    def __len__(self):
        return len(self.list)


class GameState:
    """
    Mutable state of one Hanabi game: hands, deck, trash, board and counters.

    Attributes:
        root: Index of the player the state was created for.
        player: Index of the player whose turn it is.
        hands: Structured array of ``card_dtype`` records, one row per player.
        hints: Remaining hint tokens.
        errors: Number of misplays so far.
        deck: The draw pile.
        trash: The discard pile.
        board: Highest rank played so far for each color.
        last_turn_played: Per-player flag, set once that player tried to draw
            from an empty deck.
    """

    def __init__(self, root: int, hands=None, hints=MAX_HINTS, errors=0, deck=None, trash=None, board=None):
        """
        Initializes the game state with players' hands, deck, trash, and board.

        Any of ``hands``, ``deck``, ``trash`` and ``board`` left as ``None``
        is created fresh; default hands are five cards per player drawn from
        ``deck``.
        """
        if deck is None:
            deck = Deck()  # Initialize the deck if not provided

        if hands is None:
            hands = [
                np.array([deck.draw() for _ in range(5)], dtype=card_dtype)
                for _ in range(DEFAULT_PLAYERS)
            ]

        if board is None:
            board = np.zeros(5, dtype=np.int32)  # Progress on each color

        if trash is None:
            trash = Trash()  # Initialize trash if not provided

        self.last_turn_played = np.full(DEFAULT_PLAYERS, False, dtype=np.bool_)
        self.root = root
        self.player = root
        self.hands = np.array(hands)
        self.hints = hints
        self.errors = errors
        self.deck = deck
        self.trash = trash
        self.board = board

    def _refill_slot(self, player, card_idx):
        """Draw a replacement into ``hands[player][card_idx]``, or flag the player's last turn."""
        new_card = self.deck.draw()
        # Deck.draw() returns an all-zero record when it is empty; real cards
        # always have rank >= 1, so rank 0 is an unambiguous sentinel and
        # avoids comparing whole structured records.
        if new_card['rank'] == 0:
            self.last_turn_played[player] = True
        else:
            self.hands[player][card_idx] = new_card

    def _end_turn(self):
        """Pass the turn to the next player."""
        self.player = (self.player + 1) % DEFAULT_PLAYERS

    def play_card(self, player, card_idx):
        """
        Play the card in slot ``card_idx`` of ``player``'s hand.

        A card that is the next rank of its color advances the board; any
        other card is trashed and counts as an error. The slot is then
        refilled from the deck and the turn passes on.
        """
        # Indexing a structured array yields a view into `self.hands`; copy
        # it so that refilling the slot does not rewrite the trashed card.
        card = self.hands[player][card_idx].copy()
        color = card['color']
        if self.board[color] == card['rank'] - 1:
            self.board[color] += 1
        else:
            self.trash.append(card)
            self.errors += 1

        self._refill_slot(player, card_idx)
        self._end_turn()

    def discard_card(self, player, card_idx):
        """
        Discard the card in slot ``card_idx`` of ``player``'s hand.

        The card goes to the trash, the slot is refilled from the deck, one
        hint token is regained (capped at ``MAX_HINTS``) and the turn passes on.
        """
        # Copy for the same reason as in play_card(): the trash must not
        # alias the hand slot that is about to be overwritten.
        card = self.hands[player][card_idx].copy()
        self.trash.append(card)

        self._refill_slot(player, card_idx)

        self.hints = min(self.hints + 1, MAX_HINTS)
        self._end_turn()

    def give_hint(self, destination, hint_type, hint_value):
        """
        Apply a hint to every card in ``destination``'s hand.

        Spends one hint token (never dropping below zero) and passes the
        turn on. The hint is applied even when no token is left; callers
        that need to forbid that must check ``hints`` themselves.
        """
        hand = self.hands[destination]
        for i in range(len(hand)):
            update_card_with_hint(hand[i], hint_type, hint_value)

        self.hints = max(self.hints - 1, 0)
        self._end_turn()

    def game_ended(self):
        """Return ``(done, score)`` as computed by ``check_game_ended``."""
        return check_game_ended(self.errors, self.board, self.last_turn_played)


In [8]:
from numba import njit, prange
from typing import Any, Dict, List, Tuple
import gymnasium
import numpy as np
from gymnasium.spaces import Discrete


class HanabiEnv(gymnasium.Env):
    """
    Single-agent Gymnasium wrapper around a two-player Hanabi ``GameState``.

    Actions (``Discrete(20)``):
        0-4    play the card in hand slot ``action``
        5-9    discard the card in hand slot ``action - 5``
        10-19  give a hint to the other player (see ``decode_action``)

    Observation: a flat float32 vector of length 447, laid out as
        260  belief over the current player's own cards (10 rows x 26)
        130  one-hot encoding of the other player's cards (5 rows x 26)
          5  board
          2  hints, errors
         25  deck table
         25  trash table
    """

    def __init__(self, num_players: int = 2):
        super().__init__()
        # Stored for reference only; the game logic itself is driven by
        # DEFAULT_PLAYERS and the encoding assumes exactly two players.
        self.num_players = num_players
        self.game_state = GameState(
            root=0,
            hands=None,
            hints=MAX_HINTS,
            errors=0,
            deck=Deck(),
            trash=Trash()
        )
        self.action_space = Discrete(20)
        self.observation_size = 447
        self.observation_space = gymnasium.spaces.Box(
            low=-10.0,
            high=10.0,
            shape=(self.observation_size,),
            dtype=np.float32
        )

    def reset(self, seed=None, options=None) -> Tuple[np.ndarray, Dict[str, Any]]:
        """
        Resets the environment to the initial state.

        Returns:
            ``(observation, info)`` with an empty info dict.
        """
        # Let Gymnasium record the seed / set up `self.np_random`.
        # NOTE(review): card draws use numba's own np.random state, which this
        # call presumably does not seed - confirm if reproducibility matters.
        super().reset(seed=seed)
        self.game_state = GameState(
            root=0,
            hands=None,
            hints=MAX_HINTS,
            errors=0,
            deck=Deck(),
            trash=Trash()
        )
        return self._get_observation(), {}

    @staticmethod
    @njit(parallel=True)
    def _compute_action_mask(hands, player: int, hints: int) -> np.ndarray:
        """
        Generate an additive mask for all possible actions.

        Args:
            hands: The hands of all players.
            player: The index of the current player.
            hints: The number of remaining hints.

        Returns:
            np.ndarray: float32 array of length 20 holding ``0`` for legal
            actions and ``-inf`` for illegal ones.
        """
        action_mask = np.full(20, -np.inf, dtype=np.float32)
        for card_idx in prange(len(hands[player])):
            # Elements of the structured hand array are records, never None;
            # the check is kept as in the original.
            if hands[player][card_idx] is not None:
                action_mask[card_idx] = 0  # Legal play (actions 0-4)
                action_mask[card_idx + 5] = 0  # Legal discard (actions 5-9)
        # A hint costs one token, so hints are legal only while tokens remain.
        if hints > 0:
            action_mask[10:] = 0
        return action_mask

    def get_legal_actions(self) -> np.ndarray:
        """
        Generate the action mask for the current player.

        Returns:
            np.ndarray: ``0`` for legal actions, ``-inf`` for illegal ones.
        """
        return self._compute_action_mask(self.game_state.hands, self.game_state.player, self.game_state.hints)

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
        """
        Executes an action in the environment.

        Returns:
            ``(observation, reward, terminated, truncated, info)``; the
            environment never truncates, and ``info["score"]`` is the current
            board total.
        """
        action_type, player, card_idx, destination, hint_type, hint_value = self.decode_action(action)
        previous_board = self.game_state.board.copy()
        reward = 0

        if action_type == 0:  # Play card
            played_card = self.game_state.hands[player][card_idx]
            current_rank_on_board = self.game_state.board[played_card['color']]
            reward += 2 if played_card['rank'] == current_rank_on_board + 1 else -2
            self.game_state.play_card(player, card_idx)

        elif action_type == 1:  # Discard card
            discarded_card = self.game_state.hands[player][card_idx]
            current_rank_on_board = self.game_state.board[discarded_card['color']]
            reward += 1 if discarded_card['rank'] <= current_rank_on_board else -1
            # The trash table is indexed [rank - 1, color]. A value of 1 means
            # this is the last copy that has not been trashed yet.
            remaining = self.game_state.trash.get_table()
            if remaining[discarded_card['rank'] - 1, discarded_card['color']] == 1:
                reward -= 3
            self.game_state.discard_card(player, card_idx)

        elif action_type == 2:  # Give hint
            self.game_state.give_hint(destination, hint_type, hint_value)

        # Compute incremental rewards for progress
        progress = self.game_state.board - previous_board
        reward += np.sum(progress)
        # NOTE(review): this bonus is added on every step while a color is
        # complete, not only on the step that completes it - confirm intent.
        reward += 5 * np.sum(self.game_state.board == 5)  # Bonus for completed colors

        # Check game state
        done, score = self.game_state.game_ended()
        if done:
            reward += score if self.game_state.errors < MAX_ERRORS else -5

        return self._get_observation(), float(reward), bool(done), False, {"score": np.sum(self.game_state.board)}

    @staticmethod
    @njit(parallel=True)
    def _encode_observation(
        hands, board: np.ndarray, hints: int, errors: int,
        deck_table: np.ndarray, trash_table: np.ndarray, player: int
    ) -> np.ndarray:
        """
        Encodes the current state as an observation.

        Args:
            hands: Hands of both players (structured card records).
            board: Highest rank played per color.
            hints: Remaining hint tokens.
            errors: Misplays so far.
            deck_table: Deck counts, indexed [rank, color].
            trash_table: Not-yet-trashed counts, indexed [rank, color].
            player: Index of the observing player.

        Returns:
            np.ndarray: A flattened representation of the game state.
        """
        num_possible_cards = 25
        num_cards_per_hand = 5
        belief = np.zeros((num_cards_per_hand * 2, num_possible_cards + 1), dtype=np.float32)

        # The per-card belief below is laid out [color, rank] (matching the
        # 5 * color + rank - 1 indexing used for the opponent's cards), while
        # the tables are [rank, color]; transpose so the axes line up.
        remaining_by_color = trash_table.T

        for i in prange(num_cards_per_hand):
            card = hands[player][i]
            if card is not None:
                joint = card['color_options'].reshape(-1, 1).dot(card['rank_options'].reshape(1, -1))
                belief[i][:25] = (joint * remaining_by_color).flatten()

        opponent_hands = np.zeros((5, num_possible_cards + 1), dtype=np.float32)
        for i in prange(num_cards_per_hand):
            card = hands[1 - player][i]
            if card is not None:
                opponent_hands[i][5 * card['color'] + card['rank'] - 1] = 1

        # Fresh names are used for the float32 versions so that no variable
        # changes dtype inside the jitted function.
        board_f = board.astype(np.float32)
        hints_errors = np.array([hints, errors], dtype=np.float32)
        deck_flat = deck_table.flatten().astype(np.float32)
        trash_flat = trash_table.flatten().astype(np.float32)

        # Concatenate arrays
        observation = np.concatenate((
            belief.flatten(),
            opponent_hands.flatten(),
            board_f,
            hints_errors,
            deck_flat,
            trash_flat
        ))
        return observation

    def _get_observation(self) -> np.ndarray:
        """
        Encodes the current state as an observation for the current player.
        """
        return self._encode_observation(
            self.game_state.hands,
            self.game_state.board,
            self.game_state.hints,
            self.game_state.errors,
            self.game_state.deck.get_table(),
            self.game_state.trash.get_table(),
            self.game_state.player
        )

    def render(self, mode: str = "human") -> None:
        """
        Renders the current state of the environment by printing it.
        """
        print("Board:", self.game_state.board)
        print("Hands:", [[str(card) for card in hand] for hand in self.game_state.hands])
        print("Hints:", self.game_state.hints)
        print("Errors:", self.game_state.errors)
        # Trash keeps its cards in `.list`; the Trash object itself is not
        # guaranteed to be iterable.
        print("Trash:", [str(card) for card in self.game_state.trash.list])

    def decode_action(self, action_index: int) -> Tuple:
        """
        Decodes a flattened action index into its components.

        Returns:
            ``(action_type, player, card_idx, destination, hint_type,
            hint_value)`` where ``action_type`` is 0 (play), 1 (discard) or
            2 (hint). Fields that do not apply to the action type are ``None``
            (``card_idx`` is 0 for hints).
        """
        if action_index < 5:
            return 0, self.game_state.player, action_index, None, None, None
        elif 5 <= action_index < 10:
            return 1, self.game_state.player, action_index - 5, None, None, None
        else:
            hint_idx = action_index - 10
            destination = 1 - self.game_state.player
            # Parity selects the hint type and the remainder mod 5 selects the
            # target; because 2 and 5 are coprime, the ten hint indices cover
            # every (type, target) pair exactly once.
            if hint_idx % 2 == 0:
                hint_type = "color"
                hint_value = hint_idx % 5  # colors are 0-4
            else:
                hint_type = "value"
                hint_value = hint_idx % 5 + 1  # ranks are 1-5
            return 2, self.game_state.player, 0, destination, hint_type, hint_value


In [9]:
!pip install hanabi_learning_environment



In [10]:
import ray
from ray import tune
from ray.rllib.algorithms.dqn import DQNConfig
from ray.tune.schedulers import PopulationBasedTraining
from ray.rllib.models import ModelCatalog
import torch.nn as nn
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.core.rl_module.rl_module import RLModuleConfig, RLModuleSpec
# """Example of implementing and configuring a custom (torch) LSTM containing RLModule.

# This example:
#     - demonstrates how you can subclass the TorchRLModule base class and set up your
#     own LSTM-containing NN architecture by overriding the `setup()` method.
#     - shows how to override the 3 forward methods: `_forward_inference()`,
#     `_forward_exploration()`, and `forward_train()` to implement your own custom forward
#     logic(s), including how to handle STATE in- and outputs to and from these calls.
#     - explains when each of these 3 methods is called by RLlib or the users of your
#     RLModule.
#     - shows how you then configure an RLlib Algorithm such that it uses your custom
#     RLModule (instead of a default RLModule).

# We implement a simple LSTM layer here, followed by a series of Linear layers.
# After the last Linear layer, we add fork of 2 Linear (non-activated) layers, one for the
# action logits and one for the value function output.

# We test the LSTM containing RLModule on the StatelessCartPole environment, a variant
# of CartPole that is non-Markovian (partially observable). Only an RNN-network can learn
# a decent policy in this environment due to the lack of any velocity information. By
# looking at one observation, one cannot know whether the cart is currently moving left or
# right and whether the pole is currently moving up or down).


# How to run this script
# ----------------------
# `python [script file name].py --enable-new-api-stack`

# For debugging, use the following additional command line options
# `--no-tune --num-env-runners=0`
# which should allow you to set breakpoints anywhere in the RLlib code and
# have the execution stop there for inspection and debugging.

# For logging to your WandB account, use:
# `--wandb-key=[your WandB API key] --wandb-project=[some project name]
# --wandb-run-name=[optional: WandB run name (within the defined project)]`


# Results to expect
# -----------------
# You should see the following output (during the experiment) in your console:

# """
# from ray.rllib.core.rl_module.rl_module import RLModuleSpec
# from ray.rllib.examples.envs.classes.stateless_cartpole import StatelessCartPole
# from ray.rllib.examples.envs.classes.multi_agent import MultiAgentStatelessCartPole
# from ray.rllib.examples.rl_modules.classes.lstm_containing_rlm import (
#     LSTMContainingRLModule,
# )
# from ray.tune import TuneConfig, Tuner

# from ray import train
# from ray.tune.registry import get_trainable_cls, register_env

# def hanabi_env_creator(env_config):
#     # Replace with your Hanabi environment
#     return HanabiEnv(env_config)
from ray import tune
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
# class CustomMetricsCallback(DefaultCallbacks):
#     def on_episode_end(
#         self,
#         *,
#         episode =None,
#         env_runner =None,
#         metrics_logger = MetricsLogger,
#         env = None,
#         env_index,
#         rl_module = None,
#         # TODO (sven): Deprecate these args.
#         worker = None,
#         base_env = None,
#         policies = None,
#         **kwargs,):       # Retrieve the 'score' metric from the environment
#         # Check if the last environment info contains "score"
#         last_info = episode.get_infos(-1)
#         if last_info and "score" in last_info.keys():
#             score = last_info["score"]
#             # Log the custom score metric
#             metrics_logger.log_dict({"score": score})
#         train.report({"score": score})
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo import PPO
# import os
# import torch
# storage_path = os.path.abspath("./ray_results")
# ray.tune.registry.register_env("env", hanabi_env_creator)
# # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# base_config = (
#        PPOConfig()
#         .environment(
#             env="env",
#             env_config={"num_agents": 2},
#         )
#         .training(
#             train_batch_size_per_learner=1024,
#             num_epochs=100,
#             lr=0.0009,
#             vf_loss_coeff=0.001,
#             entropy_coeff=0.0,
#         )
#         .rl_module(
#             # Plug-in our custom RLModule class.
#             rl_module_spec=RLModuleSpec(
#                 module_class=LSTMContainingRLModule,
#                 # Feel free to specify your own `model_config` settings below.
#                 # The `model_config` defined here will be available inside your
#                 # custom RLModule class through the `self.model_config`
#                 # property.
#                 model_config={
#                     "lstm_cell_size": 256,
#                     "dense_layers": [256, 256],
#                     "max_seq_len": 20,
#                 },
#             ),
#         )
#     ).callbacks(CustomMetricsCallback)




#     # Define the PBT scheduler

# tuner = tune.Tuner(
#             PPO,
#         param_space=base_config.to_dict(),
#         run_config=train.RunConfig(
#             name="PPO_Hanabi",
#             stop={"training_iteration": 50000},  # Stop after 100 iterations
#             checkpoint_config=train.CheckpointConfig(
#                 checkpoint_frequency=10,  # Save a checkpoint every 10 iterations
#                 num_to_keep=2,  # Keep only the latest 2 checkpoints
#             ),
#             storage_path=storage_path,  # Path to save results
#         ),
#     )

#     # Run the training
# results = tuner.fit()

  from .autonotebook import tqdm as notebook_tqdm
2024-11-27 15:04:25,811	INFO util.py:90 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2024-11-27 15:04:27,489	INFO util.py:90 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


TypeError: Descriptors cannot not be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
 1. Downgrade the protobuf package to 3.20.x or lower.
 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).

More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates

In [11]:
import ray
from ray import tune
from ray.rllib.algorithms.dqn import DQNConfig
from ray.tune.registry import register_env
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger

# Register the custom environment
def hanabi_env_creator(env_config):
    """
    Build a ``HanabiEnv`` from an RLlib ``env_config`` mapping.

    Args:
        env_config: Mapping of environment options (may be ``None``). Only
            ``"num_agents"`` is read; it defaults to 2.
    """
    options = env_config or {}
    return HanabiEnv(num_players=options.get("num_agents", 2))


# The training config refers to the environment by this name, so it has to
# be present in Tune's registry before the tuner is built.
register_env("hanabi_env", hanabi_env_creator)

class CustomMetricsCallback(DefaultCallbacks):
    """Callback that copies the environment's final ``score`` into the episode's custom metrics."""

    def on_episode_end(
        self,
        *,
        episode =None,
        env_runner =None,
        metrics_logger = MetricsLogger,
        env = None,
        env_index,
        rl_module = None,
        # TODO (sven): Deprecate these args.
        worker = None,
        base_env = None,
        policies = None,
        **kwargs,):
        """
        Record ``info["score"]`` of the episode's last step as a custom metric.

        The last info dict is taken from ``episode.user_data["infos"]`` when
        that list exists and is non-empty. Nothing in this file fills that
        list, so as a fallback the episode's ``last_info_for()`` accessor is
        used when the episode object provides one.
        """
        user_data = getattr(episode, "user_data", None) or {}
        infos = user_data.get("infos", [])
        last_info = infos[-1] if infos else None

        if last_info is None:
            # NOTE(review): `last_info_for` is assumed to be the accessor of
            # the episode class in use - confirm against the installed RLlib.
            last_info_for = getattr(episode, "last_info_for", None)
            if callable(last_info_for):
                last_info = last_info_for()

        if last_info and "score" in last_info:
            # Log the custom score metric
            episode.custom_metrics["score"] = last_info["score"]

# Define a custom DQN class for reporting metrics
from ray.rllib.algorithms.dqn import DQN

class CustomDQN(DQN):
    """DQN variant that exposes the Hanabi score as a top-level ``score`` result key."""

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)

    def train(self):
        """
        Run one training iteration and add ``result["score"]``.

        The score is looked up under ``env_runners`` first and then under the
        ``custom_metrics`` section (``score_mean`` or ``score``). When it is
        not available, e.g. before any episode has finished, ``score`` is set
        to NaN instead of raising ``KeyError``.
        """
        # Call the parent train method
        result = super().train()

        env_runner_results = result.get("env_runners") or {}
        # NOTE(review): where custom metrics end up depends on the RLlib
        # version; both locations are checked - confirm for the one in use.
        custom_metrics = (
            env_runner_results.get("custom_metrics")
            or result.get("custom_metrics")
            or {}
        )

        score = env_runner_results.get("score")
        if score is None:
            score = custom_metrics.get("score_mean", custom_metrics.get("score"))

        result['score'] = float("nan") if score is None else score
        return result


# Define the PBT scheduler
from ray.tune.schedulers import PopulationBasedTraining

pbt_scheduler = PopulationBasedTraining(
    time_attr="training_iteration",
    metric="score",
    mode="max",
    perturbation_interval=5,  # Frequency of hyperparameter perturbations
    hyperparam_mutations={
        "lr": tune.uniform(1e-5, 1e-3),  # Learning rate mutation range
        # NOTE(review): the training config below sets this option as
        # `dueling`; confirm that `dueling_q_model` is a key the algorithm
        # config actually accepts.
        "dueling_q_model": tune.choice([True, False]),  # Dueling network structure
        # NOTE(review): dotted key - confirm it is resolved into the nested
        # `exploration_config` dict rather than added as a new flat key.
        "exploration_config.epsilon_timesteps": tune.choice([10000, 50000, 100000]),  # Epsilon decay
        "train_batch_size": tune.choice([32, 64, 128, 256]),  # Training batch size
    },
)

# DQN configuration
base_config = (
    DQNConfig()
    .environment(
        env="hanabi_env",
        env_config={"num_agents": 2},
    )
    .training(
        dueling=True,  # Enable dueling network architecture
        double_q=True,  # Enable Double Q-learning
        lr=1e-4,
        train_batch_size=128,
        gamma=0.99,
    )
    .callbacks(CustomMetricsCallback)
    # `num_env_runners` belongs to `env_runners()`; `rollouts()` is the
    # deprecated name for that method.
    .env_runners(num_env_runners=2)  # Adjust this based on your hardware
)

import os
storage_path = os.path.abspath("./ray_results_dqn")

# Define the stopping criteria
stopping_criteria = {
    "training_iteration": 10000,  # Stop after 10000 iterations
    # NOTE(review): confirm this is the result key the installed RLlib
    # version reports for the mean episode reward.
    "episode_reward_mean": 300,  # Stop if reward mean reaches 300
}

# Set up the tuner with a storage path
tuner = tune.Tuner(
    CustomDQN,
    tune_config=tune.TuneConfig(
        scheduler=pbt_scheduler,
        num_samples=4,  # Number of trials in the PBT population
    ),
    run_config=ray.air.RunConfig(
        name="dqn_pbt_hanabi",
        stop=stopping_criteria,
        checkpoint_config=ray.air.CheckpointConfig(
            checkpoint_score_attribute="score",
            num_to_keep=2,  # Keep the top 2 checkpoints
        ),
        storage_path=storage_path,  # Specify the storage path
    ),
    param_space=base_config.to_dict(),
)

# Train the model using PBT
results = tuner.fit()


TypeError: Descriptors cannot not be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
 1. Downgrade the protobuf package to 3.20.x or lower.
 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).

More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates