In [3]:
import jax
import jax.numpy as jnp
import random
from functools import partial
from jax.numpy import pi
import random
from flax import struct

In [4]:
@struct.dataclass
class Resources:
    Xs: jnp.array
    Ys: jnp.array
    vals: jnp.array
    rads: jnp.array
    growth_rates: jnp.array
    decay_rates: jnp.array
    active_status: jnp.array

In [5]:
@struct.dataclass
class Agents:
    Xs: jnp.array
    Xdots: jnp.array
    Ys: jnp.array
    Ydots: jnp.array
    angs: jnp.array
    angdots: jnp.array
    energy: jnp.array
    active_status: jnp.array

In [6]:
class Forajax(object):
    res_init: Resources
    agents_init: Agents
    sim_time: float

    def __init__(self, num_active_res = 10, num_total_res = 20, num_active_agents = 1, num_total_agents=5, key = jax.random.PRNGKey(0)):
        self.num_active_res = num_active_res
        self.num_total_res = num_total_res
        self.num_active_agents = num_active_agents
        self.num_total_agents = num_total_agents
        self.sim_time = 0.0
        
        # initialize resources
        Xs = []
        Ys = []
        vals = []
        rads = []
        growth_rates = []
        decay_rates = []
        active_status = []

        key, *four_active_res_keys = jax.random.split(key, 4*num_active_res+1)
        self.key = key

        for i in range(num_active_res):
            Xs.append(jax.random.uniform(four_active_res_keys[i], shape=(), minval=0.0, maxval=2*pi))
            Ys.append(jax.random.uniform(four_active_res_keys[num_active_res+i], shape=(), minval=0.0, maxval=2*pi))
            vals.append(jax.random.uniform(four_active_res_keys[2*num_active_res+i], shape=(), minval=5.0, maxval=10.0))
            rads.append(0.1*vals[-1])
            growth_rates.append(jax.random.uniform(four_active_res_keys[3*num_active_res+i], shape=(), minval=0.0, maxval=1.0))
            decay_rates.append(0.1*growth_rates[-1])
            active_status.append(True)
        #zero padding
        for i in range(num_total_res-num_active_res):
            Xs.append(0.0)
            Ys.append(0.0)
            vals.append(0.0)
            rads.append(0.0)
            growth_rates.append(0.0)
            decay_rates.append(0.0)
            active_status.append(False)
        self.res_init = Resources(jnp.array(Xs), jnp.array(Ys), jnp.array(vals), jnp.array(rads), jnp.array(growth_rates), jnp.array(decay_rates), jnp.array(active_status))
        
        # initialize agents
        Xs = []
        Xdots = []
        Ys = []
        Ydots = []
        angs = []
        angdots = []
        energy = []
        active_status = []

        key, *seven_active_agents_keys = jax.random.split(key, 7*num_active_agents+1)
        self.key = key

        for i in range(num_active_agents):
            Xs.append(jax.random.uniform(seven_active_agents_keys[i], shape=(), minval=0.0, maxval=2*pi))
            Xdots.append(jax.random.uniform(seven_active_agents_keys[num_active_agents+i], shape=(), minval=-1.0, maxval=1.0))
            Ys.append(jax.random.uniform(seven_active_agents_keys[2*num_active_agents+i], shape=(), minval=0.0, maxval=2*pi))
            Ydots.append(jax.random.uniform(seven_active_agents_keys[3*num_active_agents+i], shape=(), minval=-1.0, maxval=1.0))
            angs.append(jax.random.uniform(seven_active_agents_keys[4*num_active_agents+i], shape=(), minval=0.0, maxval=2*pi))
            angdots.append(jax.random.uniform(seven_active_agents_keys[5*num_active_agents+i], shape=(), minval=-1.0, maxval=1.0))
            energy.append(jax.random.uniform(seven_active_agents_keys[6*num_active_agents+i], shape=(), minval=5.0, maxval=10.0))
            active_status.append(True)
        #zero padding
        for i in range(num_total_agents-num_active_agents):
            Xs.append(0.0)
            Xdots.append(0.0)
            Ys.append(0.0)
            Ydots.append(0.0)
            angs.append(0.0)
            angdots.append(0.0)
            energy.append(0.0)
            active_status.append(False)
        
        self.agents_init = Agents(jnp.array(Xs), jnp.array(Xdots), jnp.array(Ys), jnp.array(Ydots), jnp.array(angs), jnp.array(angdots), jnp.array(energy), jnp.array(active_status))
        print(self.agents_init)
    
    @partial(jax.jit, static_argnums=(0,))
    def dynamics(self, resources:Resources, agents:Agents, actions:jnp.array, dt:float):
        agent_Xs = agents.Xs + dt*agents.Xdots
        agent_Xdots = agents.Xdots + dt*actions[:,0]
        agent_Ys = agents.Ys + dt*agents.Ydots
        agent_Ydots = agents.Ydots + dt*actions[:,1]
        agent_angs = agents.angs + dt*agents.angdots
        agent_angdots = agents.angdots + dt*actions[:,2]
        

        # resource pos data
        sin_res_xs = jnp.sin(resources.Xs)
        cos_res_xs = jnp.cos(resources.Xs)
        sin_res_ys = jnp.sin(resources.Ys)
        cos_res_ys = jnp.cos(resources.Ys)

        # agent pos data
        sin_agent_xs = jnp.sin(agents.Xs)
        cos_agent_xs = jnp.cos(agents.Xs)
        sin_agent_ys = jnp.sin(agents.Ys)
        cos_agent_ys = jnp.cos(agents.Ys)

        # agent-resource interaction
        # distance matrix
        def one_agent_all_res_dist(sin_res_xs, cos_res_xs, sin_res_ys, cos_res_ys, sin_agent_x, cos_agent_x, sin_agent_y, cos_agent_y):
            return jnp.sqrt((sin_res_xs - sin_agent_x)**2 + (cos_res_xs - cos_agent_x)**2 + (sin_res_ys - sin_agent_y)**2 + (cos_res_ys - cos_agent_y)**2)
        dist_mat = jax.vmap(one_agent_all_res_dist, in_axes=(None, None, None, None, 0, 0, 0, 0), out_axes=0)(sin_res_xs, cos_res_xs, sin_res_ys, cos_res_ys, sin_agent_xs, cos_agent_xs, sin_agent_ys, cos_agent_ys)
        
        #energy consumption matrix
        def one_agent_all_res_energy(dist_mat_row, all_res_vals, all_res_rads):
            def one_agent_one_res_energy(dist, val, rad):
                return jax.lax.cond(dist < rad, lambda _: 0.2*val, lambda _: 0.0, None)
            return jax.vmap(one_agent_one_res_energy, in_axes=(0, 0, 0), out_axes=0)(dist_mat_row, all_res_vals, all_res_rads)
        energy_matrix = jax.vmap(one_agent_all_res_energy, in_axes=(0, None, None), out_axes=0)(dist_mat, resources.vals, resources.rads)
        

        # resource dynamics
        res_vals = resources.vals + dt*resources.vals*(resources.growth_rates - resources.decay_rates*resources.vals)
        res_vals = res_vals - jnp.sum(energy_matrix, axis=0)
        res_rads = 0.1*res_vals

        # agent energy dynamics
        agent_energy = agents.energy - dt*dt*jnp.linalg.norm(actions, axis=1) + jnp.sum(energy_matrix, axis=1)

        resources = Resources(resources.Xs, resources.Ys, res_vals, res_rads, resources.growth_rates, resources.decay_rates, resources.active_status)
        agents = Agents(agent_Xs, agent_Xdots, agent_Ys, agent_Ydots, agent_angs, agent_angdots, agent_energy, agents.active_status)
        return resources, agents
    
    @partial(jax.jit, static_argnums=(0,))
    def reset(self, resources:Resources, key):
        key, *total_res_keys = jax.random.split(key, self.num_total_res+1)
        total_res_keys = jnp.array(total_res_keys)

        key, *total_agents_keys = jax.random.split(key, self.num_total_agents+1)
        total_agents_keys = jnp.array(total_agents_keys)
        
        def reset_one_resource(x, y, val, growth_rate, active_status, key):
            return jax.lax.cond(jnp.logical_and(val < 0.1, active_status), lambda _:jnp.concatenate([
                                                                                                    jax.random.uniform(key, shape=(2,), minval=0.0, maxval=2*pi),
                                                                                                    jax.random.uniform(key, shape=(1,), minval=5.0, maxval=10.0),
                                                                                                    jax.random.uniform(key, shape=(1,), minval=0.0, maxval=1.0),
                                                                                                    ]),
                                                                                                    lambda _: jnp.array([x, y, val, growth_rate]), None)
        reset_resource_data = jax.vmap(reset_one_resource)(resources.Xs, resources.Ys, resources.vals, resources.growth_rates, resources.active_status, total_res_keys[:self.num_total_res])

        reset_res = Resources(reset_resource_data[:,0], reset_resource_data[:,1], reset_resource_data[:,2], 0.1*reset_resource_data[:,2], reset_resource_data[:,3], 0.1*reset_resource_data[:,3], reset_resource_data[:,2]>0.1)
        return reset_res, key
    
    @partial(jax.jit, static_argnums=(0,))
    def add_one_ressource(self, resources:Resources, key):
        # step 1: find the first inactive resource
        inactive_index = jnp.argmax(resources.active_status == False)
        # step 2: generate random values for the resource
        key, *four_subkeys = jax.random.split(key, 5)
        Xs = resources.Xs.at[inactive_index].set(jax.random.uniform(four_subkeys[0], shape=(), minval=0.0, maxval=2*pi))
        Ys = resources.Ys.at[inactive_index].set(jax.random.uniform(four_subkeys[1], shape=(), minval=0.0, maxval=2*pi))
        Vals = resources.vals.at[inactive_index].set(jax.random.uniform(four_subkeys[2], shape=(), minval=0.0, maxval=10.0))
        Rads = resources.rads.at[inactive_index].set(0.1*resources.vals[inactive_index])
        Growth_Rates = resources.growth_rates.at[inactive_index].set(jax.random.uniform(four_subkeys[3], shape=(), minval=0.0, maxval=1.0))
        Decay_Rates = resources.decay_rates.at[inactive_index].set(0.1*resources.growth_rates[inactive_index])
        Active_Status = resources.active_status.at[inactive_index].set(True)

        # step 3: update the resource
        reset_one_more = Resources(Xs, Ys, Vals, Rads, Growth_Rates, Decay_Rates, Active_Status)
        return reset_one_more, key
    
    @partial(jax.jit, static_argnums=(0,))
    def delete_one_resource(self, resources:Resources):
        #step 1: find the last active resource
        active_index = jnp.argmin(resources.active_status == True) - 1
        #step 2: kill the resource
        Xs = resources.Xs.at[active_index].set(0.0)
        Ys = resources.Ys.at[active_index].set(0.0)
        Vals = resources.vals.at[active_index].set(0.0)
        Rads = resources.rads.at[active_index].set(0.0)
        Growth_Rates = resources.growth_rates.at[active_index].set(0.0)
        Decay_Rates = resources.decay_rates.at[active_index].set(0.0)
        Active_Status = resources.active_status.at[active_index].set(False)

        #step 3: update the resource
        reset_one_less = Resources(Xs, Ys, Vals, Rads, Growth_Rates, Decay_Rates, Active_Status)
        return reset_one_less
    
    @partial(jax.jit, static_argnums=(0,))
    def add_one_agent(self, agents:Agents, key):
        inactive_index = jnp.argmax(agents.active_status == False)
        key, *seven_subkeys = jax.random.split(key, 8)
        Xs = agents.Xs.at[inactive_index].set(jax.random.uniform(seven_subkeys[0], shape=(), minval=0.0, maxval=2*pi))
        Xdots = agents.Xdots.at[inactive_index].set(jax.random.uniform(seven_subkeys[1], shape=(), minval=-1.0, maxval=1.0))
        Ys = agents.Ys.at[inactive_index].set(jax.random.uniform(seven_subkeys[2], shape=(), minval=0.0, maxval=2*pi))
        Ydots = agents.Ydots.at[inactive_index].set(jax.random.uniform(seven_subkeys[3], shape=(), minval=-1.0, maxval=1.0))
        angs = agents.angs.at[inactive_index].set(jax.random.uniform(seven_subkeys[4], shape=(), minval=0.0, maxval=2*pi))
        angdots = agents.angdots.at[inactive_index].set(jax.random.uniform(seven_subkeys[5], shape=(), minval=-1.0, maxval=1.0))
        energy = agents.energy.at[inactive_index].set(jax.random.uniform(seven_subkeys[6], shape=(), minval=5.0, maxval=10.0))
        active_status = agents.active_status.at[inactive_index].set(True)

        reset_one_more = Agents(Xs, Xdots, Ys, Ydots, angs, angdots, energy, active_status)
        return reset_one_more, key
    
    @partial(jax.jit, static_argnums=(0,))
    def delete_one_agent(self, agents:Agents):
        active_index = jnp.argmin(agents.active_status == True) - 1
        Xs = agents.Xs.at[active_index].set(0.0)
        Xdots = agents.Xdots.at[active_index].set(0.0)
        Ys = agents.Ys.at[active_index].set(0.0)
        Ydots = agents.Ydots.at[active_index].set(0.0)
        angs = agents.angs.at[active_index].set(0.0)
        angdots = agents.angdots.at[active_index].set(0.0)
        energy = agents.energy.at[active_index].set(0.0)
        active_status = agents.active_status.at[active_index].set(False)

        reset_one_less = Agents(Xs, Xdots, Ys, Ydots, angs, angdots, energy, active_status)
        return reset_one_less
    
    @partial(jax.jit, static_argnums=(0,))
    def step(self, resources:Resources, agents:Agents, actions:jnp.array, dt:float, sim_time:float, key:jnp.array):
        sim_time += dt

        reset_resources, key = self.reset(resources, key)
        resources_next, agents_next = self.dynamics(reset_resources, agents, actions, dt)

        ((final_resources_add, key) , flag_res_add) = jax.lax.cond(sim_time>1000.0, lambda _: (self.add_one_ressource(resources_next, key) , 1), lambda _: ((resources_next ,key) , 0), None)
        (final_resources_rem , flag_res_rem) = jax.lax.cond(sim_time>20.0, lambda _: (self.delete_one_resource(final_resources_add), 1), lambda _: (final_resources_add, 0), None)
        
        ((final_agents_add, key) , flag_agents_add) = jax.lax.cond(sim_time>1000.0, lambda _: (self.add_one_agent(agents_next, key) , 1), lambda _: ((agents_next ,key) , 0), None)
        (final_agents_rem , flag_agents_rem) = jax.lax.cond(sim_time>40.0, lambda _: (self.delete_one_agent(final_agents_add), 1), lambda _: (final_agents_add, 0), None)

        flags = jnp.array([flag_res_add, flag_res_rem, flag_agents_add, flag_agents_rem])

        return final_resources_rem, final_agents_rem, sim_time, flags, key

