## Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import math
import heapq
import mplcursors
from tqdm import tqdm

## Heap

In [None]:
class _KMaxHeapEntry(list):
  """Heap entry for KMaxHeap that orders only by its first item (the key).

  heapq and list.sort use `<`. Plain lists fall through to the payload when
  keys are equal, and that raises for payloads such as pandas Series, so
  only the key takes part in the ordering.
  """

  def __lt__(self, other):
    return self[0] < other[0]


# Bounded max heap: keeps the `capacity` entries with the smallest keys by
# evicting the current largest key whenever a smaller one arrives.
class KMaxHeap():
  """Keep the `capacity` entries with the smallest first item pushed so far.

  Entries are sequences whose first item is a numeric key; any remaining
  items are carried along as payload and are never compared. Internally
  `self.data` is a heapq min-heap on the *negated* key, so `self.data[0]`
  holds the entry with the largest key. Stored entries therefore have
  element 0 negated; readers of `self.data` must negate it back.
  """

  def __init__(self, capacity):
    self.capacity = capacity
    self.data = []

  def push(self, elem):
    """Offer `elem`; it is kept if the heap has room or its key beats the worst.

    `elem` itself is not modified. A negated copy is stored instead.
    """
    entry = _KMaxHeapEntry(elem)
    entry[0] = -entry[0]
    if len(self.data) < self.capacity:
      heapq.heappush(self.data, entry)
    # `self.data` is empty here only when capacity <= 0: keep nothing.
    elif self.data and entry[0] > self.data[0][0]:
      heapq.heappushpop(self.data, entry)

## BirdFlockSimulator

In [None]:
from numpy import disp


# TODO Approx K
class NearestNeighbourCalculator:
  """Attach each bird's nearest neighbours to a DataFrame as flat columns.

  For every row of a DataFrame holding ``x``, ``y``, ``vx`` and ``vy``
  columns, the ``neighboursN`` closest *other* rows (Euclidean distance on
  x/y) are found. They are emitted as ``nn_<k>_dist``, ``nn_<k>_x``,
  ``nn_<k>_y``, ``nn_<k>_vx`` and ``nn_<k>_vy``, with k = 0 the closest.
  """

  # Per-neighbour fields, in the order getNeighbours emits them.
  FIELDS = ("dist", "x", "y", "vx", "vy")

  def __init__(self, neighBoursN):
    self.neighboursN = neighBoursN

  def getNNDf(self, df):
    """Return ``df`` with the neighbour columns appended and x/y/vx/vy dropped."""
    res = [self.getNeighbours(i, df) for i in range(len(df))]
    # Share df's index so the column-wise concat lines rows up even when df
    # does not use a default RangeIndex.
    nndf = pd.DataFrame(np.array(res), index=df.index)
    nndf.rename(mapper=self.mapColNames, axis='columns', inplace=True)
    ret = pd.concat([df, nndf], axis='columns')
    return ret.drop(columns=['x', 'y', 'vx', 'vy'])

  # Nearest Neighbours
  def getNeighbours(self, i, df):
    """Return bird ``i``'s nearest neighbours as a flat [dist, x, y, vx, vy, ...] array.

    ``i`` is a positional row index into ``df``. Neighbours are ordered
    nearest first. Equal distances keep df's row order. The array has
    5 * min(neighboursN, len(df) - 1) entries.
    """
    bird = df.iloc[i]
    # Squared distance to every other bird, paired with that bird's row.
    candidates = (
      ((row['x'] - bird['x']) ** 2 + (row['y'] - bird['y']) ** 2, row)
      for j, (_, row) in enumerate(df.iterrows())
      if j != i
    )
    # key= means heapq never compares the Series payloads, even on ties.
    nearest = heapq.nsmallest(self.neighboursN, candidates, key=lambda c: c[0])
    # Read by label: after a tick df carries extra columns (vx_0, vx_1, ...),
    # so positions 2 and 3 are no longer vx and vy.
    return np.array([
      [math.sqrt(distSq), row['x'], row['y'], row['vx'], row['vy']]
      for distSq, row in nearest
    ]).flatten()

  def mapColNames(self, col):
    """Map flat column position ``col`` to its ``nn_<k>_<field>`` name."""
    k, field = divmod(col, len(self.FIELDS))
    return f"nn_{k}_{self.FIELDS[field]}"
  

