In [20]:
from typing import List 
from open_spiel.python import rl_environment
from acme.wrappers import open_spiel_wrapper
import numpy as np
import dm_env

class AlphaZeroWrapper(open_spiel_wrapper.OpenSpielWrapper):
    """OpenSpiel Go wrapper that emits AlphaZero-style stacked board planes.

    Each player's observation is replaced by a float32 array of shape
    ``(board_size, board_size, 2 * history_size + 1)``. For each of the last
    ``history_size`` positions (most recent first) it holds a black-stone plane
    and a white-stone plane; the final plane is OpenSpiel's "player to move"
    plane of the most recent position. History slots before the start of the
    episode are all zeros. Legal-action masks and terminal flags are passed
    through unchanged.

    Args:
        environment: OpenSpiel ``rl_environment.Environment`` for the ``'go'`` game.
        history_size: number of positions stacked into the observation (>= 1).
        board_size: side length of the board; must match the ``board_size``
            game parameter used to create ``environment`` (default 19).

    Raises:
        ValueError: if ``history_size`` is smaller than 1.
    """

    # Planes in OpenSpiel's Go observation tensor: black, white, empty, player to move.
    _NUM_GO_PLANES = 4

    def __init__(self, environment: rl_environment.Environment, history_size: int = 8,
                 board_size: int = 19):
        if history_size < 1:
            # The last plane is read from the newest stored position, so at
            # least one position must be kept in the history.
            raise ValueError(f"history_size must be >= 1, got {history_size}")
        super().__init__(environment)
        self._history_size = history_size
        self._board_size = board_size
        self._state_history = []
        self._num_planes = self._history_size * 2 + 1  # 2 planes per history step + 1 for current player

    def _convert_obs(self, observations: List[open_spiel_wrapper.OLT]) -> List[open_spiel_wrapper.OLT]:
        """Replaces every player's raw Go observation with the stacked AlphaZero planes.

        Side effect: pushes the current position onto the rolling history, so
        converting the same transition twice would duplicate that position.

        Args:
            observations: one OLT per player, as produced by the base wrapper.

        Returns:
            A new list of OLTs (same order) whose ``observation`` is the
            ``(board_size, board_size, num_planes)`` stacked array.
        """
        # Go is a perfect-information game, so every player's observation is
        # identical and the first one is used. OpenSpiel fills the tensor
        # plane-major (values[num_cells * plane + cell]), see
        # https://github.com/google-deepmind/open_spiel/blob/master/open_spiel/games/go/go.cc
        # so reshape to (planes, rows, cols) and then move the plane axis last.
        flat = np.asarray(observations[0].observation)
        current_state = flat.reshape(
            self._NUM_GO_PLANES, self._board_size, self._board_size).transpose(1, 2, 0)

        # Keep only the most recent `history_size` positions.
        self._state_history.append(current_state)
        if len(self._state_history) > self._history_size:
            self._state_history.pop(0)

        alphazero_observation = self._construct_alphazero_planes()

        # Same stacked observation for every player; masks/terminal flags kept.
        return [
            open_spiel_wrapper.OLT(
                observation=alphazero_observation,
                legal_actions=obs.legal_actions,
                terminal=obs.terminal,
            )
            for obs in observations
        ]

    def _construct_alphazero_planes(self):
        """Stacks the stored history (newest first) into one float32 plane array."""
        observation = np.zeros((self._board_size, self._board_size, self._num_planes), dtype=np.float32)

        for i, state in enumerate(reversed(self._state_history)):
            if i >= self._history_size:
                break
            observation[:, :, 2 * i] = state[:, :, 0]      # Black stones
            observation[:, :, 2 * i + 1] = state[:, :, 1]  # White stones

        # Last plane: the "player to move" plane (index 3) of the newest position.
        observation[:, :, -1] = self._state_history[-1][:, :, 3]
        return observation

    def _initialize_state_history(self):
        """Fills the history with ``history_size`` empty boards (episode start)."""
        # Same per-position layout as the frames built in _convert_obs.
        empty_state = np.zeros(
            (self._board_size, self._board_size, self._NUM_GO_PLANES), dtype=np.float32)
        # The frames are only read, never written, so sharing one array is safe.
        self._state_history = [empty_state] * self._history_size

    def observation_spec(self):
        """Per-player OLT spec whose observation is the stacked plane array in [0, 1]."""
        spec = super().observation_spec()
        new_shape = (self._board_size, self._board_size, self._num_planes)
        return open_spiel_wrapper.OLT(
            observation=dm_env.specs.BoundedArray(shape=new_shape, dtype=np.float32, name='observation', minimum=0, maximum=1),
            legal_actions=spec.legal_actions,
            terminal=spec.terminal
        )

    def reset(self) -> dm_env.TimeStep:
        """Starts a new episode with an empty history and returns the converted first timestep."""
        timestep = super().reset()
        self._initialize_state_history()
        return self._update_timestep(timestep)

    def step(self, action: int) -> dm_env.TimeStep:
        """Plays ``action`` for the current player and returns the converted timestep."""
        if getattr(self, "_reset_next_step", False):
            # The base class would call self.reset() here, and that already
            # returns a converted timestep; converting it again would push a
            # stacked array through _convert_obs and fail to reshape.
            return self.reset()
        timestep = super().step([action])
        return self._update_timestep(timestep)

    def _update_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
        """Returns ``timestep`` with its observations replaced by the stacked planes."""
        return timestep._replace(observation=self._convert_obs(timestep.observation))

