# Climbing Game Environment

In [3]:
%%capture
!pip install matrax

In [4]:
# imports
import jax
import jax.numpy as jnp

import matrax

### Make stateless environment

A **stateless** matrix game is a game that is presented to both players at each iteration without any knowledge of what has happened in previous iterations of playing the game. Concretely, agents receive as "state" a zero vector of size `(num_agents,)`, and this vector is not updated at any time.

In [5]:
# Build the stateless variant of the Climbing game (observations are never updated)
env_id = "Climbing-stateless-v0"
env = matrax.make(env_id)

In [6]:
# Display the environment; its repr shows the payoff matrix of the game
env

MatrixGame(
	payoff_matrix=Array([[[ 11, -30,   0],
        [-30,   7,   0],
        [  0,   6,   5]],

       [[ 11, -30,   0],
        [-30,   7,   0],
        [  0,   6,   5]]], dtype=int32),
)

### Reset and Step

In [7]:
# Reset the (jit-able) environment, starting from a fixed PRNG seed
key = jax.random.PRNGKey(0)
reset_fn = jax.jit(env.reset)
state, timestep = reset_fn(key)
for label, value in (("state:", state), ("timestep:", timestep)):
    print(label, value)

state: State(step_count=Array(0, dtype=int32), key=Array([0, 0], dtype=uint32))
timestep: TimeStep(step_type=Array(0, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(agent_obs=Array([[0, 0],
       [0, 0]], dtype=int32), step_count=Array(0, dtype=int32)), extras=None)


In [8]:
# Interact with the (jit-able) environment
# Joint action: one action index per agent
actions = jnp.array([0, 1])
step_fn = jax.jit(env.step)
# Take a step and observe the next state and time step
state, timestep = step_fn(state, actions)
print("state:", state)
print("timestep:", timestep)

state: State(step_count=Array(1, dtype=int32), key=Array([0, 0], dtype=uint32))
timestep: TimeStep(step_type=Array(1, dtype=int8), reward=Array([-30, -30], dtype=int32), discount=Array(1., dtype=float32), observation=Observation(agent_obs=Array([[0, 0],
       [0, 0]], dtype=int32), step_count=Array(1, dtype=int32)), extras=None)


Note that the "state", in terms of the agents' observations, has not changed, i.e. the agents' observations remain zero vectors irrespective of the actions they take. However, each agent has received a reward based on the joint action of the agents.

### Make stateful environment

A **stateful** matrix game is a game where agents have knowledge of what happened in the previous iteration of playing the game, in terms of the actions selected by each agent. Concretely, at step `t` agents receive as "state" a vector of size `(num_agents,)` containing the action each agent took at step `t-1`. At reset, before any action has been taken, this vector is filled with `-1`.

In [9]:
# Build the stateful variant of the Climbing game (observations track the previous actions)
env_id = "Climbing-stateful-v0"
env = matrax.make(env_id)

### Reset and Step

In [10]:
# Reset the (jit-able) stateful environment, reusing the PRNG key created earlier
reset_fn = jax.jit(env.reset)
state, timestep = reset_fn(key)
print("state:", state)
print("timestep:", timestep)

state: State(step_count=Array(0, dtype=int32), key=Array([0, 0], dtype=uint32))
timestep: TimeStep(step_type=Array(0, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(agent_obs=Array([[-1, -1],
       [-1, -1]], dtype=int32), step_count=Array(0, dtype=int32)), extras=None)


In [17]:
# Interact with the (jit-able) stateful environment.
# The joint action (one action index per agent) is defined here rather than
# relying on the `actions` variable left over from the stateless section, so
# this cell is self-contained and the resulting observation can be compared
# directly against the actions that produced it.
actions = jnp.array([0, 1])
state, timestep = jax.jit(env.step)(state, actions)   # Take a step and observe the next state and time step
print("state:", state)
print("timestep:", timestep)

state: State(step_count=Array(1, dtype=int32), key=Array([0, 0], dtype=uint32))
timestep: TimeStep(step_type=Array(1, dtype=int8), reward=Array([-30, -30], dtype=int32), discount=Array(1., dtype=float32), observation=Observation(agent_obs=Array([[0, 1],
       [0, 1]], dtype=int32), step_count=Array(1, dtype=int32)), extras=None)


Note that the "state" is updated: each agent's observation now consists of the actions taken by all agents in the previous step.