# Part 1: State Value Function

Import Libaries and Packages

In [1]:
import jax
import jax.numpy as jnp
from memo import memo
from functools import cache
from itertools import product

Magic Numbers for Roles, Actions, and Intents \
Lists containing roles, actions \
All possible role combinations

In [2]:
ATTACKER = 0
DEFENDER = 1
HEALER = 2

ATTACK = 0
DEFEND = 1
HEAL = 2

ENEMY_NO_ATTACK_INTENT = 0
ENEMY_ATTACK_INTENT = 1

ROLES   = [ATTACKER, DEFENDER, HEALER]
ACTIONS = [ATTACK, DEFEND, HEAL]

ROLE_COMBOS = list(product(ROLES, repeat=3))

Define environment parameters
* Max Team HP (also assumed to be team init HP)
* Max Enemy HP (also assumed to be enemy init HP)
* Players and their stats
* Role assignment under consideration
* boss damage
* enemy's attack probability in every round
* epsilon, randomness in action
* time horizon for Markov Game simulation

In [3]:
TeamMaxHP       = 10
EnemyMaxHP      = 31

player_stats    = jnp.array([
    [2,2,2],        # Player 1 stats: [Attack, Defense, Healing]
    [2,2,2],        # Player 2 stats: [Attack, Defense, Healing]
    [2,2,2],        # Player 3 stats: [Attack, Defense, Healing]
])

role_assignment = [ATTACKER, ATTACKER, ATTACKER]
boss_damage = 2
EnemyAttackProb = 1.0

EPSILON = 1e-10       # action randomness
HORIZON = 10        # simulation steps

Markov Game Formalization of the Game

* $H_t \in \{0, ..., 10\}$,           team HP
* $H_e \in \{0, ..., 10\}$,           enemy HP
* $i \in \{0,1\}$,                    enemy's attacking intent at a given round (1 for "will attack", 0 otherwise).
* $S = i \times H_t \times H_e$,             all possible intent and HP combinations
* $A = Actions^3$, Actions = {Attack, Defend, Heal}
* $T: S \times A \rightarrow S$,             Transition function
* $R: S \rightarrow \mathbb{R}$,    State reward function, $100 \times (1 - T/H)$ for dead enemy, alive team. $T$ is the number of turns it took to beat the enemy; else, 0
* role-specific policies $\pi(a|s)$

State Creation

+ This creates a grid-world-esque setup, where the team "teleports" to different cells depending on enemy intent, team and enemy HPs

In [4]:
H = TeamMaxHP + 1  # All possible Team HP values
W = EnemyMaxHP + 1  # All possible Enemy HP values

# A flattened cuboid (len(Intent)(2) * H(11) * W(11))
# Or, two squares, one on top of another
# top one, S[0,:,:] has numbers corresp to 'no intent of attack' ~ [0,120] (in this example)
# bottom one, S[1,:,:] has the ones that indicate 'attack intent' ~ [121,241] (in this example)

S               = jnp.arange(2 * H * W)
action_profiles = jnp.array(list(product(ACTIONS, repeat=3)))  # All possible joint actions (3 players)
A               = jnp.arange(len(action_profiles))                        # Action space (set of indices into Actions)

Role Policies

In [5]:
@jax.jit
def get_intent_and_hps_from_state(s):
    i = s // (H * W) # to which 'square' (see above) does s belong to?
    rem = s % (H * W)

    team_hp = rem // W
    enemy_hp = rem % W
    return i, team_hp, enemy_hp

@jax.jit
def fighter_policy(action):
    return jnp.where(
        action == ATTACK,
        1.0 - EPSILON,
        EPSILON / (len(ACTIONS) - 1),
    )

@jax.jit
def defender_policy(action, intent):
    # If opponent intends to attack → defend
    # Else → attack
    preferred_action = jnp.where(
        intent == ENEMY_ATTACK_INTENT,
        DEFEND,
        ATTACK,
    )

    return jnp.where(
        action == preferred_action,
        1.0 - EPSILON,
        EPSILON / (len(ACTIONS) - 1),
    )

