In [1]:
from agent_class_new import *
from agent_methods_new import *
import jax.numpy as jnp
import jax

In [2]:
@struct.dataclass
class Dice(Agent):
    @staticmethod
    def create_agent(type, params, id, active_state, key):
        key, subkey = jax.random.split(key)
        
        def create_active_agent(key):
            draw = jax.random.randint(key, (1,), 1, 7)
            state_content = {'draw': draw, 'value': 0}
            return State(content=state_content)
        
        def create_inactive_agent():
            state_content = {'draw': jnp.array([0]), 'value': 0}
            return State(content=state_content)
        agent_state = jax.lax.cond(active_state, lambda _: create_active_agent(subkey), lambda _: create_inactive_agent(), None)
        
        return Dice(params = params, id = id, agent_type = type, key = key,
                    active_state = active_state, state = agent_state, policy = None, age = 0.0)
    
    @staticmethod
    @jax.jit
    def step_agent(params: Params, input: Signal, dice_agent: Agent):
        
        def step_active_agent(dice_agent):
            
            input_value = input.content['value']
            input_value = input_value + dice_agent.state.content['value']
            key, subkey = jax.random.split(dice_agent.key)

            draw = jax.random.randint(subkey, (1,), 1, 7)
            state_content = {'draw': draw, 'value': input_value}
            new_state = State(content = state_content)
            return dice_agent.replace(state = new_state, age = dice_agent.age + 1.0, key = key)
        
        def step_inactive_agent(dice_agent):
            return dice_agent
        
        new_dice_agent = jax.lax.cond(dice_agent.active_state, lambda _: step_active_agent(dice_agent), lambda _: step_inactive_agent(dice_agent), None)
        return new_dice_agent
    
    @staticmethod
    def add_agent(params: Params, dice_agents: Agent, idx: jnp.int32):
        value = params.content['num_agents_add'] + 1
        inactive_agent = jax.tree_util.tree_map(lambda x:x[idx], dice_agents)
        key, subkey = jax.random.split(inactive_agent.key)
        draw = jax.random.randint(subkey, (1,), 1, 7)
        state_content = {'draw': draw, 'value': value}
        new_state = State(content = state_content)
        active_agent = inactive_agent.replace(state = new_state, active_state = True, age = 0.0, key = key)
        return active_agent
    
    @staticmethod
    def remove_agent(params: Params, dice_agents: Agent, idx: jnp.int32):
        dice_to_remove = jax.tree_util.tree_map(lambda x:x[idx], dice_agents)
        draw = jnp.array([0])
        state_content = {'draw': draw, 'value': 0}
        new_state = State(content = state_content)
        inactive_agent = dice_to_remove.replace(state = new_state, active_state = False)
        return inactive_agent

In [3]:
@struct.dataclass
class Dice_Set(Agent_Set):
    @staticmethod
    def create_agent_set(num_agents, num_active_agents, agents, agent_set_params, id, key):
        policy = None
        key, subkey = jax.random.split(key, 2)
        state_content = {'agent_set value': jax.random.uniform(subkey, minval = 0.0, maxval = 1.0, shape = (num_agents,))}
        state = State(content = state_content)

        return Dice_Set(num_agents = num_agents, agents = agents, params = agent_set_params, num_active_agents=num_active_agents, 
                        id = id, state = state, policy = policy, key = key)
    @staticmethod
    def step_agent_set(step_params, input, agent_set):
        #agent_input = Signal(content = {'value': input.content['agent_input']})
        #jax.debug.print("agent_input {x}", x = agent_input.content['value'])
        agent_input = Signal(content = {'value': input.content['agent_input']})
        agent_step_params = None
        
        new_agents = jit_step_agents_2(agent_set.agents.step_agent, agent_step_params, agent_input, agent_set.agents)
        return agent_set.replace(agents = new_agents)

In [4]:
params = None
num_agents = 10
num_active_agents = 5
agent_type = 1
key = jax.random.PRNGKey(0)


dice_agents = create_agents(Dice, params, num_agents, num_active_agents, agent_type, key)

In [5]:
step_params = None
input = Signal(content = {'value': jnp.tile( 1, num_agents)})
print(dice_agents.state.content['draw'])
print(dice_agents.state.content['value'])

[[6]
 [4]
 [6]
 [2]
 [4]
 [0]
 [0]
 [0]
 [0]
 [0]]
[0 0 0 0 0 0 0 0 0 0]


In [6]:
dice_agents = jit_step_agents(step_params, input, dice_agents)
print(dice_agents.state.content['draw'])
print(dice_agents.state.content['value'])

[[2]
 [1]
 [1]
 [2]
 [3]
 [0]
 [0]
 [0]
 [0]
 [0]]
[1 1 1 1 1 0 0 0 0 0]


In [7]:
dice_agents = jit_step_agents_2(Dice.step_agent, step_params, input, dice_agents)
print(dice_agents.state.content['draw'])
print(dice_agents.state.content['value'])

