In [1]:
#development.ipynb
import environment_MARL
import data_classes

import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any, Tuple, Union, Dict
import distrax


import functools

class ScannedRNN(nn.Module):
    @functools.partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=0,
        out_axes=0,
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(self, carry, x):
        """Applies the module."""
        rnn_state = carry
        ins, resets = x
        rnn_state = jnp.where(
            resets[:, np.newaxis],
            self.initialize_carry(ins.shape[0], ins.shape[1]),
            rnn_state,
        )
        new_rnn_state, y = nn.GRUCell(features=ins.shape[1])(rnn_state, ins)
        return new_rnn_state, y

    @staticmethod
    def initialize_carry(batch_size, hidden_size):
        # Use a dummy key since the default state init fn is just zeros.
        cell = nn.GRUCell(features=hidden_size)
        return cell.initialize_carry(jax.random.PRNGKey(0), (batch_size, hidden_size))


class ActorRNN(nn.Module):
    action_dim: Sequence[int]
    config: Dict

    @nn.compact
    def __call__(self, hidden, x):
        obs, dones, avail_actions = x
        embedding = nn.Dense(
            128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(obs)
        embedding = nn.relu(embedding)

        rnn_in = (embedding, dones)
        hidden, embedding = ScannedRNN()(hidden, rnn_in)

        actor_mean = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
            embedding
        )
        actor_mean = nn.relu(actor_mean)
        action_logits = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_mean)
        unavail_actions = 1 - avail_actions
        action_logits = action_logits - (unavail_actions * 1e10)

        pi = distrax.Categorical(logits=action_logits)

        return hidden, pi


class CriticRNN(nn.Module):
    
    @nn.compact
    def __call__(self, hidden, x):
        world_state, dones = x
        embedding = nn.Dense(
            128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(world_state)
        embedding = nn.relu(embedding)
        
        rnn_in = (embedding, dones)
        hidden, embedding = ScannedRNN()(hidden, rnn_in)
        
        critic = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
            embedding
        )
        critic = nn.relu(critic)
        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
            critic
        )
        
        return hidden, jnp.squeeze(critic, axis=-1)
    
from jaxmarl.wrappers.baselines import JaxMARLWrapper
from functools import partial

class HanabiWorldStateWrapper(JaxMARLWrapper):
    
    @partial(jax.jit, static_argnums=0)
    def reset(self,
              key):
        obs, env_state = self._env.reset(key)
        obs["world_state"] = self.world_state(obs, env_state)
        return obs, env_state
    
    @partial(jax.jit, static_argnums=0)
    def step(self,
             key,
             state,
             action):
        obs, env_state, reward, done, info = self._env.step(
            key, state, action
        )
        obs["world_state"] = self.world_state(obs, state)
        return obs, env_state, reward, done, info

    @partial(jax.jit, static_argnums=0)
    def world_state(self, obs, state):
        """ 
        For each agent: [agent obs, own hand]
        """
            
        all_obs = jnp.array([obs[agent] for agent in self._env.agents])
        # hands = state.player_hands.reshape((self._env.num_agents, -1))
        return all_obs
        
    
    def world_state_size(self):
   
        return data_classes.get_observation_size(data_classes.schema) # NOTE hardcoded hand size
    
import jax
import jax.numpy as jnp
import numpy as np
import optax
from typing import Sequence, NamedTuple, Any, Tuple, Union, Dict

from flax.training.train_state import TrainState
import hydra
from omegaconf import DictConfig, OmegaConf
import jaxmarl
from jaxmarl.wrappers.baselines import LogWrapper

import wandb
from config import train_config

class Transition(NamedTuple):
    global_done: jnp.ndarray
    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray
    world_state: jnp.ndarray
    info: jnp.ndarray
    avail_actions: jnp.ndarray


def batchify(x: dict, agent_list, num_actors):
    x = jnp.stack([x[a] for a in agent_list])
    return x.reshape((num_actors, -1))


def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors):
    x = x.reshape((num_actors, num_envs, -1))
    return {a: x[i] for i, a in enumerate(agent_list)}