In [None]:
class VectorCalculator:
  """Compute each bird's next velocity from its nearest-neighbour columns.

  Rows are expected to carry ``nn_<i>_{dist,x,y,vx,vy}`` for i in
  range(neighBoursN), as produced by NearestNeighbourCalculator, plus the
  bird's own ``x``, ``y``, ``vx`` and ``vy``.
  """

  # Distance bands used by calcVector.
  ATTRACT_DIST = 15        # neighbours at least this far away attract
  REPEL_DIST = 5           # neighbours at most this close repel
  SPACING = 10             # distance attraction/repulsion strength is measured from
  ATTRACTION_FACTOR = 0.1
  REPULSION_FACTOR = 0.1
  # Birds within this distance of an edge get pushed back inwards.
  BOUNDARY_MARGIN = 10
  # Velocities with squared length below this are scaled up to sqrt(10).
  MIN_SPEED_SQ = 10

  def __init__(self, neighBoursN, momentumFactor, boundaryFactor, maxX, maxY):
    self.neighboursN = neighBoursN
    self.momentumFactor = momentumFactor
    self.boundaryFactor = boundaryFactor
    self.maxX = maxX
    self.maxY = maxY

  def getBoundaryVectors(self, row):
    """Return the (x, y) push back from the edges, scaled by boundaryFactor.

    On each axis the push equals how far the bird has entered the
    BOUNDARY_MARGIN band next to an edge, pointing inwards. It is 0 elsewhere.
    """
    margin = self.BOUNDARY_MARGIN
    x = 0.0
    y = 0.0
    # X Axis
    if row['x'] < margin:
      x = margin - row['x']
    if row['x'] > self.maxX - margin:
      x = (self.maxX - margin) - row['x']
    # Y Axis
    if row['y'] < margin:
      y = margin - row['y']
    if row['y'] > self.maxY - margin:
      y = (self.maxY - margin) - row['y']
    return x * self.boundaryFactor, y * self.boundaryFactor

  def calcVectors(self, row):
    """Return [vx_0..vx_{N-1}, vx_b, vy_0..vy_{N-1}, vy_b, r_0..r_{N-1}] for one row."""
    individualXVectors = []
    individualYVectors = []
    vectorReason = []
    # Get Vectors from neighbours
    for i in range(self.neighboursN):
      neighbour = (row[f"nn_{i}_x"], row[f"nn_{i}_y"], row[f"nn_{i}_dist"])
      xDiff, yDiff, reason = self.calcVector(row, neighbour, self.neighboursN, i)
      individualXVectors.append(float(xDiff))
      individualYVectors.append(float(yDiff))
      vectorReason.append(reason)
    # Add vectors for hitting boundary
    xBoundary, yBoundary = self.getBoundaryVectors(row)
    return individualXVectors + [xBoundary] + individualYVectors + [yBoundary] + vectorReason

  # Vector calculation
  def getVectors(self, df, prev):
    """Return per-bird vector components and the resulting velocity.

    ``df`` holds the neighbour columns and ``prev`` the current birds
    (x, y, vx, vy), sharing one index. The result has columns vx_i/vy_i
    (per neighbour), vx_b/vy_b (boundary), r_i (reason codes), vx_d/vy_d
    (summed desired velocity) and x/y. The x/y velocity blends the desired
    velocity with prev's by momentumFactor and is raised to the minimum speed.
    """
    n = self.neighboursN
    ret = pd.concat([df, prev], axis=1).apply(self.calcVectors, axis=1, result_type='expand')
    xVectorCols = [f'vx_{i}' for i in range(n)] + ['vx_b']
    yVectorCols = [f'vy_{i}' for i in range(n)] + ['vy_b']
    reasonCols = [f'r_{i}' for i in range(n)]
    ret.columns = xVectorCols + yVectorCols + reasonCols
    ret['vx_d'] = ret[xVectorCols].sum(axis=1)
    ret['vy_d'] = ret[yVectorCols].sum(axis=1)
    # Maintain momentum from previous velocity
    ret['x'] = ret['vx_d'] * (1 - self.momentumFactor) + prev.vx * self.momentumFactor
    ret['y'] = ret['vy_d'] * (1 - self.momentumFactor) + prev.vy * self.momentumFactor
    # Min speed: rescale slow velocities to length sqrt(MIN_SPEED_SQ). A zero
    # velocity has no direction to scale along, so it is left as is.
    speedSq = ret['x'] ** 2 + ret['y'] ** 2
    slow = (speedSq < self.MIN_SPEED_SQ) & (speedSq > 0)
    if slow.any():
      hyp = np.sqrt(speedSq[slow])
      minSpeed = math.sqrt(self.MIN_SPEED_SQ)
      ret.loc[slow, 'x'] = (ret.loc[slow, 'x'] / hyp) * minSpeed
      ret.loc[slow, 'y'] = (ret.loc[slow, 'y'] / hyp) * minSpeed
    return ret

  # Alternative, use the vectors of the 5 closest neighbours instead?
  def calcVector(self, row, neighbour, N, i):
    """Return (x, y, reason) contributed by neighbour ``i``, divided by ``N``.

    reason is "A" (attract), "R" (repel) or "C" (copy the neighbour's
    velocity). ``neighbour`` is kept for interface compatibility; all values
    are read from ``row``.
    """
    neighbourX = row[f"nn_{i}_x"]
    neighbourY = row[f"nn_{i}_y"]
    neighbourDist = row[f"nn_{i}_dist"]
    neighbourVx = row[f"nn_{i}_vx"]
    neighbourVy = row[f"nn_{i}_vy"]

    vecX = neighbourX - row.x
    vecY = neighbourY - row.y
    # At ATTRACT_DIST or further: steer towards the neighbour, harder the
    # further it is beyond SPACING.
    if neighbourDist >= self.ATTRACT_DIST:
      attraction = neighbourDist - self.SPACING
      x, y = self.normalizeVector(vecX, vecY, attraction * self.ATTRACTION_FACTOR / N)
      return (x, y, "A")
    # At REPEL_DIST or closer: push away, linearly harder the closer it is.
    if neighbourDist <= self.REPEL_DIST:
      repulsion = self.SPACING - neighbourDist
      x, y = self.normalizeVector(-vecX, -vecY, repulsion * self.REPULSION_FACTOR / N)
      return (x, y, "R")
    # In between: copy the neighbour's velocity.
    return (neighbourVx / N, neighbourVy / N, "C")

  def normalizeVector(self, x, y, scale=1):
    """Return (x, y) rescaled to length ``scale``. A zero vector gives (0.0, 0.0)."""
    hyp = math.sqrt(x**2 + y**2)
    if hyp == 0:
      # Coincident birds / zero velocity: there is no direction to scale.
      return (0.0, 0.0)
    return ((x / hyp) * scale, (y / hyp) * scale)

