In [2]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
import chex
from flax import struct
from typing import Tuple, Dict
from functools import partial
from gymnax.environments.spaces import Discrete, Box
from jaxmarl.environments.multi_agent_env import MultiAgentEnv

from utils import euclidean_distance, generate_unique_pairs, get_latest_checkpoint_dir
import base_actions
from reward import reward_function
from config import train_config, env_config
# from data_classes import UnitState, EnvState, EnvParams
from render import render_game_state

In [22]:
@struct.dataclass
class UnitState:
    unit_id: int
    health_current: float
    health_max: float
    location_x: int
    location_y: int
    melee_attack_base_damage: float
    ranged_attack_base_damage: float
    melee_attack_range: float
    ranged_attack_range: float
    movement_points_current: float
    movement_points_max: float
    action_points_current: float
    action_points_max: float
    available_actions: chex.Array

# @struct.dataclass
# class TeamState:
#     hero_1: UnitState
#     hero_2: UnitState
#     hero_3: UnitState # HeroState

@struct.dataclass
class State:
    # flatten this?
    # board: chex.Array
    player: UnitState
    enemy: UnitState
    distance_to_enemy: float
    steps: int
    turn_count: int
    previous_closest_distance: float
    initial_distance: float
    cur_player_idx: chex.Array
    terminal: bool

In [6]:
low = [
    # 'board': spaces.Box(0, 1, (20, 20), jnp.int32),
    0, # 'player_health_current'
    0, # 'player_health_max'
    0, # 'player_location_x'
    0, # 'player_location_y'
    0, # 'player_melee_attack_base_damage'
    0, # 'player_ranged_attack_base_damage'
    0, # 'player_melee_attack_range'
    0, # 'player_ranged_attack_range'
    0, # 'player_movement_points_current'
    0, # 'player_movement_points_max'
    0, # 'player_action_points_current'
    0, # 'player_action_points_max'
    0, # 'enemy_health_current'
    0, # 'enemy_health_max'
    0, # 'enemy_location_x'
    0, # 'enemy_location_y'
    0, # 'distance_to_enemy'
    0, # 'turn_count'
]# + [0]*num_actions, # 'available_actions'

high = [
    jnp.finfo(jnp.float32).max,
    jnp.finfo(jnp.float32).max,
    19,
    19,
    jnp.finfo(jnp.float32).max,
    jnp.finfo(jnp.float32).max,
    jnp.finfo(jnp.float32).max,
    jnp.finfo(jnp.float32).max,
    jnp.finfo(jnp.float32).max,
    jnp.finfo(jnp.float32).max,
    jnp.finfo(jnp.float32).max,
    jnp.finfo(jnp.float32).max,
    jnp.finfo(jnp.float32).max,
    jnp.finfo(jnp.float32).max,
    19,
    19,
    jnp.finfo(jnp.float32).max,
    env_config['MAX_STEPS'],
]# + [1]*num_actions,

observation_spaces = {i: Box(low, high, (18+11,), jnp.float32) for i in [0,1]}

In [None]:
x1, y1, x2, y2 = generate_unique_pairs(key)
initial_distance = jnp.float32(euclidean_distance(x1,x2,y1,y2))

player = UnitState(
    unit_id = jnp.int32(1),
    health_current = jnp.float32(env_config['MAX_HEALTH']),
    health_max = jnp.float32(env_config['MAX_HEALTH']),
    location_x = x1,
    location_y = y1,
    melee_attack_base_damage = jnp.float32(env_config['MELEE_DAMAGE']),
    ranged_attack_base_damage = jnp.float32(env_config['RANGED_DAMAGE']),
    melee_attack_range = jnp.float32(env_config['MELEE_RANGE']),
    ranged_attack_range = jnp.float32(env_config['RANGED_RANGE']),
    movement_points_current = jnp.float32(env_config['MOVEMENT_POINTS']),
    movement_points_max = jnp.float32(env_config['MOVEMENT_POINTS']),
    action_points_current = jnp.float32(env_config['ACTION_POINTS']),
    action_points_max = jnp.float32(env_config['ACTION_POINTS']),
    available_actions = jnp.zeros(num_actions),
)

