# Basic usage

In [None]:
import jax.numpy as jnp

import chex
import jax
import jax.numpy as jnp
import numpy as np
from matplotlib import pyplot as plt

In [None]:
from src.rts.config import EnvConfig
from src.rts.env import EnvState

# from src.rts.visualizaiton import visualize_board

### Visualization

In [None]:
def visualize_board(state: EnvState) -> None:
    """Draw the board with matplotlib and show it.

    Each tile is coloured by the first troop channel with a positive count
    (channel 0: player 1, channel 1: player 2, channel 2: neutral); tiles whose
    base flag (channel 3) is positive use a darker shade of that colour. The
    troop count is written on every tile (``0`` for empty tiles) and
    ``state.time`` is drawn in the top-left corner.

    Args:
        state: Environment state whose ``board`` uses the 4-channel layout
            produced by ``init_state`` (3 troop channels + 1 base flag).
    """
    # Pull the board to host memory once; indexing a JAX array element by
    # element dispatches a separate device operation for every read.
    board = np.asarray(state.board)

    # (base colour, regular colour) for each troop channel, in channel order.
    owner_colors = (
        ([0.0, 0.0, 0.8], [0.0, 0.0, 1.0]),  # player 1
        ([0.8, 0.0, 0.0], [1.0, 0.0, 0.0]),  # player 2
        ([0.2, 0.2, 0.2], [0.5, 0.5, 0.5]),  # neutral
    )
    empty_color = [1.0, 1.0, 1.0]

    image = np.ones((board.shape[0], board.shape[1], 3))
    for i in range(board.shape[0]):
        for j in range(board.shape[1]):
            # The first troop channel with a positive count decides the owner
            # (player 1, then player 2, then neutral).
            owner = next((k for k in range(3) if board[i, j, k] > 0), None)
            if owner is None:
                image[i, j] = empty_color
                troops = 0
            else:
                base_color, color = owner_colors[owner]
                # Channel 3 is the single base flag shared by every owner;
                # there are no per-owner base channels on the board.
                image[i, j] = base_color if board[i, j, 3] > 0 else color
                troops = int(board[i, j, owner])
            plt.text(j, i, str(troops), ha="center", va="center", color="black")

    # Show the game clock as a purple number in the top-left corner.
    plt.text(-0.5, -0.5, str(state.time), ha="center", va="center", color="purple")

    plt.axis("off")
    plt.imshow(image)
    plt.show()

### Init state

