In [1]:
from agent_classe import *
from agent_method import *
import jax.numpy as jnp
import jax

In [2]:
@struct.dataclass
class Dice(Agent):
    """A dice-rolling agent.

    Each agent's ``state.content`` holds:
        'draw'        -- shape (1,) integer array: a value in [1, 6] for active agents, 0 for inactive ones.
        'key'         -- the agent's own PRNG key; it is never itself passed to a random function,
                         only split, so every draw uses a fresh subkey.
        'input_value' -- the last received ``input.content['value']`` plus one (0 until stepped).
    """

    @staticmethod
    def create_agent(params: Params, unique_id: int, active_state: int, agent_type: int, key:jax.random.PRNGKey):
        """Create one Dice agent.

        Active agents get a draw in [1, 6]; inactive agents get draw 0. ``age`` starts at 0.0.
        """
        # One half of the split is stored, the other is consumed by the draw, so the
        # stored key is never a key that has already been used.
        agent_key, draw_key = jax.random.split(key)

        def create_active_agent(_):
            draw = jax.random.randint(draw_key, (1,), 1, 7)
            return State(content={'draw': draw, 'key': agent_key, 'input_value': 0})

        def create_inactive_agent(_):
            return State(content={'draw': jnp.array([0]), 'key': agent_key, 'input_value': 0})

        agent_state = jax.lax.cond(active_state, create_active_agent, create_inactive_agent, None)

        return Dice(params = params, unique_id = unique_id, agent_type = agent_type,
                    active_state = active_state, state = agent_state, policy = None, age = 0.0)

    @staticmethod
    def step_agent(params: Params, input: Signal, dice_agent: Agent, key: jax.random.PRNGKey):
        """Advance one Dice agent by one step.

        Active agents redraw, record ``input.content['value'] + 1`` and age by 1.0;
        inactive agents are returned unchanged. ``key`` is not used (each agent carries
        its own key) and is returned as-is.
        """
        def step_active_agent(agent):
            # Store `next_key`, draw with `draw_key`: storing the key that randint just
            # consumed (as before) meant the next step split an already-used key.
            next_key, draw_key = jax.random.split(agent.state.content['key'])
            draw = jax.random.randint(draw_key, (1,), 1, 7)
            input_value = input.content['value'] + 1
            new_state = State(content={'draw': draw, 'key': next_key, 'input_value': input_value})
            return agent.replace(state = new_state, age = agent.age + 1.0)

        def step_inactive_agent(agent):
            return agent

        new_dice_agent = jax.lax.cond(dice_agent.active_state, step_active_agent, step_inactive_agent, dice_agent)
        return new_dice_agent, key

    @staticmethod
    def add_agent(params: Params, dice_agents: Agent, idx, key: jax.random.PRNGKey):
        """Activate the agent at position ``idx`` of the batched ``dice_agents``.

        The agent gets a fresh draw in [1, 6] and a fresh stored key. Its ``input_value``
        and ``age`` are reset to zero, like a newly created agent. New leaves are built
        with ``*_like`` from the existing ones so their shape/dtype match the batch.
        ``key`` is returned unchanged.
        """
        inactive_dice_agent = jax.tree_util.tree_map(lambda x: x[idx], dice_agents)
        old_content = inactive_dice_agent.state.content
        next_key, draw_key = jax.random.split(old_content['key'])
        draw = jax.random.randint(draw_key, (1,), 1, 7)
        state_content = {'draw': draw,
                         'key': next_key,
                         # scalar like create_agent/step_agent produce (was a shape-(1,) array)
                         'input_value': jnp.zeros_like(old_content['input_value'])}
        active_dice_agent = inactive_dice_agent.replace(
            active_state = jnp.ones_like(inactive_dice_agent.active_state),
            state = State(content=state_content),
            # a newly activated agent should not inherit the previous occupant's age
            age = jnp.zeros_like(inactive_dice_agent.age))
        return active_dice_agent, key

    @staticmethod
    def remove_agent(params: Params, dice_agents:Agent, idx, key: jax.random.PRNGKey):
        """Deactivate the agent at index ``params.content['remove_ids'][idx]`` of ``dice_agents``.

        ``draw`` and ``input_value`` are zeroed, with shape/dtype taken from the existing
        leaves; the stored PRNG key is kept. ``key`` is returned unchanged.
        """
        remove_ids = params.content['remove_ids']
        dice_agent_to_remove = jax.tree_util.tree_map(lambda x: x[remove_ids[idx]], dice_agents)
        old_content = dice_agent_to_remove.state.content
        state_content = {'draw': jnp.zeros_like(old_content['draw']),
                         'key': old_content['key'],
                         'input_value': jnp.zeros_like(old_content['input_value'])}
        inactive_dice_agent = dice_agent_to_remove.replace(
            active_state = jnp.zeros_like(dice_agent_to_remove.active_state),
            state = State(content=state_content))
        return inactive_dice_agent, key

