In [79]:
import sys
sys.path.insert(1, '/Users/siddarth.chaturvedi/Desktop/source/foragax/foragax_github/')

from base.agent_classes import *
from base.agent_methods import *
from base.space_methods import *
from base.space_classes import *
from policy.random_policy import *

import jax
import jax.numpy as jnp
from flax import struct


In [211]:
@struct.dataclass
class Animal(Agent):

    @staticmethod
    def create_agent(params: Params, unique_id: int, active_state: int, agent_type:int, key: jax.random.PRNGKey):
        key, policy_key = jax.random.split(key)
        policy = Random_Policy.create_policy(params, policy_key)
        
        agent_params_content = {'reproduction_rate': params.content['reproduction_rate'], 
                                'initial_energy': params.content['initial_energy']}
        agent_params = Params(content=agent_params_content) 
        
        space = params.content['space']
        X_pos_max = space.x_max
        Y_pos_max = space.y_max

        def create_active_agent(key, params):
            key, *pos_keys = jax.random.split(key, 3)
            X_pos = jax.random.randint(pos_keys[0], minval=0, maxval=X_pos_max, shape=(1,))
            Y_pos = jax.random.randint(pos_keys[1], minval=0, maxval=Y_pos_max, shape=(1,))
            energy = jnp.array([params.content['initial_energy']])
            state_content = {'X_pos': X_pos, 'Y_pos': Y_pos, 'energy': energy, 'reproduce': 0, 'key': key}
            state = State(content=state_content)
            return state
        
        def create_inactive_agent(key):
            state_content = {'X_pos': jnp.array([-1]), 'Y_pos': jnp.array([-1]), 'energy': jnp.array([-1]), 'reproduce': 0, 'key': key}
            state = State(content=state_content)
            return state
        
        agent_state = jax.lax.cond(active_state, lambda _: create_active_agent(key,params), lambda _: create_inactive_agent(key), None)
        return Animal(unique_id=unique_id, active_state=active_state, agent_type=agent_type, 
                    state=agent_state, policy=policy, params=agent_params, age=0.0)
    
    @staticmethod
    def step_agent(params:Params, input:Signal, animal:Agent):
        space = params.content['space']
        def step_active_agent(input, animal, space):
            energy_in = input.content['energy_in']
            X_pos = animal.state.content['X_pos']
            Y_pos = animal.state.content['Y_pos']
            energy = animal.state.content['energy']

            action, new_policy = Random_Policy.step_policy(None, animal.policy)
            action = action.content['action']
            
            # Move the animal
            X_pos_new = jax.lax.cond(action[0] < 0.25, lambda _: X_pos + 1, 
                                    lambda _: X_pos, None)
            X_pos_new = jax.lax.cond(jnp.logical_and(action[0] >= 0.25, action[0] < 0.5), lambda _: X_pos - 1, 
                                    lambda _: X_pos_new, None)
            Y_pos_new = jax.lax.cond(jnp.logical_and(action[0] >= 0.5, action[0] < 0.75), lambda _: Y_pos + 1,
                                    lambda _: Y_pos, None)
            Y_pos_new = jax.lax.cond(action[0] >= 0.75, lambda _: Y_pos - 1,
                                    lambda _: Y_pos_new, None)
            
            # check for boundaries
            X_pos_new = jax.lax.cond(jnp.logical_and(space.torous, X_pos_new[0] < 0), lambda _: jnp.array([space.x_max]),
                                    lambda _: X_pos_new, None)
            X_pos_new = jax.lax.cond(jnp.logical_and(space.torous, X_pos_new[0] > space.x_max), lambda _: jnp.array([space.x_min]),
                                    lambda _: X_pos_new, None)
            Y_pos_new = jax.lax.cond(jnp.logical_and(space.torous, Y_pos_new[0] < space.y_min), lambda _: jnp.array([space.y_max]),
                                    lambda _: Y_pos_new, None)
            Y_pos_new = jax.lax.cond(jnp.logical_and(space.torous, Y_pos_new[0] > space.y_max), lambda _: jnp.array([space.y_min]),
                                    lambda _: Y_pos_new, None)
            X_pos_new = jnp.clip(X_pos_new, space.x_min, space.x_max)
            Y_pos_new = jnp.clip(Y_pos_new, space.y_min, space.y_max)
            
            energy_new = energy - 1 + energy_in
            
            key, head_tail_key = jax.random.split(animal.state.content['key'])
            rand_float = jax.random.uniform(head_tail_key, shape=(1,))
            reproduce = jax.lax.cond(rand_float[0] < animal.params.content['reproduction_rate'], lambda _: 1, 
                                    lambda _: animal.state.content['reproduce'], None)
            
            new_animal_state_content = {'X_pos': X_pos_new, 'Y_pos': Y_pos_new, 'energy': energy_new, 'reproduce': reproduce, 'key': key}
            new_animal_state = State(content = new_animal_state_content)
            new_animal = animal.replace(state = new_animal_state, age = animal.age + 1.0, policy=new_policy)
            return new_animal
        
        def step_inactive_agent():
            return animal

        new_animal = jax.lax.cond(animal.active_state, lambda _: step_active_agent(input, animal, space), lambda _: step_inactive_agent(), None)
        return new_animal
    
    def remove_agent(params:Params, animals:Agent, idx):
        animal_to_remove = jax.tree_util.tree_map(lambda x: x[idx], animals)
        key = animal_to_remove.state.content['key']
        state_content = {'X_pos': jnp.array([-1]), 'Y_pos': jnp.array([-1]), 'energy': jnp.array([-1]), 'reproduce': 0, 'key':key} 
        state = State(content=state_content)
        age = 0.0
        active_state = False
        removed_animal = animal_to_remove.replace(state=state, age=age, active_state=active_state)
        return removed_animal
    
    def add_agent(params:Params, animals:Agent, idx, key:jax.random.PRNGKey):
        
        copy_ids = params.content['copy_ids'] # contains the ids of the agents that are selected for reproduction
        num_active_agents = params.content['num_active_agents']

        agent_to_add = jax.tree_util.tree_map(lambda x: x[idx], animals)
        agent_to_copy = jax.tree_util.tree_map(lambda x: x[copy_ids[idx - num_active_agents]], animals)
        
        X_pos = agent_to_copy.state.content['X_pos']
        Y_pos = agent_to_copy.state.content['Y_pos']
        energy = agent_to_copy.state.content['energy']/2
        add_agent_reproduction_key = agent_to_add.state.content['key']
        state_content = {'X_pos': X_pos, 'Y_pos': Y_pos, 'energy': energy, 'reproduce': 0, 'key': add_agent_reproduction_key}
        state = State(content=state_content)
        age = 0.0
        active_state = True
        new_animal = agent_to_add.replace(state=state, age=age, active_state=active_state)
        return new_animal, key
    
    def set_agent(params:Params, animals:Agent, idx, key:jax.random.PRNGKey):
        agent_to_set = jax.tree_util.tree_map(lambda x: x[idx], animals)
        X_pos = agent_to_set.state.content['X_pos']
        Y_pos = agent_to_set.state.content['Y_pos']
        energy = agent_to_set.state.content['energy']/2
        set_agent_reproduction_key = agent_to_set.state.content['key']
        state_content = {'X_pos': X_pos, 'Y_pos': Y_pos, 'energy': energy, 'reproduce': 0, 'key': set_agent_reproduction_key}
        state = State(content=state_content)
        new_animal = agent_to_set.replace(state=state)
        return new_animal, key