def make_train(config):
    env = environment_MARL.RL_Roguelike_JAX_MARL()
    num_actions = env.num_moves  # Get actual number of actions from environment
    config["NUM_ACTORS"] = env.num_agents * config["NUM_ENVS"]
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ACTORS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    config["CLIP_EPS"] = config["CLIP_EPS"] / env.num_agents if config["SCALE_CLIP_EPS"] else config["CLIP_EPS"]

    # env = FlattenObservationWrapper(env) # NOTE need a batchify wrapper
    env = HanabiWorldStateWrapper(env)
    env = LogWrapper(env)

    def linear_schedule(count):
        frac = (
            1.0
            - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
            / config["NUM_UPDATES"]
        )
        return config["LR"] * frac

    def train(rng):
        # INIT NETWORK
        actor_network = ActorRNN(num_actions, config=config)
        critic_network = CriticRNN()
        rng, _rng_actor, _rng_critic = jax.random.split(rng, 3)
        ac_init_x = (
            jnp.zeros((1, config["NUM_ENVS"], data_classes.get_observation_size(data_classes.schema))),
            jnp.zeros((1, config["NUM_ENVS"])),
            jnp.zeros((1, config["NUM_ENVS"], num_actions)),
        )
        ac_init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
        actor_network_params = actor_network.init(_rng_actor, ac_init_hstate, ac_init_x)
        
        cr_init_x = (
            jnp.zeros((1, config["NUM_ENVS"], env.world_state_size(),)), 
            jnp.zeros((1, config["NUM_ENVS"])),
        )
        cr_init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
        critic_network_params = critic_network.init(_rng_critic, cr_init_hstate, cr_init_x)
        
        if config["ANNEAL_LR"]:
            actor_tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
            critic_tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            actor_tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(config["LR"], eps=1e-5),
            )
            critic_tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(config["LR"], eps=1e-5),
            )
        actor_train_state = TrainState.create(
            apply_fn=actor_network.apply,
            params=actor_network_params,
            tx=actor_tx,
        )
        critic_train_state = TrainState.create(
            apply_fn=actor_network.apply,
            params=critic_network_params,
            tx=critic_tx,
        )

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0,))(reset_rng)
        ac_init_hstate = ScannedRNN.initialize_carry(config["NUM_ACTORS"], 128)
        cr_init_hstate = ScannedRNN.initialize_carry(config["NUM_ACTORS"], 128)

        # TRAIN LOOP
        def _update_step(update_runner_state, unused):
            # COLLECT TRAJECTORIES
            runner_state, update_steps = update_runner_state
            
            def _env_step(runner_state, unused):
                train_states, env_state, last_obs, last_done, hstates, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                avail_actions = jax.vmap(env.get_legal_moves)(env_state.env_state)
                avail_actions = jax.lax.stop_gradient(
                    batchify(avail_actions, env.agents, config["NUM_ACTORS"])
                )
                obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])
                ac_in = (
                    obs_batch[np.newaxis, :],
                    last_done[np.newaxis, :],
                    avail_actions,
                )
                ac_hstate, pi = actor_network.apply(train_states[0].params, hstates[0], ac_in)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)
                env_act = unbatchify(
                    action, env.agents, config["NUM_ENVS"], env.num_agents
                )

                # VALUE
                world_state = last_obs["world_state"].reshape((config["NUM_ACTORS"],-1))
                cr_in = (
                    world_state[None, :],
                    last_done[np.newaxis, :],
                )
                cr_hstate, value = critic_network.apply(train_states[1].params, hstates[1], cr_in)

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = jax.vmap(
                    env.step, in_axes=(0, 0, 0)
                )(rng_step, env_state, env_act)
                info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
                done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
                transition = Transition(
                    jnp.tile(done["__all__"], env.num_agents),
                    done_batch,
                    action.squeeze(),
                    value.squeeze(),
                    batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
                    log_prob.squeeze(),
                    obs_batch,
                    world_state,
                    info,
                    avail_actions,
                )
                runner_state = (train_states, env_state, obsv, done_batch, (ac_hstate, cr_hstate), rng)
                return runner_state, transition

            initial_hstates = runner_state[-2]
            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )
            
            # CALCULATE ADVANTAGE
            train_states, env_state, last_obs, last_done, hstates, rng = runner_state
      
            last_world_state = last_obs["world_state"].reshape((config["NUM_ACTORS"],-1))
            cr_in = (
                last_world_state[None, :],
                last_done[np.newaxis, :],
            )
            _, last_val = critic_network.apply(train_states[1].params, hstates[1], cr_in)
            last_val = last_val.squeeze()

            def _calculate_gae(traj_batch, last_val):
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.global_done,
                        transition.value,
                        transition.reward,
                    )
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    gae = (
                        delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    )
                    return (gae, value), gae

                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                def _update_minbatch(train_states, batch_info):
                    actor_train_state, critic_train_state = train_states
                    ac_init_hstate, cr_init_hstate, traj_batch, advantages, targets = batch_info

                    def _actor_loss_fn(actor_params, init_hstate, traj_batch, gae):
                        # RERUN NETWORK
                        _, pi = actor_network.apply(
                            actor_params,
                            init_hstate.transpose(),
                            (traj_batch.obs, traj_batch.done, traj_batch.avail_actions),
                        )
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["CLIP_EPS"],
                                1.0 + config["CLIP_EPS"],
                            )
                            * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean(where=(1 - traj_batch.done))
                        entropy = pi.entropy().mean(where=(1 - traj_batch.done))
                        actor_loss = (
                            loss_actor
                            - config["ENT_COEF"] * entropy
                        )
                        return actor_loss, (loss_actor, entropy)
                    
                    def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets):
                        # RERUN NETWORK
                        _, value = critic_network.apply(critic_params, init_hstate.transpose(), (traj_batch.world_state,  traj_batch.done)) 
                        
                        # CALCULATE VALUE LOSS
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = (
                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean(where=(1 - traj_batch.done))
                        )
                        critic_loss = config["VF_COEF"] * value_loss
                        return critic_loss, (value_loss)

                    actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True)
                    actor_loss, actor_grads = actor_grad_fn(
                        actor_train_state.params, ac_init_hstate, traj_batch, advantages
                    )
                    critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True)
                    critic_loss, critic_grads = critic_grad_fn(
                        critic_train_state.params, cr_init_hstate, traj_batch, targets
                    )
                    
                    actor_train_state = actor_train_state.apply_gradients(grads=actor_grads)
                    critic_train_state = critic_train_state.apply_gradients(grads=critic_grads)
                    
                    total_loss = actor_loss[0] + critic_loss[0]
                    loss_info = {
                        "total_loss": total_loss,
                        "actor_loss": actor_loss[0],
                        "critic_loss": critic_loss[0],
                        "entropy": actor_loss[1][1],
                    }
                    
                    return (actor_train_state, critic_train_state), loss_info

                (
                    train_states,
                    init_hstates,
                    traj_batch,
                    advantages,
                    targets,
                    rng,
                ) = update_state
                rng, _rng = jax.random.split(rng)

                init_hstates = jax.tree_map(lambda x: jnp.reshape(
                    x, (config["NUM_STEPS"], config["NUM_ACTORS"])
                ), init_hstates)
                
                batch = (
                    init_hstates[0],
                    init_hstates[1],
                    traj_batch,
                    advantages.squeeze(),
                    targets.squeeze(),
                )
                permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])

                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=1), batch
                )

                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.swapaxes(
                        jnp.reshape(
                            x,
                            [x.shape[0], config["NUM_MINIBATCHES"], -1]
                            + list(x.shape[2:]),
                        ),
                        1,
                        0,
                    ),
                    shuffled_batch,
                )

                train_states, loss_info = jax.lax.scan(
                    _update_minbatch, train_states, minibatches
                )
                update_state = (
                    train_states,
                    init_hstates,
                    traj_batch,
                    advantages,
                    targets,
                    rng,
                )
                return update_state, loss_info

            ac_init_hstate = initial_hstates[0][None, :].squeeze().transpose()
            cr_init_hstate = initial_hstates[1][None, :].squeeze().transpose()

            update_state = (
                train_states,
                (ac_init_hstate, cr_init_hstate),
                traj_batch,
                advantages,
                targets,
                rng,
            )
            update_state, loss_info = jax.lax.scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )
            loss_info = jax.tree_map(lambda x: x.mean(), loss_info)
            
            train_states = update_state[0]
            metric = traj_batch.info
            rng = update_state[-1]

            # def callback(metric):
                
            #     wandb.log(
            #         {
            #             "returns": metric["returned_episode_returns"][-1, :].mean(),
            #             "env_step": metric["update_steps"]
            #             * config["NUM_ENVS"]
            #             * config["NUM_STEPS"],
            #         }
            #     )
                
            
            # metric["update_steps"] = update_steps
            # jax.experimental.io_callback(callback, None, metric)
            update_steps = update_steps + 1
            runner_state = (train_states, env_state, last_obs, last_done, hstates, rng)
            return (runner_state, update_steps), metric

        rng, _rng = jax.random.split(rng)
        runner_state = (
            (actor_train_state, critic_train_state),
            env_state,
            obsv,
            jnp.zeros((config["NUM_ACTORS"]), dtype=bool),
            (ac_init_hstate, cr_init_hstate),
            _rng,
        )
        runner_state, metric = jax.lax.scan(
            _update_step, (runner_state, 0), None, config["NUM_UPDATES"]
        )
        return {"runner_state": runner_state}

    return train

