# Basic usage

In [None]:
import jax
import jax.lax
import jax.numpy as jnp
import numpy as np

In [None]:
from src.rts.config import EnvConfig
from src.rts.env import Board, EnvState, init_state, move, reinforce_troops
from src.rts.utils import assert_valid_state, get_legal_moves, fixed_argwhere

from src.rts.visualizaiton import visualize_board

from src.main import step, batched_step, p1_step

### Random play

In [None]:
# Set up a single game that is then played with random moves and
# visualized interactively.
#
# The seed key is split once up front: `init_key` is consumed by
# `init_state`, while `rng_key` is carried forward and split again for every
# step. Passing the same key to both would reuse randomness, which JAX's
# PRNG design asks callers to avoid.
rng_key, init_key = jax.random.split(jax.random.PRNGKey(3))
config = EnvConfig(
    board_width=10,
    board_height=10,
    num_neutral_bases=6,
    num_neutral_troops_start=10,
    neutral_troops_min=4,
    neutral_troops_max=10,
    player_start_troops=5,
    bonus_time=10,
)
state = init_state(init_key, config)

In [None]:
# Advance the game a handful of steps with `step`, validating the state and
# printing player 1's reward after every step.
num_steps = 5
render_every = 1  # 1 => draw the board after every single step

for step_idx in range(num_steps):
    # Carry `rng_key` forward and hand a fresh subkey to this step.
    rng_key, step_key = jax.random.split(rng_key)
    state, p1_reward = step(state, step_key, config)
    assert_valid_state(state)
    if step_idx % render_every == 0:
        visualize_board(state)
    print(p1_reward)

## Run with NN

In [None]:
from flax import nnx

class Model(nnx.Module):
  """Small MLP: Linear -> LayerNorm -> ReLU -> Linear.

  Args:
    in_dim: size of the input feature vector.
    mid_dim: width of the hidden layer (also the LayerNorm feature size).
    out_dim: size of the output vector.
    rngs: `nnx.Rngs` used to initialise the parameters of every layer.
  """

  def __init__(self, in_dim, mid_dim, out_dim, rngs: nnx.Rngs):
    # Layers are created in forward-pass order, all drawing from the same
    # `rngs` container.
    self.lin_in = nnx.Linear(in_dim, mid_dim, rngs=rngs)
    self.layer_norm = nnx.LayerNorm(mid_dim, rngs=rngs)
    self.lin_out = nnx.Linear(mid_dim, out_dim, rngs=rngs)

  def __call__(self, x):
    """Run the forward pass on `x` and return the output of the last layer."""
    hidden = self.lin_in(x)
    hidden = self.layer_norm(hidden)
    hidden = nnx.relu(hidden)
    return self.lin_out(hidden)

In [None]:
# Greedy play driven by an (untrained) network: at every step the model
# scores all flattened actions, illegal ones are masked out, and the
# best-scoring remaining action is passed to `p1_step`.
width = 6
height = 6
config = EnvConfig(
    board_width=width,
    board_height=height,
    num_neutral_bases=3,
    num_neutral_troops_start=5,
    neutral_troops_min=4,
    neutral_troops_max=10,
    player_start_troops=5,
    bonus_time=10,
)
# The output has one score per flattened action, decoded below as
# (y, x, direction). The input size presumably matches the flattened board
# (four planes of width*height) -- TODO confirm against Board.flatten.
model = Model(width*height*4, 256, width*height*4, rngs=nnx.Rngs(0))

# Split the seed so that `init_state` and the per-step subkeys never consume
# the same PRNG key.
rng_key, init_key = jax.random.split(jax.random.PRNGKey(3))
state = init_state(init_key, config)

for i in range(500):
    flat_state = jnp.array(state.board.flatten())
    # Second argument is presumably the player index -- TODO confirm.
    legal_mask = jnp.array(get_legal_moves(state, 0).flatten())

    # Mask illegal moves with -inf before the argmax. Shifting the scores by
    # a constant and multiplying by the mask would let an illegal move
    # (score 0) win whenever every legal score falls at or below minus that
    # constant.
    logits = model(flat_state)
    masked_logits = jnp.where(legal_mask.astype(bool), logits, -jnp.inf)
    action = jnp.argmax(masked_logits)

    # split action from int to array
    # y, x, direction
    cell = action // 4
    action = jnp.array([cell // width, cell % width, action % 4])

    rng_key, subkey = jax.random.split(rng_key)
    state, p1_reward = p1_step(state, subkey, config, action)
    assert_valid_state(state)
    if i % 5 == 0:
        visualize_board(state)
    print(p1_reward)

# Batched

In [None]:
# Now we vmap: run N independent games in parallel.
N = 50
# Split the seed into two independent streams: `init_keys` seed the initial
# boards, while `rng_keys` are carried forward and split again on every step.
# Using one set of keys for both would reuse randomness between
# initialisation and stepping.
rng_key, init_key = jax.random.split(jax.random.PRNGKey(3))
rng_keys = jax.random.split(rng_key, N)
init_keys = jax.random.split(init_key, N)

# Create the initial state for each game via vmap (one key per game).
batched_init_state = jax.vmap(lambda key: init_state(key, config))
states = batched_init_state(init_keys)

In [None]:
# Step all N games in parallel, periodically validating and drawing one of
# them.
#
# Index of the game to inspect. It must be < N: JAX clamps out-of-range
# indices on reads instead of raising, so an index >= N would silently show
# the last game. N - 1 selects that last game explicitly.
game_idx = N - 1

for i in range(5000):
    # For each parallel game, split its RNG key into two:
    # keys_split will have shape (N, 2, key_shape).
    keys_split = jax.vmap(lambda key: jax.random.split(key, 2))(rng_keys)
    # Update rng_keys to the first half and use the second half as subkeys.
    rng_keys = keys_split[:, 0]
    subkeys = keys_split[:, 1]

    # Take one step in parallel for all games.
    states, p1_rewards = batched_step(states, subkeys, config)

    # Visualize and validate
    if i % 250 == 0:
        print(p1_rewards)
        # Slice the selected game out of the batched board to get an
        # unbatched state for the single-game helpers.
        board = Board(
            player_1_troops=states.board.player_1_troops[game_idx],
            player_2_troops=states.board.player_2_troops[game_idx],
            neutral_troops=states.board.neutral_troops[game_idx],
            bases=states.board.bases[game_idx],
        )
        single_state = EnvState(board=board)
        assert_valid_state(single_state)
        visualize_board(single_state)