In [185]:
@struct.dataclass
class Grass(Agent):
    @staticmethod
    def create_agent(params: Params, unique_id: int, active_state: int, agent_type: int, key: jax.random.PRNGKey):
        
        space = params.content['space']
        regrowth_time = params.content['regrowth_time']
        
        params_content = {'regrowth_time': regrowth_time}
        params = Params(content=params_content)

        X_pos_max = space.x_max
        
        X_pos = jnp.array([jnp.mod(unique_id, X_pos_max)])
        Y_pos = jnp.array([jnp.floor_divide(unique_id, X_pos_max)])
        
        key, count_down_key, fully_grown_key = jax.random.split(key, 3)
        head_tail = jax.random.uniform(fully_grown_key, shape=(1,))
        fully_grown = jax.lax.cond(head_tail[0] < 0.5, lambda _: False, lambda _: True, None)
        count_down = jax.lax.cond(fully_grown, lambda _: jnp.array([0]), lambda _: jax.random.randint(count_down_key, minval=1, maxval=regrowth_time, shape=(1,)), None)

        state_content = {'X_pos': X_pos, 'Y_pos': Y_pos, 'fully_grown': fully_grown, 'count_down': count_down}
        state = State(content=state_content)
        age = 0.0
        return Grass(unique_id=unique_id, active_state=active_state, 
                     agent_type=agent_type, state=state, age=age, 
                     params=params, policy=None)
    
    @staticmethod
    def step_agent(params:Params, input:Signal, grass:Agent):
        energy_out = input.content['energy_out'] # the logic of fully grown grass gets eaten is in the main step function
        count_down = grass.state.content['count_down']
        fully_grown = grass.state.content['fully_grown']
        
        new_count_down, new_fully_grown = jax.lax.cond(energy_out[0], lambda _: (jnp.array([grass.params.content['regrowth_time']]), False), 
                                                       lambda _: (count_down, fully_grown), None)
        new_count_down = jax.lax.cond(new_fully_grown, lambda _: new_count_down, lambda _: new_count_down - 1, None)
        new_fully_grown = jax.lax.cond(new_count_down[0] <= 0, lambda _: True, lambda _: new_fully_grown, None)

        new_grass_state_content = {'X_pos': grass.state.content['X_pos'], 'Y_pos': grass.state.content['Y_pos'], 
                                   'fully_grown': new_fully_grown, 'count_down': new_count_down}
        
        new_grass_state = State(content=new_grass_state_content)

        new_grass = grass.replace(state=new_grass_state)
        return new_grass
        
        