def main():
    """Main training function using direct config."""
    # Use the train_config directly from config.py
    config = train_config
    
    # Initialize wandb
    wandb.init(
        entity=config["ENTITY"],
        project=config["PROJECT"],
        tags=["MAPPO", "RNN", config["ENV_NAME"]],
        config=config,
        mode=config["WANDB_MODE"],
    )
    
    # Set random seed
    rng = jax.random.PRNGKey(config["SEED"])
    
    # Run training
    with jax.disable_jit(False):
        train_jit = jax.jit(make_train(config))
        out = train_jit(rng)

    return out

# Remove Hydra imports and decorators
if __name__ == "__main__":
    test = main()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


In [2]:
test['runner_state'][0][1].env_state.player.health_current

Array([100., 100., 100., 100., 100., 100., 100., 100., 100., 100., 100.,
       100., 100., 100., 100.,  40.,  10., 100.,  85., 100.,  25., 100.,
       100., 100., 100., 100., 100.,  40., 100., 100., 100., 100., 100.,
       100., 100.,  70.,  25., 100., 100., 100., 100., 100., 100., 100.,
        85., 100.,  40.,  25., 100.,  40., 100., 100.,  40., 100., 100.,
       100., 100.,  85.,  40., 100., 100.,  85., 100.,  25., 100., 100.,
       100., 100., 100., 100., 100.,  25.,  85.,  25., 100.,  70., 100.,
       100., 100.,  40., 100.,  55., 100., 100., 100., 100., 100.,  85.,
        25.,  85.,  25., 100.,   0., 100., 100., 100., 100., 100., 100.,
       100., 100., 100., 100., 100., 100.,  40., 100., 100., 100.,  25.,
        85., 100., 100., 100.,  70., 100., 100., 100., 100., 100., 100.,
       100., 100.,  85., 100., 100.,  85., 100.], dtype=float32)

In [3]:
test['runner_state'][0][1].env_state.player.ability_state_1.ability_index

Array([0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1,
       0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1,
       0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1,
       1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0,
       0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0], dtype=int32)

In [4]:
test['runner_state'][0][1].env_state.player.action_points_current

Array([5., 5., 5., 5., 3., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5.,
       5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5.,
       5., 5., 3., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5.,
       5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5.,
       5., 2., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5.,
       5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5.,
       5., 5., 5., 1., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5.,
       5., 5., 5., 3., 5., 4., 5., 5., 5.], dtype=float32)

In [5]:
test['runner_state'][0][1].env_state.player.movement_points_current

