# Agent-based model of ant foraging

This notebook presents the proposed model of ant foraging. It allows running the model and performing experiments (multiple organized model runs) for sensitivity analysis and model validation. 

The model code presented here is an adaptation of the code by Thierry Hoinville (2018), written as part of the following work:
Hoinville T, Wehner R. (2018) Optimal multiguidance integration in insect navigation 
Proceedings of the National Academy of Sciences, 115 (11) 2824-2829; 
DOI: [10.1073/pnas.1721668115](https://doi.org/10.1073/pnas.1721668115)

In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Third-party dependencies of this notebook; uncomment a line to install the
# package into the current environment.
# !pip install matplotlib
# !pip install agentpy
# !pip install salem
# !pip install shapely
# !pip install rasterio
# !pip install nbformat
# !pip install geopandas
# !pip install seaborn

In [3]:
import agentpy as ap
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math
import salem
import time
from pathlib import Path
from shapely.affinity import scale
from matplotlib import colors
import rasterio
from rasterio.features import rasterize, Affine

In [4]:
def PI(p,k):
    """Return a path-integration (PI) guidance field towards the goal *p*.

    The returned function maps a position x to the vector k * (p - x), i.e. a
    vector pointing from x to p whose length grows linearly with distance.

    Args:
        p: goal position (the callers in this notebook pass a complex number).
        k: gain of the field.

    Returns:
        A function of the position x (complex scalar or numpy array).
    """
    def field(x):
        return k * (p - x)
    return field

# Uncomment to suppress warning for division by zero
# np.seterr(invalid='ignore')

# Local fields
def Place(p, k,d):
    return lambda x: np.where(
        x==p, 0+0j, # singular point (still raises warnings!)
        k*np.exp(-np.abs(p-x)/d) * (p-x)/np.abs(p-x)
        # k * e^(-länge_vector/d) * vector/länge vector
    )

# NOT USED IN MODEL!
def Route(p,q, k,d):
    """Return a route guidance field for the straight segment from *p* to *q*.

    For a position x, the closest point r on the segment [p, q] is found. The
    field's direction, expressed in the segment's frame, is
    d + z - conj(z) = d + 2i*Im(z) with z = (r - x) * conj(u) and u the unit
    vector from p to q. This mixes a component along the route (p -> q) with a
    component pulling towards the route. Its magnitude is
    k * exp(-|r - x| / d). The result is rotated back to the physical frame.

    Args:
        p, q: segment end points (complex numbers).
        k: field strength on the route.
        d: decay length, also the weight of the along-route component.

    Returns:
        A function of the position x (complex scalar or numpy array).
    """
    segment = q - p
    length = np.abs(segment)
    direction = segment / length
    direction_conj = direction.conjugate()

    def field(x):
        # position of x projected onto the segment, clamped to [0, length]
        along = np.clip(((x - p) * direction_conj).real, 0, length)
        nearest = p + along * direction  # point of [p, q] closest to x
        offset = nearest - x
        local = offset * direction_conj
        heading = d + local - np.conjugate(local)
        return k * np.exp(-np.abs(offset) / d) * heading / np.abs(heading) * direction

    return field

In [5]:
class DesertAnt(ap.Agent):
    """A single desert ant agent that alternates between searching and homing.

    Motivation states: 0 = initial search, 1 = homing, 2 = oriented search.
    Positions ``xp`` are complex numbers (x + iy); the heading ``phi`` is an
    angle in radians.
    """

    # Column layout of the rows collected by update_trajectory().
    _TRAJ_COLUMNS = ['trial', 'time', 'x', 'phi', 'motivation', 'vegetation']

    def setup(self):
        """Initialize an agent of type desert ant with its attributes."""
        # Seeded random number generators: separate streams for uniform draws,
        # normal draws (heading/speed) and cookie thresholds.
        self.seedU = self.model.nprandom.integers(1, 1e8)
        self.seedN = self.model.nprandom.integers(1, 1e8)
        self.seedC = self.model.nprandom.integers(1, 1e8)
        self.U = np.random.default_rng(self.seedU).uniform
        self.N = np.random.default_rng(self.seedN).normal
        self.cookieN = np.random.default_rng(self.seedC).normal

        # --- State variables ---
        # states: 0 = initial search, 2 = oriented search, 1 = homing
        # Thus: motivation_state % 2 == 0 -> foraging/searching,
        #       motivation_state % 2 == 1 -> homing
        self.motivation_state = 0
        self.name = f'Ant {self.id}'
        self.species = 'Cataglyphis velox'
        nest = complex(self.model.p.nest)
        # starting location = nest
        self.xp = nest
        # initial phi (heading)
        self.phi = self.N(self.model.p.iphi_mean, self.model.p.iphi_sd)
        # initial speed (clipped at 0)
        self.speed = max(0, self.N(self.model.p.open_speed_median, self.model.p.open_speed_sd) * self.model.p.speed_impact)
        # navigation memories: field name -> field function (position -> vector)
        self.search_memory = {}
        self.homing_memory = {
            'PI': PI(nest, self.model.p.pi_k),
            'LM': Place(nest, self.model.p.place_k, self.model.p.place_d),
        }
        # maximum possible change of speed per time step
        self.max_abs_acceleration_per_timestep = 6.2 * self.model.p.timestep
        # distance from the nest at which a cookie (feeder) is provided
        self.cookie_treshold = self.cookieN(self.model.p.cookie_mean, self.model.p.cookie_sd)

        # --- Derived state variables ---
        self.vegetation = 0

        # --- Functional variables for observation ---
        self.run_counter = 1
        self.traj_counter = 1
        self.previous_states = []
        self.wro = False
        self.finished = False
        # active feeders (name -> position) and every feeder ever created
        self.feeder_locations = {}
        self.feeder_locations_legacy = {}
        self.current_trajectory_list = []
        self.traj_list = []

    def setup_grid(self):
        """Attach the ant to the model's land cover grid at its current position."""
        self.grid = self.model.grid
        # Continuous location 0,0 is grid cell 150,150. Derive the start cell
        # from xp so it stays consistent with later calls to
        # update_grid_position().
        self.update_grid_position()

    def provide_cookie(self):
        """Provide food (create a feeder) at the ant's position if conditions are met.

        A feeder is created only while searching, when the ant has no active
        feeder, is farther from the nest than its cookie threshold, and is not
        closer than ``min_feeder_dist`` to any current or past feeder.
        """
        if self.motivation_state % 2 != 0:
            return  # homing
        if self.feeder_locations:
            return  # already has an active feeder
        if np.abs(self.xp - complex(self.model.p.nest)) <= self.cookie_treshold:
            return  # still too close to the nest
        min_dist = self.model.p.min_feeder_dist
        if any(np.abs(self.xp - position) < min_dist
               for position in self.feeder_locations_legacy.values()):
            return  # too close to a current or past feeder
        # Key feeders by the number created so far. Keying them by the
        # motivation state made every oriented-search feeder "F2", which
        # overwrote earlier entries in feeder_locations_legacy.
        key = "F" + str(len(self.feeder_locations_legacy))
        # create feeder (= provide cookie crumbs)
        self.feeder_locations[key] = self.xp
        self.feeder_locations_legacy[key] = self.xp

    def within_reach_of(self):
        """Check whether the ant is within reach of (wro) one of its current goals.

        Goals are the active feeders while searching and the nest while homing.
        A goal counts as reached when it is closer than 0.5.

        On reaching a goal the ant turns around by 180 degrees and ``self.wro``
        is set to True. Exception: a known feeder is exhausted with 50 %
        probability. It is then removed from memory and the ant keeps searching.

        Returns:
            The current value of ``self.wro``.
        """
        searching = self.motivation_state % 2 == 0
        if searching:
            current_goals = self.feeder_locations
        else:
            current_goals = {"Nest": complex(self.model.p.nest)}

        # first goal nearby (dist < 0.5m)
        found_goal = next((key for key, goal_position in current_goals.items()
                           if np.abs(self.xp - goal_position) < 0.5), None)
        if found_goal is None:
            return self.wro

        pi_key = "PI_" + str(found_goal)
        lm_key = "LM_" + str(found_goal)

        if searching and pi_key in self.search_memory:
            # known feeder: it disappears with 50% chance
            if self.U() > 0.5:
                # feeder is exhausted -> forget it and continue searching
                del self.search_memory[pi_key]
                del self.search_memory[lm_key]
                del self.feeder_locations[found_goal]
            else:
                # grab food and turn around 180 degrees
                self.wro = True
                self.phi += math.pi
            return self.wro

        # reached the nest or a new feeder: turn around 180 degrees
        self.wro = True
        self.phi += math.pi
        if searching:
            # new feeder -> store its guidance fields in memory
            goal_position = current_goals[found_goal]
            self.search_memory[pi_key] = PI(goal_position, 0.68)
            self.search_memory[lm_key] = Place(goal_position, 7.5, 4)
        return self.wro

    def optimal_travel_vector(self):
        """Return the optimal travel vector (in the physical coordinate system).

        It is the sum of all fields in the memory matching the motivation state
        (search memory while searching, homing memory while homing). It is 0
        when that memory is empty.
        """
        memory = self.search_memory if self.motivation_state % 2 == 0 else self.homing_memory
        return sum(v(self.xp) for v in memory.values() if v)

    def check_for_vegetation(self):
        """Return 1 if the grid cell the agent is in contains shrub, else 0."""
        return 1 if self.grid.vegetation[self.grid.positions[self]] == 1 else 0

    def run(self):
        """Move the ant one time step: steer, adapt speed, then walk."""
        # --- Steering unit ---
        # recall optimal travel vector (in physical coordinate system)
        vT = self.optimal_travel_vector()
        # dW: increment of a Wiener process (standard Brownian motion)
        dW = self.N(0, self.model.sqrt_dt)
        # steering law: turn towards the optimal absolute travel direction xi
        xi = np.angle(vT)
        dphi = -self.model.kphi * np.abs(vT) * np.sin(self.phi - xi) * self.model.dt
        dphi += self.model.sigma * dW
        self.phi += dphi

        # --- Velocity unit ---
        if self.model.p.context:  # check land cover type
            self.vegetation = self.check_for_vegetation()

        # preferred (non-negative) speed depends on the land cover type
        if self.vegetation == 1:
            pref_speed = self.N(self.model.p.veg_speed_median, self.model.p.veg_speed_sd)
            target_speed = max(0, pref_speed * self.model.p.veg_speed_impact)
        else:
            pref_speed = self.N(self.model.p.open_speed_median, self.model.p.open_speed_sd)
            target_speed = max(0, pref_speed * self.model.p.speed_impact)
        pref_accel = target_speed - self.speed
        # make sure the ant does not exceed the maximum possible acceleration
        limit = self.max_abs_acceleration_per_timestep
        actual_accel = min(max(pref_accel, -limit), limit)
        self.speed += actual_accel

        # --- Movement unit ---
        self.movement = self.speed * np.exp(1j * self.phi) * self.model.dt
        self.xp += self.movement

    def update_grid_position(self):
        """Update the agent's position on the land cover grid from ``xp``."""
        # The grid's (0,0) is the top-left corner and grid cell (150,150) is
        # continuous (0,0). One grid cell is 0.1 units of xp (factor 10).
        x = int(150 + self.xp.real * 10)
        y = int(150 - self.xp.imag * 10)
        self.grid.move_to(self, (y, x))

    def update_trajectory(self):
        """Append the current position/state as a new row to the trajectory."""
        current_row = [self.traj_counter, self.model.t, self.xp, self.phi,
                       self.motivation_state, self.vegetation]
        self.current_trajectory_list.append(current_row)

    def store_traj_as_csv(self):
        """Write the current trajectory to a CSV file and reset it.

        The file ``M_Ant<id>_R<run>_M<state>.csv`` is written to
        ``<model_output_path><sampleID>_<iteration>/``. It is ';'-separated
        with columns x, y, cookie (motivation % 2, i.e. 1 while homing) and
        vegetation.
        """
        current_trajectory = pd.DataFrame(
            self.current_trajectory_list, columns=self._TRAJ_COLUMNS
        ).set_index(['trial', 'time'])

        # add ids if not present (depending on agentpy simple run or experiment run)
        if self.model._run_id is None:
            self.model._run_id = ["X", "X"]
        sampleID = self.model._run_id[0]
        iteration = self.model._run_id[1]
        # set path and create dir
        path = self.model.model_output_path + str(sampleID) + "_" + str(iteration)
        Path(path).mkdir(parents=True, exist_ok=True)
        # split complex positions into x/y and add the cookie flag
        current_trajectory['y'] = current_trajectory.x.apply(np.imag)
        current_trajectory['x'] = current_trajectory.x.apply(np.real)
        current_trajectory['cookie'] = current_trajectory.motivation % 2
        current_trajectory = current_trajectory[["x", "y", "cookie", "vegetation"]]
        filename = ("M_Ant" + str(self.id) + "_R" + str(self.run_counter)
                    + "_M" + str(self.motivation_state) + ".csv")
        current_trajectory.to_csv(path + "/" + filename, sep=';', index=False)

        # for visualization keep the finished trajectory
        if self.p.plot_trjs:
            self.traj_list.append(self.current_trajectory_list)

        # reset current trajectory
        self.current_trajectory_list = []

    def return_all_trjs(self):
        """Return all stored trajectories of the ant as one DataFrame (for visualization)."""
        rows = [row for trajectory in self.traj_list for row in trajectory]
        return pd.DataFrame(rows, columns=self._TRAJ_COLUMNS).set_index(['trial', 'time'])

    def update_motivation_state(self):
        """Advance the motivation state after a goal has been reached.

        Searching (0 or 2) switches to homing (1). Homing switches to oriented
        search (2), starts a new foraging run and draws a new cookie threshold.
        The agent is marked finished once more than ``foraging_runs`` runs
        have been started.
        """
        # store state for metadata file
        self.previous_states.append(self.motivation_state)
        # change motivational state (0 = initial search, 2 = oriented search, 1 = homing)
        if self.motivation_state % 2 == 0:
            self.motivation_state = 1
        elif self.motivation_state == 1:
            self.motivation_state = 2
            self.run_counter += 1
            # set new cookie provider threshold
            self.cookie_treshold = self.cookieN(self.model.p.cookie_mean, self.model.p.cookie_sd)
        self.traj_counter += 1
        # reset within reach of attribute
        self.wro = False
        # check if all runs were completed
        if self.run_counter > self.model.p.foraging_runs:
            self.finished = True

In [6]:
class CADAModel(ap.Model):
    """Agent-based model of desert ant foraging on a (optionally) vegetated grid.

    Parameters are read from the agentpy parameter dict ``self.p``.
    """

    def setup(self):
        """Initialize output folders, agents, steering constants and the land cover grid."""
        # define path to store model output in and create output directories
        self.model_output_path = "./Experiments/E_%d/trjs/" % self.p.epoch_time
        Path("./Experiments/E_%d/trjs" % self.p.epoch_time).mkdir(parents=True, exist_ok=True)
        Path("./Experiments/E_%d/indices" % self.p.epoch_time).mkdir(parents=True, exist_ok=True)

        # init ant agents
        self.agents = ap.AgentList(self, self.p.agents, DesertAnt)

        # noise variance
        self.sigma2 = self.p.model_sigma2 * self.p.sigma_impact
        # steering coefficient
        self.kphi = self.p.kphi * self.p.kphi_impact

        # time sampling
        self.dt = self.model.p.timestep
        self.sigma = np.sqrt(self.sigma2)
        self.sqrt_dt = np.sqrt(self.dt)

        # create grid environment (300 x 300 cells)
        self.mask = self.create_global_grid(300, 300, False)
        self.grid = ap.Grid(self, (300, 300), track_empty=False, check_border=False)
        self.grid.add_field('vegetation', values=self.mask)
        self.grid.add_agents(self.agents)
        self.agents.setup_grid()

        # boundaries and space sampling for belief vector plots
        self.bounds = [complex(self.p.nest) + -15-15j, complex(self.p.nest) + 15+15j]
        self.ds = 2

    def step(self):
        """Call the agents' methods for one time step."""
        # select only agents that have not finished the given number of foraging runs
        not_finished = self.agents.select(self.agents.finished == False)
        # provide cookie if necessary yet
        not_finished.provide_cookie()

        # agents that reached their destination (nest or a feeder)
        at_destination = not_finished.select(not_finished.within_reach_of())
        at_destination.store_traj_as_csv()
        at_destination.update_motivation_state()

        # agents that have NOT reached their destination yet
        traveling = not_finished.select(not_finished.wro == False)
        traveling.run()
        traveling.update_trajectory()
        traveling.update_grid_position()

    def update(self):
        """Stop the model once every agent has finished its foraging runs."""
        if all(agent.finished for agent in self.agents):
            self.model.stop()

    def end(self):
        """After all agents have finished: write metadata and optionally plot trajectories."""
        # create metadata file of model output trajectories
        self.create_ant_metadata()

        if not self.p.plot_trjs:
            return

        # belief vector fields: oriented search (search memory) and homing
        self.fields_0 = self._stack_fields({e.name: self.map_bvf(e.search_memory) for e in self.agents})
        self.fields_1 = self._stack_fields({e.name: self.map_bvf(e.homing_memory) for e in self.agents})

        # get all trajectories of all ants
        self.trajs = pd.concat({e.name: e.return_all_trjs() for e in self.agents}, names=['ant'])
        self.trajs['real'] = self.trajs.x.apply(np.real)
        self.trajs['imag'] = self.trajs.x.apply(np.imag)

        # group trajectories by motivational state
        self.trajs_initial = self.trajs.query('motivation == 0')
        self.trajs_oriented = self.trajs.query('motivation == 2')
        self.trajs_homing = self.trajs.query('motivation == 1')

        # color mapper for belief vector field plots (by field name)
        def get_nav_type(x):
            if "PI" in x:
                return (0.651, 0.118, 0.302, 0.7)
            elif "LM" in x:
                return (0.361, 0.58, 0.051, 0.7)
            elif "Travel" in x:
                return (0, 0, 0, 0.5)
            elif "Training" in x:
                return 3
            elif "Channel" in x:
                return 2
            elif "FOOD" in x:
                return 1

        # three subplots per agent, one for each motivational state
        fig, axs = plt.subplots(self.p.agents, 3, figsize=(24, self.p.agents * 7),
                                sharex=True, sharey=True, squeeze=False)

        cmap = colors.ListedColormap(['#f2d2a9', '#caff70'])
        ll, ur = self.bounds
        extents = [ll.real, ur.real, ll.imag, ur.imag]
        complex_nest = complex(self.p.nest)
        # (trajectories, subplot column, line color, end-point color)
        panels = (
            (self.trajs_initial, 0, 'deepskyblue', 'blue'),
            (self.trajs_oriented, 1, 'k', 'k'),
            (self.trajs_homing, 2, 'red', 'red'),
        )
        titles = ("initial search", "oriented search", "homing search")

        for ant, ax in zip(self.agents, axs):
            # (optionally) plot belief vector fields: oriented search, homing
            if self.model.p.plot_bvf:
                for fields, col in ((self.fields_0, 1), (self.fields_1, 2)):
                    dat = fields.loc[ant.name]
                    ax[col].quiver(*dat.values.T, color=dat.index.map(get_nav_type),
                                   scale_units='width')

            # plot land cover grid
            if self.model.p.context:
                for sub_ax in ax:
                    sub_ax.imshow(self.grid.vegetation, cmap=cmap, origin='upper', extent=extents)

            # plot trajectories; an ant may have none of a given type
            for frame, col, line_color, end_color in panels:
                try:
                    dat = frame.loc[ant.name]
                except KeyError:
                    continue
                veg_segments = self._split_contiguous(dat[dat.vegetation == 1])
                self.plot_by_group(dat, veg_segments, 'trial', ax[col], line_color, 'green', end_color)

            # plot feeder and nest locations
            for complex_dest in ant.feeder_locations_legacy.values():
                for sub_ax in ax:
                    sub_ax.scatter(complex_dest.real, complex_dest.imag, c='yellow', s=20, marker='s', zorder=2)
            for sub_ax in ax:
                sub_ax.scatter(complex_nest.real, complex_nest.imag, c='k', s=10, marker='s', zorder=2)

            # subplot titles
            for sub_ax, title in zip(ax, titles):
                sub_ax.set(aspect='equal', title=f"{ant.name} - {title}", xlabel='x', ylabel='y')
            plt.setp(ax, xlabel="m", ylabel="m", xlim=(-12, 12), ylim=(-12, 12))

        # set boundaries, save and display figure
        fig.tight_layout()
        Path("model_outputs").mkdir(parents=True, exist_ok=True)
        plt.savefig("model_outputs/" + str(self.model.p.epoch_time) + ".pdf", bbox_inches='tight')
        plt.show()

    @staticmethod
    def _stack_fields(fields):
        """Concatenate per-ant field tables into rows of (x, v) real/imag components."""
        return (
            pd.concat(fields, names=['ant'])
            .stack().rename('v')
            .reset_index('x')
            .apply([np.real, np.imag])
        )

    @staticmethod
    def _split_contiguous(frame):
        """Split *frame* into runs of consecutive time steps (index level 1).

        Same result as ``np.split`` at every time gap != 1, but sliced with
        ``iloc``. Calling ``np.split`` on a DataFrame is deprecated in pandas.
        An empty *frame* yields one empty piece, as ``np.split`` does.
        """
        times = frame.index.get_level_values(1)
        breaks = np.flatnonzero(np.diff(times) != 1) + 1
        bounds = [0, *breaks, len(frame)]
        return [frame.iloc[start:stop] for start, stop in zip(bounds[:-1], bounds[1:])]

    def plot_by_group(self, data, veg_data, group_str, ax, c1, c2, c3):
        """Plot trajectories on *ax*, highlighting segments in vegetation.

        Args:
            data: trajectory rows with 'real'/'imag' columns, grouped by *group_str*.
            veg_data: iterable of DataFrames, the segments walked in vegetation.
            group_str: column/index level separating individual trajectories.
            ax: matplotlib axes to draw on.
            c1: line color of the trajectories.
            c2: line color of the vegetation segments.
            c3: color of each trajectory's end point.
        """
        data.groupby(group_str).plot.line(
            ax=ax, x='real', y='imag', color=c1, linewidth=1, legend=False)
        for segment in veg_data:
            if segment.empty:
                # nothing to draw; plotting an empty frame can raise
                continue
            segment.plot.line(ax=ax, x='real', y='imag', color=c2, linewidth=1, legend=False)
        data.groupby(group_str).nth(-1).plot.scatter(
            ax=ax, x='real', y='imag', color=c3)

    def map_bvf(self, memory):
        """Return the memory's vector fields sampled on a regular grid.

        The grid spans ``self.bounds`` with spacing ``self.ds``. The result
        has one column per field plus 'Travel' (the sum of all fields) and is
        indexed by the complex sample position 'x'.
        """
        [xmin, xmax] = np.real(self.model.bounds)
        [ymin, ymax] = np.imag(self.model.bounds)
        ds = float(self.model.ds)
        x, y = np.ogrid[xmin:xmax + ds:ds, ymin:ymax + ds:ds]
        x = (x + 1j * y).flatten()
        v = {field: v(x) for field, v in memory.items() if v}
        v = pd.DataFrame(v, index=pd.Index(x, name='x'))
        v['Travel'] = v.sum(axis='columns')  # travel field = sum of all fields
        return v

    def create_global_grid(self, nx, ny, all_touched):
        """Rasterize the shrub polygons of the selected environment into a grid.

        Args:
            nx, ny: grid width and height in cells.
            all_touched: passed to ``rasterio.features.rasterize``; if True,
                every cell touched by a polygon is burned, not only cells whose
                center lies inside it.

        Returns:
            An int16 array of shape (ny, nx) with 1 for shrub cells and 0 elsewhere.

        Raises:
            ValueError: if ``p.env_ant`` has no shapefile (supported: 3, 5, 6, 11).
        """
        shapefiles = {
            3: "model_inputs/ant03_scaled.shp",
            5: "model_inputs/ant05_scaled.shp",
            6: "model_inputs/ant06_scaled.shp",
            11: "model_inputs/ant11_scaled.shp",
        }
        env_ant = self.model.p.env_ant
        try:
            poly_path = shapefiles[env_ant]
        except KeyError:
            raise ValueError(
                f"No vegetation shapefile for env_ant={env_ant!r}; "
                f"expected one of {sorted(shapefiles)}") from None

        # empty mask
        mask = np.zeros((ny, nx), dtype=np.int16)
        # read shapefile and scale from meter to decimeter resolution (x10)
        vegetation = salem.read_shapefile(poly_path, cached=False)
        vegetation = vegetation.scale(xfact=10, yfact=10, origin=(0, 0))

        # transform polygons to stay in mask boundaries (move origin from 0,0 to 150,150)
        transformation = Affine(1, 0.0, -150.0, 0.0, -1, 150)
        # convert vector to raster
        return rasterize(vegetation.geometry, out=mask, transform=transformation,
                         all_touched=all_touched)

    def create_ant_metadata(self):
        """Create a metadata table of all written trajectories and save it as CSV.

        One row is written per stored trajectory, to
        ``<model_output_path><sampleID>_<iteration>/metadata.csv``
        (';'-separated).
        """
        columns = ["filename", "trackname", "run", "number", "behavior", "scale_factor",
                   "col", "environment_ant", "sampleID", "iteration"]
        # if model is run out of experiment, set run id and sample ID to "X"
        if self._run_id is None:
            self._run_id = ["X", "X"]
        sampleID = self._run_id[0]
        iteration = self._run_id[1]

        # make sure parent folders of file path exist
        path = self.model_output_path + str(sampleID) + "_" + str(iteration)
        Path(path).mkdir(parents=True, exist_ok=True)

        # Each foraging run has two trajectories (search + homing), so the run
        # ids are 1, 1, 2, 2, ..., foraging_runs, foraging_runs.
        run_ids = [run for run in range(1, int(self.model.p.foraging_runs) + 1) for _ in range(2)]
        behaviors = {0: "initial search", 1: "homing", 2: "oriented search"}

        # create metadata rows (collected first, then one DataFrame at the end)
        rows = []
        for agent in self.agents:
            for run_id, state in zip(run_ids, agent.previous_states):
                rows.append({
                    "filename": path + "/M_Ant" + str(agent.id) + "_R" + str(run_id) + "_M" + str(state) + ".csv",
                    "trackname": "MA" + str(agent.id) + "RM" + str(run_id),
                    "run": "M" + str(run_id),
                    "number": str(agent.id),
                    "behavior": behaviors[state],
                    "scale_factor": 1,
                    "col": "black",
                    "environment_ant": self.p.env_ant,
                    "sampleID": sampleID,
                    "iteration": iteration,
                })
        metadf = pd.DataFrame(rows, columns=columns)

        # save to file
        metadf.to_csv(path + "/metadata.csv", sep=";", index=False)