In [3]:
# Build a set of 10 Dice agents (5 of them initially active) and initialise their states.
Dice_set = Agent_Set(agent=Dice, num_total_agents=10, num_active_agents=5, agent_type=0)
key, subkey = jax.random.split(jax.random.PRNGKey(0))
Dice_set.agents = create_agents(params=None, agent_set=Dice_set, key=subkey)

AgentSet initialized


In [4]:
# Show which agents are active (the recorded output prints 1. for active, 0. for inactive).
print(Dice_set.agents.active_state)

[1. 1. 1. 1. 1. 0. 0. 0. 0. 0.]


In [5]:
# Step every agent once: one step key per agent, and an input value of 0 for each agent.
num_agents = Dice_set.num_total_agents
key, *step_keys = jax.random.split(key, num_agents + 1)

step_keys = jnp.array(step_keys)
input_content = {'value': jnp.zeros(num_agents, dtype = int)}
input_signal = Signal(content = input_content)

print(Dice_set.agents.state.content['input_value'])
# Dice.step_agent returns the key it is given, so the second result is presumably just
# step_keys again. Discard it instead of overwriting the notebook's master `key` with a
# batch of keys, which later cells (e.g. the remove call) would then receive.
Dice_set.agents, _ = step_agents(params = None, input = input_signal, agent_set = Dice_set, key = step_keys)
print(Dice_set.agents.state.content['input_value'])


[0 0 0 0 0 0 0 0 0 0]
[1 1 1 1 1 0 0 0 0 0]


In [6]:
%%timeit
Dice_set.agents, key = step_agents(params = None, input = input_signal, agent_set = Dice_set, key = step_keys)

29.9 µs ± 10.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
# Deactivate every agent whose current draw is 1.
# NOTE(review): the recorded run of this cell raised
# "ValueError: safe_zip() argument 2 is shorter than argument 1"; which call raised it
# is not visible here — TODO investigate.
def is_one(dice_agent: Agent, select_params: Params):
    """Selection predicate: True where the agent's flattened draw equals 1."""
    return jnp.ravel(dice_agent.state.content['draw']) == 1

num_agents_dead, remove_indices = jit_select_agents(select_func = is_one, select_params = None, agents = Dice_set.agents)
dice_remove_params_content = dict(remove_ids = remove_indices)
dice_remove_params = Params(content = dice_remove_params_content)
Dice_set.agents, key = jit_remove_agents(
    remove_func = Dice.remove_agent,
    num_agents_remove = num_agents_dead,
    remove_params = dice_remove_params,
    agents = Dice_set.agents,
    key = key,
)
print(jnp.ravel(Dice_set.agents.state.content['draw']))

ValueError: safe_zip() argument 2 is shorter than argument 1

In [None]:
# One round of the dice population dynamics:
#   1. deactivate every agent that drew a 1,
#   2. sort so active agents come first (new agents are ALWAYS added at the END of the set),
#   3. activate one new agent per agent that drew a 6, capped by the number of free slots.
# Uses the same un-prefixed helpers (star-imported from agent_method / agent_classe) as the
# executed cells above; `fgx_methods` / `fgx_classes` are not defined by any visible import.

# remove all agents who have drawn a 1
# first, select all agents who have drawn a 1
def is_one(dice_agent: Agent, select_params: Params):
    """Selection predicate: True where the agent's flattened draw equals 1."""
    draw = jnp.reshape(dice_agent.state.content['draw'], (-1,))
    return draw == 1
num_agents_dead, remove_indices = jit_select_agents(select_func = is_one, select_params = None, agents = Dice_set.agents)
dice_remove_params_content = {'remove_ids': remove_indices}
dice_remove_params = Params(content = dice_remove_params_content)
# Dice.remove_agent takes and returns a key, so thread the notebook key through, as the In [7] cell does.
Dice_set.agents, key = jit_remove_agents(remove_func = Dice.remove_agent, num_agents_remove = num_agents_dead,
                                         remove_params = dice_remove_params, agents = Dice_set.agents, key = key)
print(jnp.reshape(Dice_set.agents.state.content['draw'], (-1,)))

# sort agents by active state, as new agents are ALWAYS added at the END of the set
Dice_set.agents, sorted_indices = jit_sort_agents(quantity = Dice_set.agents.active_state, ascend = False, agents = Dice_set.agents)
print(jnp.reshape(Dice_set.agents.state.content['draw'], (-1,)))
print(Dice_set.agents.active_state)

# add a new agent for every agent that has drawn a 6
# first, select all agents who have drawn a 6
def is_six(dice_agent: Agent, select_params: Params):
    """Selection predicate: True where the agent's flattened draw equals 6."""
    draw = jnp.reshape(dice_agent.state.content['draw'], (-1,))
    return draw == 6