Array([5.        , 0.1715728 , 3.        , 3.5857863 , 2.5857863 ,
       4.        , 5.        , 5.        , 2.5857863 , 3.5857863 ,
       0.17157292, 0.1715728 , 5.        , 5.        , 5.        ,
       4.        , 0.1715728 , 5.        , 5.        , 5.        ,
       5.        , 5.        , 0.1715728 , 5.        , 3.        ,
       5.        , 1.1715728 , 5.        , 5.        , 5.        ,
       2.5857863 , 4.        , 1.1715728 , 5.        , 0.75735915,
       5.        , 0.17157269, 5.        , 5.        , 5.        ,
       5.        , 5.        , 2.5857863 , 5.        , 5.        ,
       5.        , 5.        , 5.        , 5.        , 5.        ,
       5.        , 5.        , 5.        , 5.        , 5.        ,
       5.        , 5.        , 5.        , 2.5857863 , 2.5857863 ,
       4.        , 5.        , 5.        , 5.        , 4.        ,
       2.1715727 , 4.        , 0.1715728 , 5.        , 1.1715728 ,
       5.        , 5.        , 5.        , 3.        , 2.58578

In [7]:
test['runner_state'][0][1].env_state.player.ability_state_1.current_cooldown

Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 3, 3, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0], dtype=int32)

In [6]:
# import environment_MARL

In [None]:
# import jax
# import jax.numpy as jnp
# import flax.linen as nn
# import numpy as np
# from flax.linen.initializers import constant, orthogonal
# from typing import Sequence, NamedTuple, Any, Tuple, Union, Dict
# import distrax


# import functools

# class ScannedRNN(nn.Module):
#     @functools.partial(
#         nn.scan,
#         variable_broadcast="params",
#         in_axes=0,
#         out_axes=0,
#         split_rngs={"params": False},
#     )
#     @nn.compact
#     def __call__(self, carry, x):
#         """Applies the module."""
#         rnn_state = carry
#         ins, resets = x
#         rnn_state = jnp.where(
#             resets[:, np.newaxis],
#             self.initialize_carry(ins.shape[0], ins.shape[1]),
#             rnn_state,
#         )
#         new_rnn_state, y = nn.GRUCell(features=ins.shape[1])(rnn_state, ins)
#         return new_rnn_state, y

#     @staticmethod
#     def initialize_carry(batch_size, hidden_size):
#         # Use a dummy key since the default state init fn is just zeros.
#         cell = nn.GRUCell(features=hidden_size)
#         return cell.initialize_carry(jax.random.PRNGKey(0), (batch_size, hidden_size))


# class ActorRNN(nn.Module):
#     action_dim: Sequence[int]
#     config: Dict

#     @nn.compact
#     def __call__(self, hidden, x):
#         obs, dones, avail_actions = x
#         embedding = nn.Dense(
#             128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
#         )(obs)
#         embedding = nn.relu(embedding)

#         rnn_in = (embedding, dones)
#         hidden, embedding = ScannedRNN()(hidden, rnn_in)

#         actor_mean = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
#             embedding
#         )
#         actor_mean = nn.relu(actor_mean)
#         action_logits = nn.Dense(
#             self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
#         )(actor_mean)
#         unavail_actions = 1 - avail_actions
#         action_logits = action_logits - (unavail_actions * 1e10)

#         pi = distrax.Categorical(logits=action_logits)

#         return hidden, pi


# class CriticRNN(nn.Module):
    
#     @nn.compact
#     def __call__(self, hidden, x):
#         world_state, dones = x
#         embedding = nn.Dense(
#             128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
#         )(world_state)
#         embedding = nn.relu(embedding)
        
#         rnn_in = (embedding, dones)
#         hidden, embedding = ScannedRNN()(hidden, rnn_in)
        
#         critic = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
#             embedding
#         )
#         critic = nn.relu(critic)
#         critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
#             critic
#         )
        
#         return hidden, jnp.squeeze(critic, axis=-1)

In [None]:
# from jaxmarl.wrappers.baselines import JaxMARLWrapper
# from functools import partial

# class HanabiWorldStateWrapper(JaxMARLWrapper):
    
#     @partial(jax.jit, static_argnums=0)
#     def reset(self,
#               key):
#         obs, env_state = self._env.reset(key)
#         obs["world_state"] = self.world_state(obs, env_state)
#         return obs, env_state
    
#     @partial(jax.jit, static_argnums=0)
#     def step(self,
#              key,
#              state,
#              action):
#         obs, env_state, reward, done, info = self._env.step(
#             key, state, action
#         )
#         obs["world_state"] = self.world_state(obs, state)
#         return obs, env_state, reward, done, info

#     @partial(jax.jit, static_argnums=0)
#     def world_state(self, obs, state):
#         """ 
#         For each agent: [agent obs, own hand]
#         """
            
#         all_obs = jnp.array([obs[agent] for agent in self._env.agents])
#         # hands = state.player_hands.reshape((self._env.num_agents, -1))
#         return all_obs
        
    
#     def world_state_size(self):
   
#         return 18 # NOTE hardcoded hand size

In [None]:
# import jax
# import jax.numpy as jnp
# import numpy as np
# import optax
# from typing import Sequence, NamedTuple, Any, Tuple, Union, Dict

# from flax.training.train_state import TrainState
# import hydra
# from omegaconf import DictConfig, OmegaConf
# import jaxmarl
# from jaxmarl.wrappers.baselines import LogWrapper

# import wandb
# from config import train_config