In [103]:
space = create_space(x_min=0, x_max=10, y_min=0, y_max=10, torous=True, wall_array = None)

In [212]:
def interaction(wolves:Animal, sheeps:Animal, grasses:Grass):
    
    # wolves eat sheep:
    def one_wolf_all_sheep(wolf, sheeps): # vmap across all wolves

        def one_wolf_one_sheep(wolf, sheep): # vmap across all sheep
            wolf_X_pos = wolf.state.content['X_pos']
            wolf_Y_pos = wolf.state.content['Y_pos']
            sheep_X_pos = sheep.state.content['X_pos']
            sheep_Y_pos = sheep.state.content['Y_pos']
            condition = jnp.logical_and(wolf_X_pos[0] == sheep_X_pos[0], wolf_Y_pos[0] == sheep_Y_pos[0])
            
            wolf_energy_in = jax.lax.cond(condition, lambda _: jnp.array([1.0]), lambda _: jnp.array([0.0]), None)
            return wolf_energy_in
        
        one_wolf_energy_from_all_sheeps = jax.vmap(one_wolf_one_sheep, in_axes=(None, 0))(wolf, sheeps)
        return one_wolf_energy_from_all_sheeps
    
    wolves_sheeps_matrix = jax.vmap(one_wolf_all_sheep, in_axes=(0, None))(wolves, sheeps)
    '''
    for wolves this matrix is summed across all columns to get the total energy gained by each wolf
    for sheeps we take the max of the matrix across all rows to get if the sheep is eaten or not
    '''
    wolves_energy_in = jnp.sum(wolves_sheeps_matrix, axis=1, dtype=jnp.int32)
    sheeps_eaten = jnp.max(wolves_sheeps_matrix, axis=0)

    # sheeps eat grass:
    def one_sheep_all_grass(sheep, grasses): # vmap across all sheeps

        def one_sheep_one_grass(sheep, grass): # vmap across all grass
            sheep_X_pos = sheep.state.content['X_pos']
            sheep_Y_pos = sheep.state.content['Y_pos']
            grass_X_pos = grass.state.content['X_pos']
            grass_Y_pos = grass.state.content['Y_pos']
            grass_fully_grown = grass.state.content['fully_grown']
            condition = jnp.logical_and(sheep_X_pos[0] == grass_X_pos[0], sheep_Y_pos[0] == grass_Y_pos[0])
            condition = jnp.logical_and(condition, grass_fully_grown)

            sheep_energy_in = jax.lax.cond(condition, lambda _: jnp.array([1.0]), lambda _: jnp.array([0.0]), None)
            return sheep_energy_in
        one_sheep_energy_from_all_grass = jax.vmap(one_sheep_one_grass, in_axes=(None, 0))(sheep, grasses)
        return one_sheep_energy_from_all_grass
    
    sheeps_grass_matrix = jax.vmap(one_sheep_all_grass, in_axes=(0, None))(sheeps, grasses)
    '''
    for sheeps this matrix is summed across all columns to get the total energy gained by each sheep
    for grass we take the max of the matrix across all rows to get if the grass is eaten or not
    '''
    sheeps_energy_in = jnp.sum(sheeps_grass_matrix, axis=1, dtype=jnp.int32)
    grasses_eaten = jnp.max(sheeps_grass_matrix, axis=0)

    return wolves_energy_in, sheeps_energy_in, sheeps_eaten, grasses_eaten
