# A

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

from src.rts.config import EnvConfig
from src.rts.env import init_state, p1_step
from src.rts.utils import assert_valid_state, get_legal_moves
from src.rts.visualizaiton import visualize_board

from flax import nnx
from src.rl.pqn import Model
import optax
from src.rl.pqn import Params

In [None]:
from src.rl.pqn import train_minibatched

# Environment setup: a square 10x10 board shared by two players.
width, height = 10, 10
config = EnvConfig(
    num_players=2,
    board_width=width,
    board_height=height,
    num_neutral_bases=3,
    num_neutral_troops_start=5,
    neutral_troops_min=4,
    neutral_troops_max=10,
    player_start_troops=5,
    bonus_time=10,
)

# Hyperparameters for the PQN training loop.
params = Params(
    num_iterations=50,
    lr=4e-4,
    gamma=0.99,
    q_lambda=0.92,
    num_envs=50,
    num_steps=250,
    update_epochs=1,
    num_minibatches=4,
    epsilon=0.3,
)

# Input and output widths of the Q-network are both width * height * 4
# (presumably 4 values per board cell -- confirm against Model).
io_dim = width * height * 4
q_net = Model(io_dim, 512, io_dim, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(q_net, optax.adam(params.lr))

q_net, losses, cum_returns = train_minibatched(q_net, optimizer, config, params)

# Plot each training curve in its own figure, sharing the same x-axis label.
for series, plot_title, y_label in (
    (losses, "Losses", "Loss"),
    (cum_returns, "Cumulative Returns", "Cumulative Return"),
):
    plt.plot(series)
    plt.title(plot_title)
    plt.xlabel("Iteration")
    plt.ylabel(y_label)
    plt.show()

In [None]:
# Average cum_returns over its first axis and plot the result, labelled like
# the other training plots. What axis 0 indexes depends on the shape returned
# by train_minibatched -- TODO confirm before adding an x-axis label.
mean_cum_returns = np.mean(cum_returns, axis=0)
plt.plot(mean_cum_returns)
plt.title("Mean Cumulative Return (averaged over axis 0)")
plt.ylabel("Mean Cumulative Return")
plt.show()

In [None]:
# Now we can use the trained model to play a game
rng_key = jax.random.PRNGKey(0)
state = init_state(rng_key, config)
rewards = []
for i in range(200):
    legal_mask = get_legal_moves(state, 0)
    legal_mask = jnp.array(legal_mask.flatten())
    print(q_net(state.board.flatten())* legal_mask)
    action = jnp.argmax((q_net(state.board.flatten()) + 1000) * legal_mask)
    # split action from int to array
    # y, x, direction
    action = jnp.array([action // (width*4), (action % (width*4))//4, action % 4])
    rng_key, subkey = jax.random.split(rng_key)
    state, p1_reward = p1_step(state, subkey, config, action)
    rewards.append(p1_reward)
    assert_valid_state(state)
    if i % 5 == 0:
        visualize_board(state)
    print(p1_reward)
print(rewards)
# for each step print the cumulative reward to the end from that step
print(np.cumsum(rewards))
plt.plot(np.cumsum(rewards))