In [21]:
from acme import wrappers

# Games are capped at three moves (max_game_length=3) so the cells below
# reach a terminal state quickly.
env_configs = dict(
    max_game_length=3,
    komi=7.5,
    board_size=19,
)
raw_environment = rl_environment.Environment('go', **env_configs)
# Stack AlphaZero planes first, then wrap with acme's SinglePrecisionWrapper.
environment = wrappers.SinglePrecisionWrapper(AlphaZeroWrapper(raw_environment))

In [22]:
# Start a new episode; `obs` holds one OLT (observation, legal_actions, terminal) per player.
timestep = environment.reset()
obs = timestep.observation
# Displayed: id of the player to move.
environment.current_player

0

In [23]:
# Per-player legal-action masks at the initial position.
# NOTE(review): legal_actions0/1 are overwritten by a later cell and not otherwise used.
legal_actions0 = obs[0].legal_actions
legal_actions1 = obs[1].legal_actions

In [24]:
# First move: the player to move (0) plays action 2; displayed is the next player to move.
timestep1 = environment.step(2)
obs1 = timestep1.observation
environment.current_player

1

In [25]:
# Per-player legal-action masks after the first move (overwrites the previous cell's values).
legal_actions0 = obs1[0].legal_actions
legal_actions1 = obs1[1].legal_actions

In [26]:
# Rewards, one entry per player; both zero after a non-terminal move.
timestep1.reward

array([0., 0.], dtype=float32)

In [27]:
# Second move: player 1 plays action 4; displayed is the next player to move.
timestep2 = environment.step(4)
obs2 = timestep2.observation
environment.current_player

0

In [28]:
# Third move: with max_game_length=3 the game ends here, so current_player
# shows OpenSpiel's terminal player id (-4) instead of a real player.
timestep3 = environment.step(8)
obs3 = timestep3.observation
environment.current_player

-4

In [31]:
# Terminal rewards, one entry per player (player 0, player 1).
timestep3.reward

array([-1.,  1.], dtype=float32)

In [None]:
import tensorflow as tf 

In [None]:
# Dummy uniform (unnormalized) policy with one entry per action. The length
# comes from the legal-action mask (362 = 19*19 + 1 pass on a 19x19 board)
# instead of a hard-coded number, so it follows the configured board size.
policy = [0.5] * len(obs[0].legal_actions)

In [None]:
# Mask out illegal moves: keep the policy value where the 0/1 legal-action
# mask is set and use the most negative float32 elsewhere (presumably so a
# later softmax/argmax ignores those moves).
# `legal_actions` is defined here because the cell that otherwise defines it
# comes later in the notebook, which made this cell fail on a fresh kernel.
legal_actions = obs[0].legal_actions
masked = tf.where(legal_actions > 0, policy, tf.float32.min)

In [22]:
# Inspect player 0's legal-action mask at the start of the game.
legal_actions = obs[0].legal_actions
print(f"Shape: {np.shape(legal_actions)}")
print(f"Values: {legal_actions}")
# The action count is usually board_size * board_size + 1 (the extra one is pass).
print(f"Total actions: {environment.action_spec()}")

Shape: (362,)
Values: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1