jit_interaction = jax.jit(interaction)

        


In [229]:
class Ecosystem():
    grasses: Agent_Set
    sheeps: Agent_Set
    wolves: Agent_Set
    space: Space
    sim_steps: jnp.int32
    key: jax.random.PRNGKey
    
    def __init__(self, grass_regrowth_time, wolf_reproduction_rate, wolf_energy, sheep_reproduction_rate, sheep_energy, X_max, Y_max, sim_steps, key):
        self.key = key

        self.space = create_space(x_min=0, x_max=X_max, y_min=0, y_max=Y_max, torous=True, wall_array = None)
        
        key, grass_key = jax.random.split(key)
        grass_num = self.space.x_max*self.space.y_max
        self.grasses = Agent_Set(agent=Grass, num_total_agents=grass_num, num_active_agents=grass_num, agent_type=0)
        grass_params_content = {'regrowth_time': grass_regrowth_time, 'space': self.space}
        grass_params = Params(content=grass_params_content)
        self.grasses.agents = create_agents(params=grass_params, agent_set=self.grasses, key=grass_key)

        key, sheep_key = jax.random.split(key)
        self.sheeps = Agent_Set(agent=Animal, num_total_agents = 100, num_active_agents = 30, agent_type=1)
        sheep_params_content = {'reproduction_rate': sheep_reproduction_rate, 'initial_energy': sheep_energy, 'space': self.space}
        sheep_params = Params(content=sheep_params_content)
        self.sheeps.agents = create_agents(params=sheep_params, agent_set=self.sheeps, key=sheep_key)

        key, wolf_key = jax.random.split(key)
        self.wolves = Agent_Set(agent=Animal, num_total_agents = 100, num_active_agents = 10, agent_type=2)
        wolf_params_content = {'reproduction_rate': wolf_reproduction_rate, 'initial_energy': wolf_energy, 'space': self.space}
        wolf_params = Params(content=wolf_params_content)
        self.wolves.agents = create_agents(params=wolf_params, agent_set=self.wolves, key=wolf_key)
        self.sim_steps = sim_steps

    @staticmethod
    def select_dead_sheeps(sheeps_agents:Agent, select_params:Params):
        energy = jnp.reshape(sheeps_agents.state.content['energy'], (-1))
        is_eaten = jnp.reshape(select_params.content['sheeps_eaten'], (-1))
        is_dead = jnp.logical_or(energy <= 0, is_eaten)
        is_dead = jnp.logical_and(is_dead, sheeps_agents.active_state)
        return is_dead
    
    @staticmethod
    def select_dead_wolves(wolves_agents:Agent, select_params:Params):
        energy = jnp.reshape(wolves_agents.state.content['energy'], (-1))
        is_dead = jnp.logical_and(energy <= 0, wolves_agents.active_state)
        return is_dead
    
    @staticmethod
    def select_reproduce_animals(animals_agents:Agent, select_params:Params):
        is_reproduce = jnp.reshape(animals_agents.state.content['reproduce'], (-1))
        return jnp.logical_and(is_reproduce, animals_agents.active_state)
    
    @staticmethod
    def get_num_animals_to_add(num_selected, active_state_arr, max_num):
        num_active = jnp.sum(active_state_arr, dtype=jnp.int32)
        num_to_add = max_num - num_active
        num_to_add = jax.lax.cond(num_to_add > num_selected, lambda _: num_selected, lambda _: num_to_add, None)
        return num_to_add, num_active
    
    @staticmethod
    def add_animals(animal_agents, max_num, key):
        # select the animals that are going to reproduce
        num_animal_reproduce, animal_reproduce_indx = jit_select_agents(Ecosystem.select_reproduce_animals, None, animal_agents)
        num_animal_reproduce, num_active_animals = jax.jit(Ecosystem.get_num_animals_to_add)(num_animal_reproduce, animal_agents.active_state, max_num)

        # add the animals that are going to reproduce
        animal_add_params_content = {'copy_ids': animal_reproduce_indx, 'num_active_agents': num_active_animals}
        animal_add_params = Params(content=animal_add_params_content)
        new_animal_agents, key = jit_add_agents(Animal.add_agent, num_agents_add=num_animal_reproduce, 
                                                add_params=animal_add_params, agents=animal_agents, key=key)
        
        # set the energy of the animals that are going to reproduce to half and remove the reproduce flag
        animal_set_params_content = {'set_ids': animal_reproduce_indx}
        animal_set_params = Params(content=animal_set_params_content)
        new_animal_agents, key = jit_set_agents(Animal.set_agent, num_agents_set=num_animal_reproduce, 
                                                set_params=animal_set_params, agents=new_animal_agents, key=key)
        
        # sort the animals in descending order of reproduction so that in the next step the animals that were slected but did not reproduce are at beginning
        new_animal_agents, sorted_ids = jit_sort_agents(quantity=new_animal_agents.state.content['reproduce'], ascend=False, agents=new_animal_agents)
        
        return new_animal_agents, key

    
    
    def step(self):
        # interaction
        wolves_energy_in, sheeps_energy_in, sheeps_eaten, grasses_eaten = jit_interaction(self.wolves.agents, self.sheeps.agents, self.grasses.agents)
        
        # agent steps movement and energy update
        grass_step_signal = Signal(content={'energy_out': grasses_eaten})
        self.grasses.agents = step_agents(params=None, agent_set=self.grasses, input=grass_step_signal)


        animals_step_param_content = {'space': self.space}
        animals_step_params = Params(content=animals_step_param_content)

        sheep_step_signal = Signal(content={'energy_in': sheeps_energy_in})
        self.sheeps.agents = step_agents(params=animals_step_params, agent_set=self.sheeps, input=sheep_step_signal)

        wolf_step_signal = Signal(content={'energy_in': wolves_energy_in})
        self.wolves.agents = step_agents(params=animals_step_params, agent_set=self.wolves, input=wolf_step_signal)

        # remove eaten or energy defficient sheeps
        
        # select the dead sheeps
        dead_sheep_select_params_content = {'sheeps_eaten': sheeps_eaten}   
        dead_sheep_select_params = Params(content=dead_sheep_select_params_content)
        num_sheeps_dead, sheep_remove_indx = jit_select_agents(Ecosystem.select_dead_sheeps, dead_sheep_select_params, self.sheeps.agents)
        
        # remove the dead sheeps
        dead_sheep_remove_params_content = {'remove_ids': sheep_remove_indx}
        dead_sheep_remove_params = Params(content=dead_sheep_remove_params_content)
        self.sheeps.agents = jit_remove_agents(Animal.remove_agent, num_agents_remove=num_sheeps_dead, 
                                               remove_params=dead_sheep_remove_params, agents=self.sheeps.agents)

        #sort the sheeps in descending order of activation so that the inactive sheeps are at the end
        self.sheeps.agents, sorted_ids = jit_sort_agents( quantity=self.sheeps.agents.active_state, ascend=False, agents=self.sheeps.agents)
        
        
        # remove dead wolves

        # select the dead wolves
        num_wolves_dead, wolf_remove_indx = jit_select_agents(Ecosystem.select_dead_wolves, None, self.wolves.agents)
        # remove the dead wolves
        dead_wolf_remove_params_content = {'remove_ids': wolf_remove_indx}
        dead_wolf_remove_params = Params(content=dead_wolf_remove_params_content)
        self.wolves.agents = jit_remove_agents(Animal.remove_agent, num_agents_remove=num_wolves_dead, 
                                               remove_params=dead_wolf_remove_params, agents=self.wolves.agents)
        
        #sort the wolves in descending order of activation so that the inactive wolves are at the end
        self.wolves.agents, sorted_ids = jit_sort_agents( quantity=self.wolves.agents.active_state, ascend=False, agents=self.wolves.agents)
        
        # add new sheeps and wolves
        self.sheeps.agents, self.key = jax.jit(Ecosystem.add_animals)(self.sheeps.agents, self.sheeps.num_total_agents, self.key)
        self.wolves.agents, self.key = jax.jit(Ecosystem.add_animals)(self.wolves.agents, self.wolves.num_total_agents, self.key)

        return num_sheeps_dead#, num_sheeps_reproduce, sheep_reproduce_indx



        




        
        

    

