Import Libraries and Packages

In [1]:
import jax
import jax.numpy as jnp
from memo import memo
from functools import cache
from itertools import product

Magic Numbers for Roles, Actions and Intents (readability)

In [2]:
# Role ids: which policy a player follows (see choose_policy).
ATTACKER, DEFENDER, HEALER = 0, 1, 2

# Action ids: what a single player does in one turn.
# They also serve as column indices into player_stats ([Attack, Defense, Healing]).
ATTACK, DEFEND, HEAL = 0, 1, 2
NUM_ACTIONS = 3

# Enemy intent flag stored in each state.
ENEMY_NO_ATTACK_INTENT, ENEMY_ATTACK_INTENT = 0, 1

Define environment parameters
* Max Team HP (also assumed to be team init HP)
* Max Enemy HP (also assumed to be enemy init HP)
* Players and their stats
* Role assignment under consideration
* boss damage
* enemy's attack probability in a given round
* epsilon, randomness in action
* time horizon for Markov Game simulation

In [3]:
# --- Hit points (the maximum is also used as the starting value) ---
TeamMaxHP = 10
EnemyMaxHP = 10

# --- Players: one row per player, columns are [Attack, Defense, Healing] ---
player_stats = jnp.array(
    [
        [2, 2, 2],  # Player 1
        [2, 2, 2],  # Player 2
        [2, 2, 2],  # Player 3
    ]
)

# Role assignment under consideration: one role id per player, in player order.
role_assignment = [ATTACKER, DEFENDER, HEALER]

# --- Enemy ---
boss_damage = 4        # damage dealt by an enemy attack before defense is subtracted
EnemyAttackProb = 1.0  # probability that the enemy intends to attack in the next round

# --- Simulation ---
EPSILON = 0.1  # probability mass a role policy spreads over its non-preferred actions
HORIZON = 10   # number of simulation steps

Markov Game Formalization of the Game

* $H_t \in \{0, ..., 10\}$,           team HP
* $H_e \in \{0, ..., 10\}$,           enemy HP
* $i \in \{0,1\}$,                    enemy's attacking intent at a given round (1 for "will attack", 0 otherwise).
* $S = i \times H_t \times H_e$,             all possible HP and intent combinations
* $A = Actions^3$, Actions = {Attack, Defend, Heal}
* $T: S \times A \rightarrow S$,             Transition function
* $R: S \rightarrow \mathbb{R}$,    State reward function: $100 \times (1 - \tau/\text{HORIZON})$ if the enemy is dead and the team is alive, where $\tau$ is the number of turns taken to beat the enemy; 0 otherwise
* role-specific policies $\pi(a|s)$

State Creation

+ This creates a grid-world-esque setup, where the team "teleports" to different cells depending on team and enemy HPs

In [4]:
# Grid dimensions: one row per possible team HP, one column per possible enemy HP.
H = 1 + TeamMaxHP
W = 1 + EnemyMaxHP

# States are integer ids over a flattened (intent, team HP, enemy HP) cuboid,
# i.e. two H x W grids stacked on top of each other:
#   * ids of S[0, :, :] -> 'no intent of attack'  (0..120 with the default 11 x 11 grid)
#   * ids of S[1, :, :] -> 'attack intent'        (121..241 with the default 11 x 11 grid)
S = jnp.arange(2 * H * W)

# Every joint action of the 3 players: one row per profile, one action id per player.
Action_profiles = jnp.array(
    [profile for profile in product(range(NUM_ACTIONS), repeat=3)]
)

# Action space: row indices into Action_profiles.
A = jnp.arange(Action_profiles.shape[0])

Role Policies

In [5]:
@jax.jit
def get_hp_and_intent_from_state(s):
    """Decode a flat state id into ``(team_hp, enemy_hp, intent)``.

    State ids are laid out as ``intent * (H * W) + team_hp * W + enemy_hp``.

    Args:
        s: Flat state id (an entry of ``S``).

    Returns:
        Tuple ``(team_hp, enemy_hp, intent)``.
    """
    layer_size = H * W

    # Which H x W layer the id falls in: 0 = no attack intent, 1 = attack intent.
    intent = s // layer_size

    # Position inside that layer, split into row (team HP) and column (enemy HP).
    cell = s % layer_size
    return cell // W, cell % W, intent