num_agents_add, add_indices = jit_select_agents(select_func = is_six, select_params = None, agents = Dice_set.agents)

# clip the number of agents to add to the number of inactive agents
num_active_agents = jnp.sum(Dice_set.agents.active_state, dtype = jnp.int32)
num_agents_add = jnp.minimum(num_agents_add, Dice_set.num_total_agents - num_active_agents)

# add the agents; pass the real key, because the returned key is re-bound to `key`
# (Dice.add_agent returns the key it is given, so key=None would discard the notebook's key)
Dice_set.agents, key = jit_add_agents(add_func = Dice.add_agent, num_agents_add = num_agents_add,
                                      add_params = None, agents = Dice_set.agents, key = key)
print(jnp.reshape(Dice_set.agents.state.content['draw'], (-1,)))
print(Dice_set.agents.active_state)

In [32]:
%%timeit
# NOTE(review): `step_params` is not assigned anywhere in this notebook, and `dice_agent_set` is
# only created by later cells, so this fails on a fresh kernel — TODO define both first.
# The result is not waited on (no block_until_ready); because JAX dispatches asynchronously,
# the reported time may not cover the full computation.
agents = jit_step_agents(Dice.step_agent, step_params, input_signal, dice_agent_set.agents)


266 ms ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
%%time
# Single-call timing of jit_add_agents_1 (compare with the jit_add_agents_2 cell).
# NOTE(review): %%time runs the call once, so this may include tracing/compilation on first use;
# warm both variants up before comparing them. `num_agents_add` and `dice_agent_set` come from
# other cells. With key=None, the `key` re-bound here is whatever the helper returns
# (Dice.add_agent returns the key it is given, so presumably None).
agents, key = jit_add_agents_1(add_func = Dice.add_agent, num_agents_add = num_agents_add, 
                                                  add_params = None, agents = dice_agent_set.agents, key = None)


CPU times: user 193 ms, sys: 5.3 ms, total: 199 ms
Wall time: 196 ms


In [20]:
%%time
# Single-call timing of the alternative jit_add_agents_2, with the same arguments as the jit_add_agents_1 cell.
# NOTE(review): the large gap to jit_add_agents_1 may only reflect whether compilation happened
# inside the timed call — TODO confirm with warmed-up calls. As above, key=None means the
# re-bound `key` is whatever the helper returns.
agents, key = jit_add_agents_2(add_func = Dice.add_agent, num_agents_add = num_agents_add, 
                                                  add_params = None, agents = dice_agent_set.agents, key = None)


CPU times: user 375 µs, sys: 75 µs, total: 450 µs
Wall time: 434 µs


In [None]:
# Full round from scratch: re-create the agents, step them once, then remove / sort / add.
# Uses the same un-prefixed helpers (star-imported from agent_method / agent_classe) as the
# executed cells above; `fgx_methods` / `fgx_classes` are not defined by any visible import.

# re-create the agents from a fixed seed, exactly as the In [3] cell does
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
Dice_set.agents = create_agents(params = None, agent_set = Dice_set, key = subkey)
print(jnp.reshape(Dice_set.agents.state.content['draw'], (-1,)))

# Dice.step_agent reads input.content['value'], so a real Signal is required (input=None would fail).
# Step keys and the all-zero input are built the same way as in the In [5] cell.
key, *step_keys = jax.random.split(key, Dice_set.num_total_agents + 1)
step_keys = jnp.array(step_keys)
input_signal = Signal(content = {'value': jnp.zeros(Dice_set.num_total_agents, dtype = int)})
# the returned key is the step_keys batch again; keep the master `key` intact
Dice_set.agents, _ = step_agents(params = None, input = input_signal, agent_set = Dice_set, key = step_keys)
print(jnp.reshape(Dice_set.agents.state.content['draw'], (-1,)))

# remove all agents who have drawn a 1
# first, select all agents who have drawn a 1
def is_one(dice_agent: Agent, select_params: Params):
    """Selection predicate: True where the agent's flattened draw equals 1."""
    draw = jnp.reshape(dice_agent.state.content['draw'], (-1,))
    return draw == 1
num_agents_dead, remove_indices = jit_select_agents(select_func = is_one, select_params = None, agents = Dice_set.agents)
dice_remove_params_content = {'remove_ids': remove_indices}
dice_remove_params = Params(content = dice_remove_params_content)
# Dice.remove_agent takes and returns a key, so thread the notebook key through, as the In [7] cell does.
Dice_set.agents, key = jit_remove_agents(remove_func = Dice.remove_agent, num_agents_remove = num_agents_dead,
                                         remove_params = dice_remove_params, agents = Dice_set.agents, key = key)
print(jnp.reshape(Dice_set.agents.state.content['draw'], (-1,)))