# class Transition(NamedTuple):
#     global_done: jnp.ndarray
#     done: jnp.ndarray
#     action: jnp.ndarray
#     value: jnp.ndarray
#     reward: jnp.ndarray
#     log_prob: jnp.ndarray
#     obs: jnp.ndarray
#     world_state: jnp.ndarray
#     info: jnp.ndarray
#     avail_actions: jnp.ndarray


# def batchify(x: dict, agent_list, num_actors):
#     x = jnp.stack([x[a] for a in agent_list])
#     return x.reshape((num_actors, -1))


# def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors):
#     x = x.reshape((num_actors, num_envs, -1))
#     return {a: x[i] for i, a in enumerate(agent_list)}


# def make_train(config):
#     env = environment_MARL.RL_Roguelike_JAX_MARL()
#     config["NUM_ACTORS"] = env.num_agents * config["NUM_ENVS"]
#     config["NUM_UPDATES"] = (
#         config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
#     )
#     config["MINIBATCH_SIZE"] = (
#         config["NUM_ACTORS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
#     )
#     config["CLIP_EPS"] = config["CLIP_EPS"] / env.num_agents if config["SCALE_CLIP_EPS"] else config["CLIP_EPS"]

#     # env = FlattenObservationWrapper(env) # NOTE need a batchify wrapper
#     env = HanabiWorldStateWrapper(env)
#     env = LogWrapper(env)

#     def linear_schedule(count):
#         frac = (
#             1.0
#             - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
#             / config["NUM_UPDATES"]
#         )
#         return config["LR"] * frac

#     def train(rng):
#         # INIT NETWORK
#         actor_network = ActorRNN(11, config=config)
#         critic_network = CriticRNN()
#         rng, _rng_actor, _rng_critic = jax.random.split(rng, 3)
#         ac_init_x = (
#             jnp.zeros((1, config["NUM_ENVS"], 18)),
#             jnp.zeros((1, config["NUM_ENVS"])),
#             jnp.zeros((1, config["NUM_ENVS"], 11)),
#         )
#         ac_init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
#         actor_network_params = actor_network.init(_rng_actor, ac_init_hstate, ac_init_x)
        
#         cr_init_x = (
#             jnp.zeros((1, config["NUM_ENVS"], env.world_state_size(),)), 
#             jnp.zeros((1, config["NUM_ENVS"])),
#         )
#         cr_init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
#         critic_network_params = critic_network.init(_rng_critic, cr_init_hstate, cr_init_x)
        
#         if config["ANNEAL_LR"]:
#             actor_tx = optax.chain(
#                 optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
#                 optax.adam(learning_rate=linear_schedule, eps=1e-5),
#             )
#             critic_tx = optax.chain(
#                 optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
#                 optax.adam(learning_rate=linear_schedule, eps=1e-5),
#             )
#         else:
#             actor_tx = optax.chain(
#                 optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
#                 optax.adam(config["LR"], eps=1e-5),
#             )
#             critic_tx = optax.chain(
#                 optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
#                 optax.adam(config["LR"], eps=1e-5),
#             )
#         actor_train_state = TrainState.create(
#             apply_fn=actor_network.apply,
#             params=actor_network_params,
#             tx=actor_tx,
#         )
#         critic_train_state = TrainState.create(
#             apply_fn=actor_network.apply,
#             params=critic_network_params,
#             tx=critic_tx,
#         )

#         # INIT ENV
#         rng, _rng = jax.random.split(rng)
#         reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
#         obsv, env_state = jax.vmap(env.reset, in_axes=(0,))(reset_rng)
#         ac_init_hstate = ScannedRNN.initialize_carry(config["NUM_ACTORS"], 128)
#         cr_init_hstate = ScannedRNN.initialize_carry(config["NUM_ACTORS"], 128)

#         # TRAIN LOOP
#         def _update_step(update_runner_state, unused):
#             # COLLECT TRAJECTORIES
#             runner_state, update_steps = update_runner_state
            
#             def _env_step(runner_state, unused):
#                 train_states, env_state, last_obs, last_done, hstates, rng = runner_state

#                 # SELECT ACTION
#                 rng, _rng = jax.random.split(rng)
#                 avail_actions = jax.vmap(env.get_legal_moves)(env_state.env_state)
#                 avail_actions = jax.lax.stop_gradient(
#                     batchify(avail_actions, env.agents, config["NUM_ACTORS"])
#                 )
#                 obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])
#                 ac_in = (
#                     obs_batch[np.newaxis, :],
#                     last_done[np.newaxis, :],
#                     avail_actions,
#                 )
#                 ac_hstate, pi = actor_network.apply(train_states[0].params, hstates[0], ac_in)
#                 action = pi.sample(seed=_rng)
#                 log_prob = pi.log_prob(action)
#                 env_act = unbatchify(
#                     action, env.agents, config["NUM_ENVS"], env.num_agents
#                 )

#                 # VALUE
#                 world_state = last_obs["world_state"].reshape((config["NUM_ACTORS"],-1))
#                 cr_in = (
#                     world_state[None, :],
#                     last_done[np.newaxis, :],
#                 )
#                 cr_hstate, value = critic_network.apply(train_states[1].params, hstates[1], cr_in)

