# Basic usage

In [None]:
import jax
import jax.lax
import jax.numpy as jnp
import numpy as np
import os
import matplotlib.pyplot as plt

In [None]:
from src.rts.config import EnvConfig
from src.rts.env import Board, EnvState, init_state, move, reinforce_troops, is_done
from src.rts.utils import assert_valid_state, get_legal_moves, fixed_argwhere

from src.rts.visualizaiton import visualize_board

from src.main import step, batched_step, p1_step

In [None]:
# Request JAX's tracer-leak checker through its environment variable.
# NOTE(review): `jax` has already been imported in an earlier cell. JAX config
# flags are normally read from the environment at import time, so this assignment
# may have no effect here — confirm, or switch to
# jax.config.update("jax_check_tracer_leaks", True).
os.environ["JAX_CHECK_TRACER_LEAKS"] = "TRUE"

### Random play

In [None]:
# Build the environment configuration and the initial state for the random-play demo.
# The board itself is stepped and visualized interactively by the loop that follows.
rng_key = jax.random.PRNGKey(3)
config = EnvConfig(
    board_width = 10,
    board_height = 10,
    num_neutral_bases = 6,
    num_neutral_troops_start = 10,
    neutral_troops_min = 4,
    neutral_troops_max = 10,
    player_start_troops=5,
    bonus_time=10,
)
# NOTE(review): the same key is later split for the step keys as well — confirm
# that reusing it for init_state is intended.
state = init_state(rng_key, config)

In [None]:
# Advance the game five steps. After every step the state is validated,
# the board is drawn, and player 1's reward is printed.
for i in range(5):
    # Carry one half of the split forward; spend the other half on this step.
    rng_key, step_key = jax.random.split(rng_key)
    state, p1_reward = step(state, step_key, config)
    assert_valid_state(state)
    visualize_board(state)
    print(p1_reward)

## Run with NN

In [None]:
from flax import nnx
import optax

class Model(nnx.Module):
    """Small MLP: linear -> layer norm -> ReLU -> linear.

    Args:
        in_dim: Size of the input feature vector.
        mid_dim: Width of the hidden layer (also the LayerNorm feature size).
        out_dim: Size of the output vector.
        rngs: Flax NNX random-number streams used to initialise the layers.
    """

    def __init__(self, in_dim, mid_dim, out_dim, rngs: nnx.Rngs):
        # Layers are created in the order input, norm, output so that they draw
        # from `rngs` in a fixed sequence.
        self.lin_in = nnx.Linear(in_dim, mid_dim, rngs=rngs)
        self.layer_norm = nnx.LayerNorm(mid_dim, rngs=rngs)
        self.lin_out = nnx.Linear(mid_dim, out_dim, rngs=rngs)

    def __call__(self, x):
        """Map an input vector (or batch of vectors) to `out_dim` outputs."""
        hidden = self.lin_in(x)
        hidden = self.layer_norm(hidden)
        hidden = nnx.relu(hidden)
        return self.lin_out(hidden)

In [None]:
# Seed, board dimensions and an untrained model for the neural-network demo.
rng_key = jax.random.PRNGKey(3)
width = 10
height = 10
# Input and output sizes are both width * height * 4. On the output side the
# action decoding below treats the index as (y, x, direction) with 4 directions.
# NOTE(review): the input size assumes board.flatten() yields width*height*4
# values — confirm against Board.flatten.
model = Model(width*height*4, 256, width*height*4, rngs=nnx.Rngs(0))
nnx.display(model)

In [None]:
# Environment configuration for the neural-network demo; the board size comes
# from the `width` / `height` variables that also size the model.
config = EnvConfig(
    board_width = width,
    board_height = height,
    num_neutral_bases = 3,
    num_neutral_troops_start = 5,
    neutral_troops_min = 4,
    neutral_troops_max = 10,
    player_start_troops=5,
    bonus_time=10,
)

