# Adding OpenSpiel to TorchRL

I need to add OpenSpiel to TorchRL. These are my notes about that.

Issue: <https://github.com/pytorch/rl/issues/2133>

OpenSpiel: <https://github.com/google-deepmind/open_spiel>

OpenSpiel basic API reference:
<https://github.com/google-deepmind/open_spiel/blob/master/docs/api_reference.md>

This is a very instructive tutorial on how to create a new stateless env:
<https://pytorch.org/rl/stable/tutorials/pendulum.html>

Action masks: <https://pytorch.org/rl/stable/reference/envs.html#environments-with-masked-actions>

## Basic demo of OpenSpiel

In [1]:
import pyspiel
chess = pyspiel.load_game('chess')
chess_state = chess.new_initial_state()
chess_state

rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1

In [2]:
actions = chess_state.legal_actions()
actions

[89,
 90,
 652,
 656,
 673,
 674,
 1257,
 1258,
 1841,
 1842,
 2425,
 2426,
 3009,
 3010,
 3572,
 3576,
 3593,
 3594,
 4177,
 4178]

In [3]:
[chess_state.action_to_string(action) for action in actions]

['a3',
 'a4',
 'Na3',
 'Nc3',
 'b3',
 'b4',
 'c3',
 'c4',
 'd3',
 'd4',
 'e3',
 'e4',
 'f3',
 'f4',
 'Nf3',
 'Nh3',
 'g3',
 'g4',
 'h3',
 'h4']

In [4]:
move_idx = 11
print(f'playing move {chess_state.action_to_string(actions[move_idx])}')
chess_state.apply_action(actions[move_idx])
chess_state

playing move e4


rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1

In [5]:
type(chess_state)

pyspiel.ChessState

In [6]:
import numpy as np
np.where(np.array(chess_state.legal_actions_mask()) == 1)

(array([  89,   90,  652,  656,  673,  674, 1257, 1258, 1841, 1842, 2425,
        2426, 3009, 3010, 3572, 3576, 3593, 3594, 4177, 4178]),)

In [7]:
chess_state.get_type()

dir(pyspiel.StateType.DECISION)


['CHANCE',
 'DECISION',
 'MEAN_FIELD',
 'TERMINAL',
 '__class__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__entries',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__index__',
 '__init__',
 '__init_subclass__',
 '__int__',
 '__le__',
 '__lt__',
 '__members__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 'name',
 'value']

Some notes:

* The state of a given game type is a derived class of `pyspiel.State`.
* Legal actions are encoded as integers, but we can convert them into human-readable format with `pyspiel.<State-derived>.action_to_string()`.
* The `pyspiel.<State-derived>.__repr__()` method displays the state in some specific format. In the case of `pyspiel.ChessState`, it's a FEN string.
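
Since legal actions are plain integers, converting between a list of legal actions and a boolean mask (the form the action-mask docs linked above use) is just indexing. The following is a minimal sketch in pure Python; the helper names and the mask size of 4200 are assumptions made for illustration, and a real wrapper would take the size from the game object instead.

```python
def legal_actions_to_mask(legal_actions, num_actions):
    """Return a list of booleans with True at every legal action index."""
    mask = [False] * num_actions
    for action in legal_actions:
        mask[action] = True
    return mask


def mask_to_legal_actions(mask):
    """Inverse helper: the indices whose mask entry is True."""
    return [index for index, allowed in enumerate(mask) if allowed]


# The first four legal chess actions from the demo above.
legal = [89, 90, 652, 656]
# 4200 is an arbitrary size chosen for illustration (an assumption).
mask = legal_actions_to_mask(legal, num_actions=4200)
print(mask_to_legal_actions(mask) == legal)
```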

## Existing TorchRL env demo

In [8]:
#import brax.envs
#from torchrl.envs import BraxWrapper
#base_env = brax.envs.get_environment("ant")
#env = BraxWrapper(base_env)
#env.set_seed(0)
#td = env.reset()
#td

In [9]:
import jumanji
import jax
env = jumanji.make('Snake-v1')
key = jax.random.PRNGKey(0)
state, _ = env.reset(key)

def state_to_dict_of_arrays(state):
    res = {}
    for key, value in state.items():
        if hasattr(value, '_fields'):
            res[key] = {}
            for field in value._fields:
                res[key][field] = jax.numpy.asarray(getattr(value, field))
        else:
            res[key] = jax.numpy.asarray(value)
    
    return res

state_to_dict_of_arrays(state)




{'body': Array([[False, False, False, False, False, False, False, False, False,
         False, False, False],
        [False, False, False, False, False, False, False, False, False,
         False, False, False],
        [False, False, False, False, False, False, False, False, False,
         False, False, False],
        [False, False, False, False, False, False, False, False, False,
         False, False, False],
        [False, False, False, False, False, False, False, False, False,
         False, False, False],
        [False, False, False, False, False, False, False, False,  True,
         False, False, False],
        [False, False, False, False, False, False, False, False, False,
         False, False, False],
        [False, False, False, False, False, False, False, False, False,
         False, False, False],
        [False, False, False, False, False, False, False, False, False,
         False, False, False],
        [False, False, False, False, False, False, False, False, 

## History



In [10]:
chess_state = chess.new_initial_state()

for _ in range(20):
    action = np.random.choice(chess_state.legal_actions())
    chess_state.apply_action(action)

chess_state.serialize()
chess_state.clone()

1r1q1bnr/p1nkppp1/b7/2p4p/1pP4P/3P2P1/PP1KPP2/RNB2BNR w - - 0 11

## Plan

I need to add an `OpenSpielWrapper` class to TorchRL, derived from `_EnvWrapper`. I can base it on the other env wrappers in `torchrl/envs/libs`.

The wrapper should be stateless, meaning that an instance of `OpenSpielWrapper` does not hold onto the state (`pyspiel.State`) of the game. Instead, the state of the game should be part of the `TensorDict` that we pass in to `OpenSpielWrapper.step`, and the new state should be part of the output `TensorDict`.

So that means that we will need a way to convert a `pyspiel.State` into something that we can put into a `TensorDict`. We will also need a way to reconstruct the `pyspiel.State` from the `TensorDict` so that we can continue to make moves upon each call to `OpenSpielWrapper.step`. What are some ways to do that?

I think we may be able to simply place the `pyspiel.State` directly into the `TensorDict`.

## Convert `pyspiel.State` to `TensorDict`

In [11]:
from tensordict import TensorDict

def state_to_td(state):
  td = TensorDict(
    source={
      'state': state,
      'observation': state.observation_tensor(),
    },
    batch_size=[],
  )
  return td

td = state_to_td(chess_state)
td

TensorDict(
    fields={
        observation: Tensor(shape=torch.Size([1280]), device=cpu, dtype=torch.float64, is_shared=False),
        state: NonTensorData(data=1r1q1bnr/p1nkppp1/b7/2p4p/1pP4P/3P2P1/PP1KPP2/RNB2BNR w - - 0 11, batch_size=torch.Size([]), device=None)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

OpenSpiel states aren't JAX objects, so we can't convert them the way `BraxWrapper` and `JumanjiWrapper` do:

In [12]:
#from torchrl.envs.libs.jax_utils import _object_to_tensordict
#
#state_dict = _object_to_tensordict(chess_state, device='cpu', batch_size=())

## New Plan

OpenSpiel offers no way to obtain a state dict. It can, however, serialize a state to a string and rebuild the state from that string, so we will use that. Doing this round trip on every call to `OpenSpielWrapper.step` would be too inefficient, so the env will have to be stateful. We can still support resetting the `OpenSpielWrapper` to a given serialized state in `OpenSpielWrapper.reset`, which should be good enough to support MCTS.
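
To make the shape of this plan concrete, here is a minimal sketch in pure Python. `ToyGame` and `ToyState` are hypothetical stand-ins that only mimic the `serialize()` / `deserialize_state()` round trip, and `StatefulEnvSketch` is an illustrative name, not the eventual TorchRL class. The env keeps the live state internally and only deserializes when `reset` is given a string.

```python
class ToyState:
    """Hypothetical stand-in for a game state: just the action history."""

    def __init__(self, history=()):
        self.history = list(history)

    def apply_action(self, action):
        self.history.append(action)

    def serialize(self):
        return "\n".join(str(action) for action in self.history)


class ToyGame:
    """Hypothetical stand-in for a game object."""

    def new_initial_state(self):
        return ToyState()

    def deserialize_state(self, text):
        return ToyState(int(line) for line in text.splitlines() if line)


class StatefulEnvSketch:
    """Holds the live state; `reset` can optionally restore a serialized one."""

    def __init__(self, game):
        self._game = game
        self._state = game.new_initial_state()

    def reset(self, serialized_state=None):
        if serialized_state is None:
            self._state = self._game.new_initial_state()
        else:
            self._state = self._game.deserialize_state(serialized_state)
        return {"state": self._state.serialize()}

    def step(self, action):
        # No deserialization here: the env mutates the state it already holds.
        self._state.apply_action(action)
        return {"state": self._state.serialize()}


env = StatefulEnvSketch(ToyGame())
env.reset()
checkpoint = env.step(89)["state"]
env.step(1257)
restored = env.reset(checkpoint)  # jump back, as a tree search would
print(restored["state"] == checkpoint)
```

With this layout, a search procedure pays the deserialization cost only when it jumps back to a saved position, not on every step.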

### Example

In [13]:
import numpy as np
chess_game = pyspiel.load_game('chess')
chess_env = chess_game.new_initial_state()

for _ in range(4):
    chess_env.apply_action(np.random.choice(chess_env.legal_actions()))

state = chess_env.serialize()
print('State 1:')
print(state)

for _ in range(4):
    chess_env.apply_action(np.random.choice(chess_env.legal_actions()))

print('State 2:')
print(chess_env.serialize())

chess_env = chess_env.get_game().deserialize_state(state)
print('Reloaded state 1:')
print(chess_env.serialize())


State 1:
FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
89
1257
162
1330

State 2:
FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
89
1257
162
1330
1841
89
673
16

Reloaded state 1:
FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
89
1257
162
1330



In [14]:
td = TensorDict(source={'state': state})

if 'state' in td:
    print(td['state'])



FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
89
1257
162
1330



In [15]:
chess_env.rewards()

[0.0, 0.0]

In [16]:
chess_env.is_terminal()

False

## Multiplayer games

I've been wondering what the "observation" should return for OpenSpiel after a `step` call. Should it only contain the observation for one of the players, or should it contain all of the observations?

`PettingZooEnv` has tic-tac-toe, a 2-player game, so I could do whatever it does.

In [17]:
from torchrl.envs import PettingZooEnv

env = PettingZooEnv(
    task="tictactoe_v3",
    parallel=False,
    # A group map allows you to combine the agents into a batched td
    #group_map={"player": ["player_1", "player_2"]},
    categorical_actions=False,
    seed=0,
    use_mask=True,
)

td = env.reset()
print('td after reset:')
print(td)



td after reset:
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        player_1: TensorDict(
            fields={
                action_mask: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.bool, is_shared=False),
                done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                mask: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: TensorDict(
                    fields={
                        observation: Tensor(shape=torch.Size([1, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False)},
                    batch_size=torch.Size([1]),
                    device=None,
                    is_shared=False),
                terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=to



In [18]:
action = env.action_spec.rand()

#for _ in range(100):
#    action = env.action_spec.rand()
#    if (action['player_1', 'action'] == action['player_2', 'action']).all():
#        print('yes')
print('rand action:')
print(action)
print(action['player_1', 'action'])
print(action['player_2', 'action'])

rand action:
TensorDict(
    fields={
        player_1: TensorDict(
            fields={
                action: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([1]),
            device=None,
            is_shared=False),
        player_2: TensorDict(
            fields={
                action: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([1]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0]])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1]])


In [19]:

td = env.step(env.action_spec.rand())
print('td after step:')
print(td)

td after step:
TensorDict(
    fields={
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                player_1: TensorDict(
                    fields={
                        action_mask: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.bool, is_shared=False),
                        done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        mask: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: TensorDict(
                            fields={
                                observation: Tensor(shape=torch.Size([1, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False)},
                            batch_size=torch.Size([1]),
                            device=None,
                            is_shared=False),
                        reward: Tensor(shape=to

In [20]:
# What happens if both players do the same action?
td = env.reset()

action = env.action_spec.rand()

#action['player_2', 'action'] = action['player_1', 'action'].clone().detach()

td = env.step(action)
print(env.action_spec['player_1', 'action'].mask)
print(env.action_spec['player_2', 'action'].mask)
#print(td['next', 'player_1', 'observation', 'observation'])
print(td['next', 'player_1', 'observation', 'observation'])

action = env.action_spec.rand()
td = env.step(action)
print(env.action_spec['player_1', 'action'].mask)
print(env.action_spec['player_2', 'action'].mask)
#print(td['next', 'player_1', 'observation', 'observation'])
print(td['next', 'player_1', 'observation', 'observation'])

tensor([[True, True, True, True, True, True, True, True, True]])
tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True]])
tensor([[[[1, 0],
          [0, 0],
          [0, 0]],

         [[0, 0],
          [0, 0],
          [0, 0]],

         [[0, 0],
          [0, 0],
          [0, 0]]]], dtype=torch.int8)
tensor([[False,  True,  True, False,  True,  True,  True,  True,  True]])
tensor([[True, True, True, True, True, True, True, True, True]])
tensor([[[[1, 0],
          [0, 0],
          [0, 0]],

         [[0, 1],
          [0, 0],
          [0, 0]],

         [[0, 0],
          [0, 0],
          [0, 0]]]], dtype=torch.int8)


So the action spec makes it apparent that there are two players. When we do `env.action_spec.rand()`, a random action is generated for both players, but only the player whose turn it is has their actions masked according to the actual legal actions. Then, when we do `env.step()`, we're only applying the action for the player whose turn it currently is, and the other one gets ignored.

`MeltingpotEnv` also seems to support multiplayer games, apparently in the same way, or at least very similarly.
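
The masking behaviour observed with `PettingZooEnv` can be sketched in pure Python as follows. The function names are hypothetical and the code is an illustration of the pattern, not TorchRL's implementation: every player gets a mask, only the current player's mask reflects legality, and only the current player's action is applied.

```python
def build_action_masks(players, current_player, legal_actions, num_actions):
    """Per-player masks: only the current player's mask encodes legality."""
    legal = set(legal_actions)
    masks = {}
    for player in players:
        if player == current_player:
            masks[player] = [action in legal for action in range(num_actions)]
        else:
            # Players who are not acting are left unconstrained.
            masks[player] = [True] * num_actions
    return masks


def select_applied_action(actions, current_player):
    """Keep only the current player's action; the others are ignored."""
    return actions[current_player]


# Tic-tac-toe after player_1 has taken cell 0: it is now player_2's turn.
players = ["player_1", "player_2"]
masks = build_action_masks(players, "player_2", range(1, 9), num_actions=9)
applied = select_applied_action({"player_1": 3, "player_2": 4}, "player_2")
print(masks["player_2"])
print(applied)
```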


## `pyspiel.State.apply_actions`

In [21]:
chess_game = pyspiel.load_game('chess')
env = chess_game.new_initial_state()
action = np.random.choice(env.legal_actions())

try:
    env.apply_actions_with([action])
except AttributeError:
    print("failed")

failed


Ah, OK. We can't rely on applying multiple actions at once in OpenSpiel anyway, since not every `State` implements it.

In [22]:
chess_game = pyspiel.load_game('chess')
env = chess_game.new_initial_state()
env.current_player()

1

In [23]:
action = np.random.choice(env.legal_actions())
env.apply_action(action)
env.current_player()

0

## TensorDict returned by `reset()` versus `step()`

In [24]:
import brax.envs
from torchrl.envs import BraxWrapper
base_env = brax.envs.get_environment("ant")
env = BraxWrapper(base_env)
env.set_seed(0)
td = env.reset()
td_reset = td.clone()

td["action"] = env.action_spec.rand()
td = env.step(td)
td_step = td.clone()



In [25]:
print(td_reset.keys())
#print(td_reset['state'].keys())
print('---------------')
print(td_step.keys())
#print(td_step['state'].keys())
print(td_step['next'].keys())
#print(td_step['next', 'state'].keys())

<class 'tensordict.utils._StringKeys'>(dict_keys(['observation', 'done', 'terminated', 'state']))
---------------
<class 'tensordict.utils._StringKeys'>(dict_keys(['observation', 'done', 'terminated', 'state', 'action', 'next']))
<class 'tensordict.utils._StringKeys'>(dict_keys(['observation', 'reward', 'done', 'terminated', 'state']))


In [33]:
from torchrl.envs import PettingZooEnv

env = PettingZooEnv(
    task="tictactoe_v3",
    parallel=False,
    categorical_actions=False,
    seed=0,
    use_mask=True,
)
td = env.reset()
td_reset = td.clone()

td = env.step(env.action_spec.rand())
td_step = td.clone()



In [34]:
#print(td_reset.keys())
#print('---------------')
#print(td_step.keys())
#print(td_step['next'].keys())

action_keys = env.action_keys
done_keys = env.done_keys
reward_keys = env.reward_keys
observation_keys = env.full_observation_spec.keys(True, True)
state_keys = env.full_state_spec.keys(True, True)
print(f'action_keys: {action_keys}')
print(f'done_keys: {done_keys}')
print(f'reward_keys: {reward_keys}')
print(f'observation_keys: {observation_keys}')
print(f'state_keys: {state_keys}')

action_keys: [('player_1', 'action'), ('player_2', 'action')]
done_keys: ['done', 'terminated', 'truncated', ('player_1', 'done'), ('player_1', 'terminated'), ('player_1', 'truncated'), ('player_2', 'done'), ('player_2', 'terminated'), ('player_2', 'truncated')]
reward_keys: [('player_1', 'reward'), ('player_2', 'reward')]
observation_keys: _CompositeSpecKeysView(keys=[('player_1', 'observation', 'observation'), ('player_1', 'action_mask'), ('player_1', 'mask'), ('player_2', 'observation', 'observation'), ('player_2', 'action_mask'), ('player_2', 'mask')])
state_keys: _CompositeSpecKeysView(keys=[])


In [36]:
from torchrl.envs import OpenSpielEnv

env = OpenSpielEnv("chess")
td = env.reset()
td_reset = td.clone()

td = env.step(TensorDict({'action': env.action_spec.rand()}))
td_step = td.clone()

In [37]:
#print(td_reset.keys())
#print('---------------')
#print(td_step.keys())
#print(td_step['next'].keys())

action_keys = env.action_keys
done_keys = env.done_keys
reward_keys = env.reward_keys
observation_keys = env.full_observation_spec.keys(True, True)
state_keys = list(env.full_state_spec.keys(True, True))
print(f'action_keys: {action_keys}')
print(f'done_keys: {done_keys}')
print(f'reward_keys: {reward_keys}')
print(f'observation_keys: {observation_keys}')
print(f'state_keys: {state_keys}')

action_keys: ['action']
done_keys: ['done', 'terminated']
reward_keys: ['reward']
observation_keys: _CompositeSpecKeysView(keys=['observation'])
state_keys: ['observation']


Maybe the TensorDict returned by `step` needs to have 'observation' in it?

In [32]:
env.rollout(3)

RuntimeError: The sets of keys in the tensordicts to stack are exclusive. Consider using `LazyStackedTensorDict.maybe_dense_stack` instead.

{'observation', 'action', 'reward', 'current_player', 'next', 'done', 'terminated', 'state'}
{'observation', 'action',           'current_player', 'next', 'done', 'terminated', 'state'}
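
A quick way to pin down this kind of stacking error is to compare each step's keys against the union of all keys. The helper below is a hypothetical debugging aid in pure Python, operating on plain dicts rather than on `TensorDict` objects.

```python
def missing_keys_per_step(step_dicts):
    """For each step, list the keys absent relative to the union of all keys."""
    all_keys = set()
    for step in step_dicts:
        all_keys |= set(step)
    return [sorted(all_keys - set(step)) for step in step_dicts]


# The two key sets printed above, with placeholder values.
step_a = dict.fromkeys(
    ["observation", "action", "reward", "current_player",
     "next", "done", "terminated", "state"]
)
step_b = dict.fromkeys(
    ["observation", "action", "current_player",
     "next", "done", "terminated", "state"]
)
missing = missing_keys_per_step([step_a, step_b])
print(missing)
```

Here the second entry lacks `reward`, which points at the steps not emitting a consistent set of keys.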


## OpenSpiel games info

In [8]:
import pyspiel

def game_info_str(game_type):
    dynamics = game_type.dynamics
    name = game_type.short_name

    return (
        f"{name}: {dynamics}"
    )


for game_type in pyspiel.registered_games():
    if game_type.dynamics != pyspiel.GameType.Dynamics.SEQUENTIAL:
        print(game_info_str(game_type))


blotto: Dynamics.SIMULTANEOUS
coop_box_pushing: Dynamics.SIMULTANEOUS
goofspiel: Dynamics.SIMULTANEOUS
laser_tag: Dynamics.SIMULTANEOUS
markov_soccer: Dynamics.SIMULTANEOUS
matching_pennies_3p: Dynamics.SIMULTANEOUS
matrix_bos: Dynamics.SIMULTANEOUS
matrix_brps: Dynamics.SIMULTANEOUS
matrix_cd: Dynamics.SIMULTANEOUS
matrix_coordination: Dynamics.SIMULTANEOUS
matrix_mp: Dynamics.SIMULTANEOUS
matrix_pd: Dynamics.SIMULTANEOUS
matrix_rps: Dynamics.SIMULTANEOUS
matrix_rpsw: Dynamics.SIMULTANEOUS
matrix_sh: Dynamics.SIMULTANEOUS
matrix_shapleys_game: Dynamics.SIMULTANEOUS
mfg_crowd_modelling: Dynamics.MEAN_FIELD
mfg_crowd_modelling_2d: Dynamics.MEAN_FIELD
mfg_dynamic_routing: Dynamics.MEAN_FIELD
mfg_garnet: Dynamics.MEAN_FIELD
nfg_game: Dynamics.SIMULTANEOUS
normal_form_extensive_game: Dynamics.SIMULTANEOUS
oshi_zumo: Dynamics.SIMULTANEOUS
pathfinding: Dynamics.SIMULTANEOUS
repeated_game: Dynamics.SIMULTANEOUS


## TODO list

* Support chance nodes: <https://openspiel.readthedocs.io/en/latest/concepts.html#playing-a-trajectory>

* Change specs to be like PettingZooEnv, as shown in the exploration above. Specifically, support all agents acting at each step, but for games where players act sequentially, use a mask to mask out all but the player whose turn it currently is. PettingZooEnv docs explain it: <https://pytorch.org/rl/stable/reference/generated/torchrl.envs.PettingZooEnv.html>

* (Maybe) write some more documentation, since OpenSpiel has a fair number of important concepts that people would need to be aware of to use it properly. Could base the docs on the existing OpenSpiel docs: <https://openspiel.readthedocs.io/en/latest/concepts.html>. Alternatively, just link to the relevant OpenSpiel docs.

* "cursor_go" and "cliff_walking" fail tests with a similar error. My hunch is that the game reaches a terminal state and I'm not detecting it correctly.

* Explanation of when `current_player() < 0` and why: <https://openspiel.readthedocs.io/en/latest/api_reference/state_current_player.html?highlight=players#openspiel-state-methods-current-player>
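
For the chance-node item, one possible approach is to resolve chance nodes inside the wrapper by sampling an outcome until a player has to act. The sketch below assumes a state object exposing `is_chance_node()`, `chance_outcomes()` (a list of `(action, probability)` pairs) and `apply_action()`. `ToyChanceState` is a hypothetical stand-in so the example runs without OpenSpiel installed.

```python
import random


def resolve_chance_nodes(state, rng):
    """Apply sampled chance outcomes until a non-chance node is reached."""
    while state.is_chance_node():
        outcomes = state.chance_outcomes()
        actions = [action for action, _ in outcomes]
        probs = [prob for _, prob in outcomes]
        action = rng.choices(actions, weights=probs, k=1)[0]
        state.apply_action(action)
    return state


class ToyChanceState:
    """Hypothetical stand-in: two chance moves, then a decision node."""

    def __init__(self):
        self.history = []

    def is_chance_node(self):
        return len(self.history) < 2

    def chance_outcomes(self):
        return [(0, 0.5), (1, 0.5)]

    def apply_action(self, action):
        self.history.append(action)


state = resolve_chance_nodes(ToyChanceState(), random.Random(0))
print(len(state.history))
```

Passing the random generator in explicitly keeps the sampling reproducible, which would fit with seeding the env.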