In [8]:
env = Forajax()
res = env.res_init
key = env.key
sim_time = env.sim_time
print(res)
agents = env.agents_init
actions = jnp.array([[0.0, 0.0, 0.0]])
dt = 0.1
#env.dynamics(res, agents, actions, dt)
for i in range(1000):
    print(i)
    res, agents, sim_time, flags, key = env.step(res, agents, actions, dt, sim_time, key)

Agents(Xs=Array([5.5600796, 0.       , 0.       , 0.       , 0.       ], dtype=float32), Xdots=Array([-0.82164264,  0.        ,  0.        ,  0.        ,  0.        ],      dtype=float32), Ys=Array([0.0567498, 0.       , 0.       , 0.       , 0.       ], dtype=float32), Ydots=Array([0.41381764, 0.        , 0.        , 0.        , 0.        ],      dtype=float32), angs=Array([0.2633204, 0.       , 0.       , 0.       , 0.       ], dtype=float32), angdots=Array([-0.46721315,  0.        ,  0.        ,  0.        ,  0.        ],      dtype=float32), energy=Array([6.186626, 0.      , 0.      , 0.      , 0.      ], dtype=float32), active_status=Array([ True, False, False, False, False], dtype=bool))
Resources(Xs=Array([1.6011232 , 2.5005746 , 5.398833  , 5.36206   , 2.0646498 ,
       2.812676  , 5.5102344 , 4.631612  , 0.10558852, 4.5209575 ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ],      dtype=floa