In [None]:
class BirdFlockSimulator:
  """Flock simulation whose state lives in pandas DataFrames.

  ``birds`` has one row per bird with position (x, y) and velocity (vx, vy).
  After the first applied tick it also carries the vector columns returned
  by VectorCalculator.getVectors.

  Parameters beyond ``numBirds`` are optional and default to the values
  this class previously hard-coded.
  """

  # Init
  def __init__(self, numBirds, maxX=500, maxY=500, neighboursN=5,
               momentumFactor=0.5, boundaryFactor=0.5):
    self.numBirds = numBirds
    self.maxX = maxX
    self.maxY = maxY
    self.neighboursN = neighboursN
    self.momentumFactor = momentumFactor
    self.boundaryFactor = boundaryFactor
    self.birds = self.createXYObjectsDataframe(self.numBirds)
    # Created but not read anywhere else in this class.
    self.predators = self.createXYObjectsDataframe(1)
    # Cached (birds, nnDf) for the upcoming tick; filled lazily by getNext.
    self.next = None
    # Nearest-neighbour frame of the last applied tick (None before any tick).
    self.nnDf = None

  def createXYObjectsDataframe(self, NObjects):
    """Return NObjects random objects: positions uniform in [0, max), velocities in [0, 10)."""
    d = {
      'x': np.random.rand(NObjects) * self.maxX,
      'y': np.random.rand(NObjects) * self.maxY,
      'vx': np.random.rand(NObjects) * 10,
      'vy': np.random.rand(NObjects) * 10
    }
    return pd.DataFrame(data=d)

  def getNext(self):
    """Return the birds DataFrame for the next tick, computing and caching it once."""
    if self.next is None:
      self.next = self.calculateTick()
    return self.next[0]

  def tick(self, update=True):
    """Advance one step. With update=False, only compute and cache the next state."""
    self.getNext()
    if update:
      self.birds, self.nnDf = self.next
      self.next = None

  def calculateTick(self):
    """Compute (newBirds, nnDf) for one step without modifying the simulator."""
    # NN
    nnCalc = NearestNeighbourCalculator(self.neighboursN)
    nnDf = nnCalc.getNNDf(self.birds)
    # Vectors
    vCalc = VectorCalculator(self.neighboursN, self.momentumFactor, self.boundaryFactor, self.maxX, self.maxY)
    vectors = vCalc.getVectors(nnDf, self.birds)

    # New positions = old positions + new velocity (the x/y columns of vectors).
    newD = self.birds[['x', 'y']] + vectors[['x', 'y']]
    vectors = vectors.rename(columns={'x': 'vx', 'y': 'vy'})
    ret = pd.concat([newD, vectors], axis=1)
    return ret, nnDf

  