# Play 50 steps in which player 1 always takes the legal move with the highest
# model output. The board is drawn every fifth step and the reward is printed
# on every step.
state = init_state(rng_key, config)
for i in range(50):
    flat_state = jnp.array(state.board.flatten())
    legal_mask = jnp.array(get_legal_moves(state, 0).flatten())
    # Illegal moves are masked with -inf. The previous "(logits + 10) * mask"
    # trick selected an illegal move (score 0) whenever every legal logit was
    # below -10.
    masked_logits = jnp.where(legal_mask.astype(bool), model(flat_state), -jnp.inf)
    action = jnp.argmax(masked_logits)
    # Split the flat action index into (y, x, direction).
    action = jnp.array([action // (width*4), (action % (width*4))//4, action % 4])
    rng_key, subkey = jax.random.split(rng_key)
    state, p1_reward = p1_step(state, subkey, config, action)
    assert_valid_state(state)
    if i % 5 == 0:
        visualize_board(state)
    print(p1_reward)

# Batched

In [None]:
# Run N games in parallel with vmap: one PRNG key per game.
N = 50
rng_key = jax.random.PRNGKey(3)
rng_keys = jax.random.split(rng_key, N)

# Create the initial state for each game via vmap; `config` is closed over so
# only the key is mapped.
batched_init_state = jax.vmap(lambda key: init_state(key, config))
states = batched_init_state(rng_keys)

In [None]:
# Index of the single game pulled out of the batch for inspection. The previous
# hard-coded index 79 lies outside the batch of N games created above (N = 50),
# so it is clamped to the last valid game.
game_idx = min(79, N - 1)

for i in range(50):
    # For each parallel game, split its RNG key into two:
    # keys_split has shape (N, 2, key_shape). The first half is carried forward,
    # the second half is consumed by this step.
    keys_split = jax.vmap(lambda key: jax.random.split(key, 2))(rng_keys)
    rng_keys = keys_split[:, 0]
    subkeys = keys_split[:, 1]

    # Take one step in parallel for all games.
    states, p1_rewards = batched_step(states, subkeys, config)

    # Validate and visualize one game. With 50 iterations and a period of 250
    # this only fires on the first iteration.
    if i % 250 == 0:
        print(p1_rewards)
        board = Board(
            player_1_troops = states.board.player_1_troops[game_idx],
            player_2_troops = states.board.player_2_troops[game_idx],
            neutral_troops = states.board.neutral_troops[game_idx],
            bases = states.board.bases[game_idx],
        )
        single_state = EnvState(board = board)
        assert_valid_state(single_state)
        visualize_board(single_state)

### PQN

In [None]:
import functools
import time
from contextlib import contextmanager
from dataclasses import dataclass

In [None]:
@contextmanager
def time_block(label: str, timing_dict: dict[str, float]):
    """Time the body of a ``with`` block and store the duration in ``timing_dict``.

    Args:
        label: Key under which the elapsed time is stored.
        timing_dict: Mapping that receives ``timing_dict[label] = seconds``;
            an existing entry for ``label`` is overwritten.

    The elapsed time is recorded even when the body raises; the exception then
    propagates unchanged.
    """
    # perf_counter is monotonic, so the measurement cannot go negative if the
    # system clock is adjusted while the block runs.
    start_time = time.perf_counter()
    try:
        yield
    finally:
        timing_dict[label] = time.perf_counter() - start_time

In [None]:
@dataclass(frozen=True)
class Params:
    """Training hyperparameters.

    The dataclass is frozen (and therefore hashable), which lets instances be
    passed as static arguments to the jitted rollout / return functions.
    """

    # Number of outer training iterations (one rollout per iteration).
    num_iterations: int
    # Learning rate handed to the optax optimizer.
    lr: float
    # Discount factor used when computing returns.
    gamma: float
    # Mixing weight between the recursive return and the one-step bootstrap value.
    q_lambda: float
    # NOTE(review): only used to derive `batch_size` in the visible code.
    num_envs: int
    # Number of environment steps collected per rollout (scan length).
    num_steps: int
    # Number of gradient steps taken on each rollout.
    update_epochs: int
    # NOTE(review): only used to derive `batch_size` in the visible code.
    num_minibatches: int
    # Probability of taking a random legal action instead of the greedy one.
    epsilon: float

In [None]:
# Greedy action for the current `state`: the highest-scoring move among the legal ones.
flat_state = state.board.flatten()
legal_mask = get_legal_moves(state, 0).flatten()
# Illegal moves are masked with -inf; "(logits + 1000) * mask" silently breaks as
# soon as a legal logit drops below -1000.
q_net_action = jnp.argmax(jnp.where(legal_mask.astype(bool), model(flat_state), -jnp.inf))
q_net_action

In [None]:
# Hyperparameters, Q-network and Adam optimizer for a short training run.
params = Params(
    num_iterations=100,
    lr=6e-4,
    gamma=0.99,
    q_lambda=0.65,
    num_envs=2048,
    num_steps=26,
    update_epochs=4,
    num_minibatches=4,
    epsilon=0.25,
)
# Same architecture as the demo model above: width*height*4 inputs and outputs.
q_net = Model(width*height*4, 256, width*height*4, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(q_net, optax.adam(params.lr))

@functools.partial(nnx.jit, static_argnums=(1,3))
def single_rollout(rng_key, config: EnvConfig, model: Model, params: Params):
    """Roll out one game for ``params.num_steps`` steps with an epsilon-greedy policy.

    Args:
        rng_key: PRNG key for the whole rollout; it is split internally so that
            every random draw gets its own key.
        config: Environment configuration (static under jit).
        model: Q-network mapping a flattened board to one value per action.
        params: Hyperparameters (static under jit); ``epsilon`` and
            ``num_steps`` are read here.

    Returns:
        ``(obs, actions, rewards, dones, next_obs, cum_return)``. The first five
        are per-step buffers stacked by ``jax.lax.scan`` along a leading axis of
        length ``params.num_steps``; ``cum_return`` is the sum of player 1's
        rewards over the rollout.

    Note:
        The rollout keeps stepping after ``is_done`` becomes true; no reset is
        performed here.
    """
    # Separate keys for initialisation and for the scan: the same key must not
    # be consumed twice.
    init_key, rollout_key = jax.random.split(rng_key)
    state = init_state(init_key, config)

    def policy_step(carry, _):
        state, rng_key, cum_reward = carry
        # One independent key per random draw in this step, plus the carried key.
        rng_key, explore_key, epsilon_key, step_key = jax.random.split(rng_key, 4)

        flat_state = state.board.flatten()
        legal_mask = get_legal_moves(state, 0).flatten()

        # Greedy action: the highest Q-value among legal moves. Illegal moves
        # are masked with -inf so they can never win the argmax.
        q_values = model(flat_state)
        greedy_action = jnp.argmax(
            jnp.where(legal_mask.astype(bool), q_values, -jnp.inf)
        )

        # Exploration action: a uniformly random legal move.
        # NOTE(review): presumably fixed_argwhere returns padded indices of the
        # legal entries plus their count — confirm; num_actions == 0 is not handled.
        legal_actions, num_actions = fixed_argwhere(
            legal_mask, max_actions=state.board.width * state.board.height * 4
        )
        action_idx = jax.random.randint(explore_key, (), 0, num_actions)
        explore_action = jnp.take(legal_actions, action_idx, axis=0)[0]

        # Epsilon-greedy choice. Both candidates are already computed, so a
        # plain select is sufficient.
        explore = jax.random.bernoulli(epsilon_key, params.epsilon)
        action = jnp.where(explore, explore_action, greedy_action)

        # Split the scalar action into (row, col, direction) components.
        action_split = jnp.array([
            action // (config.board_width * 4),
            (action % (config.board_width * 4)) // 4,
            action % 4
        ])

        next_state, p1_reward = p1_step(state, step_key, config, action_split)
        done = is_done(next_state)

        # Per-step record: observation, action, reward, done, next observation.
        y = (state.board.flatten(), action, p1_reward, done, next_state.board.flatten())

        # The new carry is the updated state, RNG key, and cumulative reward.
        return (next_state, rng_key, cum_reward + p1_reward), y

    (final_state, final_rng, cum_return), scan_out = jax.lax.scan(
        policy_step,
        (state, rollout_key, jnp.array(0.0)),
        None,
        params.num_steps
    )
    obs_buffer, actions_buffer, rewards_buffer, done_buffer, next_obs_buffer = scan_out
    return obs_buffer, actions_buffer, rewards_buffer, done_buffer, next_obs_buffer, cum_return

# Smoke-test the rollout once and display the actions it chose.
obs_buffer, actions_buffer, rewards_buffer, done_buffer, next_obs_buffer, cum_return = single_rollout(rng_key, config, q_net, params)
actions_buffer

In [None]:
@functools.partial(nnx.jit, static_argnums=(4,))
def q_lambda_return(q_net: Model, rewards_buffer: jnp.ndarray, done_buffer: jnp.ndarray, next_obs_buffer: jnp.ndarray, params: Params):
    """Compute Q(lambda) return targets for one rollout.

    The buffers are indexed by transition ``t``: ``rewards_buffer[t]`` and
    ``done_buffer[t]`` are the outcome of step ``t``, and ``next_obs_buffer[t]``
    is the observation reached by that step. With
    ``V[t] = max_a Q(next_obs[t])`` the targets are::

        G[T-1] = r[T-1] + gamma * (1 - done[T-1]) * V[T-1]
        G[t]   = r[t]   + gamma * (1 - done[t]) * (lam * G[t+1] + (1 - lam) * V[t])

    The bootstrap value and the done flag for step ``t`` are therefore taken at
    index ``t`` (not ``t + 1``), and a terminal step cuts off everything that
    follows it, including the recursive ``G[t+1]`` term.

    Args:
        q_net: Q-network evaluated on the next observations.
        rewards_buffer: Per-step rewards, leading axis of length ``num_steps``.
        done_buffer: Per-step done flags, same leading axis.
        next_obs_buffer: Per-step next observations, same leading axis.
        params: Hyperparameters (static under jit); ``gamma`` and ``q_lambda``
            are read here.

    Returns:
        Array of shape ``(num_steps,)`` with the return target for every step.
    """
    # V[t]: best Q-value in the state reached by step t.
    q_values = jax.vmap(lambda obs: jnp.max(q_net(obs), axis=-1))(next_obs_buffer)
    nonterminal = 1.0 - done_buffer

    # The last step can only bootstrap from the network.
    returns_last = rewards_buffer[-1] + params.gamma * q_values[-1] * nonterminal[-1]

    def backward_step(next_return, inputs):
        reward, not_done, bootstrap_q = inputs
        blended = (
            params.q_lambda * next_return
            + (1 - params.q_lambda) * bootstrap_q
        )
        current_return = reward + params.gamma * not_done * blended
        return current_return, current_return

    # reverse=True walks t = T-2 ... 0 while keeping the outputs in forward
    # order, so no manual flipping of the buffers is needed.
    _, earlier_returns = jax.lax.scan(
        backward_step,
        returns_last,
        (rewards_buffer[:-1], nonterminal[:-1], q_values[:-1]),
        reverse=True,
    )

    return jnp.concatenate([earlier_returns, returns_last[None]], axis=0)

# Smoke-test the return computation on the rollout collected above and show its shape.
returns = q_lambda_return(q_net, rewards_buffer, done_buffer, next_obs_buffer, params)
returns.shape

In [None]:
# Minibatch size implied by the hyperparameters.
# NOTE(review): not referenced by any other code visible in this notebook section.
batch_size = params.num_steps * params.num_envs // params.num_minibatches

@nnx.jit
def train_step(q_net: Model, optimizer: nnx.Optimizer, observations: jnp.ndarray, actions: jnp.ndarray, returns: jnp.ndarray):
    """Take one gradient step on the squared error between Q(s, a) and the targets.

    Args:
        q_net: Q-network being trained.
        optimizer: NNX optimizer wrapping ``q_net``; it is updated in place.
        observations: Batch of flattened observations, one row per transition.
        actions: Flat action index taken in each observation.
        returns: Regression target for each transition.

    Returns:
        The scalar mean-squared-error loss computed before the update.
    """
    def mse_loss(net):
        predicted = net(observations)
        # Select, for every row, the Q-value of the action that was taken.
        chosen = jnp.take_along_axis(predicted, actions[:, None], axis=1).squeeze()
        td_error = chosen - returns
        return jnp.mean(jnp.square(td_error))

    loss, grads = nnx.value_and_grad(mse_loss)(q_net)
    optimizer.update(grads)

    return loss

# Smoke-test one optimisation step on the rollout collected above, then show the
# network's outputs for a freshly initialised state.
train_step(q_net, optimizer, obs_buffer, actions_buffer, returns)
state = init_state(rng_key, config)
q_net(state.board.flatten())

In [None]:
def train(q_net: Model, optimizer: nnx.Optimizer, config: EnvConfig, params: Params):
    """Train ``q_net`` on repeated single rollouts and plot the loss curve.

    Each iteration collects one rollout, computes its Q(lambda) targets and
    takes ``params.update_epochs`` gradient steps on it. The latest loss is
    printed every tenth iteration, and all losses (clipped to 1000) are plotted
    with matplotlib at the end.

    Args:
        q_net: Q-network to train; updated in place through ``optimizer``.
        optimizer: NNX optimizer wrapping ``q_net``.
        config: Environment configuration.
        params: Hyperparameters; ``num_iterations`` and ``update_epochs`` are
            read here.

    Returns:
        The same ``q_net`` object after training.
    """
    rng_key = jax.random.PRNGKey(0)
    losses = []
    for iteration in range(params.num_iterations):
        rng_key, rollout_key = jax.random.split(rng_key)
        (
            obs_buffer,
            actions_buffer,
            rewards_buffer,
            done_buffer,
            next_obs_buffer,
            _cum_return,
        ) = single_rollout(rollout_key, config, q_net, params)

        returns = q_lambda_return(q_net, rewards_buffer, done_buffer, next_obs_buffer, params)

        # Stays None when update_epochs is 0, so the progress print below can
        # never hit an unbound name.
        loss = None
        for _ in range(params.update_epochs):
            loss = train_step(q_net, optimizer, obs_buffer, actions_buffer, returns)
            losses.append(loss)
        if iteration % 10 == 0 and loss is not None:
            print(f"Iteration {iteration} - Loss: {loss}")

    # Clip the curve so a few very large losses do not flatten the plot.
    plt.plot([min(loss_value, 1000) for loss_value in losses])

    return q_net

# Short training run with the hyperparameters defined above.
q_net = train(q_net, optimizer, config, params)

In [None]:
# Longer training run: fresh config, hyperparameters, network and optimizer.
width = 10
height = 10
config = EnvConfig(
    board_width = width,
    board_height = height,
    num_neutral_bases = 3,
    num_neutral_troops_start = 5,
    neutral_troops_min = 4,
    neutral_troops_max = 10,
    player_start_troops=5,
    bonus_time=10,
)
# Compared with the short run: 5000 iterations instead of 100, and epsilon 0.4
# instead of 0.25.
params = Params(
    num_iterations=5000,
    lr=6e-4,
    gamma=0.99,
    q_lambda=0.65,
    num_envs=2048,
    num_steps=26,
    update_epochs=4,
    num_minibatches=4,
    epsilon=0.4,
)
# Start from a newly initialised network so this run does not continue the earlier one.
q_net = Model(width*height*4, 256, width*height*4, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(q_net, optax.adam(params.lr))

q_net = train(q_net, optimizer, config, params)

In [None]:
# Play a game with the trained network: player 1 always takes the legal move
# with the highest Q-value. The board is drawn every fifth step and the reward
# is printed on every step.
state = init_state(rng_key, config)
for i in range(500):
    legal_mask = jnp.array(get_legal_moves(state, 0).flatten())
    # Illegal moves are masked with -inf. After training, Q-values are not
    # bounded below by -10, so the "(q + 10) * mask" trick could select an
    # illegal move.
    masked_q = jnp.where(legal_mask.astype(bool), q_net(state.board.flatten()), -jnp.inf)
    action = jnp.argmax(masked_q)
    # Split the flat action index into (y, x, direction).
    action = jnp.array([action // (width*4), (action % (width*4))//4, action % 4])
    rng_key, subkey = jax.random.split(rng_key)
    state, p1_reward = p1_step(state, subkey, config, action)
    assert_valid_state(state)
    if i % 5 == 0:
        visualize_board(state)
    print(p1_reward)