enemy = UnitState(
    unit_id = jnp.int32(-1),
    health_current = jnp.float32(env_config['MAX_HEALTH']),
    health_max = jnp.float32(env_config['MAX_HEALTH']),
    location_x = x2,
    location_y = y2,
    melee_attack_base_damage = jnp.float32(env_config['MELEE_DAMAGE']),
    ranged_attack_base_damage = jnp.float32(env_config['RANGED_DAMAGE']),
    melee_attack_range = jnp.float32(env_config['MELEE_RANGE']),
    ranged_attack_range = jnp.float32(env_config['RANGED_RANGE']),
    movement_points_current = jnp.float32(env_config['MOVEMENT_POINTS']),
    movement_points_max = jnp.float32(env_config['MOVEMENT_POINTS']),
    action_points_current = jnp.float32(env_config['ACTION_POINTS']),
    action_points_max = jnp.float32(env_config['ACTION_POINTS']),
    available_actions = jnp.zeros(num_actions),
)

state = State(
    player = player,
    enemy = enemy,
    distance_to_enemy = initial_distance,
    steps = jnp.int32(0),
    turn_count = jnp.int32(0),
    previous_closest_distance = initial_distance,
    initial_distance = initial_distance,
    cur_player_idx = jnp.zeros(self.num_agents).at[0].set(1), # TODO make random
    terminal = False,
)

state = state.replace(
    player = player.replace(available_actions = self.get_available_actions(state.player, state.enemy, state)),
    enemy = enemy.replace(available_actions = self.get_available_actions(state.enemy, state.player, state)),
    )

In [None]:
"""Generate individual agent's observation"""

## TODO can we use the following pattern to do this better?
# actions = jnp.array([actions[i] for i in self.agents])
# aidx = jnp.nonzero(state.cur_player_idx, size=1)[0][0]
# action = actions.at[aidx].get()

def get_player_obs(state: State) -> chex.Array:
    return jnp.array(
        [
            state.player.health_current,
            state.player.health_max,
            state.player.location_x,
            state.player.location_y,
            state.player.melee_attack_base_damage,
            state.player.ranged_attack_base_damage,
            state.player.melee_attack_range,
            state.player.ranged_attack_range,
            state.player.movement_points_current,
            state.player.movement_points_max,
            state.player.action_points_current,
            state.player.action_points_max,
            state.enemy.health_current,
            state.enemy.health_max,
            state.enemy.location_x,
            state.enemy.location_y,
            state.distance_to_enemy,
            state.turn_count,
            ]
            )
def get_enemy_obs(state: State) -> chex.Array:
    return jnp.array(
        [
            state.enemy.health_current,
            state.enemy.health_max,
            state.enemy.location_x,
            state.enemy.location_y,
            state.enemy.melee_attack_base_damage,
            state.enemy.ranged_attack_base_damage,
            state.enemy.melee_attack_range,
            state.enemy.ranged_attack_range,
            state.enemy.movement_points_current,
            state.enemy.movement_points_max,
            state.enemy.action_points_current,
            state.enemy.action_points_max,
            state.player.health_current,
            state.player.health_max,
            state.player.location_x,
            state.player.location_y,
            state.distance_to_enemy,
            state.turn_count,
            ]
            )

# return jax.lax.cond(
#     aidx == 0,
#     lambda _: get_player_obs(state),
#     lambda _: get_enemy_obs(state),
#     operand=None
#     )

In [10]:
float_max = jnp.finfo(jnp.float32).max