[[1]
 [4]
 [2]
 [2]
 [5]
 [0]
 [0]
 [0]
 [0]
 [0]]
[2 2 2 2 2 0 0 0 0 0]


In [8]:
agent_set_params = None
agent_params = None
agent_type = 1
num_agent_set = 10
num_agents = 10
num_active_agents = jnp.arange(1, 11)
key = jax.random.PRNGKey(0)

dice_agent_set = create_agent_sets(agent_set_params = agent_set_params, agent_params = agent_params,
                                    agent_set = Dice_Set, agent = Dice, num_agent_set = num_agent_set, agent_type = agent_type,
                                    num_agents = num_agents, num_active_agents = num_active_agents, key = key)

In [9]:
print(dice_agent_set.agents.state.content['draw'].reshape(num_agent_set, num_agents))
print(dice_agent_set.num_active_agents)
print(dice_agent_set.agents.active_state)

[[5 0 0 0 0 0 0 0 0 0]
 [1 5 0 0 0 0 0 0 0 0]
 [3 3 4 0 0 0 0 0 0 0]
 [1 2 5 2 0 0 0 0 0 0]
 [1 6 3 4 1 0 0 0 0 0]
 [3 5 2 6 6 3 0 0 0 0]
 [6 1 1 6 2 5 2 0 0 0]
 [4 2 1 5 4 1 5 2 0 0]
 [2 2 3 1 2 4 6 5 6 0]
 [3 4 5 1 6 5 4 2 6 5]]
[ 1  2  3  4  5  6  7  8  9 10]
[[1 0 0 0 0 0 0 0 0 0]
 [1 1 0 0 0 0 0 0 0 0]
 [1 1 1 0 0 0 0 0 0 0]
 [1 1 1 1 0 0 0 0 0 0]
 [1 1 1 1 1 0 0 0 0 0]
 [1 1 1 1 1 1 0 0 0 0]
 [1 1 1 1 1 1 1 0 0 0]
 [1 1 1 1 1 1 1 1 0 0]
 [1 1 1 1 1 1 1 1 1 0]
 [1 1 1 1 1 1 1 1 1 1]]


In [10]:
num_agents_add = jnp.tile(1, num_agent_set)
add_params = Params(content= {'num_agents_add': jnp.tile(1, num_agent_set)})
dice_agent_set = jit_add_agents(Dice.add_agent, num_agents_add, add_params, dice_agent_set)
print(dice_agent_set.agents.state.content['draw'].reshape(num_agent_set, num_agents))
#print(dice_agent_set.agents.active_state)
#print(dice_agent_set.num_active_agents)
#print(dice_agent_set.agents.state.content['value'].reshape(num_agent_set, num_agents))

[[5 5 0 0 0 0 0 0 0 0]
 [1 5 6 0 0 0 0 0 0 0]
 [3 3 4 6 0 0 0 0 0 0]
 [1 2 5 2 4 0 0 0 0 0]
 [1 6 3 4 1 2 0 0 0 0]
 [3 5 2 6 6 3 4 0 0 0]
 [6 1 1 6 2 5 2 5 0 0]
 [4 2 1 5 4 1 5 2 6 0]
 [2 2 3 1 2 4 6 5 6 6]
 [3 4 5 1 6 5 4 2 6 5]]


In [11]:
def select_draw(agents:Dice, params:Params):
    draws = agents.state.content['draw'].reshape(-1)
    select_draw = params.content['select_draw']
    return draws == select_draw
select_params = Params(content = {'select_draw': jnp.array([5,6,3,2,1,6,2,4,3,5])})

selected_len, slected_ids = jit_select_agents(select_draw, select_params, dice_agent_set.agents)
print(selected_len)
print(slected_ids)

[2 1 2 2 2 2 2 2 1 3]
[[0 1 2 3 4 5 6 7 8 9]
 [2 0 1 3 4 5 6 7 8 9]
 [0 1 2 3 4 5 6 7 8 9]
 [1 3 0 2 4 5 6 7 8 9]
 [0 4 1 2 3 5 6 7 8 9]
 [3 4 0 1 2 5 6 7 8 9]
 [4 6 0 1 2 3 5 7 8 9]
 [0 4 1 2 3 5 6 7 8 9]
 [2 0 1 3 4 5 6 7 8 9]
 [2 5 9 0 1 3 4 6 7 8]]


In [12]:
num_agents_remove = jnp.tile(1, num_agent_set)
remove_ids = jnp.array([[0 , 1, 2, 3, 4, 5, 6, 7, 8, 9],
                        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
                        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
                        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
                        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
                        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]])

remove_ids = slected_ids

#num_agents_remove = jnp.array([2, 5, 3, 4, 2, 5, 3, 4, 2, 5])
num_agents_remove = selected_len

remove_params = Params(content= {'remove_ids': remove_ids})

dice_agent_set = jit_remove_agents(Dice.remove_agent, num_agents_remove, remove_params, dice_agent_set)
print(dice_agent_set.num_active_agents)
print(dice_agent_set.agents.active_state)