## Plotting

In [None]:
%matplotlib widget

def selFun(sel):
  """Set the hover annotation to this point's entry in ``annotation_names``."""
  # Selection.index is the documented accessor for the picked point;
  # Selection.target.index is the deprecated spelling.
  sel.annotation.set_text(sel.artist.annotation_names[sel.index])

def positionPlot(curr):
  """Show a scatter of ``curr``'s x/y with each point's row position on hover."""
  fig, ax = plt.subplots(1, 1)
  scat = ax.scatter(curr["x"], curr["y"])
  # One label per plotted point, so hover works for any flock size.
  scat.annotation_names = list(range(len(curr)))
  cursor = mplcursors.cursor([scat], hover=True)
  cursor.connect("add", selFun)
  plt.show()

In [None]:
def velPlot(sim):
  """Return a figure with an arrow from each bird's current to its next position.

  Calls ``sim.getNext()``, which computes and caches the upcoming tick on
  ``sim``. The figure is closed (removed from pyplot's figure manager)
  before being returned; the caller is expected to save or display it.
  """
  curr = sim.birds
  upcoming = sim.getNext()
  fig, ax = plt.subplots()
  for cx, cy, nx, ny in zip(curr.x, curr.y, upcoming.x, upcoming.y):
    ax.annotate('', xytext=(cx, cy), xy=(nx, ny), arrowprops=dict(arrowstyle='->'))
  # 10-unit margin around the simulated world.
  ax.set(xlim=(-10, sim.maxX + 10), ylim=(-10, sim.maxY + 10))
  plt.close(fig)
  return fig

In [None]:
def posPlot(sim):
  """Return a scatter figure of the birds' current positions with hover labels.

  The figure is closed (removed from pyplot's figure manager) before being
  returned; the caller is expected to save or display it.
  """
  curr = sim.birds
  fig, ax = plt.subplots(1, 1)
  scat = ax.scatter(curr["x"], curr["y"])
  # One label per plotted point, so hover works for any flock size.
  scat.annotation_names = list(range(len(curr)))
  cursor = mplcursors.cursor([scat], hover=True)
  cursor.connect("add", selFun)
  # Same framing as the previous fixed (-10, 550) limits for a 500-unit world.
  ax.set(xlim=(-10, sim.maxX + 50), ylim=(-10, sim.maxY + 50))
  plt.close(fig)
  return fig

## Test

In [None]:
# Flock of 100 birds used by the cells below.
sim = BirdFlockSimulator(numBirds=100)

## Position Plot


In [None]:
import os

# savefig does not create missing directories, so make sure figs/ exists.
os.makedirs('figs', exist_ok=True)
# Advance the simulation and save one zero-padded PNG frame per tick.
for i in tqdm(range(1000)):
  sim.tick()
  fig = posPlot(sim)
  num = f'{i:03d}'
  fig.savefig(f'figs/{num}.png')