In [None]:
from dataclasses import dataclass, field
import numpy as np
from typing import List, Callable, Tuple, Optional
import matplotlib.collections as m_collec
import matplotlib
import matplotlib.pyplot as plt
import scipy.stats as stats
import numpy.typing as npt
from collections import defaultdict
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.patches as patches
from typing import Set
import copy
import itertools



In [None]:
import matplotlib as mpl

# High-resolution figures (on screen and when saved) with a slightly larger base font.
mpl.rcParams.update({
    'figure.dpi': 500,
    'savefig.dpi': 500,
    'font.size': 14,
    'legend.fontsize': 'medium',
})

In [None]:
# `plt.style.use('default')` resets every style rcParam to matplotlib's defaults,
# which would silently undo the dpi/font settings configured in the previous cell.
# Snapshot those settings and re-apply them after switching style.
_preserved_rc = {key: matplotlib.rcParams[key]
                 for key in ('figure.dpi', 'savefig.dpi', 'font.size', 'legend.fontsize')}
plt.style.use('default')
matplotlib.rcParams.update(_preserved_rc)

In [None]:
# Constants
# Distance a forager advances per model step (Forager.move uses pol2cart(STEP_SIZE, ...)).
# NOTE(review): the inline unit says "per second" while the model advances per step — confirm 1 step == 1 s.
STEP_SIZE = 0.2 #ant_lengths per second
# Crop cut-off: a forager with crop <= THRESHOLD uses the "below" heading data, otherwise "above".
THRESHOLD = 0.45

# Restore variables persisted with `%store` in another kernel/notebook (TODO: record which one).
%store -r above_hist
%store -r below_hist

# Raw angle samples; they are passed through np.deg2rad where used, so presumably degrees.
%store -r above_angs
%store -r below_angs

%store -r above_angs_before_int
%store -r below_angs_before_int

In [None]:
# Angles rendered with two decimals (e.g. for copying into a table).
above_angs_table = [f'{angle:.2f}' for angle in above_angs]
below_angs_table = [f'{angle:.2f}' for angle in below_angs]

In [None]:
# ', '.join(below_angs_table)

In [None]:
# Module-level aliases read by Forager.sample_angle_from_data.
ABOVE_ANGS, BELOW_ANGS = above_angs, below_angs

# ABOVE_ANGS_BEFORE_INT = above_angs_before_int
# BELOW_ANGS_BEFORE_INT = below_angs_before_int

In [None]:
from scipy.stats import circmean

# Circular (wrap-around aware) mean of each angle set: computed in radians, reported in degrees.
print(np.degrees(circmean(np.radians(above_angs))))
np.degrees(circmean(np.radians(below_angs)))

In [None]:
# Circular mean, in degrees, of 91 and 272 degrees (presumably a check of circmean's wrap-around handling).
np.rad2deg(circmean(np.deg2rad(np.array([91, 272]))))

In [None]:
# Display both stored histograms; a bare expression that is not last in a cell is
# discarded, so previously only below_hist was shown.
above_hist, below_hist

In [None]:
# Number of entries in above_hist[1] — presumably the bin edges of an np.histogram result
# (counts in above_hist[0]); TODO confirm against where above_hist was stored.
len(above_hist[1])

In [None]:
# Midpoints of 16 consecutive 22.5-degree intervals covering 0-360 degrees (11.25, 33.75, ..., 348.75).
angles_for_hist = np.arange(360 / 32, 361, 360 / 16)
angles_for_hist

In [None]:
# Normalise histogram counts into probability vectors for np.random.choice(p=...).
above_p, below_p = (hist[0] / sum(hist[0]) for hist in (above_hist, below_hist))

In [None]:
def prop_given(c: float, scale: float = 0.15) -> float:
    """Draw the amount of food a forager offers to a nest ant.

    The draw is exponential with mean ``scale`` and is shrunk linearly by the
    recipient's fullness, so a recipient with ``c == 1`` is offered nothing.

    Args:
        c: the recipient's current crop.
        scale: mean of the exponential draw (previously hard-coded as 0.15).

    Returns:
        The offered amount (before any capping by the caller).
    """
    return np.random.exponential(scale) * (1 - c)

In [None]:
# Sanity check of the exponential(0.15) draw used by prop_given: 10 000 samples
# histogrammed on a 50-point grid over [0, 1] (samples >= 1 land on the last
# grid point), plus the sample mean.
y = np.random.exponential(0.15, size=10000)

x = np.linspace(0, 1, 50)
ys = np.digitize(y, x)

counts = np.unique(ys, return_counts=True)
xs = x[counts[0] - 1]
plt.plot(xs, counts[1])
print(y.mean())

In [None]:
# Flag referenced only by the commented-out model-run cell further down.
run_model1 = True

In [None]:
# @dataclass(slots=True)
@dataclass
class Ant:
    """A nest ant with a position and a crop load that takes uniform random steps.

    Without ``@dataclass`` the annotated fields generated no ``__init__``, so
    ``Ant(p, crop)`` raised ``TypeError``.
    """
    # (x, y) position; updated in place by move().
    p: npt.NDArray[np.float32]
    # Amount of food carried.
    crop: float

    def move(self) -> None:
        """Shift each coordinate by a uniform draw in [-STEP_SIZE, STEP_SIZE)."""
        self.p += (np.random.uniform(-1, 1, 2) * STEP_SIZE)

In [None]:
def cart2pol(x, y):
    """Cartesian -> polar: returns ``np.array([r, alpha])`` with alpha in radians."""
    return np.array([np.sqrt(x**2 + y**2), np.arctan2(y, x)])


def pol2cart(r, alpha):
    """Polar (alpha in radians) -> Cartesian: returns ``np.array([x, y])``."""
    return np.array([r * np.cos(alpha), r * np.sin(alpha)])