[0 2 2 3 4 5 6 7 9 7]
[[0 0 0 0 0 0 0 0 0 0]
 [1 1 0 0 0 0 0 0 0 0]
 [1 1 0 0 0 0 0 0 0 0]
 [1 1 1 0 0 0 0 0 0 0]
 [1 1 1 1 0 0 0 0 0 0]
 [1 1 1 1 1 0 0 0 0 0]
 [1 1 1 1 1 1 0 0 0 0]
 [1 1 1 1 1 1 1 0 0 0]
 [1 1 1 1 1 1 1 1 1 0]
 [1 1 1 1 1 1 1 0 0 0]]


In [13]:
print(dice_agent_set.agents.state.content['draw'].reshape(num_agent_set, num_agents))

[[0 0 0 0 0 0 0 0 0 0]
 [1 5 0 0 0 0 0 0 0 0]
 [4 6 0 0 0 0 0 0 0 0]
 [1 5 4 0 0 0 0 0 0 0]
 [6 3 4 2 0 0 0 0 0 0]
 [3 5 2 3 4 0 0 0 0 0]
 [6 1 1 6 5 5 0 0 0 0]
 [2 1 5 1 5 2 6 0 0 0]
 [2 2 1 2 4 6 5 6 6 0]
 [3 4 1 6 4 2 6 0 0 0]]


In [68]:
agent_set_input = None
agent_input = jnp.tile(1,(num_agent_set, num_agents))
input = Signal(content = {'agent_set_input': agent_set_input, 'agent_input': agent_input})

step_params = None


dice_agent_set = jit_step_agent_sets(step_params, input, dice_agent_set)
print(dice_agent_set.agents.state.content['value'])
print(dice_agent_set.agents.state.content['draw'])

[[2 4 4 4 4 4 4 4 4 4]
 [2 2 4 4 4 4 4 4 4 4]
 [2 2 2 4 4 4 4 4 4 4]
 [2 2 2 2 4 4 4 4 4 4]
 [2 2 2 2 2 4 4 4 4 4]
 [2 2 2 2 2 2 4 4 4 4]
 [2 2 2 2 2 2 2 4 4 4]
 [2 2 2 2 2 2 2 2 4 4]
 [2 2 2 2 2 2 2 2 2 4]
 [2 2 2 2 2 2 2 2 2 2]]
[[[1]
  [3]
  [3]
  [6]
  [6]
  [2]
  [4]
  [6]
  [3]
  [3]]

 [[4]
  [5]
  [4]
  [1]
  [4]
  [6]
  [1]
  [3]
  [2]
  [1]]

 [[4]
  [2]
  [5]
  [4]
  [5]
  [3]
  [4]
  [2]
  [6]
  [3]]

 [[2]
  [5]
  [2]
  [2]
  [1]
  [4]
  [2]
  [2]
  [5]
  [5]]

 [[4]
  [3]
  [4]
  [3]
  [6]
  [6]
  [4]
  [5]
  [5]
  [6]]

 [[1]
  [5]
  [2]
  [2]
  [2]
  [6]
  [3]
  [5]
  [2]
  [6]]

 [[2]
  [1]
  [5]
  [1]
  [4]
  [1]
  [5]
  [6]
  [5]
  [3]]

 [[2]
  [4]
  [3]
  [6]
  [1]
  [6]
  [6]
  [6]
  [4]
  [1]]

 [[5]
  [5]
  [4]
  [1]
  [1]
  [4]
  [5]
  [4]
  [3]
  [2]]

 [[3]
  [6]
  [4]
  [1]
  [2]
  [1]
  [5]
  [1]
  [6]
  [5]]]


In [169]:

agent_set = jit_step_agent_sets_2(Dice_Set.step_agent_set, step_params, input, dice_agent_set)

In [175]:
agent_set = jit_step_agent_sets_2(Dice_Set.step_agent_set, step_params, input, agent_set)
print(agent_set.agents.state.content['value'])

[[7 0 0 0 0 0 0 0 0 0]
 [7 7 0 0 0 0 0 0 0 0]
 [7 7 7 0 0 0 0 0 0 0]
 [7 7 7 7 0 0 0 0 0 0]
 [7 7 7 7 7 0 0 0 0 0]
 [7 7 7 7 7 7 0 0 0 0]
 [7 7 7 7 7 7 7 0 0 0]
 [7 7 7 7 7 7 7 7 0 0]
 [7 7 7 7 7 7 7 7 7 0]
 [7 7 7 7 7 7 7 7 7 7]]


In [167]:
%%timeit
agent_set = jit_step_agent_sets(step_params, input, dice_agent_set)

47 µs ± 82 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [13]:
def add(x, y):
    return x+y
vmap_add = jax.vmap(add)
a = jnp.arange(10)
b = jnp.arange(10)
print(vmap_add(a,b))
c = 1.0
d = 2.0
print(vmap_add(c,d))



[ 0  2  4  6  8 10 12 14 16 18]


ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())