In [230]:
eco = Ecosystem(grass_regrowth_time=5, wolf_reproduction_rate=0.05, wolf_energy=10, sheep_reproduction_rate=0.1, sheep_energy=5, X_max=10, Y_max=10, sim_steps=100, key=jax.random.PRNGKey(0))
num_sheeps_dead = eco.step()
print(num_sheeps_dead)
#print(num_sheeps_reproduce)
#print(indx)

AgentSet initialized
AgentSet initialized
AgentSet initialized




3


In [232]:
for i in range(1000):
    num_sheeps_dead = eco.step()

In [None]:
grass_params_content = {'regrowth_time': 30, 'space': space}
grass_params = Params(content=grass_params_content)



In [75]:
grass_params_content = {'regrowth_time': 30, 'space': space}
grass_params = Params(content=grass_params_content)
grasses = Agent_Set(agent=Grass, num_total_agents=space.x_max*space.y_max, num_active_agents=space.x_max*space.y_max, agent_type=0)
grasses.agents = create_agents(params = grass_params, agent_set = grasses, key = jax.random.PRNGKey(1))
print(grasses.agents.state.content['fully_grown'])
print(grasses.agents.state.content['count_down'])

AgentSet initialized
[ True  True  True  True  True  True  True False  True False False  True
 False False  True False False  True  True  True  True  True False  True
  True False  True False False False  True False False False False  True
 False  True False False False False False  True  True False False  True
 False False  True False False  True  True False False False  True  True
  True  True False  True  True False False  True  True False  True False
 False False False  True  True  True False False  True  True False False
  True  True False False  True  True  True False False False  True  True
 False False False False]