@jax.jit
def fighter_policy(action):
    """Probability that an ATTACKER-role player takes ``action``.

    The preferred action is always ATTACK. It gets probability ``1 - EPSILON``;
    every other action gets ``EPSILON / (NUM_ACTIONS - 1)``.
    """
    p_preferred = 1.0 - EPSILON
    p_other = EPSILON / (NUM_ACTIONS - 1)
    is_preferred = action == ATTACK
    return jnp.where(is_preferred, p_preferred, p_other)

@jax.jit
def defender_policy(action, intent):
    """Probability that a DEFENDER-role player takes ``action``.

    The preferred action is DEFEND when ``intent`` equals ENEMY_ATTACK_INTENT and
    ATTACK otherwise. It gets probability ``1 - EPSILON``; every other action
    gets ``EPSILON / (NUM_ACTIONS - 1)``.
    """
    enemy_will_attack = intent == ENEMY_ATTACK_INTENT
    target = jnp.where(enemy_will_attack, DEFEND, ATTACK)

    p_target = 1.0 - EPSILON
    p_off_target = EPSILON / (NUM_ACTIONS - 1)
    return jnp.where(action == target, p_target, p_off_target)

@jax.jit
def healer_policy(action, team_hp):
    """Probability that a HEALER-role player takes ``action``.

    The preferred action is HEAL while ``team_hp`` is below half of TeamMaxHP and
    ATTACK otherwise. It gets probability ``1 - EPSILON``; every other action
    gets ``EPSILON / (NUM_ACTIONS - 1)``.
    """
    team_is_low = team_hp < (0.5 * TeamMaxHP)
    target = jnp.where(team_is_low, HEAL, ATTACK)

    p_target = 1.0 - EPSILON
    p_off_target = EPSILON / (NUM_ACTIONS - 1)
    return jnp.where(action == target, p_target, p_off_target)

@jax.jit
def choose_policy(role, action, intent, team_hp):
    """Probability of ``action`` under the policy that belongs to ``role``.

    ATTACKER uses fighter_policy and DEFENDER uses defender_policy; any other
    role value falls through to healer_policy.
    """
    # All three candidates are computed and the matching one is selected with
    # jnp.where, so the role can be a traced value under jit.
    p_attacker = fighter_policy(action)
    p_defender = defender_policy(action, intent)
    p_healer = healer_policy(action, team_hp)

    p_non_attacker = jnp.where(role == DEFENDER, p_defender, p_healer)
    return jnp.where(role == ATTACKER, p_attacker, p_non_attacker)

@jax.jit
def action_profile_prob(s, a, r0, r1, r2):
    """Probability of joint action ``a`` in state ``s`` under roles ``(r0, r1, r2)``.

    Each player's probability comes from its own role policy; the joint
    probability is the product of the three.

    Args:
        s: Flat state id.
        a: Index into ``Action_profiles``.
        r0, r1, r2: Role ids of players 1, 2 and 3.
    """
    team_hp, _, intent = get_hp_and_intent_from_state(s)
    joint_action = Action_profiles[a]  # one action id per player

    per_player = jnp.stack(
        [
            choose_policy(role, joint_action[idx], intent, team_hp)
            for idx, role in enumerate((r0, r1, r2))
        ]
    )
    return jnp.prod(per_player)

Transition Function

+ uses the `get_hp_and_intent_from_state` helper defined above to decode state ids into HPs and intent