schema = {
  "UnitState": {
    "unit_id": {
        "type": int,
        "default": 0,
        "obs": False
    },
    "action_points_base": {
      "type": float,
      "default": 5,
      "obs": True,
      "low": 0,
      "high": 20,
    },
    "action_points_current": {
      "type": float,
      "default": 5,
      "obs": True,
      "low": 0,
      "high": 20,
    },
    "action_points_max": {
      "type": float,
      "default": 5,
      "obs": True,
      "low": 0,
      "high": 20,
    },
    "movement_points_base": {
      "type": float,
      "default": 5,
      "obs": True,
      "low": 0,
      "high": 50,
    },
    "movement_points_current": {
      "type": float,
      "default": 5,
      "obs": True,
      "low": 0,
      "high": 50,
    },
    "movement_points_max": {
      "type": float,
      "default": 5,
      "obs": True,
      "low": 0,
      "high": 50,
    },
    "movement_points_percentage": {
      "type": float,
      "default": 1,
      "obs": True,
      "low": 0,
      "high": 1,
    },
    "movement_points_multiplier": {
      "type": float,
      "default": 1,
      "obs": True,
      "low": 0,
      "high": 10,
    },
    "health_current": {
      "type": float,
      "default": 100,
      "obs": True,
      "low": 0,
      "high": 1000,
    },
    "health_max": {
      "type": float,
      "default": 100,
      "obs": True,
      "low": 0,
      "high": 1000,
    },
    "health_percentage": {
      "type": float,
      "default": 1,
      "obs": True,
      "low": 0,
      "high": 1,
    },
    "health_regeneration": {
      "type": float,
      "default": 1,
      "obs": True,
      "low": -1000,
      "high": 1000,
    },
    "health_regeneration_rate": {
      "type": float,
      "default": 0,
      "obs": True,
      "low": -1,
      "high": 1,
    },
    "mana_current": {
      "type": float,
      "default": 100,
      "obs": True,
      "low": 0,
      "high": 1000,
    },
    "mana_max": {
      "type": float,
      "default": 100,
      "obs": True,
      "low": 0,
      "high": 1000,
    },
    "mana_percentage": {
      "type": float,
      "default": 1,
      "obs": True,
      "low": 0,
      "high": 1,
    },
    "mana_regeneration": {
      "type": float,
      "default": 5,
      "obs": True,
      "low": 0,
      "high": 1000,
    },
    "mana_regeneration_rate": {
      "type": float,
      "default": 0,
      "obs": True,
      "low": -1,
      "high": 1,
    },
    "barrier_current": {
      "type": float,
      "default": 10,
      "obs": True,
      "low": 0,
      "high": 1000,
    },
    "barrier_status_reduction": {
      "type": float,
      "default": 1,
      "obs": True,
      "low": 0,
      "high": 100,
    },
    "barrier_max": {
      "type": float,
      "default": 100,
      "obs": True,
      "low": 0,
      "high": 1000,
    },
    "barrier_percentage": {
      "type": float,
      "default": 0,
      "obs": True,
      "low": 0,
      "high": 1,
    },
    "barrier_regeneration": {
      "type": float,
      "default": 0,
      "obs": True,
      "low": 0,
      "high": 1000,
    },
    "barrier_regeneration_rate": {
      "type": float,
      "default": 0,
      "obs": True,
      "low": -1,
      "high": 1,
    },
    "physical_block": {
      "type": float,
      "default": 0,
      "obs": True,
      "low": -1000,
      "high": 1000,
    },
    "magical_block": {
      "type": float,
      "default": 0,
      "obs": True,
      "low": -1000,
      "high": 1000,
    },
    "physical_resist": {
      "type": float,
      "default": 0,
      "obs": True,
      "low": -1,
      "high": 1,
    },
    "magical_resist": {
      "type": float,
      "default": 0,
      "obs": True,
      "low": -1,
      "high": 1,
    },
    "physical_immunity": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "magical_immunity": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "physical_evasion": {
      "type": float,
      "default": 0,
      "obs": True,
      "low": 0,
      "high": 1,
    },
    "magical_evasion": {
      "type": float,
      "default": 0,
      "obs": True,
      "low": 0,
      "high": 1,
    },
    "physical_damage_return": {
      "type": float,
      "default": 0,
      "obs": True,
      "low": -1000,
      "high": 1000,
    },
    "physical_damage_return_rate": {
      "type": float,
      "default": 0,
      "obs": True,
      "low": -1,
      "high": 1,
    },
    "magical_damage_return": {
      "type": float,
      "default": 0,
      "obs": True,
      "low": -1000,
      "high": 1000,
    },
    "magical_damage_return_rate": {
      "type": float,
      "default": 0,
      "obs": True,
      "low": -1,
      "high": 1,
    },
    "pure_damage_return": {
      "type": float,
      "default": 0,
      "obs": True,
      "low": -1000,
      "high": 1000,
    },
    "pure_damage_return_rate": {
      "type": float,
      "default": 0,
      "obs": True,
      "low": -1,
      "high": 1,
    },
    "base_strength": {
      "type": float,
      "default": 10,
      "obs": True,
      "low": 0,
      "high": 1000,
    },
    "strength_current": {
      "type": float,
      "default": 10,
      "obs": True,
      "low": 0,
      "high": 1000,
    },
    "base_agility": {
      "type": float,
      "default": 10,
      "obs": True,
      "low": 0,
      "high": 1000,
    },
    "agility_current": {
      "type": float,
      "default": 10,
      "obs": True,
      "low": 0,
      "high": 1000,
    },
    "base_intelligence": {
      "type": float,
      "default": 10,
      "obs": True,
      "low": 0,
      "high": 1000,
    },
    "intelligence_current": {
      "type": float,
      "default": 10,
      "obs": True,
      "low": 0,
      "high": 1000,
    },
    "base_resolve": {
      "type": float,
      "default": 10,
      "obs": True,
      "low": 0,
      "high": 1000,
    },
    "resolve_current": {
      "type": float,
      "default": 10,
      "obs": True,
      "low": 0,
      "high": 1000,
    },
    "attack_damage_amplification": {
      "type": float,
      "default": 1,
      "obs": True,
      "low": -1,
      "high": 10,
    },
    "melee_base_attack_damage": {
      "type": float,
      "default": 225,
      "obs": True,
      "low": -1000,
      "high": 1000,
    },
    "melee_attack_range": {
      "type": float,
      "default": 2.6,
      "obs": True,
      "low": 0,
      "high": 10,
    },
    "melee_crit_chance": {
      "type": float,
      "default": 0,
      "obs": True,
      "low": 0,
      "high": 1,
    },
    "melee_crit_modifier": {
      "type": float,
      "default": 1.5,
      "obs": True,
      "low": 1,
      "high": 10,
    },
    "ranged_base_attack_damage": {
      "type": float,
      "default": 15,
      "obs": True,
      "low": -1000,
      "high": 1000,
    },
    "ranged_attack_range": {
      "type": float,
      "default": 5,
      "obs": True,
      "low": 0,
      "high": 20,
    },
    "ranged_crit_chance": {
      "type": float,
      "default": 0,
      "obs": True,
      "low": 0,
      "high": 1,
    },
    "ranged_crit_modifier": {
      "type": float,
      "default": 1.5,
      "obs": True,
      "low": 1,
      "high": 3,
    },
    "damage_amplification": {
      "type": float,
      "default": 1,
      "obs": True,
      "low": 0,
      "high": 10,
    },
    "silenced_flag": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "silenced_duration": {
      "type": int,
      "default": 0,
      "obs": True,
      "low": 0,
      "high": 30,
    },
    "silenced_permanent": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "silenced_dispelable": {
      "type": bool,
      "default": True,
      "obs": True,
      "low": False,
      "high": True,
    },
    "silenced_needs_greater_dispel": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "broken_flag": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "broken_duration": {
      "type": int,
      "default": 0,
      "obs": True,
      "low": 0,
      "high": 30,
    },
    "broken_permanent": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "broken_dispelable": {
      "type": bool,
      "default": True,
      "obs": True,
      "low": False,
      "high": True,
    },
    "broken_needs_greater_dispel": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "stunned_flag": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "stunned_duration": {
      "type": int,
      "default": 0,
      "obs": True,
      "low": 0,
      "high": 30,
    },
    "stunned_permanent": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "stunned_dispelable": {
      "type": bool,
      "default": True,
      "obs": True,
      "low": False,
      "high": True,
    },
    "stunned_needs_greater_dispel": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "feared_flag": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "feared_duration": {
      "type": int,
      "default": 0,
      "obs": True,
      "low": 0,
      "high": 30,
    },
    "feared_permanent": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "feared_dispelable": {
      "type": bool,
      "default": True,
      "obs": True,
      "low": False,
      "high": True,
    },
    "feared_needs_greater_dispel": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "taunted_flag": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "taunted_duration": {
      "type": int,
      "default": 0,
      "obs": True,
      "low": 0,
      "high": 30,
    },
    "taunted_permanent": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "taunted_dispelable": {
      "type": bool,
      "default": True,
      "obs": True,
      "low": False,
      "high": True,
    },
    "taunted_needs_greater_dispel": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "invisible_flag": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "invisible_duration": {
      "type": int,
      "default": 0,
      "obs": True,
      "low": 0,
      "high": 30,
    },
    "invisible_permanent": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "invisible_dispelable": {
      "type": bool,
      "default": True,
      "obs": True,
      "low": False,
      "high": True,
    },
    "invisible_needs_greater_dispel": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "sleeping_flag": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "sleeping_duration": {
      "type": int,
      "default": 0,
      "obs": True,
      "low": 0,
      "high": 30,
    },
    "sleeping_permanent": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "sleeping_dispelable": {
      "type": bool,
      "default": True,
      "obs": True,
      "low": False,
      "high": True,
    },
    "sleeping_needs_greater_dispel": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "ethereal_flag": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "ethereal_duration": {
      "type": int,
      "default": 0,
      "obs": True,
      "low": 0,
      "high": 30,
    },
    "ethereal_permanent": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "ethereal_dispelable": {
      "type": bool,
      "default": True,
      "obs": True,
      "low": False,
      "high": True,
    },
    "ethereal_needs_greater_dispel": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "untargetable_flag": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "untargetable_duration": {
      "type": int,
      "default": 0,
      "obs": True,
      "low": 0,
      "high": 30,
    },
    "untargetable_permanent": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "untargetable_dispelable": {
      "type": bool,
      "default": True,
      "obs": True,
      "low": False,
      "high": True,
    },
    "untargetable_needs_greater_dispel": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "hidden_flag": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "hidden_duration": {
      "type": int,
      "default": 0,
      "obs": True,
      "low": 0,
      "high": 30,
    },
    "hidden_permanent": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "hidden_dispelable": {
      "type": bool,
      "default": True,
      "obs": True,
      "low": False,
      "high": True,
    },
    "hidden_needs_greater_dispel": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "phased_flag": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "phased_duration": {
      "type": int,
      "default": 0,
      "obs": True,
      "low": 0,
      "high": 30,
    },
    "phased_permanent": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "phased_dispelable": {
      "type": bool,
      "default": True,
      "obs": True,
      "low": False,
      "high": True,
    },
    "phased_needs_greater_dispel": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "blind_flag": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "blind_duration": {
      "type": int,
      "default": 0,
      "obs": True,
      "low": 0,
      "high": 30,
    },
    "blind_permanent": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "blind_dispelable": {
      "type": bool,
      "default": True,
      "obs": True,
      "low": False,
      "high": True,
    },
    "blind_needs_greater_dispel": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "disarmed_flag": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "disarmed_duration": {
      "type": int,
      "default": 0,
      "obs": True,
      "low": 0,
      "high": 30,
    },
    "disarmed_permanent": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "disarmed_dispelable": {
      "type": bool,
      "default": True,
      "obs": True,
      "low": False,
      "high": True,
    },
    "disarmed_needs_greater_dispel": {
      "type": bool,
      "default": False,
      "obs": True,
      "low": False,
      "high": True,
    },
    "available_actions": {
        "type": chex.Array,
        "default": jnp.zeros(11),
        "obs": False
    },
  },
  # "TeamState": {
  # },
  "GameState": {
    "player": {
        "type": "UnitState",
        "obs": False
        },
    "enemy": {
        "type": "UnitState",
        "obs": False
        },
    "distance_to_enemy": {
        "type": float,
        "default": 0,
        "obs": True,
        "low": 0,
        "high": 30
        },
    "steps": {
        "type": int,
        "default": 0,
        "obs": True,
        "low": 0,
        "high": 1000
        },
    "turn_count": {
        "type": int,
        "default": 0,
        "obs": True,
        "low": 0,
        "high": 100
        },
    "previous_closest_distance": {
        "type": float,
        "default": 0,
        "obs": False
        },
    "initial_distance": {
        "type": float,
        "default": 0,
        "obs": False
        },
    "cur_player_idx": {
        "type": chex.Array,
        "default": jnp.zeros(2).at[0].set(1),
        "obs": False
        },
    "terminal": {
        "type": bool,
        "default": False,
        "obs": False
        },
  },
}