#                 # STEP ENV
#                 rng, _rng = jax.random.split(rng)
#                 rng_step = jax.random.split(_rng, config["NUM_ENVS"])
#                 obsv, env_state, reward, done, info = jax.vmap(
#                     env.step, in_axes=(0, 0, 0)
#                 )(rng_step, env_state, env_act)
#                 info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
#                 done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
#                 transition = Transition(
#                     jnp.tile(done["__all__"], env.num_agents),
#                     done_batch,
#                     action.squeeze(),
#                     value.squeeze(),
#                     batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
#                     log_prob.squeeze(),
#                     obs_batch,
#                     world_state,
#                     info,
#                     avail_actions,
#                 )
#                 runner_state = (train_states, env_state, obsv, done_batch, (ac_hstate, cr_hstate), rng)
#                 return runner_state, transition

#             initial_hstates = runner_state[-2]
#             runner_state, traj_batch = jax.lax.scan(
#                 _env_step, runner_state, None, config["NUM_STEPS"]
#             )
            
#             # CALCULATE ADVANTAGE
#             train_states, env_state, last_obs, last_done, hstates, rng = runner_state
      
#             last_world_state = last_obs["world_state"].reshape((config["NUM_ACTORS"],-1))
#             cr_in = (
#                 last_world_state[None, :],
#                 last_done[np.newaxis, :],
#             )
#             _, last_val = critic_network.apply(train_states[1].params, hstates[1], cr_in)
#             last_val = last_val.squeeze()

#             def _calculate_gae(traj_batch, last_val):
#                 def _get_advantages(gae_and_next_value, transition):
#                     gae, next_value = gae_and_next_value
#                     done, value, reward = (
#                         transition.global_done,
#                         transition.value,
#                         transition.reward,
#                     )
#                     delta = reward + config["GAMMA"] * next_value * (1 - done) - value
#                     gae = (
#                         delta
#                         + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
#                     )
#                     return (gae, value), gae

#                 _, advantages = jax.lax.scan(
#                     _get_advantages,
#                     (jnp.zeros_like(last_val), last_val),
#                     traj_batch,
#                     reverse=True,
#                     unroll=16,
#                 )
#                 return advantages, advantages + traj_batch.value

#             advantages, targets = _calculate_gae(traj_batch, last_val)

#             # UPDATE NETWORK
#             def _update_epoch(update_state, unused):
#                 def _update_minbatch(train_states, batch_info):
#                     actor_train_state, critic_train_state = train_states
#                     ac_init_hstate, cr_init_hstate, traj_batch, advantages, targets = batch_info

#                     def _actor_loss_fn(actor_params, init_hstate, traj_batch, gae):
#                         # RERUN NETWORK
#                         _, pi = actor_network.apply(
#                             actor_params,
#                             init_hstate.transpose(),
#                             (traj_batch.obs, traj_batch.done, traj_batch.avail_actions),
#                         )
#                         log_prob = pi.log_prob(traj_batch.action)

#                         # CALCULATE ACTOR LOSS
#                         ratio = jnp.exp(log_prob - traj_batch.log_prob)
#                         gae = (gae - gae.mean()) / (gae.std() + 1e-8)
#                         loss_actor1 = ratio * gae
#                         loss_actor2 = (
#                             jnp.clip(
#                                 ratio,
#                                 1.0 - config["CLIP_EPS"],
#                                 1.0 + config["CLIP_EPS"],
#                             )
#                             * gae
#                         )
#                         loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
#                         loss_actor = loss_actor.mean(where=(1 - traj_batch.done))
#                         entropy = pi.entropy().mean(where=(1 - traj_batch.done))
#                         actor_loss = (
#                             loss_actor
#                             - config["ENT_COEF"] * entropy
#                         )
#                         return actor_loss, (loss_actor, entropy)
                    
#                     def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets):
#                         # RERUN NETWORK
#                         _, value = critic_network.apply(critic_params, init_hstate.transpose(), (traj_batch.world_state,  traj_batch.done)) 
                        
#                         # CALCULATE VALUE LOSS
#                         value_pred_clipped = traj_batch.value + (
#                             value - traj_batch.value
#                         ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
#                         value_losses = jnp.square(value - targets)
#                         value_losses_clipped = jnp.square(value_pred_clipped - targets)
#                         value_loss = (
#                             0.5 * jnp.maximum(value_losses, value_losses_clipped).mean(where=(1 - traj_batch.done))
#                         )
#                         critic_loss = config["VF_COEF"] * value_loss
#                         return critic_loss, (value_loss)

#                     actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True)
#                     actor_loss, actor_grads = actor_grad_fn(
#                         actor_train_state.params, ac_init_hstate, traj_batch, advantages
#                     )
#                     critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True)
#                     critic_loss, critic_grads = critic_grad_fn(
#                         critic_train_state.params, cr_init_hstate, traj_batch, targets
#                     )
                    
#                     actor_train_state = actor_train_state.apply_gradients(grads=actor_grads)
#                     critic_train_state = critic_train_state.apply_gradients(grads=critic_grads)
                    
#                     total_loss = actor_loss[0] + critic_loss[0]
#                     loss_info = {
#                         "total_loss": total_loss,
#                         "actor_loss": actor_loss[0],
#                         "critic_loss": critic_loss[0],
#                         "entropy": actor_loss[1][1],
#                     }
                    
#                     return (actor_train_state, critic_train_state), loss_info

#                 (
#                     train_states,
#                     init_hstates,
#                     traj_batch,
#                     advantages,
#                     targets,
#                     rng,
#                 ) = update_state
#                 rng, _rng = jax.random.split(rng)

#                 init_hstates = jax.tree_map(lambda x: jnp.reshape(
#                     x, (config["NUM_STEPS"], config["NUM_ACTORS"])
#                 ), init_hstates)
                
