# Simple RTS

In [None]:
import os

import jax
import jax.numpy as jnp
import optax
from flax import nnx
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from src.rts.config import EnvConfig
from src.rts.env import init_state
from src.rts.utils import get_legal_moves, p1_step
from src.rts.visualization import visualize_state
from src.rl.pqn import Params, train_minibatched
from src.rl.model import MLP
from src.rl.eval import evaluate_batch

In [None]:
# Turn on JAX's tracer-leak checker for this session.
# JAX reads its JAX_* environment variables when `jax` is imported, and the
# import has already happened by the time this cell runs, so exporting the
# variable alone would not affect the current process. The export is kept for
# any child process that imports jax later; the live config flag is set
# explicitly so the checker is actually active here.
os.environ["JAX_CHECK_TRACER_LEAKS"] = "TRUE"
jax.config.update("jax_check_tracer_leaks", True)

In [None]:
# Build the environment/training configuration and the Q-network, then train.
# Everything in this cell runs with JAX's default matmul precision set to
# "bfloat16".
with jax.default_matmul_precision("bfloat16"):
    width, height = 10, 10
    # The network's first and last size arguments are both width*height*4.
    # NOTE(review): presumably the flattened board size (input) and the
    # flattened action space (output) -- confirm against MLP / the env.
    flat_size = width * height * 4

    config = EnvConfig(
        num_players=2,
        board_width=width,
        board_height=height,
        num_neutral_bases=3,
        num_neutral_troops_start=5,
        neutral_troops_min=4,
        neutral_troops_max=10,
        player_start_troops=5,
        bonus_time=10,
    )

    params = Params(
        num_iterations=200,
        lr=8e-4,
        gamma=0.9,
        q_lambda=0.95,
        num_envs=256,
        num_steps=250,
        update_epochs=2,
        num_minibatches=10,
        epsilon=0.0008,
    )

    # Single hidden layer of 512 units, seeded with 0 for reproducibility.
    q_net = MLP(flat_size, [512], flat_size, rngs=nnx.Rngs(0))
    optimizer = nnx.Optimizer(q_net, optax.adam(params.lr))

    # Training returns the updated network plus per-run diagnostics that the
    # following cells plot.
    q_net, losses, cum_returns, timings = train_minibatched(
        q_net, optimizer, config, params
    )

In [None]:
# Plot the two training diagnostics, one figure each: the raw losses, and the
# cumulative returns averaged over axis 1.
for series, title, y_label in (
    (losses, "Losses", "Loss"),
    (np.mean(cum_returns, axis=1), "Cumulative Returns", "Cumulative Return"),
):
    plt.plot(series)
    plt.title(title)
    plt.xlabel("Iteration")
    plt.ylabel(y_label)
    plt.show()

In [None]:
# Evaluation: run evaluate_batch with a fixed PRNG key (seed 0), a batch of
# 100 and 250 steps, then report the mean of whatever it returns as a float.
eval_key = jax.random.PRNGKey(0)
eval_results = evaluate_batch(
    q_net, config, eval_key, batch_size=100, num_steps=250
)
output = float(np.mean(eval_results))
print(f"Evaluation output: {output}")

In [None]:
# Greedy rollout of the trained network as player 0 for 300 steps, starting
# from a state initialised with PRNG seed 0. The board is rendered every fifth
# step and the running sum of player-1-step rewards is plotted at the end.
rng_key = jax.random.PRNGKey(0)
state = init_state(rng_key, config)
rewards = []
for i in range(300):
    legal_mask = get_legal_moves(state, 0)
    legal_mask = jnp.array(legal_mask.flatten())
    q_values = q_net(state.board.flatten())
    # Pick the best *legal* action. Illegal entries are replaced by -inf so
    # they can never win the argmax, whatever the scale or sign of the
    # Q-values. (Shifting the Q-values by a constant and multiplying by the
    # mask only works while every legal Q-value stays above minus that
    # constant, and it throws away float precision.) If no move is legal the
    # argmax is index 0.
    masked_q = jnp.where(legal_mask.astype(bool), q_values, -jnp.inf)
    action = jnp.argmax(masked_q)
    # Fresh subkey per step so each environment step gets its own randomness.
    rng_key, subkey = jax.random.split(rng_key)
    state, p1_reward = p1_step(state, subkey, config, action)
    rewards.append(p1_reward)
    if i % 5 == 0:
        visualize_state(state)
plt.plot(np.cumsum(rewards))