In [None]:
from dataclasses import dataclass
import functools
import os

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from src.rts.config import EnvConfig
from src.rts.env import Board, EnvState, init_state, random_step, p1_step
from src.rts.utils import assert_valid_state, get_legal_moves
from src.rts.visualizaiton import visualize_board

In [None]:
# JAX reads its JAX_* environment variables only once, when `jax` is first
# imported. `jax` is already imported above, so the env var alone does not
# turn leak checking on in this kernel. Update the live config as well.
# Leak checking may slow execution.
os.environ["JAX_CHECK_TRACER_LEAKS"] = "TRUE"
jax.config.update("jax_check_tracer_leaks", True)

## Random play

In [None]:
# Set up a single game on a 10x10 board for random play; the next cell
# advances it with random moves and visualizes the board.
rng_key = jax.random.PRNGKey(3)
config = EnvConfig(
    board_width=10,
    board_height=10,
    num_neutral_bases=6,
    num_neutral_troops_start=10,
    neutral_troops_min=4,
    neutral_troops_max=10,
    player_start_troops=5,
    bonus_time=10,
)
state = init_state(rng_key, config)

In [None]:
# Advance the game with random moves, validating the state after every step.
# Raise `visualize_every` to render fewer boards.
num_random_steps = 5
visualize_every = 1
for i in range(num_random_steps):
    rng_key, subkey = jax.random.split(rng_key)
    state = random_step(state, subkey, config)
    assert_valid_state(state)
    if i % visualize_every == 0:
        visualize_board(state)

## Batched Random

In [None]:
@functools.partial(jax.jit, static_argnames=("config",))
def batched_step(states, rng_keys, config):
    """Advance a batch of games by one random move each.

    `states` is a batched EnvState whose leading axis indexes the games, and
    `rng_keys` holds one PRNG key per game (both are mapped over axis 0).
    `config` is a static argument to `jax.jit`, so it must be hashable and a
    different config triggers a recompilation.
    """
    step_every_game = jax.vmap(
        lambda one_state, one_key: random_step(one_state, one_key, config)
    )
    return step_every_game(states, rng_keys)

In [None]:
# Batched setup: N independent games, each seeded by its own PRNG key.
rng_key = jax.random.PRNGKey(3)
N = 50
rng_keys = jax.random.split(rng_key, N)

# Map init_state over the keys so `states` carries a leading game axis of size N.
batched_init_state = jax.vmap(lambda game_key: init_state(game_key, config))
states = batched_init_state(rng_keys)

In [None]:
# Step all N games in parallel. Every `visualize_every` steps, pull one game
# out of the batch to validate and draw it.
num_batched_steps = 50
visualize_every = 25
game_to_show = 0
# A game index >= N is out of range for the batch (JAX may clamp it silently
# instead of raising), so check it up front.
assert 0 <= game_to_show < N, f"game_to_show must be in [0, {N}), got {game_to_show}"

for i in range(num_batched_steps):
    # Split each game's key in two (keys_split has shape (N, 2, key_shape)):
    # the first half is carried forward, the second half drives this step.
    keys_split = jax.vmap(lambda key: jax.random.split(key, 2))(rng_keys)
    rng_keys = keys_split[:, 0]
    subkeys = keys_split[:, 1]

    # Take one step in parallel for all games.
    states = batched_step(states, subkeys, config)

    if i % visualize_every == 0:
        # Every leaf of the vmapped state has a leading game axis, so indexing
        # all leaves at the same position yields that game's complete EnvState
        # (no need to rebuild Board/EnvState field by field).
        single_state = jax.tree_util.tree_map(lambda leaf: leaf[game_to_show], states)
        assert_valid_state(single_state)
        visualize_board(single_state)

## Run with Neural Network

In [None]:
from flax import nnx
from src.rl.pqn import Model
import optax

In [None]:
# Q-network for a 10x10 board. Its input and output sizes are both
# width * height * 4, matching the flattened board fed to it below and the
# (y, x, direction) action decoding used when playing.
# NOTE(review): presumably Model(in_dim, hidden_dim, out_dim) — confirm in src.rl.pqn.
rng_key = jax.random.PRNGKey(3)
width, height = 10, 10
model = Model(width * height * 4, 256, width * height * 4, rngs=nnx.Rngs(0))
nnx.display(model)

In [None]:
config = EnvConfig(
    board_width=width,
    board_height=height,
    num_neutral_bases=3,
    num_neutral_troops_start=5,
    neutral_troops_min=4,
    neutral_troops_max=10,
    player_start_troops=5,
    bonus_time=10,
)