#                 batch = (
#                     init_hstates[0],
#                     init_hstates[1],
#                     traj_batch,
#                     advantages.squeeze(),
#                     targets.squeeze(),
#                 )
#                 permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])

#                 shuffled_batch = jax.tree_util.tree_map(
#                     lambda x: jnp.take(x, permutation, axis=1), batch
#                 )

#                 minibatches = jax.tree_util.tree_map(
#                     lambda x: jnp.swapaxes(
#                         jnp.reshape(
#                             x,
#                             [x.shape[0], config["NUM_MINIBATCHES"], -1]
#                             + list(x.shape[2:]),
#                         ),
#                         1,
#                         0,
#                     ),
#                     shuffled_batch,
#                 )

#                 train_states, loss_info = jax.lax.scan(
#                     _update_minbatch, train_states, minibatches
#                 )
#                 update_state = (
#                     train_states,
#                     init_hstates,
#                     traj_batch,
#                     advantages,
#                     targets,
#                     rng,
#                 )
#                 return update_state, loss_info

#             ac_init_hstate = initial_hstates[0][None, :].squeeze().transpose()
#             cr_init_hstate = initial_hstates[1][None, :].squeeze().transpose()

#             update_state = (
#                 train_states,
#                 (ac_init_hstate, cr_init_hstate),
#                 traj_batch,
#                 advantages,
#                 targets,
#                 rng,
#             )
#             update_state, loss_info = jax.lax.scan(
#                 _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
#             )
#             loss_info = jax.tree_map(lambda x: x.mean(), loss_info)
            
#             train_states = update_state[0]
#             metric = traj_batch.info
#             rng = update_state[-1]

#             # def callback(metric):
                
#             #     wandb.log(
#             #         {
#             #             "returns": metric["returned_episode_returns"][-1, :].mean(),
#             #             "env_step": metric["update_steps"]
#             #             * config["NUM_ENVS"]
#             #             * config["NUM_STEPS"],
#             #         }
#             #     )
                
            
#             # metric["update_steps"] = update_steps
#             # jax.experimental.io_callback(callback, None, metric)
#             update_steps = update_steps + 1
#             runner_state = (train_states, env_state, last_obs, last_done, hstates, rng)
#             return (runner_state, update_steps), metric

#         rng, _rng = jax.random.split(rng)
#         runner_state = (
#             (actor_train_state, critic_train_state),
#             env_state,
#             obsv,
#             jnp.zeros((config["NUM_ACTORS"]), dtype=bool),
#             (ac_init_hstate, cr_init_hstate),
#             _rng,
#         )
#         runner_state, metric = jax.lax.scan(
#             _update_step, (runner_state, 0), None, config["NUM_UPDATES"]
#         )
#         return {"runner_state": runner_state}

#     return train

# def main():
#     """Main training function using direct config."""
#     # Use the train_config directly from config.py
#     config = train_config
    
#     # Initialize wandb
#     wandb.init(
#         entity=config["ENTITY"],
#         project=config["PROJECT"],
#         tags=["MAPPO", "RNN", config["ENV_NAME"]],
#         config=config,
#         mode=config["WANDB_MODE"],
#     )
    
#     # Set random seed
#     rng = jax.random.PRNGKey(config["SEED"])
    
#     # Run training
#     with jax.disable_jit(False):
#         train_jit = jax.jit(make_train(config))
#         out = train_jit(rng)

#     return out

# # Remove Hydra imports and decorators
# if __name__ == "__main__":
#     test = main()

ScopeParamShapeError: Initializer expected to generate shape (18, 128) but got shape (249, 128) instead for parameter "kernel" in "/Dense_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)

In [None]:
# class ActionInfo:
#     def __init__(self, id: int, name: str, create_fn):
#         self.id = id
#         self.name = name 
#         self.create_fn = create_fn

# class ActionRegistry:
#     def __init__(self):
#         self.actions = {}  # name -> ActionInfo
#         self.id_lookup = {} # id -> ActionInfo
#         self._next_id = 0
    
#     def register(self, name: str, create_fn) -> None:
#         action_id = self._next_id
#         info = ActionInfo(action_id, name, create_fn)
#         self.actions[name] = info
#         self.id_lookup[action_id] = info
#         self._next_id += 1
        
#     def get_by_name(self, name: str) -> ActionInfo:
#         return self.actions.get(name)
        
#     def get_by_id(self, id: int) -> ActionInfo:
#         return self.id_lookup.get(id)

# # Replace global registry
# action_registry = ActionRegistry()

# # Update register_action function
# def register_action(name: str, create_fn):
#     action_registry.register(name, create_fn)

# # Example usage:
# """
# # Get action by ID
# action_info = action_registry.get_by_id(0)
# if action_info:
#     action = action_info.create_fn()

# # Get action by name 
# action_info = action_registry.get_by_name("SuicideAction")
# if action_info:
#     action = action_info.create_fn()
# """

'\n# Get action by ID\naction_info = action_registry.get_by_id(0)\nif action_info:\n    action = action_info.create_fn()\n\n# Get action by name \naction_info = action_registry.get_by_name("SuicideAction")\nif action_info:\n    action = action_info.create_fn()\n'

In [None]:
# import ability_actions

In [None]:
# ability_registry = ability_actions.ability_registry

In [None]:
# ability_registry.num_abilities

2

In [None]:
# import jax.numpy as jnp
# from jax import lax, debug
# from utils import euclidean_distance, is_within_bounds, is_collision, do_invalid_move, do_damage
# from actions import Action
# from data_classes import DamageType

