# Imports

In [37]:
# Object-oriented
from abc import ABC, abstractmethod, abstractproperty

# Math 
import random
import numpy as np
from scipy.spatial import distance

# System
import sys
import io

# Plotting, widgets
import seaborn as sns
from ipywidgets import *
import matplotlib.patches
import matplotlib.pyplot as plt
from IPython.display import clear_output, display
from IPython.display import HTML as html
# Apply seaborn's "ticks" theme globally (it updates matplotlib's rcParams,
# so it affects every figure drawn afterwards in this kernel).
sns.set(style="ticks")
# NOTE(review): Arial must be installed for matplotlib to use it; if it is
# missing, matplotlib presumably warns and falls back to a default font — confirm.
plt.rcParams["font.family"] = "Arial"

# Seed value for the random number generators. This line only defines the
# constant; the actual seeding happens where RANDOM_SEED is passed to seed(...).
RANDOM_SEED = 0

In [40]:
# Import the Firefly notebook (Firefly.ipynb) as a Python module through the
# `ipynb` package; the `fs.full` loader presumably runs the whole notebook
# (not just its definitions) — confirm against the ipynb package docs.
import ipynb
import importlib

import ipynb.fs.full.Firefly as Firefly

# Re-run this cell after editing the Firefly notebook to pick up the changes.
# The trailing semicolon stops the notebook from echoing the returned module.
importlib.reload(Firefly);

# World class: the simulation model (environment, fireflies, data recording, and plotting)

Pattern used: Facade. `World` provides the GUI client (`SimulationGUI` and its associated class, `WidgetCollection`) with a simplified interface to the fireflies' attributes and behaviors, as well as to data recording and plotting.

In [39]:
# Make every simulation in this notebook reproducible by seeding both the
# standard-library generator and NumPy's legacy global generator.
random.seed(RANDOM_SEED); np.random.seed(RANDOM_SEED)