In [36]:
def create_struct_dataclass(schema):
    classes = {}
    for key, value in schema.items():
        annotations = {}
        for field, attrs in value.items():
            dtype = attrs["type"]#eval(attrs["type"]) if attrs["type"] in dir(jnp) else attrs["type"]
            annotations[field] = dtype
        classes[key] = struct.dataclass(type(key, (object,), {'__annotations__': annotations}))
    return classes

schema_classes = create_struct_dataclass(schema)  # Assuming 'json_schema' is loaded from the JSON file
UnitState = schema_classes['UnitState']
GameState = schema_classes['GameState']

{'unit_id': <class 'int'>, 'action_points_base': <class 'float'>, 'action_points_current': <class 'float'>, 'action_points_max': <class 'float'>, 'movement_points_base': <class 'float'>, 'movement_points_current': <class 'float'>, 'movement_points_max': <class 'float'>, 'movement_points_percentage': <class 'float'>, 'movement_points_multiplier': <class 'float'>, 'health_current': <class 'float'>, 'health_max': <class 'float'>, 'health_percentage': <class 'float'>, 'health_regeneration': <class 'float'>, 'health_regeneration_rate': <class 'float'>, 'mana_current': <class 'float'>, 'mana_max': <class 'float'>, 'mana_percentage': <class 'float'>, 'mana_regeneration': <class 'float'>, 'mana_regeneration_rate': <class 'float'>, 'barrier_current': <class 'float'>, 'barrier_status_reduction': <class 'float'>, 'barrier_max': <class 'float'>, 'barrier_percentage': <class 'float'>, 'barrier_regeneration': <class 'float'>, 'barrier_regeneration_rate': <class 'float'>, 'physical_block': <class 'fl

In [38]:
#NOW MAKE IT CONCAT FROM COMPENENTS EG 

def get_low_high(schema, class_name):
    low = jnp.array([info["low"] for field, info in schema[class_name].items() if info.get("obs", False)])
    high = jnp.array([info["high"] for field, info in schema[class_name].items() if info.get("obs", False)])
    return low, high

# Example usage:
low, high = get_low_high(schema, 'UnitState')


In [41]:
high

Array([  20,   20,   20,   50,   50,   50,    1,   10, 1000, 1000,    1,
       1000,    1, 1000, 1000,    1, 1000,    1, 1000,  100, 1000,    1,
       1000,    1, 1000, 1000,    1,    1,    1,    1,    1,    1, 1000,
          1, 1000,    1, 1000,    1, 1000, 1000, 1000, 1000, 1000, 1000,
       1000, 1000,   10, 1000,   10,    1,   10, 1000,   20,    1,    3,
         10,    1,   30,    1,    1,    1,    1,   30,    1,    1,    1,
          1,   30,    1,    1,    1,    1,   30,    1,    1,    1,    1,
         30,    1,    1,    1,    1,   30,    1,    1,    1,    1,   30,
          1,    1,    1,    1,   30,    1,    1,    1,    1,   30,    1,
          1,    1,    1,   30,    1,    1,    1,    1,   30,    1,    1,
          1,    1,   30,    1,    1,    1,    1,   30,    1,    1,    1],      dtype=int32)

In [43]:
def initialize_game_state(schema, UnitState, State):
    default_unit_state = UnitState(**{k: eval(v['default']) if 'jnp' in v['default'] else v['default']
                                      for k, v in schema['UnitState'].items()})
    player = default_unit_state
    enemy = default_unit_state.replace(unit_id=-1)
    state = State(player=player, enemy=enemy, **{k: eval(v['default']) if 'jnp' in v['default'] else v['default']
                                                 for k, v in schema['State'].items() if k not in ['player', 'enemy']})
    return state

# Initialize the game state
game_state = initialize_game_state(schema, UnitState, State)


TypeError: argument of type 'int' is not iterable

In [49]:
# from jax import jit
# import jax.numpy as jnp
# from jax import struct
global schema

# Define the function that initializes game state
def initialize_game_state(UnitState, GameState):
    # Construct the default UnitState by evaluating the provided defaults
    default_unit_state = UnitState(**{
        k: eval(v['default']) if isinstance(v['default'], str) and 'jnp' in v['default'] else v['default']
        for k, v in schema['UnitState'].items()
    })
    
    # Create player and enemy unit states
    player = default_unit_state
    enemy = default_unit_state.replace(unit_id=-1)  # Assuming 'unit_id' differentiates players and enemies
    
    # Construct the GameState by providing defaults for non-unit state properties
    game_state = GameState(
        player=player,
        enemy=enemy,
        **{k: eval(v['default']) if isinstance(v['default'], str) and 'jnp' in v['default'] else v['default']
           for k, v in schema['GameState'].items() if k not in ['player', 'enemy']}
    )
    return game_state

# JIT compile the initialization function to ensure compatibility with JAX transformations
jit_initialize_game_state = jax.jit(initialize_game_state, static_argnums=(0, 1, 2))

# Example usage
game_state = jit_initialize_game_state(UnitState, GameState)


In [50]:
game_state

GameState(player=UnitState(unit_id=Array(0, dtype=int32, weak_type=True), action_points_base=Array(5, dtype=int32, weak_type=True), action_points_current=Array(5, dtype=int32, weak_type=True), action_points_max=Array(5, dtype=int32, weak_type=True), movement_points_base=Array(5, dtype=int32, weak_type=True), movement_points_current=Array(5, dtype=int32, weak_type=True), movement_points_max=Array(5, dtype=int32, weak_type=True), movement_points_percentage=Array(1, dtype=int32, weak_type=True), movement_points_multiplier=Array(1, dtype=int32, weak_type=True), health_current=Array(100, dtype=int32, weak_type=True), health_max=Array(100, dtype=int32, weak_type=True), health_percentage=Array(1, dtype=int32, weak_type=True), health_regeneration=Array(1, dtype=int32, weak_type=True), health_regeneration_rate=Array(0, dtype=int32, weak_type=True), mana_current=Array(100, dtype=int32, weak_type=True), mana_max=Array(100, dtype=int32, weak_type=True), mana_percentage=Array(1, dtype=int32, weak_t