@jax.jit
def healer_policy(action, team_hp):
    # If team HP is low → heal
    # Else → attack
    preferred_action = jnp.where(
        team_hp < (0.5 * TeamMaxHP),
        HEAL,
        ATTACK,
    )

    return jnp.where(
        action == preferred_action,
        1.0 - EPSILON,
        EPSILON / (len(ACTIONS) - 1),
    )

@jax.jit
def choose_policy(role, action, intent, team_hp):
    return jnp.where(
        role == ATTACKER,
        fighter_policy(action),
        jnp.where(
            role == DEFENDER,
            defender_policy(action, intent),
            healer_policy(action, team_hp),
        ),
    )

@jax.jit
def action_profile_prob(s, a, r0, r1, r2):
    intent, team_hp, _  = get_intent_and_hps_from_state(s)
    a = action_profiles[a]  # decode joint action
    # r = role_combos[r]      # decode role profile

    prob = jnp.ones(3)
    prob = prob.at[0].set(choose_policy(r0, a[0], intent, team_hp))
    prob = prob.at[1].set(choose_policy(r1, a[1], intent, team_hp))
    prob = prob.at[2].set(choose_policy(r2, a[2], intent, team_hp))
    prob = jnp.prod(prob)
    return prob

Transition Function

In [6]:
@jax.jit
def T(s, a, s_, r0, r1, r2):
    """ Transition function T(s, a, s_) = P(s_ | s, a) """

    # extract HPs and intent from s, current state
    prob_act = action_profile_prob(s, a, r0, r1, r2)
    a = action_profiles[a]  # decode joint action

    i, team_hp, enemy_hp = get_intent_and_hps_from_state(s)

    # Calculate Total Attack, Max Defense, Total Heal

    total_attack = jnp.sum(jnp.where((a == ATTACK), player_stats[:, ATTACK],0))
    max_defense = jnp.max(jnp.where((a == DEFEND), player_stats[:, DEFEND],0))
    total_heal = jnp.sum(jnp.where((a == HEAL), player_stats[:, HEAL],0))

    # Update HPs
    new_enemy_hp = jnp.maximum(0, enemy_hp - total_attack)

    # No damage taken if enemy does not attack
    damage_incoming = jnp.where(
        i == ENEMY_ATTACK_INTENT,
        jnp.maximum(0, boss_damage - max_defense),
        0
    )

    new_team_hp = team_hp - damage_incoming + total_heal
    new_team_hp = jnp.maximum(0, jnp.minimum(TeamMaxHP, new_team_hp))

    # extract HPs and intent from s_, future state to compare and decide transition probability
    s_intent, s_team_hp, s_enemy_hp  = get_intent_and_hps_from_state(s_)

    prob =  ((new_team_hp == s_team_hp) & (new_enemy_hp == s_enemy_hp)) * \
                (
                    (
                        ((1 - EnemyAttackProb) * (s_intent == ENEMY_NO_ATTACK_INTENT)) +   # enemy does not attack, 1 - p
                        (EnemyAttackProb * (s_intent == ENEMY_ATTACK_INTENT))             # enemy attacks, p
                    )
                ) * prob_act

    return prob

Reward Function, termination check \
\
**NOTE:** $R: S \rightarrow \mathbb{R}$,    State reward function, $100 \times (1 - T/H)$ for dead enemy, alive team. $T$ is the number of turns it took to beat the enemy; else, 0. \
\
The code, however, computes $100 \times (1 - T/H)$. This is because the memo recurses *back* in time.
 

In [7]:
@jax.jit
def R(s, t):
    _, team_hp, enemy_hp  = get_intent_and_hps_from_state(s)
    return ((team_hp != 0) & (enemy_hp == 0)) * 100 * (t/HORIZON)

@jax.jit
def is_terminal(s):
    _, team_hp, enemy_hp = get_intent_and_hps_from_state(s)
    return jnp.logical_or(team_hp == 0, enemy_hp == 0) # if either entity dies

State-Value Function $V(s)$ (iterating over all role combos for parallel processing)