In [6]:
@jax.jit
def T(s, a, s_, r0, r1, r2):
    """Transition function T(s, a, s_) = P(s_ | s, a).

    The HP update is deterministic given the joint action and the current enemy
    intent. The only randomness is the enemy's intent in the next state, which
    is "attack" with probability ``EnemyAttackProb``.

    The policy probability of ``a`` is deliberately NOT part of the result. It
    does not depend on ``s_``, so including it would make this something other
    than the conditional P(s_ | s, a), and it would zero out every weight for
    any joint action whose policy probability is 0 (e.g. with EPSILON = 0).

    Args:
        s: Current flat state id.
        a: Index into ``Action_profiles``.
        s_: Candidate next flat state id.
        r0, r1, r2: Role ids. Unused here; kept so existing callers that pass
            them keep working.

    Returns:
        Probability of landing in ``s_`` after taking ``a`` in ``s``.
    """
    joint_action = Action_profiles[a]  # one action id per player

    team_hp, enemy_hp, intent = get_hp_and_intent_from_state(s)

    # Aggregate what the team does this turn. The action ids double as column
    # indices into player_stats ([Attack, Defense, Healing]).
    total_attack = jnp.sum(jnp.where(joint_action == ATTACK, player_stats[:, ATTACK], 0))
    max_defense = jnp.max(jnp.where(joint_action == DEFEND, player_stats[:, DEFEND], 0))
    total_heal = jnp.sum(jnp.where(joint_action == HEAL, player_stats[:, HEAL], 0))

    # Enemy HP cannot drop below 0.
    new_enemy_hp = jnp.maximum(0, enemy_hp - total_attack)

    # No damage is taken if the enemy does not attack this round; otherwise the
    # best single defense is subtracted from the boss damage.
    damage_incoming = jnp.where(
        intent == ENEMY_ATTACK_INTENT,
        jnp.maximum(0, boss_damage - max_defense),
        0,
    )

    # Team HP is clamped to [0, TeamMaxHP].
    new_team_hp = team_hp - damage_incoming + total_heal
    new_team_hp = jnp.maximum(0, jnp.minimum(TeamMaxHP, new_team_hp))

    # Compare the candidate next state against the deterministic HP outcome.
    s_team_hp, s_enemy_hp, s_intent = get_hp_and_intent_from_state(s_)
    hp_matches = (new_team_hp == s_team_hp) & (new_enemy_hp == s_enemy_hp)

    # Probability of the candidate's intent flag.
    intent_prob = (
        ((1 - EnemyAttackProb) * (s_intent == ENEMY_NO_ATTACK_INTENT))  # enemy does not attack, 1 - p
        + (EnemyAttackProb * (s_intent == ENEMY_ATTACK_INTENT))         # enemy attacks, p
    )

    return hp_matches * intent_prob