In [None]:
def init_state(rng_key: jnp.ndarray, params: EnvConfig) -> EnvState:
    """Create a random initial state.

    The board has shape (board_height, board_width, 4) with channels:
    0. Player 1 troops
    1. Player 2 troops
    2. Neutral troops
    3. Base flag (1 on base tiles)

    Each player gets a base holding 5 troops. Then ``params.num_neutral_bases``
    neutral bases and ``params.num_neutral_troops_start`` neutral troop stacks
    are placed. Every item lands on its own distinct tile, and each neutral
    item receives a troop count drawn with ``jax.random.randint`` from
    [neutral_bases_min_troops, neutral_bases_max_troops) (upper bound
    exclusive).

    Args:
        rng_key: JAX PRNG key driving all random choices.
        params: Board size and neutral-placement settings.

    Returns:
        An ``EnvState`` holding the new board.

    Raises:
        ValueError: If the board has fewer tiles than items to place.
    """
    width = params.board_width
    height = params.board_height
    num_items = 2 + params.num_neutral_bases + params.num_neutral_troops_start
    if num_items > width * height:
        # Otherwise the rejection sampling below could never terminate.
        raise ValueError(
            f"A {height}x{width} board cannot hold {num_items} distinct start tiles."
        )

    # Rows index axis 0 and columns axis 1, matching board[y, x] in `move`.
    board = jnp.zeros((height, width, 4), dtype=jnp.int32)
    occupied = set()

    def take_free_tile(key):
        """Draw tiles until an unoccupied one is found; return (key, row, col)."""
        while True:
            # Sample only with fresh subkeys and carry the remaining key
            # forward, so no key is ever consumed twice.
            key, row_key, col_key = jax.random.split(key, 3)
            row = int(jax.random.randint(row_key, (), 0, height))
            col = int(jax.random.randint(col_key, (), 0, width))
            if (row, col) not in occupied:
                occupied.add((row, col))
                return key, row, col

    def neutral_troop_count(key):
        """Return (key, count) with count drawn from the neutral troop range."""
        key, count_key = jax.random.split(key)
        count = jax.random.randint(
            count_key,
            shape=(),
            minval=params.neutral_bases_min_troops,
            maxval=params.neutral_bases_max_troops,
        )
        return key, count

    # Player bases: 5 troops on the player's channel plus the base flag.
    for player_channel in (0, 1):
        rng_key, row, col = take_free_tile(rng_key)
        board = board.at[row, col, player_channel].set(5)
        board = board.at[row, col, 3].set(1)

    # Neutral bases: neutral troops plus the base flag.
    for _ in range(params.num_neutral_bases):
        rng_key, row, col = take_free_tile(rng_key)
        rng_key, count = neutral_troop_count(rng_key)
        board = board.at[row, col, 2].set(count)
        board = board.at[row, col, 3].set(1)

    # Neutral troop stacks without a base.
    for _ in range(params.num_neutral_troops_start):
        rng_key, row, col = take_free_tile(rng_key)
        rng_key, count = neutral_troop_count(rng_key)
        board = board.at[row, col, 2].set(count)

    return EnvState(board=board)

# Build a 10x10 demo configuration and draw the resulting start position.
params = EnvConfig(
    board_width=10,
    board_height=10,
    num_neutral_bases=4,
    num_neutral_troops_start=8,
    neutral_bases_min_troops=1,
    neutral_bases_max_troops=10,
)
rng_key = jax.random.PRNGKey(0)
state = init_state(rng_key, params)
visualize_board(state)

### Validation

In [None]:
def assert_valid_state(state: EnvState, board_shape=(10, 10)) -> None:
    """Raise ``AssertionError`` if ``state.board`` breaks a board invariant.

    Checked invariants:
    * the board has shape ``(*board_shape, 4)`` (3 troop channels + base flag);
    * the board has an integer dtype and no negative entries;
    * every base tile (channel 3 == 1) holds at least one troop on that tile;
    * the base flag is at most 1 on every tile;
    * at most one of the three troop channels is positive on any tile.

    Args:
        state: State to validate.
        board_shape: Expected (rows, cols); defaults to the 10x10 demo board.
    """
    board = state.board
    chex.assert_shape(board, (*board_shape, 4))
    chex.assert_type(board, jnp.integer)
    assert jnp.all(board >= 0), "Board has negative values."

    troops = board[..., :3]

    # For tiles that are bases, ensure at least one troop ON THAT TILE.
    # The sum must be per tile (axis=-1); a whole-board sum would let any
    # troop anywhere mask an empty base.
    is_base = board[..., 3] == 1
    tile_has_troops = jnp.sum(troops, axis=-1) > 0
    assert jnp.all(~is_base | tile_has_troops), "Some bases do not have any troops."

    # Check that no tile has multiple bases (channel 3 at most 1).
    assert jnp.all(board[..., 3] <= 1), "Some tiles have multiple bases."

    # Only one of the troop channels 0-2 may be positive on a tile.
    no_multiple_troops = jnp.sum(troops > 0, axis=-1) <= 1
    assert jnp.all(no_multiple_troops), "Some tiles have troops from multiple players."

# Sanity-check the freshly initialised demo state.
assert_valid_state(state)

### Move/Step