In [8]:
@cache
@memo
def V[s:S](t, r0, r1, r2):
  observer: knows(s)
  observer: chooses(a in A, wpp = action_profile_prob(s, a, r0, r1, r2))
  observer: draws(s_ in S, wpp = T(s, a, s_, r0, r1, r2))
  return E[
        R(s, t) + 
            (
                0.0 if t == 0 else             # recursion depth reached
                0.0 if is_terminal(s) else     # terminal state (either entity dead)
                V[observer.s_](t - 1, r0, r1, r2)       # continue to recurse
            ) 
    ]

In [9]:
V(0, *role_assignment)
values = V(HORIZON, *role_assignment).reshape((2, H, W))

In [10]:
ROLE_MAP = {0: 'ATTACKER', 1: 'DEFENDER', 2: 'HEALER'}

value_dict = {}

roles_to_values = {}

for roles in ROLE_COMBOS:

    values = V(0, *roles)
    values = V(HORIZON, *roles)
    values = values.reshape((2, H, W))

    start_values = values[:, -1, -1].tolist()

    role_names = tuple(ROLE_MAP[r] for r in roles)
    value_dict[role_names] = start_values

    roles_to_values[roles] = values

sorted_strategies = sorted(value_dict.items(), key=lambda item: item[1], reverse=True)

print(f"{'Role Assignment':<45} | {'No Init Attack':<20} | {'Init Attack':<20}")
print("-" * 90)

for strategy, value in sorted_strategies:

    formatted_roles = [f"'{role}'" for role in strategy]
    strategy_str = "(" + ", ".join(formatted_roles) + ")"

    val_no_attack = value[0]
    val_attack = value[1]

    print(f"{strategy_str:<45} : {val_no_attack:.4f}{'':<14} | {val_attack:.4f}")