# class AbilityRegistry:
#     def __init__(self):
#         self.abilities = []  # List of (name, create_fn) tuples
#         self.num_abilities = 0

#     def register(self, name, create_fn):
#         """Register ability with auto-incrementing index"""
#         self.abilities.append((name, create_fn))
#         self.num_abilities += 1
#         return self.num_abilities - 1  # Return index of registered ability

#     def get_by_index(self, index):
#         """Get ability create_fn by index"""
#         return self.abilities[index][1]

# # Create global registry
# ability_registry = AbilityRegistry()

# # suicide:
# # Register suicide action
# SUICIDE_ACTION_IDX = ability_registry.register("SuicideAction", lambda: SuicideAction())

# class SuicideAction(Action):
#     def __init__(self):
#         super().__init__()
#         self._base_cooldown = jnp.int32(3)  # Set cooldown to 3
#         self.ability_index = SUICIDE_ACTION_IDX
#         self.base_damage = jnp.float32(5.0)
#         self.range = jnp.float32(8.0)
#         self._ability_description = "Deal damage based on strength to an enemy and take damage yourself"
        
#     def is_valid(self, state, unit, target):
#         enough_action_points = unit.action_points_current >= 1
#         within_range = state.distance_to_enemy <= self.range
#         return jnp.logical_and(enough_action_points, within_range)

#     def _perform_action(self, state, unit, target):
#         # Generate all 8 adjacent grid positions
#         adjacent_positions = [
#             (target.location_x + dx, target.location_y + dy)
#             for dx, dy in [(-1,-1), (-1,0), (-1,1), (0,-1), (0,1), (1,-1), (1,0), (1,1)]
#         ]
        
#         # Find closest valid adjacent position
#         best_x, best_y = adjacent_positions[0]
#         best_dist = euclidean_distance(unit.location_x, unit.location_y, best_x, best_y)
        
#         def update_best_position(i, val):
#             x, y = adjacent_positions[i]
#             dist = euclidean_distance(unit.location_x, unit.location_y, x, y)
#             use_new = jnp.logical_and(
#                 is_within_bounds(x, y),
#                 dist < val[2]
#             )
#             return lax.cond(
#                 use_new,
#                 lambda _: (x, y, dist),
#                 lambda _: val,
#                 None
#             )
            
#         new_x, new_y, _ = lax.fori_loop(1, 8, update_best_position, (best_x, best_y, best_dist))

#         damage_dealt = self.base_damage + unit.strength_current

#         # Apply damage using do_damage utility
#         new_unit, new_target = do_damage(unit, target, damage_dealt, DamageType.PURE)
#         # Do base_damage to self
#         new_unit, _ = do_damage(unit, new_unit, self.base_damage, DamageType.PURE) #TODO: don't return self damage
        
#         # Update position and action points
#         new_unit = new_unit.replace(
#             action_points_current=jnp.float32(new_unit.action_points_current - 1),
#             location_x=jnp.float32(new_x),
#             location_y=jnp.float32(new_y)
#         )
        
#         # Calculate new distance
#         new_distance = euclidean_distance(new_x, new_y, target.location_x, target.location_y)
        
#         return lax.cond(
#             jnp.equal(state.player.unit_id, unit.unit_id),
#             lambda: state.replace(
#                 player=new_unit, 
#                 enemy=new_target,
#                 distance_to_enemy=new_distance
#             ),
#             lambda: state.replace(
#                 player=new_target, 
#                 enemy=new_unit,
#                 distance_to_enemy=new_distance
#             )
#         )

# # Steal Strength - reduce enemy strength and increase own strength
# # Register the new action
# STEAL_STRENGTH_IDX = ability_registry.register("StealStrengthAction", lambda: StealStrengthAction())

# class StealStrengthAction(Action):
#     def __init__(self):
#         super().__init__()
#         self.ability_index = STEAL_STRENGTH_IDX
#         self._base_cooldown = jnp.int32(1)
#         self.range = jnp.float32(4.0)
#         self.strength_steal_amount = jnp.float32(2.0)
#         self._ability_description = "Steal 2 strength from the target"

#     def is_valid(self, state, unit, target):
#         enough_action_points = unit.action_points_current >= 1
#         within_range = state.distance_to_enemy <= self.range
#         return jnp.logical_and(enough_action_points, within_range)

#     def _perform_action(self, state, unit, target):
#         # Reduce target's strength
#         new_target = target.replace(
#             strength_current=jnp.maximum(0, target.strength_current - self.strength_steal_amount)
#         )
        
#         # Increase caster's strength and reduce action points
#         new_unit = unit.replace(
#             strength_current=unit.strength_current + self.strength_steal_amount,
#             action_points_current=unit.action_points_current - 1
#         )

#         return lax.cond(
#             jnp.equal(state.player.unit_id, unit.unit_id),
#             lambda: state.replace(player=new_unit, enemy=new_target),
#             lambda: state.replace(player=new_target, enemy=new_unit)
#         )
    
# # Strength Regen - regen health based on your strength
# # Add barrier - add a barrier based on resolve
# # Mana Burn
# # Multi Attack
# # Return
# # Fury Swipes
# # Push
# # Frost Arrows
# # Stun
# # Hook
# # Lifesteal / feast
# # Int steal
# # Int based nuke
# # Armour reduction
# # Add barrier
# # Spellsteal
# # Fracture casting