In [None]:
def move(state: EnvState, player: int, x: int, y: int, action: int) -> EnvState:
    """Move all but one of ``player``'s troops from tile (x, y) one step.

    The board is indexed as ``board[y, x]`` (y = row, x = column). Actions:
    0 = up (y - 1), 1 = right (x + 1), 2 = down (y + 1), 3 = left (x - 1).

    The source tile keeps 1 troop and the other ``source - 1`` troops move:
    * If the target holds no opponent or neutral troops, they join the
      player's troops already there.
    * Otherwise they fight the defenders (the opponent's troops if present,
      else the neutral troops) one-for-one: if the defenders outnumber the
      attackers they keep the difference; otherwise the defenders are removed
      and any surviving attackers occupy the target (a tie leaves it empty).

    The state is returned unchanged when (x, y) is off the board or the source
    holds fewer than 2 of the player's troops (both print a message), or when
    the action is not 0-3 or the target is off the board.

    Args:
        state: Current environment state.
        player: Index of the moving player's troop channel (0 or 1).
        x: Source column.
        y: Source row.
        action: Direction index, see above.

    Returns:
        A new ``EnvState`` with the updated board and the same ``time``.
    """
    board = state.board
    height, width = board.shape[0], board.shape[1]

    # x indexes columns (axis 1) and y rows (axis 0). Negative values are
    # rejected explicitly because array indexing would wrap them around.
    if not (0 <= x < width and 0 <= y < height):
        print("Out of bounds")
        return state
    if board[y, x, player] < 2:
        print("Not enough troops")
        return state

    # (dx, dy) per action: up, right, down, left.
    offsets = ((0, -1), (1, 0), (0, 1), (-1, 0))
    if not 0 <= action < len(offsets):
        # Without this guard an unknown action would leave target == source
        # and the stack would be merged with itself, then reset to 1.
        return state
    dx, dy = offsets[int(action)]
    target_x, target_y = x + dx, y + dy
    if not (0 <= target_x < width and 0 <= target_y < height):
        return state

    # Decide whose troops defend the target: the opponent first, then the
    # neutral channel (2); None means the target has no defenders.
    opponent = (player + 1) % 2
    if board[target_y, target_x, opponent] > 0:
        defender = opponent
    elif board[target_y, target_x, 2] > 0:
        defender = 2
    else:
        defender = None

    attackers = board[y, x, player] - 1
    board = board.at[y, x, player].set(1)

    if defender is None:
        board = board.at[target_y, target_x, player].add(attackers)
    else:
        defenders = board[target_y, target_x, defender]
        if defenders > attackers:
            # Defenders hold the tile with the difference.
            board = board.at[target_y, target_x, defender].set(defenders - attackers)
        else:
            # Defenders are wiped out; surviving attackers occupy the tile.
            board = board.at[target_y, target_x, defender].set(0)
            if attackers > defenders:
                board = board.at[target_y, target_x, player].set(attackers - defenders)

    return EnvState(board=board, time=state.time)

# Two scripted demo moves, applied in order, given as (player, x, y, action).
for _player, _x, _y, _action in ((1, 1, 1, 1), (0, 0, 9, 0)):
    state = move(state, player=_player, x=_x, y=_y, action=_action)
visualize_board(state)

In [None]:
def increase_troops(state: EnvState) -> EnvState:
    """Advance the troop-growth clock by one tick.

    Only the two players' troop channels (0 and 1) grow; neutral troops
    (channel 2) never do. Every tile where a player has troops gains:
    * +1 troop when ``state.time == 0`` (the periodic bonus), and
    * +1 troop when the tile is a base (channel 3 > 0).

    The clock counts down by one per tick; on the bonus tick (time 0) it
    becomes 0 - 1 + 10 = 9, so the bonus recurs every 10 ticks.

    Args:
        state: Current environment state.

    Returns:
        A new ``EnvState`` with grown troops and the advanced clock.
    """
    board = state.board
    bonus_troops = state.time == 0

    # Vectorised over the whole board: a per-cell `.at[].set` loop copies the
    # full array on every update.
    owned = board[..., :2] > 0
    is_base = board[..., 3:4] > 0  # keep a trailing axis to broadcast over channels 0-1
    gain = jnp.asarray(bonus_troops, dtype=board.dtype) + is_base.astype(board.dtype)
    board = board.at[..., :2].add(jnp.where(owned, gain, 0))

    # Decrease time; on the bonus tick it wraps back to 9.
    time = state.time - 1 + bonus_troops * 10
    return EnvState(board=board, time=time)