# sort agents by active state, as new agents are ALWAYS added at the END of the set
Dice_set.agents, sorted_indices = jit_sort_agents(quantity = Dice_set.agents.active_state, ascend = False, agents = Dice_set.agents)
print(jnp.reshape(Dice_set.agents.state.content['draw'], (-1,)))
print(Dice_set.agents.active_state)

# add a new agent for every agent that has drawn a 6
# first, select all agents who have drawn a 6
def is_six(dice_agent: Agent, select_params: Params):
    """Selection predicate: True where the agent's flattened draw equals 6."""
    draw = jnp.reshape(dice_agent.state.content['draw'], (-1,))
    return draw == 6
num_agents_add, add_indices = jit_select_agents(select_func = is_six, select_params = None, agents = Dice_set.agents)

# clip the number of agents to add to the number of inactive agents
num_active_agents = jnp.sum(Dice_set.agents.active_state, dtype = jnp.int32)
num_agents_add = jnp.minimum(num_agents_add, Dice_set.num_total_agents - num_active_agents)

# add the agents; pass the real key, because the returned key is re-bound to `key`
Dice_set.agents, key = jit_add_agents(add_func = Dice.add_agent, num_agents_add = num_agents_add,
                                      add_params = None, agents = Dice_set.agents, key = key)
print(jnp.reshape(Dice_set.agents.state.content['draw'], (-1,)))
print(Dice_set.agents.active_state)


In [11]:
# Build a batch of three agent-set parameter records with vmap, then check that a vmapped
# function can read the batched Params fields.
def create_agent_set_params(num_total_agents, num_active_agents, agent_type):
    """Bundle an agent set's total size, initial active count and agent type into a Params record."""
    return Params(content = {'num_total_agents': num_total_agents, 'num_active_agents': num_active_agents, 'agent_type': agent_type})

# Backward-compatible alias for the original (misspelled) name.
create_ageent_set_params = create_agent_set_params

num_total_agents = jnp.array([10, 10, 10])
num_active_agents = jnp.array([4, 5, 3])
agent_type = jnp.array([1, 1, 1])

agent_set_params = jax.vmap(create_agent_set_params, in_axes=(0, 0, 0))(num_total_agents, num_active_agents, agent_type)
print(agent_set_params)

def do_something(params):
    """Toy per-record computation: sum of the three agent-set fields."""
    num_total_agents = params.content['num_total_agents']
    num_active_agents = params.content['num_active_agents']
    agent_type = params.content['agent_type']
    return num_total_agents + num_active_agents + agent_type
vmapped_do_something = jax.vmap(do_something)

print(vmapped_do_something(agent_set_params))


Params(content={'agent_type': Array([1, 1, 1], dtype=int32), 'num_active_agents': Array([4, 5, 3], dtype=int32), 'num_total_agents': Array([10, 10, 10], dtype=int32)})
[15 16 14]


In [24]:
# Attempt to create one agent set per (agent_set_params record, key) pair via vmap.
# NOTE(review): the recorded run raised "TypeError: Cannot interpret value of type <class 'function'>
# as an abstract array"; the only function argument here is Dice.create_agent. A likely fix is to
# close over it (e.g. functools.partial(create_agents, Dice.create_agent)) rather than pass it
# through vmap — TODO confirm. `agent_params` is only assigned in a later cell and `subkeys` is
# never assigned in this notebook.
dice_agent_set = jax.vmap(create_agents, in_axes=(None, None, 0, 0))(Dice.create_agent, agent_params, agent_set_params, subkeys)

TypeError: Cannot interpret value of type <class 'function'> as an abstract array; it does not have a dtype attribute

In [9]:
# Single (un-batched) agent-set parameter record: 10 slots, 7 initially active, agent type 1.
agent_set_params_content = dict(num_total_agents=10, num_active_agents=7, agent_type=1)
agent_set_params = Params(content=agent_set_params_content)

In [11]:
# Create one agent set per key in `subkeys` (in_axes maps only the fourth argument).
# NOTE(review): `agent_params` is only assigned in a later cell and `subkeys` is never assigned in
# this notebook, so this fails on a fresh kernel. It also still passes the Python function
# Dice.create_agent through vmap, which is what the In [24] cell's TypeError complained about — TODO confirm.
dice_agent_set = jax.vmap(create_agents, in_axes=(None, None, None, 0))(Dice.create_agent, agent_params, agent_set_params, subkeys)

In [None]:
# NOTE(review): create_agents is called here as (create function, params, agent-set Params, key),
# while the In [3] cell calls it as create_agents(params=..., agent_set=..., key=...). These look
# like two different versions of the helper's signature — TODO confirm which one agent_method provides.
agent_params = None


dice_agent_set = create_agents(Dice.create_agent, agent_params, agent_set_params, key = jax.random.PRNGKey(0))