In [None]:
def theta_to_beta(p: npt.NDArray[np.float32], theta: float) -> float:
    """Convert angle ``theta`` (degrees) into a heading for a forager at ``p``.

    ``alpha`` is the polar angle of ``p`` about the origin, in degrees; the
    result is ``theta - 180 + alpha``.
    """
    alpha_deg = np.rad2deg(np.arctan2(p[1], p[0]))
    return theta - 180 + alpha_deg

# def beta_to_theta(beta: float) -> float:
#     theta = beta + 180 - alpha
#     return theta

In [None]:
# @dataclass(slots=True)
@dataclass
class Forager:
    """A forager that walks through the nest chamber and unloads food to nest ants.

    Each move advances ``STEP_SIZE`` along ``angle`` (degrees) and reflects off
    the square chamber whose x and y limits are ``bounds``. A forager within
    distance 1 of the entrance at ``(bounds[0], bounds[0])`` refills its crop
    and starts a new trip (``enter_nest``). ``interaction`` hands food to a nest
    ant, and ``sampling`` draws a new heading.
    """
    # Current (x, y) position.
    p: npt.NDArray[np.float32]
    # Lower and upper chamber limits, applied to both x and y.
    bounds: npt.NDArray[np.float32]
    # Full width (degrees) of the uniform noise added to each wall reflection.
    reflection_noise_factor: float
    crop: float = 0
    # Heading in degrees.
    angle: float = 0
    # Crop after refilling, set on the step the forager enters the nest; None on other steps.
    exiting_crop: Optional[float] = None
    # Model tick, assigned by the model after each move.
    step: int = 0
    # Largest distance from the origin reached on the current trip.
    depth: float = 0
    trip: int = 0
    trip_length: int = 0
    prev_trip_length: int = 0
    interacted: bool = False
    # 'hist' draws headings from the binned histograms; any other value uses the raw angle samples.
    sample_type: str = 'hist'
    # Heading sampler, bound in __post_init__ according to sample_type.
    sampling: Optional[Callable[[], float]] = None
    # trip number -> list of [position, nest-ant snapshot (first record) or []] per move.
    trajectory: dict = field(default_factory=lambda: defaultdict(list))
    # step -> [position, recipient crop before the transfer] for each interaction.
    interaction_pos_other_crop: dict = field(default_factory=dict)
    # step -> [colony state, crop on arrival, trip depth] for each nest visit.
    exiting_step_crop: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Bind ``sampling`` to the histogram or raw-data heading sampler."""
        if self.sample_type == 'hist':
            self.sampling = self.sample_angle_from_hist
        else:
            self.sampling = self.sample_angle_from_data

    def dist_from_origin(self, coor: npt.NDArray[np.float32]) -> bool:
        """Return True if ``coor`` is within distance 1 of the entrance ``(bounds[0], bounds[0])``.

        Despite the name, this returns a flag rather than a distance.
        """
        entrance = np.array([self.bounds[0], self.bounds[0]])
        return bool(np.linalg.norm(entrance - coor) < 1)

    def move(self, col_state: float, other_ants) -> Optional[bool]:
        """Advance one step, or enter the nest if at the entrance.

        Args:
            col_state: current colony state, recorded on nest visits.
            other_ants: snapshot of the nest ants, stored in ``trajectory``
                when a new trip starts.

        Returns:
            True if the forager entered the nest this step (it does not move
            then), otherwise None.
        """
        def move_by_angle() -> npt.NDArray[np.float32]:
            return self.p + pol2cart(STEP_SIZE, np.deg2rad(self.angle))

        if self.dist_from_origin(self.p):
            self.exiting_step_crop[self.step] = [col_state, self.crop, self.depth]
            self.angle = self.enter_nest(other_ants)
            return True

        self.trip_length += 1
        self.exiting_crop = None
        potential_pos = move_by_angle()
        out_of_bounds = self.check_new_pos(potential_pos)

        # Reflect until the step stays inside the chamber. reflection() takes an index
        # into [sin, cos]: an x violation ([True, False]) gives argmin 1 (negate cos, the
        # x component); a y violation gives 0 (negate sin). If both axes are out, argmin
        # is 0 (y first) and x is handled on a later iteration if still out.
        while out_of_bounds.any():
            self.angle = self.reflection(np.argmin(out_of_bounds))
            potential_pos = move_by_angle()
            out_of_bounds = self.check_new_pos(potential_pos)

        self.p = potential_pos
        self.trip_depth()
        self.trajectory[self.trip].append([self.p, []])

    def trip_depth(self) -> None:
        """Raise ``depth`` to the current distance from the origin if that is larger."""
        current_depth = np.linalg.norm(self.p)
        if current_depth > self.depth:
            self.depth = current_depth

    def enter_nest(self, other_ants) -> float:
        """Refill the crop, reset per-trip state and start a new trip at (1, 1).

        Returns:
            A new heading drawn uniformly from [-90, 180) degrees.
        """
        self.crop = 1
        self.trip += 1
        self.prev_trip_length = self.trip_length
        self.trip_length = 0
        self.interacted = False
        self.p = np.ones(2)
        self.depth = 0
        self.exiting_crop = self.crop
        # The first record of each trip keeps the snapshot of the nest ants.
        self.trajectory[self.trip] = [[self.p, other_ants]]
        return np.random.uniform(-90, 180)

    def check_new_pos(self, p: npt.NDArray[np.float32]) -> npt.NDArray[np.bool_]:
        """Per-axis flags: True where ``p`` lies outside ``[bounds[0], bounds[1]]``."""
        return (p < self.bounds[0]) | (p > self.bounds[1])

    def reflection(self, boundary: int) -> float:
        """Mirror the heading off a wall and add uniform noise.

        Args:
            boundary: index into ``[sin, cos]`` of the component to negate:
                0 reverses the y component, 1 reverses the x component.

        Returns:
            The new heading in degrees, plus a uniform draw in
            ``[-reflection_noise_factor / 2, reflection_noise_factor / 2)``.
        """
        r_angle = np.deg2rad(self.angle)
        coors = [np.sin(r_angle), np.cos(r_angle)]
        coors[boundary] *= -1
        return np.rad2deg(np.arctan2(*coors)) + ((np.random.random() - 0.5) * self.reflection_noise_factor)

    def sample_angle_from_hist(self) -> float:
        """Draw a heading from the binned angle histogram matching the current crop.

        Uses ``below_p`` when crop <= THRESHOLD, else ``above_p``, as the
        probabilities over the bin angles ``angles_for_hist``.
        """
        p = below_p if self.crop <= THRESHOLD else above_p
        # The candidate angles live in `angles_for_hist`; the previously used name
        # `angles` is not defined anywhere and raised NameError.
        theta = np.random.choice(angles_for_hist, p=p)
        return theta_to_beta(self.p, theta)

    def sample_angle_from_data(self) -> float:
        """Draw a heading from the raw angle samples matching the current crop."""
        if self.crop <= THRESHOLD:
            theta = np.random.choice(BELOW_ANGS)
        else:
            theta = np.random.choice(ABOVE_ANGS)

        return theta_to_beta(self.p, theta)

    @staticmethod
    def feed_duration_calc(given_amount: float) -> int:
        """Steps the model pauses for a feeding event; currently always 0.

        An earlier variant used ``int(60 * given_amount)``.
        """
        return 0

    def interaction(self, other_crop: float) -> Tuple[float, int]:
        """Give food to a nest ant whose crop is ``other_crop``.

        The offer is ``prop_given(other_crop)`` clipped to ``[0, own crop]``,
        and the amount given is capped so the recipient's crop does not exceed 1.

        Returns:
            (amount given, feeding duration in steps).
        """
        self.interaction_pos_other_crop[self.step] = [self.p, other_crop]

        offered_amount = max(0, min(self.crop, prop_given(other_crop)))
        given_amount = min(offered_amount, 1 - other_crop)
        self.crop -= given_amount
        self.interacted = True
        feed_duration = self.feed_duration_calc(given_amount)
        return given_amount, feed_duration

In [None]:
# Scratch check of the out-of-bounds test: index of the first coordinate outside [0, 11].
p = np.array([1, -1])
np.argmax(np.logical_or(p < 0, p > 11))

In [None]:
# @dataclass(slots=True)
# class ModelAntClass:
#     bounds: npt.NDArray[np.float32]
#     biases: npt.NDArray[np.float32]
#     ants: list = field(default_factory=list)
#     foragers: list = field(default_factory=list)
    
#     def inialise(self, no_ants: int, no_foragers: int) -> None:
#         self.ants = [Ant(p, 0) for p in np.random.uniform(*self.bounds, size=(no_ants, 2))]
#         self.foragers = [Forager(p, 1, self.biases) for p in np.zeros([no_foragers, 2])]
    
#     def step(self) -> None:
#         pass

In [None]:
import csv

# @dataclass(slots=True)
@dataclass
class ModelAntArray:
    """Nest model: nest ants stored as rows ``[x, y, crop]`` of ``ants`` plus Forager objects.

    Each ``step`` moves the nest ants randomly, moves every forager, and lets a
    forager feed the closest nest ant within ``radius``. The colony state is
    the mean nest-ant crop.

    ``forager_data``, ``interaction_data`` and ``visit_data`` are objects with a
    ``writerow`` method (e.g. ``csv.writer``); they are used only when ``log``
    is True and may be left as None otherwise.
    """
    bounds: npt.NDArray[np.float32]
    no_ants: int
    sample_type: str
    forager_data: object = None
    interaction_data: object = None
    visit_data: object = None
    repeat: int = 1
    tick: int = 0
    # Remaining ticks during which the whole model is paused after a feeding event.
    waiting_duration: int = 0
    # Forager-to-nest-ant distance below which an interaction happens.
    radius: float = 0.2
    # Smallest transfer that makes the forager resample its heading.
    minimum_interaction: float = 0.01
    reflection_noise_factor: float = 0.1
    ant_bounds: list = field(default_factory=list)
    angles: dict = field(default_factory=dict)
    # ndarray defaults must go through default_factory: since Python 3.11, dataclasses
    # reject unhashable default values such as np.ndarray.
    ants: npt.NDArray = field(default_factory=lambda: np.empty(0))
    foragers: npt.NDArray = field(default_factory=lambda: np.empty(0))
    # tick -> colony state at the end of that tick.
    col_states: dict = field(default_factory=dict)
    save_trajectories: bool = False

    log: bool = False

    @property
    def col_state(self) -> float:
        """Colony state: mean crop over all nest ants."""
        return self.ants[:, 2].mean()

    def __post_init__(self) -> None:
        # One row per nest ant, [x, y, crop] — the layout col_state and inialise() use.
        self.ants = np.zeros((self.no_ants, 3))
        self.angles = {'above': [], 'below': []}
        self.ant_bounds = [self.bounds[0], self.bounds[1]]

    def inialise(self, no_foragers: int = 1) -> None:
        """Scatter nest ants uniformly with empty crops and place foragers at the origin."""
        positions = np.random.uniform(*self.ant_bounds, size=(self.no_ants, 2))
        self.ants = np.column_stack((positions, np.zeros(self.no_ants)))
        self.foragers = np.array([Forager(p, self.bounds,
                                          sample_type=self.sample_type,
                                          reflection_noise_factor=self.reflection_noise_factor)
                                  for p in np.zeros([no_foragers, 2])])

    def step(self) -> None:
        """Advance the model one tick: nest ants first, then each forager."""
        self.tick += 1
        if self.waiting_duration > 0:
            # Model is paused while a feeding event completes.
            self.waiting_duration -= 1
            return
        # Nest ants take a uniform step of up to 3 * STEP_SIZE per axis; crops unchanged.
        moves = np.hstack((np.random.uniform(-1, 1, size=(self.no_ants, 2)) * (STEP_SIZE * 3),
                           np.zeros((self.no_ants, 1))))
        self.ants += moves

        # Keep nest ants inside the chamber ...
        self.ants[:, :-1] = self.ants[:, :-1].clip(self.ant_bounds[0] + self.radius / 2,
                                                   self.bounds[1] - self.radius / 2)
        # ... and subtract this tick's move again for ants within 1 of the entrance corner.
        entrance = np.array([self.bounds[0], self.bounds[0]])
        in_entrance = np.linalg.norm(entrance - self.ants[:, :-1], axis=1) < 1
        self.ants[in_entrance] -= moves[in_entrance]

        f: Forager
        for f in self.foragers:
            # Snapshot stored by the forager if this move takes it into the nest.
            other_ants = self.ants.copy()
            for_exit = f.move(self.col_state, other_ants)
            f.step = self.tick
            dists = np.linalg.norm(self.ants[:, :-1] - f.p, axis=1)
            close_ant = np.argmin(dists)

            if for_exit and self.log:
                self.visit_data.writerow([self.tick, self.col_state, 1-self.col_state, self.repeat, f.exiting_crop, f.prev_trip_length])

            if dists[close_ant] < self.radius:
                self.store_angs(f.angle, f.crop)
                other_ant = self.ants[close_ant]
                given_food, self.waiting_duration = f.interaction(other_ant[2])
                self.ants[close_ant][2] += given_food
                if given_food >= self.minimum_interaction:
                    f.angle = f.sampling()
                if self.log:
                    # other_ant is a view into self.ants, so this logs the recipient's crop after the transfer.
                    self.interaction_data.writerow([self.tick, self.col_state, 1-self.col_state, self.repeat, f.crop, other_ant[-1], given_food, f.p, f.trip])

        self.col_states[self.tick] = self.col_state

        if self.log:
            forag = self.foragers[0]  # type: Forager
            self.forager_data.writerow([self.tick, self.col_state, 1-self.col_state, self.repeat, forag.crop, forag.trip, forag.p, forag.exiting_crop, forag.crop < THRESHOLD])

    def store_angs(self, a: float, crop: float) -> None:
        """Record heading ``a`` under 'below' (crop <= THRESHOLD) or 'above'."""
        if crop <= THRESHOLD:
            self.angles['below'].append(a)
        else:
            self.angles['above'].append(a)

    def visualise(self, ax, color=True, plot=True) -> m_collec.PathCollection:
        """Scatter nest ants (coloured by crop) and foragers (squares) on ``ax``.

        Args:
            ax: matplotlib axes to draw on.
            color: add a colorbar for the nest-ant crops.
            plot: call ``plt.show()`` afterwards.

        Returns:
            The nest-ant scatter (useful for a shared colorbar).
        """
        ax.set_xlim(-1, 12)
        ax.set_ylim(-1, 12)
        # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9; use the colormap registry.
        cmap = matplotlib.colormaps['viridis']
        sc = ax.scatter(self.ants[:, 0], self.ants[:, 1], c=self.ants[:, 2], vmin=0, vmax=1)
        if color:
            plt.colorbar(sc, ax=ax)
        for f in self.foragers:
            # Explicit RGBA colour: vmin/vmax only apply to colormapped data, so they are not passed.
            ax.scatter(*f.p, c=np.array([cmap(f.crop)]), s=90, marker='s')
        if plot:
            plt.show()
        return sc

In [None]:
def nest_state_plot(n_for: int = 1, interval: int = 800) -> None:
    """Plot 10 snapshots of the nest on a 2x5 grid, taken every ``interval`` model steps.

    Args:
        n_for: number of foragers.
        interval: model steps between consecutive snapshots.
    """
    # The writer fields must be supplied when constructing ModelAntArray (they had no
    # defaults, so omitting them raised TypeError); logging is off here, so pass None.
    m = ModelAntArray([0, 11], 90, sample_type='all',
                      forager_data=None, interaction_data=None, visit_data=None)
    m.inialise(n_for)
    fig, axes = plt.subplots(2, 5, figsize=(20, 10))
    fig.subplots_adjust(left=0.1, bottom=0.1, right=0.92, top=0.9, wspace=0.4, hspace=0.4)

    sc = None
    for n, ax in enumerate(axes.flatten()):
        ax.set_title(f"s: {interval * n}, cs: {round(m.col_state, 2)}")
        sc = m.visualise(ax, color=False, plot=False)
        for _ in range(interval):
            m.step()

    # One shared colorbar, in its own axes to the right of the grid.
    cbar_ax = fig.add_axes([0.95, 0.15, 0.01, 0.7])
    fig.colorbar(sc, cax=cbar_ax)
    plt.show()

In [None]:
# f = open("forager_data.csv", 'w+') 
# i = open('interaction_data.csv', 'w+')
# v = open('visit_data.csv', 'w+') 
# 
# f_writer = csv.writer(f)
# i_writer = csv.writer(i)
# v_writer = csv.writer(v)
# 
# f_writer.writerow(['step', 'colony state', 'empty colony state', 'repeat', 'crop', 'trip', 'position', 'exiting crop', 'fall thresh'])
# i_writer.writerow(['step', 'colony state', 'empty colony state', 'repeat', 'forager crop', 'nest ant crop', 'interaction volume', 'position', 'trip'])
# v_writer.writerow(['step', 'colony state', 'empty colony state', 'repeat', 'exiting crop', 'trip length'])


# m = ModelAntArray([0, 11], 89, sample_type= 'all', forager_data = None, interaction_data = i_writer, visit_data = v_writer, 
# log=True, radius=STEP_SIZE+0.1, minimum_interaction=0, save_trajectories=True)
# m.inialise(1)

In [None]:
# if run_model1:
#     for _ in range(60000):
#         m.step()
# run_model1 = True

# f.close()
# i.close()
# v.close()

In [None]:
def twod_numpy_intersect(arr1: npt.NDArray, arr2: npt.NDArray) -> Set:
    """Return the set of rows (as tuples) present in both 2-D arrays."""
    rows_a = {tuple(row) for row in arr1}
    rows_b = {tuple(row) for row in arr2}
    return rows_a & rows_b


In [None]:
def trajectory_plot(ax, m: ModelAntArray, trip_no: int) -> None:
    """Draw one trip of the model's first forager over a snapshot of the nest ants.

    The path is a black line with direction arrows on every 4th segment; nest
    ants are coloured by crop as stored in the trip's first record, and the
    region within radius 1 of the origin is shaded.

    Args:
        ax: matplotlib axes to draw on.
        m: a model whose first forager recorded trajectories.
        trip_no: key into ``m.foragers[0].trajectory``.

    Raises:
        ValueError: if the trip was not recorded.
    """
    trip = m.foragers[0].trajectory.get(trip_no)
    if not trip:
        raise ValueError("Trip out of range")
    # First record of a trip: [position, copy of the nest-ant array at trip start].
    nest_ants = np.asarray(trip[0][1])

    f_pos_x = np.array([record[0][0] for record in trip])
    f_pos_y = np.array([record[0][1] for record in trip])
    ax.plot(f_pos_x, f_pos_y, zorder=2, c='k', linewidth=2)

    sc = ax.scatter(nest_ants[:, 0], nest_ants[:, 1], c=nest_ants[:, 2], vmin=0, vmax=1, zorder=1, s=90)
    entrance = patches.Circle((0, 0), radius=1, linewidth=1, edgecolor='grey', facecolor='grey', fill=True, alpha=0.3)
    ax.add_patch(entrance)

    # Unit direction arrows at segment midpoints, thinned to every `sli`-th segment.
    u = np.diff(f_pos_x)
    v = np.diff(f_pos_y)
    norm = np.sqrt(u**2 + v**2)
    pos_x = f_pos_x[:-1] + u / 2
    pos_y = f_pos_y[:-1] + v / 2
    sli = 4
    ax.quiver(pos_x[::sli], pos_y[::sli], (u / norm)[::sli], (v / norm)[::sli],
              angles="xy", zorder=5, pivot="mid", width=0.006, scale=30)

    col_state = nest_ants[:, 2].mean()

    ax.set_title(f"Colony state: {round(col_state, 1)}")
    ax.title.set_fontsize(28)
    ax.set_ylim(0, 12)
    ax.set_xlim(0, 12)
    ax.set_aspect(1)

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(sc, cax=cax)


In [None]:
# fig, ax = plt.subplots(1, 2, figsize=(16, 16))
# trajectory_plot(ax[0], m, 12)
# trajectory_plot(ax[1], m, 130)
# plt.savefig(f"2d_continuous_trip_traj.png", format='png')

In [None]:
def col_state_plot(n_for: int = 1, n_repeats: int = 2, n_steps: int = 50000) -> ModelAntArray:
    """Plot the colony-state time course averaged over independent runs.

    Each of ``n_repeats`` runs uses a freshly initialised model and records the
    colony state before each of ``n_steps`` steps; the curves are averaged.
    (Previously one model kept running across repeats and the sum was divided
    by 10 regardless of the number of repeats.)

    Returns:
        The model from the last run.

    Raises:
        ValueError: if ``n_repeats`` < 1.
    """
    if n_repeats < 1:
        raise ValueError("n_repeats must be at least 1")
    cs = np.zeros(n_steps)
    m = None
    for _ in range(n_repeats):
        m = ModelAntArray([0, 11], 90, sample_type='all',
                          forager_data=None, interaction_data=None, visit_data=None,
                          radius=STEP_SIZE + 0.1, minimum_interaction=0.02)
        m.inialise(n_for)
        for s in range(n_steps):
            cs[s] += m.col_state
            m.step()

    cs /= n_repeats
    plt.plot(cs)
    plt.show()
    return m

In [None]:
# Scratch arithmetic (1/240); purpose not recorded.
1/240

## TODO: consider skipping interactions when the forager's crop is 0

In [None]:
# Scratch arithmetic: mean of the reciprocals 1/2, 1/2, 1/2, 1/200, 1/200.
(1/np.array([2, 2, 2, 200, 200])).mean()

In [None]:
# Flag set here and again after the batch run; nothing in the visible cells reads it.
run_model = True

# Interaction rate

In [None]:
def flatten(t):
    """Concatenate the items of every sub-iterable of ``t`` into one list."""
    return list(itertools.chain.from_iterable(t))

In [None]:
def interaction_rate(for_int_data, model_data, ax) -> None:
    """Plot the forager's mean interaction rate against colony state.

    The rate for an interaction is 1 / (steps since the previous interaction,
    or since step 0 for the first one). Rates are averaged in 10 equal-width
    colony-state bins and drawn at the bin centres.

    Args:
        for_int_data: per-forager dicts keyed by interaction step
            (``Forager.interaction_pos_other_crop``).
        model_data: per-forager dicts mapping step -> colony state
            (``ModelAntArray.col_states``), aligned with ``for_int_data``.
        ax: matplotlib axes to draw on.
    """
    cols = []
    freqs = []
    for fd, md in zip(for_int_data, model_data):
        steps = np.fromiter(fd.keys(), dtype=int)
        # Prepend step 0 so the first interaction also gets an interval.
        gaps = np.diff(np.insert(steps, 0, 0))
        cols.extend(md[step] for step in fd.keys())
        freqs.extend(1 / gaps)

    binned_avg = stats.binned_statistic(np.array(cols), np.array(freqs), 'mean', bins=10)
    edges = binned_avg.bin_edges
    # Edges span [min, max] of the data, not [0, 1], so use the true centres
    # rather than a fixed +0.05 offset.
    centres = 0.5 * (edges[:-1] + edges[1:])
    ax.plot(centres, binned_avg.statistic, '--ko', label="forager interaction rate")
    ax.set_title("Interaction rate")
    ax.set_xlim([0, 1])

# Sampled colony state

In [None]:
def sampled_col_state(for_int_data, model_data, ax) -> None:
    """Compare the crop of the ants the forager fed with the crop of all workers.

    Both series are binned into 10 equal-width colony-state bins, drawn at the
    bin centres, with ±1 standard-deviation bands.

    Args:
        for_int_data: per-forager dicts mapping interaction step ->
            ``[forager position, recipient crop]``.
        model_data: per-forager dicts mapping step -> colony state, aligned
            with ``for_int_data``.
        ax: matplotlib axes to draw on.
    """
    cols = []
    sampled_col = []
    for fd, md in zip(for_int_data, model_data):
        cols.extend(md[step] for step in fd.keys())
        # Index each record directly; building an object array and slicing [:, 1]
        # raised IndexError for a forager with no interactions.
        sampled_col.extend(record[1] for record in fd.values())

    cols = np.array(cols)
    sampled_col = np.array(sampled_col)

    def _bin(values, statistic):
        return stats.binned_statistic(cols, values, statistic, bins=10)

    recipient_avg, recipient_std = _bin(sampled_col, 'mean'), _bin(sampled_col, 'std')
    colony_avg, colony_std = _bin(cols, 'mean'), _bin(cols, 'std')

    edges = recipient_avg.bin_edges
    # Edges span the data range, so compute true centres instead of a fixed +0.05 offset.
    centres = 0.5 * (edges[:-1] + edges[1:])

    for avg, std, colour, label in ((recipient_avg, recipient_std, 'green', "Recipients"),
                                    (colony_avg, colony_std, 'purple', "All workers")):
        ax.plot(centres, avg.statistic, colour, zorder=4, label=label)
        ax.fill_between(centres,
                        avg.statistic - std.statistic,
                        avg.statistic + std.statistic,
                        color=colour,
                        alpha=0.25)
    ax.set_ylabel("Crop state [-]")
    ax.set_xlim([0, 1])

# Unloading rate

In [None]:
def unloading_rate(for_exit_data, ax) -> None:
    """Plot the forager's food-unloading rate against colony state.

    For every nest visit after the first, the food delivered on the preceding
    trip is ``1 - crop`` (crop recorded on arrival), divided by the steps since
    the previous visit. Rates are averaged in 5 equal-width colony-state bins
    and drawn at the bin centres with a ±SEM band.

    Args:
        for_exit_data: per-forager dicts mapping visit step ->
            ``[colony state, crop, trip depth]`` (``Forager.exiting_step_crop``).
        ax: matplotlib axes to draw on.
    """
    cols = []
    rates = []
    for fd in for_exit_data:
        # A rate needs two visits; an empty dict also broke the [1:, 0] slicing.
        if len(fd) < 2:
            continue
        steps = np.fromiter(fd.keys(), dtype=int)
        records = np.array(list(fd.values()))
        given_amounts = 1 - records[1:, 1]
        rates.extend(given_amounts / np.diff(steps))
        cols.extend(records[1:, 0])

    cols = np.array(cols)
    rates = np.array(rates)
    binned_avg = stats.binned_statistic(cols, rates, 'mean', bins=5)
    binned_sem = stats.binned_statistic(cols, rates, stats.sem, 5)
    edges = binned_avg.bin_edges
    # Edges span the data range, so compute true centres instead of a fixed +0.1 offset.
    centres = 0.5 * (edges[:-1] + edges[1:])
    ax.plot(centres, binned_avg.statistic, 'green', zorder=4)
    ax.fill_between(centres,
                    binned_avg.statistic - binned_sem.statistic,
                    binned_avg.statistic + binned_sem.statistic,
                    color='green',
                    alpha=0.25)

    ax.set_ylabel(r"Unloading rate [$seconds^{-1}$]")
    ax.ticklabel_format(axis='y', style='sci', scilimits=(-3, -3))
    ax.set_xlim([0, 1])

# Trip frequency

In [None]:
def exiting_frequency(forag_exit_data, ax, short_trips) -> None:
    """Plot the forager's trip frequency against *empty* colony state.

    Trip frequency is 1 / (steps between consecutive nest visits); trips of
    ``short_trips`` steps or fewer are dropped. Frequencies are averaged in 5
    equal-width colony-state bins (±SEM band) and drawn at the bin centres,
    mapped to empty colony state (1 - colony state).

    Args:
        forag_exit_data: per-forager dicts mapping visit step ->
            ``[colony state, crop, trip depth]``.
        ax: matplotlib axes to draw on.
        short_trips: trips with duration <= this many steps are excluded.
    """
    cols = []
    freqs = []
    for fd in forag_exit_data:
        # Need two visits for a duration; an empty dict also broke the [1:, 0] slicing.
        if len(fd) < 2:
            continue
        steps = np.fromiter(fd.keys(), dtype=int)
        durs = np.diff(steps)
        colony_state = np.array(list(fd.values()))[1:, 0]
        long_enough = durs > short_trips
        freqs.extend(1 / durs[long_enough])
        cols.extend(colony_state[long_enough])

    cols = np.array(cols)
    freqs = np.array(freqs)
    binned_avg = stats.binned_statistic(cols, freqs, 'mean', 5)
    binned_sem = stats.binned_statistic(cols, freqs, stats.sem, 5)

    edges = binned_avg.bin_edges
    # Map the true bin centres to empty colony state; the old `1 - edges[1:] + 0.1`
    # was only correct for bins exactly 0.2 wide.
    empty_col_state = 1 - 0.5 * (edges[:-1] + edges[1:])
    ax.plot(empty_col_state, binned_avg.statistic, 'green', zorder=4)
    ax.fill_between(empty_col_state,
                    binned_avg.statistic - binned_sem.statistic,
                    binned_avg.statistic + binned_sem.statistic,
                    color='green',
                    alpha=0.25)
    ax.set_ylabel(r"Foraging frequency [$steps^{-1}$]")
    ax.set_xlabel(r"Empty colony state [-]")
    ax.set_xlim([0, 1])
    ax.ticklabel_format(axis='y', style='sci', scilimits=(-3, -3))

# Exiting crop

In [None]:
def exiting_crop(forag_exit_data, ax, short_trips) -> None:
    """Plot the forager's crop on returning to the nest against colony state.

    Trips of ``short_trips`` steps or fewer are dropped. The mean crop in 10
    equal-width colony-state bins is drawn at the bin centres with a ±1
    standard-deviation band, over a scatter of the individual values.

    Args:
        forag_exit_data: per-forager dicts mapping visit step ->
            ``[colony state, crop, trip depth]``.
        ax: matplotlib axes to draw on.
        short_trips: trips with duration <= this many steps are excluded.
    """
    cols = []
    exit_crops = []
    for fd in forag_exit_data:
        # Need two visits for a duration; an empty dict also broke the [1:, ...] slicing.
        if len(fd) < 2:
            continue
        steps = np.fromiter(fd.keys(), dtype=int)
        durs = np.diff(steps)
        records = np.array(list(fd.values()))
        long_enough = durs > short_trips
        exit_crops.extend(records[1:, 1][long_enough])
        cols.extend(records[1:, 0][long_enough])

    binned_avg = stats.binned_statistic(np.array(cols), np.array(exit_crops), 'mean', 10)
    binned_std = stats.binned_statistic(np.array(cols), np.array(exit_crops), 'std', 10)

    edges = binned_avg.bin_edges
    # Edges span the data range, so compute true centres instead of a fixed +0.05 offset.
    centres = 0.5 * (edges[:-1] + edges[1:])
    ax.plot(centres, binned_avg.statistic, 'green', zorder=4)
    ax.fill_between(centres,
                    binned_avg.statistic - binned_std.statistic,
                    binned_avg.statistic + binned_std.statistic,
                    color='green',
                    alpha=0.25)

    ax.scatter(cols, exit_crops, alpha=0.2, c='k', s=1.5)
    ax.set_ylim(0, 1)
    ax.set_xlim([0, 1])
    ax.set_ylabel("Forager's crop at exit [-]")

# Trip duration

In [None]:
def trip_duration(forag_exit_data, ax) -> None:
    """Plot the mean trip duration (steps between nest visits) against colony state.

    Durations are averaged in 10 equal-width colony-state bins and drawn at
    the bin centres.

    Args:
        forag_exit_data: per-forager dicts mapping visit step ->
            ``[colony state, crop, trip depth]``.
        ax: matplotlib axes to draw on.
    """
    cols = []
    durations = []
    for fd in forag_exit_data:
        # Need two visits for a duration; an empty dict also broke the [1:, 0] slicing.
        if len(fd) < 2:
            continue
        steps = np.fromiter(fd.keys(), dtype=int)
        durations.extend(np.diff(steps))
        cols.extend(np.array(list(fd.values()))[1:, 0])

    binned_avg = stats.binned_statistic(np.array(cols), np.array(durations), 'mean', 10)
    edges = binned_avg.bin_edges
    # Edges span the data range, so compute true centres instead of a fixed +0.05 offset.
    centres = 0.5 * (edges[:-1] + edges[1:])
    ax.plot(centres, binned_avg.statistic, '--ko', zorder=4, label="foragers trip duration")
    ax.set_title("Forager trip duration")
    ax.set_xlim([0, 1])

# Trip depth

In [None]:
def trip_depth(forag_exit_data, ax) -> None:
    """Plot the mean trip depth (furthest distance from the origin) against colony state.

    Depths recorded at each nest visit after the first are averaged in 10
    equal-width colony-state bins and drawn at the bin centres with a ±SEM band.

    Args:
        forag_exit_data: per-forager dicts mapping visit step ->
            ``[colony state, crop, trip depth]``.
        ax: matplotlib axes to draw on.
    """
    cols = []
    depths = []
    for fd in forag_exit_data:
        # The first visit is skipped; an empty dict also broke the [1:, ...] slicing.
        if len(fd) < 2:
            continue
        records = np.array(list(fd.values()))
        depths.extend(records[1:, 2])
        cols.extend(records[1:, 0])

    binned_avg = stats.binned_statistic(np.array(cols), np.array(depths), 'mean', 10)
    binned_sem = stats.binned_statistic(np.array(cols), np.array(depths), stats.sem, 10)
    edges = binned_avg.bin_edges
    # Edges span the data range, so compute true centres instead of a fixed +0.05 offset.
    centres = 0.5 * (edges[:-1] + edges[1:])
    ax.plot(centres, binned_avg.statistic, 'green', zorder=4)
    ax.fill_between(centres,
                    binned_avg.statistic - binned_sem.statistic,
                    binned_avg.statistic + binned_sem.statistic,
                    color='green',
                    alpha=0.25)
    ax.set_ylabel(r"Trip depth [ant-lengths]")
    ax.set_xlim([0, 1])

In [None]:
import multiprocessing as mp
import os

In [None]:
# Output CSV files for the batch run below. The csv module expects files opened
# with newline=''; otherwise extra blank rows appear on Windows.
f = open("forager_data.csv", 'w', newline='')
inter = open('interaction_data.csv', 'w', newline='')
v = open('visit_data.csv', 'w', newline='')

f_writer = csv.writer(f)
i_writer = csv.writer(inter)
v_writer = csv.writer(v)

f_writer.writerow(['step', 'colony state', 'empty colony state', 'repeat', 'crop', 'trip', 'position', 'exiting crop', 'fall thresh'])
i_writer.writerow(['step', 'colony state', 'empty colony state', 'repeat', 'forager crop', 'nest ant crop', 'interaction volume', 'position', 'trip'])
v_writer.writerow(['step', 'colony state', 'empty colony state', 'repeat', 'exiting crop', 'trip length'])


def run_model_func(i):
    """Run repeat ``i`` until the colony state reaches 0.95, logging to the CSV writers above.

    Returns:
        The finished ModelAntArray.
    """
    m = ModelAntArray([0, 11], 89, repeat=i, sample_type='all',
                      forager_data=f_writer, interaction_data=i_writer, visit_data=v_writer,
                      log=True, radius=0.3, minimum_interaction=0, reflection_noise_factor=0.3)
    m.inialise(1)
    while m.col_state < 0.95:
        m.step()
    return m

In [None]:

# Half the logical CPUs for the pool; os.cpu_count() may return None and Pool needs >= 1 process.
cpus = max(1, (os.cpu_count() or 2) // 2)

no_forag = 1
no_runs = 20
output = []
logging = True

try:
    if not logging:
        # NOTE(review): run_model_func always builds the model with log=True and writes through
        # the module-level csv writers, which worker processes cannot share safely; the returned
        # models also hold those writers, which likely cannot be pickled back. Confirm before use.
        with mp.Pool(processes=cpus) as pool:
            for i in range(no_runs):
                # run_model_func takes the repeat index (previously omitted -> TypeError on .get()).
                output.append(pool.apply_async(run_model_func, (i,)))
            pool.close()
            pool.join()
    else:
        for i in range(no_runs):
            output.append(run_model_func(i))
finally:
    # Close the CSV files (close() also flushes) even if a run raises.
    for handle in (f, inter, v):
        handle.close()

run_model = True

In [None]:
# Gather per-forager records for the analysis plots. model_data gets one entry per
# forager (not per model) so it stays aligned with the forager lists when zipped.
forager_int_data = []
forager_exit_data = []
model_data = []

for result in output:
    # The serial branch stores finished models; only the multiprocessing branch
    # stores AsyncResult objects that need .get() (calling .get() on a model failed).
    m = result if isinstance(result, ModelAntArray) else result.get()
    # Named `forager` so the loop does not rebind `f`, the forager CSV file handle.
    for forager in m.foragers:
        forager_int_data.append(forager.interaction_pos_other_crop)
        forager_exit_data.append(forager.exiting_step_crop)
        model_data.append(m.col_states)

In [None]:
import matplotlib as mpl

# Make sure the high-resolution figure settings are in effect for the final figures.
mpl.rcParams.update({
    'figure.dpi': 500,
    'savefig.dpi': 500,
    'font.size': 14,
    'legend.fontsize': 'medium',
})

In [None]:
fig, ax = plt.subplots(1, 4, figsize=[20, 4.5])
# Trips of this many steps or fewer are excluded from the exit-crop panel.
short_trips = 8

sampled_col_state(forager_int_data, model_data, ax[0])
unloading_rate(forager_exit_data, ax[1])
trip_depth(forager_exit_data, ax[2])
exiting_crop(forager_exit_data, ax[3], short_trips)
fig.tight_layout()

# Tag the file with the parameters the runs actually used (read from the last collected
# model `m`), written without the decimal point as before (0.15 -> "015"). The previous
# hard-coded "015radius_03reflection_005mininter" did not match run_model_func
# (radius=0.3, minimum_interaction=0).
param_tag = "{}radius_{}reflection_{}mininter".format(
    *(str(value).replace('.', '') for value in (m.radius, m.reflection_noise_factor, m.minimum_interaction)))
fig.savefig(f"continuous_figure_6_shorttrips{short_trips}_{param_tag}_{no_forag}forag.svg", format='svg')

In [None]:
fig, ax = plt.subplots(figsize=(3.5, 3.5))
exiting_frequency(forager_exit_data, ax, short_trips)

# Tag the file with the parameters the runs actually used (from the last collected model `m`),
# in the same "0.15 -> 015" style; the old hard-coded tag did not match run_model_func.
param_tag = "{}radius_{}reflection_{}mininter".format(
    *(str(value).replace('.', '') for value in (m.radius, m.reflection_noise_factor, m.minimum_interaction)))
fig.savefig(f"continuous_trip_freq_shorttrips{short_trips}_{param_tag}_{no_forag}forag.svg", format='svg')


In [None]:
fig, axes = plt.subplots(5, 5, figsize=(22, 22))

# Sensitivity of the trip-frequency curve to the short-trip cutoff:
# panel n keeps only trips longer than 2 * n steps.
for n, ax in enumerate(axes.flat):
    st = n * 2
    exiting_frequency(forager_exit_data, ax, st)
    # Without a title the 25 panels are indistinguishable.
    ax.set_title(f"trips > {st} steps")
fig.tight_layout()