In [None]:
# Run one growth tick, printing the clock before and after it.
print(state.time)
state = increase_troops(state)
visualize_board(state)
print(state.time)

### Random policy

In [None]:
def get_legal_moves(state: EnvState, player: int) -> jnp.ndarray:
    """Return a boolean mask of the legal moves for ``player``.

    The mask has shape (rows, cols, 4). Entry [i, j, a] is True when tile
    (row i, column j) holds more than one of ``player``'s troops (``move``
    rejects sources with fewer than 2) and direction ``a`` (0 up, 1 right,
    2 down, 3 left) stays on the board.

    Args:
        state: Current environment state.
        player: Index of the player's troop channel (0 or 1).

    Returns:
        Boolean array of shape (rows, cols, 4).
    """
    board = state.board
    height, width = board.shape[0], board.shape[1]
    rows = jnp.arange(height)[:, None]
    cols = jnp.arange(width)[None, :]

    # Per-direction in-bounds masks, broadcast to (rows, cols) and stacked
    # in action order: up, right, down, left.
    in_bounds = jnp.stack(
        jnp.broadcast_arrays(rows > 0, cols < width - 1, rows < height - 1, cols > 0),
        axis=-1,
    )
    can_move = board[..., player] > 1
    return can_move[..., None] & in_bounds

# Legal-move mask for the second player (troop channel index 1).
legal_moves = get_legal_moves(state, player=1)
print(legal_moves)

### Random play

In [None]:
# Play a short game with uniformly random legal moves, alternating players
# and running one growth tick after every move.
# Set SHOW_EVERY_STEP to True to redraw the board after each step.
NUM_STEPS = 5
SHOW_EVERY_STEP = False
np_rng = np.random.default_rng(0)  # seeded so the random game is reproducible

state = init_state(rng_key, params)

for step in range(NUM_STEPS):
    player = step % 2
    legal_moves = get_legal_moves(state, player)
    if not jnp.any(legal_moves):
        print(f"Player {player} has no legal moves.")
    else:
        # Each row of legal_indices is (row, col, action) = (y, x, action).
        legal_indices = np.argwhere(np.asarray(legal_moves))
        y, x, action = legal_indices[np_rng.integers(len(legal_indices))]
        state = move(state, player, x, y, action)
    state = increase_troops(state)
    if SHOW_EVERY_STEP:
        visualize_board(state)


In [None]:
# Show the board after the random game finishes.
visualize_board(state)

### Benchmarking

In [None]:
# NOTE(review): `move` branches on array values with Python `if` statements
# (e.g. `if board[y, x, player] < 2`), and jax.jit traces its arguments as
# abstract values, so the jitted call below is expected to raise a
# tracer-to-bool conversion error until `move` is rewritten with
# `jnp.where` / `jax.lax.cond`.
jit_move = jax.jit(move)

rng_key = jax.random.PRNGKey(0)
params = EnvConfig(
    board_width=10,
    board_height=10,
    num_neutral_bases=4,
    num_neutral_troops_start=8,
    neutral_bases_min_troops=1,
    neutral_bases_max_troops=10,
)
# Use a separate name for the state: assigning to `init_state` would shadow
# the function and break re-running this cell or any later call to it.
start_state = init_state(rng_key, params)
visualize_board(start_state)

state = jit_move(start_state, player=1, x=1, y=1, action=1)