class World:
    
    ''' Represent the entire model with the environment (2D space and time) and 
    all the fireflies and their behaviors and interactions with the environment.

    Serves as a facade for the GUI: it owns the firefly collection, advances the
    simulation one tick at a time, records per-tick statistics, and renders the
    line plots and the sky view as SVG strings for display in widgets. '''
    
    def __init__(self, num_males=30, step_size=1.0, num_total_steps=50, turning_angle_distribution=np.pi/8,
                 flash_interval_min=5, flash_interval_max=50, initial_arena_size=100, track_on=True):
        
        ''' Store the simulation parameters and immediately build the initial world.

        num_males: number of fireflies created for this world.
        step_size, turning_angle_distribution: movement parameters forwarded to
            Firefly.FireflyCollection.
        num_total_steps: number of ticks performed by run().
        flash_interval_min, flash_interval_max: flash-interval bounds forwarded
            to Firefly.FireflyCollection.
        initial_arena_size: half-width of the plotted arena; also passed to each
            firefly's take_step().
        track_on: whether plot_simulation() draws each firefly's trajectory. '''
        
        # Firefly & environment parameters
        self.num_males = num_males       # Number of fireflies in sky for a given world/simulation
        self.step_size = step_size
        self.num_total_steps = num_total_steps
        self.turning_angle_distribution = turning_angle_distribution
        self.flash_interval_min = flash_interval_min
        self.flash_interval_max = flash_interval_max
        self.initial_arena_size = initial_arena_size
        
        # For plotting - option of showing tracks or not
        self.track_on = track_on
        
        self.state = "stopped"            # "running" while run() is executing, otherwise "stopped"
        self.setup()                      # Build the sky and its fireflies right away
        
    def setup(self):
        
        ''' (Re)initialise the world: reset the clock and recorded data, and create
        a fresh collection of Firefly objects at their initial positions. '''
        
        self.ticks = 0                    # Simulation time elapsed (at most num_total_steps per run)
        
        # Set up data to record 
        self.all_fireflies_history = {}   # "firefly_<i>" -> dict returned by that firefly's record_history()
        self.num_flashes_history = [0]    # Number of flashing fireflies per tick (starts with 0 at t=0)
        self.distance_bw_flashes = [0]    # Mean pairwise distance between flashers per tick
        
        # List of Firefly objects, created using FireflyCollection which uses FireflyFactory
        self.firefly_collection = Firefly.FireflyCollection(self.num_males, self.step_size, self.num_total_steps, 
                                     self.flash_interval_min, self.flash_interval_max, 
                                     self.turning_angle_distribution,
                                     self.initial_arena_size)

    def update_data(self, num_flashes, avg_distance_bw_flashes):
        
        ''' Append one tick's values to the series shown in the two line plots
        (number of flashes and mean distance between flashes). '''
        
        self.num_flashes_history.append(num_flashes)
        self.distance_bw_flashes.append(avg_distance_bw_flashes)
        
    def update_fireflies(self):
        
        ''' Advance every firefly by one time step.

        Each firefly picks a turning angle, checks whether it flashes, and takes
        a step; its history is then recorded. The number of flashers and the mean
        pairwise distance between them are appended to the recorded series. '''
        
        num_flashes = 0
        flash_positions = []
        
        for ff_i, ff in enumerate(self.firefly_collection):
            ff.pick_turning_angle()
            ff.check_flash()
            if ff.flash:
                num_flashes += 1
                # Position is captured before the step, i.e. where the flash happened
                flash_positions.append((ff.position_x, ff.position_y))
            ff.take_step(self.initial_arena_size)
            self.all_fireflies_history[f"firefly_{ff_i}"] = ff.record_history()

        # Mean pairwise distance between flashing fireflies. pdist yields each
        # unordered pair exactly once and excludes the zero self-distances that a
        # full cdist matrix would average in (which biased the mean by (n-1)/n).
        # With fewer than two flashers there is no pair, so record 0.
        if len(flash_positions) > 1:
            avg_distance_bw_flashes = np.mean(distance.pdist(flash_positions, 'euclidean'))
        else:
            avg_distance_bw_flashes = 0
        
        # Append data to the respective lists
        self.update_data(num_flashes, avg_distance_bw_flashes)      
        
    def run(self, widgets):
        
        ''' Run num_total_steps ticks, refreshing the GUI widgets after each one.

        widgets: mapping with "plot_data", "plot_simulation" and "ticks" entries,
            each exposing a writable .value attribute. '''
        
        self.state = "running"
        try:
            for t_i in range(self.num_total_steps):
                self.ticks += 1
                self.update_fireflies()
                
                # Update the widgets to plot in real time as simulation runs
                widgets["plot_data"].value = self.plot_data()
                widgets["plot_simulation"].value = self.plot_simulation(t_i)
                widgets["ticks"].value = str(self.ticks)
        finally:
            # Never leave the world stuck in "running" if a plot or widget update fails
            self.state = "stopped"
    
    def plot_data(self, show=False):
        
        ''' Plot the number of flashing fireflies and the mean distance between
        flashers over time; return the figure as an SVG string. '''
        
        fig, (ax_count, ax_distance) = plt.subplots(2, 1, sharex=True)
        
        panels = (
            (ax_count, self.num_flashes_history, "orange", 'Count', "Number of flashing fireflies"),
            (ax_distance, self.distance_bw_flashes, "green", 'Distance', "Mean distance between flashes"),
        )
        for ax, series, color, ylabel, title in panels:
            ax.plot(series, c=color)
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)
            ax.set_ylabel(ylabel)
            ax.set_title(title)
        
        # Invisible full-figure axes used only to draw one shared "Time" label
        ax_shared = fig.add_subplot(111, frameon=False)
        ax_shared.set_xticks([])
        ax_shared.set_yticks([])
        ax_shared.set_xlabel('Time', labelpad=25) 

        fig.subplots_adjust(wspace=0.25, hspace=0.5)
        
        if show:
            plt.show()
            
        # Save frame for widget
        return self._figure_to_svg(fig)
    
    def plot_simulation(self, t_i, show=False):
        
        ''' Plot the sky at time index t_i: a flash marker for every firefly that
        flashes at t_i and, if track_on, each firefly's trajectory up to (but not
        including) t_i. Return the figure as an SVG string. '''
        
        arena = self.initial_arena_size
        margin = 50                       # Extra space drawn around the arena
        history = self.all_fireflies_history
        
        # One colour per recorded firefly, so indexing can never run past the palette
        colors = plt.cm.rainbow(np.linspace(0, 1, len(history)))

        # A fresh figure each time (a fixed figure number could reuse an open figure)
        fig = plt.figure(figsize=(7, 7)) 
        ax = fig.add_subplot(1, 1, 1)

        # Sky with fireflies ❉ ❉ ❉ ✌✌✌❉ ❉ ❉
        for ff_i, firefly_data in enumerate(history.values()):
            trajectory = firefly_data['trajectory_history']
            x = trajectory[:, 0]
            y = trajectory[:, 1]
            
            # Plot position trajectory if track on
            if self.track_on:
                ax.plot(x[0:t_i], y[0:t_i], lw=1, zorder=0, label=f"{ff_i+1}", c=colors[ff_i])

            # Mark the firefly if one of its flashes happens at this time index
            flash_steps = np.where(np.asarray(firefly_data['flash_history']) == 1)[0]
            if t_i in flash_steps:
                ax.scatter(x[t_i], y[t_i], c='orange', s=40)

        # Set black background
        ax.set_facecolor('xkcd:black')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.set_xlim(-arena - margin, arena + margin)
        ax.set_ylim(-arena - margin, arena + margin)

        if show:
            plt.show()

        # Save frame for widget
        return self._figure_to_svg(fig, bbox_inches='tight', pad_inches=0)

    @staticmethod
    def _figure_to_svg(fig, **savefig_kwargs):
        
        ''' Serialise `fig` to an SVG string and close it, even if saving fails.

        Saving through the Figure object (not pyplot's "current figure") keeps
        this correct after plt.show(), which may close or replace that figure. '''
        
        buffer = io.BytesIO()
        try:
            fig.savefig(buffer, format='svg', **savefig_kwargs)
        finally:
            plt.close(fig)
        return buffer.getvalue().decode()