# Play 50 steps for player 1, always choosing the network's best legal move.
state = init_state(rng_key, config)
for i in range(50):
    flat_state = jnp.array(state.board.flatten())
    legal_mask = jnp.array(get_legal_moves(state, 0).flatten())
    # Mask illegal moves with -inf rather than `(q + 10) * mask`: the shift-and-
    # multiply trick selects an illegal move (score 0) whenever every legal
    # Q-value is below -10.
    q_values = model(flat_state)
    action = jnp.argmax(jnp.where(legal_mask, q_values, -jnp.inf))
    # Decode the flat action index into [y, x, direction].
    action = jnp.array([action // (width * 4), (action % (width * 4)) // 4, action % 4])
    rng_key, subkey = jax.random.split(rng_key)
    state, p1_reward = p1_step(state, subkey, config, action)
    assert_valid_state(state)
    if i % 5 == 0:
        visualize_board(state)
    print(p1_reward)

## PQN

In [None]:
from src.rl.pqn import Params

In [None]:
# Greedy action of the (still untrained) network for the current state.
# Illegal moves are masked with -inf so they can never win the argmax,
# regardless of the sign or scale of the Q-values.
flat_state = state.board.flatten()
legal_mask = get_legal_moves(state, 0).flatten()
q_net_action = jnp.argmax(jnp.where(legal_mask, model(flat_state), -jnp.inf))
q_net_action

In [None]:
from src.rl.pqn import single_rollout
# Sanity check of the PQN pieces: hyper-parameters, a fresh network (seed 0)
# with an Adam optimizer, and one rollout collected with that network.
params = Params(
    num_iterations=100,
    lr=6e-4,
    gamma=0.99,
    q_lambda=0.65,
    num_envs=2048,
    num_steps=20,
    update_epochs=4,
    num_minibatches=4,
    epsilon=0.25,
)
q_net = Model(width * height * 4, 256, width * height * 4, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(q_net, optax.adam(params.lr))

(
    obs_buffer, actions_buffer, rewards_buffer,
    done_buffer, next_obs_buffer, cum_return,
) = single_rollout(rng_key, config, q_net, params)
actions_buffer, done_buffer, rewards_buffer

In [None]:
from src.rl.pqn import q_lambda_return

# Q(lambda) returns for the rollout collected in the previous cell.
returns = q_lambda_return(
    q_net,
    rewards_buffer,
    done_buffer,
    next_obs_buffer,
    params,
)
returns

In [None]:
from src.rl.pqn import train_step

# Apply train_step once to the rollout, then compare the network's Q-values
# for the actions actually taken against their Q(lambda) targets.
train_step(q_net, optimizer, obs_buffer, actions_buffer, returns)

taken_action_q = jnp.take_along_axis(
    q_net(obs_buffer), actions_buffer[:, None], axis=1
).squeeze()
print(taken_action_q)
print(returns)

In [None]:
def train(q_net: Model, optimizer: nnx.Optimizer, config: EnvConfig, params: Params):
    """Train `q_net` with one `single_rollout` per iteration.

    Each iteration collects a rollout, computes Q(lambda) targets with
    `q_lambda_return`, and calls `train_step` `params.update_epochs` times on
    the whole rollout. A loss curve (clipped at 1000) is plotted at the end.

    Args:
        q_net: Q-network; presumably updated in place by `train_step`.
        optimizer: nnx optimizer wrapping `q_net`.
        config: environment configuration passed to the rollout.
        params: training hyper-parameters.

    Returns:
        `q_net`, the same object that was passed in.
    """
    rng_key = jax.random.PRNGKey(0)
    losses = []
    for iteration in range(params.num_iterations):
        rng_key, rollout_key = jax.random.split(rng_key)
        obs_buffer, actions_buffer, rewards_buffer, done_buffer, next_obs_buffer, _ = single_rollout(
            rollout_key, config, q_net, params
        )

        returns = q_lambda_return(q_net, rewards_buffer, done_buffer, next_obs_buffer, params)

        for _ in range(params.update_epochs):
            losses.append(train_step(q_net, optimizer, obs_buffer, actions_buffer, returns))

        # With update_epochs == 0 there is no loss to report (and no NameError).
        if iteration % 10 == 0 and losses:
            print(f"Iteration {iteration} - Loss: {losses[-1]}")

    # Clip large losses so early spikes don't flatten the rest of the curve.
    _, ax = plt.subplots()
    ax.plot(np.minimum(np.asarray(losses, dtype=float), 1000))
    ax.set(title="Training loss (clipped at 1000)", xlabel="update", ylabel="loss")

    return q_net

q_net = train(q_net, optimizer, config, params)

In [None]:
# Retrain from scratch with a longer schedule (500 iterations, lr 4e-4,
# q_lambda 0.92, num_envs=50, num_steps=250).
width, height = 10, 10
config = EnvConfig(
    board_width=width,
    board_height=height,
    num_neutral_bases=3,
    num_neutral_troops_start=5,
    neutral_troops_min=4,
    neutral_troops_max=10,
    player_start_troops=5,
    bonus_time=10,
)
params = Params(
    num_iterations=500,
    lr=4e-4,
    gamma=0.99,
    q_lambda=0.92,
    num_envs=50,
    num_steps=250,
    update_epochs=1,
    num_minibatches=4,
    epsilon=0.3,
)
q_net = Model(width * height * 4, 256, width * height * 4, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(q_net, optax.adam(params.lr))

q_net = train(q_net, optimizer, config, params)

In [None]:
# Play one 200-step game greedily with the trained network and record
# player 1's reward at every step.
state = init_state(rng_key, config)
rewards = []
for i in range(200):
    legal_mask = jnp.array(get_legal_moves(state, 0).flatten())
    # Evaluate the network once per step. Masking illegal moves with -inf
    # (instead of `(q + 1000) * mask`) keeps them out of the argmax no matter
    # how negative the legal Q-values are.
    q_values = q_net(state.board.flatten())
    action = jnp.argmax(jnp.where(legal_mask, q_values, -jnp.inf))
    # Decode the flat action index into [y, x, direction].
    action = jnp.array([action // (width * 4), (action % (width * 4)) // 4, action % 4])
    rng_key, subkey = jax.random.split(rng_key)
    state, p1_reward = p1_step(state, subkey, config, action)
    rewards.append(p1_reward)
    assert_valid_state(state)
    if i % 5 == 0:
        visualize_board(state)
    print(p1_reward)
print(rewards)
# Running total: reward accumulated from the first step up to each step.
cumulative_rewards = np.cumsum(rewards)
print(cumulative_rewards)
_, ax = plt.subplots()
ax.plot(cumulative_rewards)
ax.set(title="Cumulative player-1 reward", xlabel="step", ylabel="cumulative reward");

## PQN with vmap

In [None]:
# One rollout per environment in parallel: vmap single_rollout over the PRNG
# key only; config, network and params are shared (in_axes=None).
rng_keys = jax.random.split(rng_key, params.num_envs)

vmapped_rollout = jax.vmap(single_rollout, in_axes=(0, None, None, None))

(
    obs_buffer, actions_buffer, rewards_buffer,
    done_buffer, next_obs_buffer, cum_returns,
) = vmapped_rollout(rng_keys, config, q_net, params)
obs_buffer.shape, actions_buffer.shape

In [None]:
# Per-environment Q(lambda) returns: map over the leading axis of the reward,
# done and next-observation buffers while sharing the network and params.
vmapped_q_lambda_return = jax.vmap(q_lambda_return, in_axes=(None, 0, 0, 0, None))
returns = vmapped_q_lambda_return(
    q_net, rewards_buffer, done_buffer, next_obs_buffer, params
)
returns.shape

In [None]:
# Merge the env and time-step axes into a single batch axis, then take one
# training step with the single model on the whole batch.
flat_observations = obs_buffer.reshape((-1, obs_buffer.shape[-1]))
flat_actions, flat_returns = actions_buffer.reshape(-1), returns.reshape(-1)

loss = train_step(q_net, optimizer, flat_observations, flat_actions, flat_returns)

In [None]:
def train_vmapped(q_net: Model, optimizer: nnx.Optimizer, config: EnvConfig, params: Params):
    """Train `q_net` on `params.num_envs` rollouts collected in parallel.

    Each iteration vmaps `single_rollout` over one PRNG key per environment,
    computes per-environment Q(lambda) returns, flattens env and time into
    one batch, and calls `train_step` on that full batch
    `params.update_epochs` times. A loss curve (clipped at 1000) is plotted at
    the end.

    Args:
        q_net: Q-network; presumably updated in place by `train_step`.
        optimizer: nnx optimizer wrapping `q_net`.
        config: environment configuration shared by all rollouts.
        params: training hyper-parameters.

    Returns:
        `q_net`, the same object that was passed in.
    """
    rng_key = jax.random.PRNGKey(0)
    vmapped_rollout = jax.vmap(single_rollout, in_axes=(0, None, None, None))
    vmapped_q_lambda_return = jax.vmap(q_lambda_return, in_axes=(None, 0, 0, 0, None))
    losses = []
    for iteration in range(params.num_iterations):
        # One key is carried to the next iteration; the rest seed one rollout each.
        rng_keys = jax.random.split(rng_key, params.num_envs + 1)
        rng_key, rollout_keys = rng_keys[0], rng_keys[1:]
        obs_buffer, actions_buffer, rewards_buffer, done_buffer, next_obs_buffer, _ = vmapped_rollout(
            rollout_keys, config, q_net, params
        )

        returns = vmapped_q_lambda_return(q_net, rewards_buffer, done_buffer, next_obs_buffer, params)

        flat_observations = obs_buffer.reshape(-1, obs_buffer.shape[-1])
        flat_actions = actions_buffer.reshape(-1)
        flat_returns = returns.reshape(-1)

        # Respect params.update_epochs: that many full-batch updates per rollout.
        for _ in range(params.update_epochs):
            losses.append(
                train_step(q_net, optimizer, flat_observations, flat_actions, flat_returns)
            )

        if iteration % 10 == 0 and losses:
            print(f"Iteration {iteration} - Loss: {losses[-1]}")

    # Clip large losses so early spikes don't flatten the rest of the curve.
    _, ax = plt.subplots()
    ax.plot(np.minimum(np.asarray(losses, dtype=float), 1000))
    ax.set(title="Training loss (clipped at 1000)", xlabel="update", ylabel="loss")

    return q_net

In [None]:
# Same board and hyper-parameters as the single-rollout run, but training
# with the vmapped multi-environment loop from a fresh network.
width, height = 10, 10
config = EnvConfig(
    board_width=width,
    board_height=height,
    num_neutral_bases=3,
    num_neutral_troops_start=5,
    neutral_troops_min=4,
    neutral_troops_max=10,
    player_start_troops=5,
    bonus_time=10,
)
params = Params(
    num_iterations=500,
    lr=4e-4,
    gamma=0.99,
    q_lambda=0.92,
    num_envs=50,
    num_steps=250,
    update_epochs=1,
    num_minibatches=4,
    epsilon=0.3,
)
q_net = Model(width * height * 4, 256, width * height * 4, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(q_net, optax.adam(params.lr))

q_net = train_vmapped(q_net, optimizer, config, params)

### Minibatched

In [None]:
def train_minibatched(q_net: Model, optimizer: nnx.Optimizer, config: EnvConfig, params: Params):
    """Train `q_net` on parallel rollouts using shuffled minibatches.

    Each iteration vmaps `single_rollout` over one PRNG key per environment,
    computes per-environment Q(lambda) returns, and flattens env and time into
    one batch. It then makes `params.update_epochs` shuffled passes over that
    batch, each split into up to `params.num_minibatches` minibatches with one
    `train_step` call per minibatch. A loss curve (clipped at 1000) is
    plotted at the end.

    Args:
        q_net: Q-network; presumably updated in place by `train_step`.
        optimizer: nnx optimizer wrapping `q_net`.
        config: environment configuration shared by all rollouts.
        params: training hyper-parameters.

    Returns:
        `q_net`, the same object that was passed in.
    """
    rng_key = jax.random.PRNGKey(0)
    vmapped_rollout = jax.vmap(single_rollout, in_axes=(0, None, None, None))
    vmapped_q_lambda_return = jax.vmap(q_lambda_return, in_axes=(None, 0, 0, 0, None))
    losses = []

    for iteration in range(params.num_iterations):
        # One key is carried to the next iteration; the rest seed one rollout each.
        rng_keys = jax.random.split(rng_key, params.num_envs + 1)
        rng_key, rollout_keys = rng_keys[0], rng_keys[1:]

        obs_buffer, actions_buffer, rewards_buffer, done_buffer, next_obs_buffer, _ = vmapped_rollout(
            rollout_keys, config, q_net, params
        )

        returns = vmapped_q_lambda_return(q_net, rewards_buffer, done_buffer, next_obs_buffer, params)

        # Combine envs and timesteps into one batch dimension.
        flat_observations = obs_buffer.reshape(-1, obs_buffer.shape[-1])
        flat_actions = actions_buffer.reshape(-1)
        flat_returns = returns.reshape(-1)

        num_samples = flat_observations.shape[0]
        # Never request more minibatches than samples; otherwise some
        # minibatches would be empty and train_step would get an empty batch.
        num_minibatches = max(1, min(params.num_minibatches, num_samples))

        # params.update_epochs shuffled passes over the flattened rollout data.
        for _ in range(params.update_epochs):
            rng_key, perm_key = jax.random.split(rng_key)
            permuted_indices = jax.random.permutation(perm_key, num_samples)

            # array_split gives near-equal minibatches (sizes differ by at most
            # one) instead of dumping the whole remainder into the last one.
            for minibatch_idx in jnp.array_split(permuted_indices, num_minibatches):
                loss = train_step(
                    q_net,
                    optimizer,
                    flat_observations[minibatch_idx],
                    flat_actions[minibatch_idx],
                    flat_returns[minibatch_idx],
                )
                losses.append(loss)

        # With update_epochs == 0 there is no loss to report (and no NameError).
        if iteration % 10 == 0 and losses:
            print(f"Iteration {iteration} - Loss: {losses[-1]}")

    # Clip large losses so early spikes don't flatten the rest of the curve.
    _, ax = plt.subplots()
    ax.plot(np.minimum(np.asarray(losses, dtype=float), 1000))
    ax.set(title="Training loss (clipped at 1000)", xlabel="minibatch update", ylabel="loss")
    return q_net