Reward Function and termination check (no discount factor $\gamma$ is used; faster wins earn more through the reward's time scaling)

In [7]:
@jax.jit
def R(s, t):
    """Reward for being in state ``s`` with ``t`` steps of horizon remaining.

    Pays ``100 * t / HORIZON`` when the enemy has 0 HP and the team has non-zero
    HP, and 0 otherwise.
    """
    team_hp, enemy_hp, _ = get_hp_and_intent_from_state(s)

    team_alive = team_hp != 0
    enemy_dead = enemy_hp == 0
    victory = team_alive & enemy_dead

    remaining_fraction = t / HORIZON
    return victory * 100 * remaining_fraction

@jax.jit
def is_terminal(s):
    """Return True when either side has 0 HP in state ``s``."""
    team_hp, enemy_hp, _ = get_hp_and_intent_from_state(s)
    team_dead = team_hp == 0
    enemy_dead = enemy_hp == 0
    return jnp.logical_or(team_dead, enemy_dead)

State-Value Function $V(s)$ under the role policies

In [8]:
# State-value function under fixed role policies, written in the memo DSL.
#
# V(t, r0, r1, r2) yields one value per state s in S: the expected reward when
# starting in s with t steps of horizon left and players following the policies
# of roles r0, r1, r2. The caller reshapes the result to (2, H, W).
# functools.cache memoizes calls by their (t, r0, r1, r2) arguments, which the
# recursive V[...](t - 1, ...) call below relies on.
@cache
@memo
def V[s:S](t, r0, r1, r2):
  # The observer knows the current state ...
  observer: knows(s)
  # ... picks a joint action with probability proportional to the role policies ...
  observer: chooses(a in A, wpp = action_profile_prob(s, a, r0, r1, r2))
  # ... and the next state is drawn with probability proportional to T.
  observer: draws(s_ in S, wpp = T(s, a, s_, r0, r1, r2))
  # Expected reward now plus the value of the next state (0 once recursion stops).
  return E[
        R(s, t) + 
            (
                0.0 if t == 0 else                                              # recursion depth reached
                0.0 if is_terminal(s) else                                      # terminal state (either entity dead)
                V[observer.s_](t - 1, r0, r1, r2)                               # continue to recurse
            ) 
    ]

In [9]:
# Display settings for the value table printed below.
jnp.set_printoptions(linewidth=100, precision=2, suppress=True)

# Evaluate at t = 0 first (result discarded), then at the full horizon.
V(0, *role_assignment)
vals = V(HORIZON, *role_assignment)

# One (team HP x enemy HP) grid per enemy intent.
print(vals.reshape(2, H, W))

[[[  0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.  ]
  [100.    89.9   89.9   87.98  87.98  77.62  77.62  66.64  66.64  55.13  55.13]
  [100.    89.9   89.9   87.98  87.98  77.62  77.62  66.64  66.64  55.13  55.13]
  [100.    89.9   89.9   88.15  88.15  79.81  79.81  75.94  75.94  66.56  66.56]
  [100.    89.9   89.9   88.15  88.15  79.81  79.81  75.94  75.94  66.56  66.56]
  [100.    89.99  89.99  89.72  89.72  87.21  87.21  79.17  79.17  75.49  75.49]
  [100.    89.99  89.99  89.72  89.72  87.21  87.21  79.17  79.17  75.49  75.49]
  [100.    89.99  89.99  89.72  89.72  87.21  87.21  79.2   79.2   75.82  75.82]
  [100.    89.99  89.99  89.72  89.72  87.21  87.21  79.2   79.2   75.82  75.82]
  [100.    89.99  89.99  89.72  89.72  87.21  87.21  79.2   79.2   75.83  75.83]
  [100.    89.99  89.99  89.72  89.72  87.21  87.21  79.2   79.2   75.83  75.83]]

 [[  0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.  ]
  [100.    76.57  76.57  5

In [10]:
# Human-readable names for the role ids, keyed by the role constants rather
# than repeated integer literals.
ROLE_MAP = {ATTACKER: 'ATTACKER', DEFENDER: 'DEFENDER', HEALER: 'HEALER'}

# Every way of giving each of the 3 players one of the roles.
role_combos = list(product(ROLE_MAP, repeat=3))
value_dict = {}

for roles in role_combos:

    # Same t = 0 pre-call as before; its result was never used, so it is no
    # longer bound to a name.
    V(0, *roles)
    values = V(HORIZON, *roles).reshape((2, H, W))

    # Value of the start state (full team HP, full enemy HP) for both initial
    # enemy intents: [no attack intent, attack intent].
    start_values = values[:, -1, -1].tolist()

    role_names = tuple(ROLE_MAP[r] for r in roles)
    value_dict[role_names] = start_values

# Best assignments first; ties on the first value are broken by the second.
sorted_strategies = sorted(value_dict.items(), key=lambda item: item[1], reverse=True)

print(f"{'Role Assignment':<45} | {'No Init Attack':<20} | {'Init Attack':<20}")
print("-" * 90)

for strategy, value in sorted_strategies:

    formatted_roles = [f"'{role}'" for role in strategy]
    strategy_str = "(" + ", ".join(formatted_roles) + ")"

    val_no_attack = value[0]
    val_attack = value[1]

    # Use the header's separators and column widths so the table lines up.
    print(f"{strategy_str:<45} | {val_no_attack:<20.4f} | {val_attack:.4f}")

Role Assignment                               | No Init Attack       | Init Attack         
------------------------------------------------------------------------------------------
('ATTACKER', 'HEALER', 'HEALER')              : 78.8484               | 78.8361
('HEALER', 'ATTACKER', 'HEALER')              : 78.8484               | 78.8361
('HEALER', 'HEALER', 'ATTACKER')              : 78.8484               | 78.8361
('ATTACKER', 'ATTACKER', 'HEALER')            : 78.8484               | 78.8184
('ATTACKER', 'HEALER', 'ATTACKER')            : 78.8484               | 78.8184
('HEALER', 'ATTACKER', 'ATTACKER')            : 78.8484               | 78.8184
('HEALER', 'HEALER', 'HEALER')                : 78.8483               | 78.7606
('ATTACKER', 'ATTACKER', 'ATTACKER')          : 78.8478               | 78.3330
('ATTACKER', 'DEFENDER', 'HEALER')            : 75.8265               | 69.5600
('ATTACKER', 'HEALER', 'DEFENDER')            : 75.8265               | 69.5600
('HEALER', 'ATTAC