Role Assignment                               | No Init Attack       | Init Attack         
------------------------------------------------------------------------------------------
('ATTACKER', 'ATTACKER', 'HEALER')            : 40.0000               | 30.0000
('ATTACKER', 'HEALER', 'ATTACKER')            : 40.0000               | 30.0000
('ATTACKER', 'HEALER', 'HEALER')              : 40.0000               | 30.0000
('HEALER', 'ATTACKER', 'ATTACKER')            : 40.0000               | 30.0000
('HEALER', 'ATTACKER', 'HEALER')              : 40.0000               | 30.0000
('HEALER', 'HEALER', 'ATTACKER')              : 40.0000               | 30.0000
('HEALER', 'HEALER', 'HEALER')                : 30.0000               | 20.0000
('ATTACKER', 'ATTACKER', 'DEFENDER')          : 20.0000               | 20.0000
('ATTACKER', 'DEFENDER', 'ATTACKER')          : 20.0000               | 20.0000
('ATTACKER', 'DEFENDER', 'HEALER')            : 20.0000               | 20.0000
('ATTACKER', 'HEA

# Part 2: Simulation

memo model for role inference

In [11]:
@memo
def role_inference[r0 : ROLES, r1 : ROLES, r2 : ROLES](role_prior:... , obs_a0, obs_a1, obs_a2, s):
    observer: knows(r0, r1, r2) # Push array axis variables into observer's frame
    observer: thinks[
        team: assigned(r0 in ROLES, r1 in ROLES, r2 in ROLES, wpp=get_element(role_prior, r0, r1, r2)), # Assign roles to each player
        team: chooses(a0 in ACTIONS, wpp=role_policy(r0, a0, s)), # Choose player 0's action (a0) according to their role (r0)
        team: chooses(a1 in ACTIONS, wpp=role_policy(r1, a1, s)), # Choose player 1's action (a1) according to their role (r1)
        team: chooses(a2 in ACTIONS, wpp=role_policy(r2, a2, s))  # Choose player 2's action (a2) according to their role (r2)
    ]
    observer: observes_that [team.a0 == obs_a0] # Observe player 0's action
    observer: observes_that [team.a1 == obs_a1] # Observe player 1's action
    observer: observes_that [team.a2 == obs_a2] # Observe player 2's action
    return observer[Pr[r0 == team.r0 and r1 == team.r1 and r2 == team.r2]]

# Due to some of memo's limitations, we also need to introduce a helper function that extracts an element from a (3D) array:
@jax.jit
def get_element(array, i0, i1, i2):
    return array[i0, i1, i2]

def role_policy(role, action, state):
    intent, team_hp, _ = get_intent_and_hps_from_state(state)
    return choose_policy(role, action, intent, team_hp)

In [12]:
role_prior = jnp.ones((3, 3, 3))
actions = [ATTACK, DEFEND, HEAL]
role_probs = role_inference(role_prior, *actions, 2 * H * W - 1, print_table=True)
role_prior = jnp.ones((3, 3, 3))

+-----------+-----------+-----------+---------------------+
| r0: ROLES | r1: ROLES | r2: ROLES | role_inference      |
+-----------+-----------+-----------+---------------------+
| 0         | 0         | 0         | 0.0                 |
| 0         | 0         | 1         | 0.0                 |
| 0         | 0         | 2         | 0.0                 |
| 0         | 1         | 0         | 0.1666666567325592  |
| 0         | 1         | 1         | 0.1666666567325592  |
| 0         | 1         | 2         | 0.1666666567325592  |
| 0         | 2         | 0         | 0.0                 |
| 0         | 2         | 1         | 0.0                 |
| 0         | 2         | 2         | 0.0                 |
| 1         | 0         | 0         | 0.0                 |
| 1         | 0         | 1         | 0.0                 |
| 1         | 0         | 2         | 0.0                 |
| 1         | 1         | 0         | 0.0                 |
| 1         | 1         | 1         | 0.

Softmax sampling of roles

$r_i \sim Softmax \left( \mathbb{E}_{r \sim \text{role\_prior}(r_i, r_{-i})} \left[ V^{r}(s)\right]\right)$

In [13]:
def softmax_dist_over_roles(i, s):

    global ROLES
    other_agents = [a for a in range(3) if a != i]
    intent, team_hp, enemy_hp = get_intent_and_hps_from_state(s)

    other_probs = jnp.sum(role_prior, axis=i)
    other_probs = other_probs / jnp.sum(other_probs)

    expected_values = jnp.zeros(3)

    for r_i in ROLES:
        ev = 0.0
        for r_j in ROLES:
            for r_k in ROLES:

                curr_roles = [None] * 3
                curr_roles[i] = r_i
                curr_roles[other_agents[0]] = r_j
                curr_roles[other_agents[1]] = r_k

                weight = other_probs[r_j, r_k]
                ev += weight * roles_to_values[tuple(curr_roles)][intent, team_hp, enemy_hp]

        expected_values = expected_values.at[r_i].set(ev)

    return jax.nn.softmax(expected_values)

Based on role policies defined before, chose action

In [14]:
def choose_action(role, intent, team_hp, key):
    # Compute action probabilities for this role
    probs = jnp.array([
        choose_policy(role, ATTACK, intent, team_hp),
        choose_policy(role, DEFEND, intent, team_hp),
        choose_policy(role, HEAL, intent, team_hp),
    ])

    # (Optional but safe) normalize in case of numerical drift
    probs = probs / jnp.sum(probs)

    action = jax.random.choice(key, len(ACTIONS), p=probs)
    return action


Change game state variables when actions occur

In [15]:
def game_step(intent, team_hp, enemy_hp, action_profile, key):

    total_attack = jnp.sum(jnp.where((action_profile == ATTACK), player_stats[:, ATTACK],0))
    max_defense = jnp.max(jnp.where((action_profile == DEFEND), player_stats[:, DEFEND],0))
    total_heal = jnp.sum(jnp.where((action_profile == HEAL), player_stats[:, HEAL],0))

    new_enemy_hp = jnp.maximum(0, enemy_hp - total_attack)

    damage_incoming = jnp.where(intent == ENEMY_ATTACK_INTENT, jnp.maximum(0, boss_damage - max_defense), 0)

    new_team_hp = team_hp - damage_incoming + total_heal
    new_team_hp = jnp.maximum(0, jnp.minimum(TeamMaxHP, new_team_hp))

    intent_key, _ = jax.random.split(key)
    new_intent = jnp.where(jax.random.uniform(intent_key) < EnemyAttackProb, 
                           ENEMY_ATTACK_INTENT, ENEMY_NO_ATTACK_INTENT)

    return new_intent, new_team_hp, new_enemy_hp

In [16]:
def simulate_game(n_steps, initial_state):
    
    global role_prior 
    key = jax.random.PRNGKey(42) 
    intent, team_hp, enemy_hp = get_intent_and_hps_from_state(initial_state)
    
    state_history, role_history, action_history = [], [], []
    belief_history = [role_prior] 
    
    # Initialize roles for the first round
    current_roles = [0, 0, 0] 

    for t in range(n_steps):
        current_state = intent * (H * W) + team_hp * W + enemy_hp
        state_history.append(current_state)
        
        if is_terminal(current_state):
            break
            
        # 1. Sample Roles ONLY every 2 turns
        if t % 2 == 0:
            new_roles = current_roles[:]
            i = 0   # only player 0 is doing the sampling, rest are stubborn
            key, subkey = jax.random.split(key)
            probs = softmax_dist_over_roles(i, current_state)
            role = jax.random.categorical(subkey, jnp.log(probs))
            new_roles[0] = int(role)
            current_roles = new_roles
        
        # Always append the 'active' roles for this turn to the history
        role_history.append(list(current_roles))

        # 2. Choose Actions
        current_actions = []
        for i in range(3):
            key, subkey = jax.random.split(key)
            action = choose_action(current_roles[i], intent, team_hp, key=subkey)
            current_actions.append(int(action))
        action_history.append(current_actions)

        # 3. Update Role Beliefs (Bayesian Inference)
        role_prior = role_inference(
            role_prior, 
            current_actions[0], current_actions[1], current_actions[2], 
            current_state
        )
        belief_history.append(role_prior)

        # 4. Advance Environment
        key, subkey = jax.random.split(key)
        intent, team_hp, enemy_hp = game_step(intent, team_hp, enemy_hp, jnp.array(current_actions), key=subkey)

    # Capture final state
    final_state = intent * (H * W) + team_hp * W + enemy_hp
    if not is_terminal(state_history[-1]):
        state_history.append(final_state)

    return state_history, role_history, action_history, belief_history

In [17]:
def get_state_from_intent_and_hps(intent, team_hp, enemy_hp):
    return intent * (H * W) + team_hp * W + enemy_hp

initial_state_idx = get_state_from_intent_and_hps(ENEMY_ATTACK_INTENT, TeamMaxHP, EnemyMaxHP)

states, roles, actions, beliefs = simulate_game(n_steps=20, initial_state=initial_state_idx)

In [18]:
# Table Header
print(f"{'Step':<5} | {'Intent':<10} | {'Team HP':<8} | {'Enemy HP':<8} | {'Roles (P0, P1, P2)':<22} | {'Actions':<15} | {'Enemy Atk?'}")
print("-" * 110)

# Map integers to names for readability
role_names = {0: "ATK", 1: "DEF", 2: "HEL"}
action_names = {0: "ATK", 1: "DEF", 2: "HEL"}

for t in range(len(actions)):
    # Get state info at the START of the round
    intent_val, team_hp, enemy_hp = get_intent_and_hps_from_state(states[t])
    
    # Format Roles and Actions for this step
    step_roles = f"({role_names[roles[t][0]]}, {role_names[roles[t][1]]}, {role_names[roles[t][2]]})"
    step_actions = f"({action_names[actions[t][0]]}, {action_names[actions[t][1]]}, {action_names[actions[t][2]]})"
    
    # Determine enemy behavior
    did_attack = "YES" if intent_val == ENEMY_ATTACK_INTENT else "NO"
    intent_str = "ATTACK" if intent_val == ENEMY_ATTACK_INTENT else "WAIT"
    
    print(f"{t:<5} | {intent_str:<10} | {team_hp:<8} | {enemy_hp:<8} | {step_roles:<22} | {step_actions:<15} | {did_attack}")

# Final termination row
final_intent, final_team_hp, final_enemy_hp = get_intent_and_hps_from_state(states[-1])
print("-" * 110)
print(f"FINAL | {'-':<10} | {final_team_hp:<8} | {final_enemy_hp:<8} | {'-':<22} | {'-':<15} | GAME OVER")

Step  | Intent     | Team HP  | Enemy HP | Roles (P0, P1, P2)     | Actions         | Enemy Atk?
--------------------------------------------------------------------------------------------------------------
0     | ATTACK     | 10       | 31       | (HEL, ATK, ATK)        | (ATK, ATK, ATK) | YES
1     | ATTACK     | 8        | 25       | (HEL, ATK, ATK)        | (ATK, ATK, ATK) | YES
2     | ATTACK     | 6        | 19       | (DEF, ATK, ATK)        | (DEF, ATK, ATK) | YES
3     | ATTACK     | 6        | 15       | (DEF, ATK, ATK)        | (DEF, ATK, ATK) | YES
4     | ATTACK     | 6        | 11       | (ATK, ATK, ATK)        | (ATK, ATK, ATK) | YES
5     | ATTACK     | 4        | 5        | (ATK, ATK, ATK)        | (ATK, ATK, ATK) | YES
--------------------------------------------------------------------------------------------------------------
FINAL | -          | 2        | 0        | -                      | -               | GAME OVER


In [19]:
from itertools import product

# Map indices to names
ROLE_LABELS = {0: "ATK", 1: "DEF", 2: "HEL"}
role_indices = list(product(range(3), repeat=3))

print("\n--- Belief Evolution (Probability of Role Combinations) ---")

for t, current_belief in enumerate(beliefs):
    print(f"\nStep {t} Beliefs:")
    
    # Sort combinations by probability to show the most likely ones first
    probs = []
    for r0, r1, r2 in role_indices:
        prob = current_belief[r0, r1, r2]
        combo_name = f"({ROLE_LABELS[r0]}, {ROLE_LABELS[r1]}, {ROLE_LABELS[r2]})"
        probs.append((combo_name, prob))
    
    # Sort by probability descending
    probs.sort(key=lambda x: x[1], reverse=True)
    
    for combo, p in probs:
        print(f"  {combo:<15}: {p:.4f}")


--- Belief Evolution (Probability of Role Combinations) ---

Step 0 Beliefs:
  (ATK, ATK, ATK): 1.0000
  (ATK, ATK, DEF): 1.0000
  (ATK, ATK, HEL): 1.0000
  (ATK, DEF, ATK): 1.0000
  (ATK, DEF, DEF): 1.0000
  (ATK, DEF, HEL): 1.0000
  (ATK, HEL, ATK): 1.0000
  (ATK, HEL, DEF): 1.0000
  (ATK, HEL, HEL): 1.0000
  (DEF, ATK, ATK): 1.0000
  (DEF, ATK, DEF): 1.0000
  (DEF, ATK, HEL): 1.0000
  (DEF, DEF, ATK): 1.0000
  (DEF, DEF, DEF): 1.0000
  (DEF, DEF, HEL): 1.0000
  (DEF, HEL, ATK): 1.0000
  (DEF, HEL, DEF): 1.0000
  (DEF, HEL, HEL): 1.0000
  (HEL, ATK, ATK): 1.0000
  (HEL, ATK, DEF): 1.0000
  (HEL, ATK, HEL): 1.0000
  (HEL, DEF, ATK): 1.0000
  (HEL, DEF, DEF): 1.0000
  (HEL, DEF, HEL): 1.0000
  (HEL, HEL, ATK): 1.0000
  (HEL, HEL, DEF): 1.0000
  (HEL, HEL, HEL): 1.0000

Step 1 Beliefs:
  (ATK, ATK, ATK): 0.1250
  (ATK, ATK, HEL): 0.1250
  (ATK, HEL, ATK): 0.1250
  (ATK, HEL, HEL): 0.1250
  (HEL, ATK, ATK): 0.1250
  (HEL, ATK, HEL): 0.1250
  (HEL, HEL, ATK): 0.1250
  (HEL, HEL, HEL): 0.