[[ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 5]
 [ 0]
 [19]
 [22]
 [ 0]
 [12]
 [10]
 [ 0]
 [28]
 [ 7]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [29]
 [ 0]
 [ 0]
 [ 1]
 [ 0]
 [28]
 [ 2]
 [22]
 [ 0]
 [29]
 [12]
 [19]
 [27]
 [ 0]
 [ 5]
 [ 0]
 [14]
 [ 9]
 [17]
 [27]
 [15]
 [ 0]
 [ 0]
 [10]
 [16]
 [ 0]
 [12]
 [ 2]
 [ 0]
 [25]
 [29]
 [ 0]
 [ 0]
 [ 5]
 [16]
 [15]
 [ 0]
 [ 0]
 [ 0]
 [ 

In [76]:
grass_step_params=Params(content={'regrowth_time': 30})
grass_input = Signal(content={'energy_out': jnp.tile(False, space.x_max*space.y_max)})
grasses.agents = step_agents(params=grass_step_params, input=grass_input, agent_set=grasses)
print(grasses.agents.state.content['fully_grown'])
print(grasses.agents.state.content['count_down'])


[ True  True  True  True  True  True  True False  True False False  True
 False False  True False False  True  True  True  True  True False  True
  True  True  True False False False  True False False False False  True
 False  True False False False False False  True  True False False  True
 False False  True False False  True  True False False False  True  True
  True  True False  True  True False False  True  True False  True False
 False False False  True  True  True False False  True  True False False
  True  True False False  True  True  True False False False  True  True
 False False False False]
[[ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 4]
 [ 0]
 [18]
 [21]
 [ 0]
 [11]
 [ 9]
 [ 0]
 [27]
 [ 6]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [28]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [27]
 [ 1]
 [21]
 [ 0]
 [28]
 [11]
 [18]
 [26]
 [ 0]
 [ 4]
 [ 0]
 [13]
 [ 8]
 [16]
 [26]
 [14]
 [ 0]
 [ 0]
 [ 9]
 [15]
 [ 0]
 [11]
 [ 1]
 [ 0]
 [24]
 [28]
 [ 0]
 [ 0]
 [ 4]
 [15]
 [14]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 7]
 [ 0]
 [ 0]


In [77]:
wolves_params = {'reproduction_rate': 0.1, 'initial_energy': 20, 'space': space}
wolves_params = Params(content=wolves_params)
wolves = Agent_Set( agent = Animal, num_total_agents=10, num_active_agents=5, agent_type=1)
wolves.agents = create_agents( params=wolves_params, agent_set=wolves, key=jax.random.PRNGKey(0))
print(wolves.agents.active_state)
#print(wolves.agents.state.content['energy'])
print(wolves.agents.state.content['X_pos'][0:5])
print(wolves.agents.state.content['Y_pos'][0:5])

AgentSet initialized
[1. 1. 1. 1. 1. 0. 0. 0. 0. 0.]
[[7]
 [2]
 [4]
 [0]
 [3]]
[[8]
 [7]
 [6]
 [6]
 [9]]


In [5]:
energy_in = jnp.tile(jnp.array([0]), 10)
input = Signal(content={'energy_in': energy_in})
step_params = Params(content={'space': space})
wolves.agents = step_agents(params=step_params, input=input, agent_set=wolves)

In [64]:
sheep_params = {'reproduction_rate': 0.1, 'initial_energy': 20, 'space': space}
sheep_params = Params(content=sheep_params)
sheeps = Agent_Set( agent = Animal, num_total_agents=10, num_active_agents=5, agent_type=2)
sheeps.agents = create_agents( params=sheep_params, agent_set=sheeps, key=jax.random.PRNGKey(1))
print(sheeps.agents.state.content['X_pos'][0:5])
print(sheeps.agents.state.content['Y_pos'][0:5])


AgentSet initialized
[[3]
 [4]
 [1]
 [4]
 [7]]
[[9]
 [4]
 [6]
 [4]
 [0]]


In [78]:
wolf_energy_in, sheeps_eaten, sheeps_energy_in, grasses_eaten = interaction(wolves.agents, sheeps.agents, grasses.agents)
print(wolf_energy_in)
print(sheeps_energy_in)
print(sheeps_eaten)
print(grasses_eaten)

[[0.]
 [0.]
 [0.]
 [0.]
 [1.]
 [5.]
 [5.]
 [5.]
 [5.]
 [5.]]
[[1.]
 [0.]
 [0.]
 [0.]
 [0.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]]
[[0.]
 [1.]
 [1.]
 [1.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]]
[[0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [1.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [1.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]]


In [117]:
from typing import Callable

In [123]:
@struct.dataclass
class jit_check:
    a_funct: callable = struct.field(pytree_node=False)
    an_array: jnp.ndarray

def make_jit_check(a_funct, an_array):
    return jit_check(a_funct=jax.jit(a_funct), an_array=an_array)

a_funct = lambda x: x+1
an_array = jnp.array([1,2,3])
jit_check = make_jit_check(a_funct, an_array)

def do_some_jitting(jit_check):
    return jit_check.a_funct(jit_check.an_array)
jit_do_some_jitting = jax.jit(do_some_jitting)

print(jit_do_some_jitting(jit_check))

[2 3 4]


In [208]:
a =0
b = jax.lax.cond(a, lambda _: 1, lambda _: 2, None)
print(b)

2
