## Imports

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import math
import heapq
import mplcursors

## Heap

In [2]:
# Bounded heap that keeps the `capacity` entries with the SMALLEST keys pushed so far.
# It is a max heap built on heapq's min heap by negating the key, so data[0] is
# always the retained entry with the largest key (the next one to be evicted).
class KMaxHeap():
  def __init__(self, capacity):
    # Maximum number of entries retained.
    self.capacity = capacity
    # heapq-ordered list of [-key, *payload] lists.
    self.data = []

  def push(self, elem):
    """Offer ``elem`` (a sequence ``[key, *payload]``) to the heap.

    A new list ``[-key, *payload]`` is stored; ``elem`` itself is not modified.
    When the heap is full, the entry is kept only if its key is strictly
    smaller than the largest retained key, which is then evicted.
    """
    if self.capacity <= 0:
      # Nothing can be retained, and data[0] below would not exist.
      return
    entry = [-elem[0], *elem[1:]]
    if len(self.data) < self.capacity:
      heapq.heappush(self.data, entry)
    elif entry[0] > self.data[0][0]:
      # Smaller key than the current maximum: replace the root.
      heapq.heappushpop(self.data, entry)

## BirdFlockSimulator

In [3]:
class NearestNeighbourCalculator:
  """Find the ``neighboursN`` nearest other birds for every bird in a DataFrame.

  Input frames need ``x`` and ``y`` columns; ``getNNDf`` additionally drops
  ``x``, ``y``, ``vx`` and ``vy`` from its result, so those must be present.
  """

  def __init__(self, neighBoursN):
    # Number of nearest neighbours reported per bird.
    self.neighboursN = neighBoursN

  def getNNDf(self, df):
    """Return one row per bird describing its nearest neighbours.

    The result keeps any columns of ``df`` other than x, y, vx, vy, followed
    by ``nn_{k}_dist``, ``nn_{k}_x``, ``nn_{k}_y`` for k = 0 (closest) upwards.
    It shares ``df``'s index, so rows line up in the concat even when ``df``
    does not use a default RangeIndex.
    """
    rows = [self.getNeighbours(pos, df) for pos in range(len(df))]
    nndf = pd.DataFrame(np.array(rows), index=df.index)
    nndf = nndf.rename(mapper=self.mapColNames, axis='columns')
    ret = pd.concat([df, nndf], axis='columns')
    return ret.drop(columns=['x', 'y', 'vx', 'vy'])

  # Nearest Neighbours
  def getNeighbours(self, i, df):
    """Return the nearest neighbours of the bird at row position ``i`` of ``df``.

    The result is a flat float array ``[dist_0, x_0, y_0, dist_1, ...]`` sorted
    by increasing Euclidean distance. It has at most ``neighboursN`` entries,
    fewer if ``df`` has fewer other birds. The bird itself is excluded, and
    equal distances are ordered by row position.
    """
    xs = df['x'].to_numpy(dtype=float)
    ys = df['y'].to_numpy(dtype=float)
    sqDist = (xs - xs[i]) ** 2 + (ys - ys[i]) ** 2
    # A stable sort makes ties deterministic instead of comparing row objects.
    order = np.argsort(sqDist, kind='stable')
    order = order[order != i][:self.neighboursN]
    return np.column_stack((np.sqrt(sqDist[order]), xs[order], ys[order])).flatten()

  def mapColNames(self, col):
    """Map flat column number ``col`` to ``nn_{k}_{dist|x|y}`` (3 columns per neighbour)."""
    nnn = col // 3
    label = ""
    if col % 3 == 0:
      label = "dist"
    elif col % 3 == 1:
      label = "x"
    else:
      label = "y"
    return f"nn_{nnn}_{label}"
  

In [4]:
def foo(x):
  """Return the sum of the first two items of ``x`` (used to try out DataFrame.apply)."""
  left, right = x[0], x[1]
  return left + right


# Row-wise apply: each row arrives as a Series, so foo adds its two columns.
a = pd.DataFrame(data=[[1, 2], [3, 4]])
a.apply(func=foo, axis='columns')

0    3
1    7
dtype: int64

In [5]:
class VectorCalculator:
  """Turn nearest-neighbour data into a new velocity for every bird.

  Each neighbour contributes one of three things:
  - an attraction vector, if it is at least 15 away;
  - a repulsion vector, if it is at most 5 away;
  - nothing, otherwise.

  The summed contribution is blended with the bird's current velocity using
  ``momentumFactor``. Any resulting non-zero velocity whose squared length is
  below ``MIN_SPEED_SQ`` is scaled up to length ``sqrt(MIN_SPEED_SQ)``.
  """

  # Squared minimum speed; slower (non-zero) velocities are scaled up to it.
  MIN_SPEED_SQ = 10

  def __init__(self, neighBoursN, momentumFactor):
    self.neighboursN = neighBoursN
    # Weight of the current velocity in the blend; (1 - factor) goes to the flock vector.
    self.momentumFactor = momentumFactor

  def calcVectors(self, row):
    """Return ``[vx_0, ..., vx_{k-1}, vy_0, ..., vy_{k-1}]`` for one bird.

    ``row`` must carry the bird's ``x``/``y`` and the ``nn_{i}_x``,
    ``nn_{i}_y`` and ``nn_{i}_dist`` values for every neighbour i < neighboursN.
    """
    xs, ys = [], []
    for i in range(self.neighboursN):
      neighbour = (row[f"nn_{i}_x"], row[f"nn_{i}_y"], row[f"nn_{i}_dist"])
      xDiff, yDiff = self.calcVector(row, neighbour)
      xs.append(float(xDiff))
      ys.append(float(yDiff))
    return xs + ys

  # Vector calculation
  def getVectors(self, df, df2):
    """Compute the next velocity of every bird.

    Args:
      df: nearest-neighbour frame (``nn_*`` columns), index-aligned with ``df2``.
      df2: bird frame with at least ``x``, ``y``, ``vx`` and ``vy``.

    Returns:
      A DataFrame with these columns:
      - ``vx_i``/``vy_i``: the per-neighbour components;
      - ``vx_d``/``vy_d``: their sums;
      - ``x``/``y``: the blended new velocity.
    """
    ret = pd.concat([df, df2], axis=1).apply(self.calcVectors, axis=1, result_type='expand')
    # One x and one y column per neighbour, however many neighbours are used.
    xVectorCols = [f'vx_{i}' for i in range(self.neighboursN)]
    yVectorCols = [f'vy_{i}' for i in range(self.neighboursN)]
    ret.columns = xVectorCols + yVectorCols

    ret['vx_d'] = ret[xVectorCols].sum(axis=1)
    ret['vy_d'] = ret[yVectorCols].sum(axis=1)
    # Each component is blended with the matching component of the current velocity.
    ret['x'] = ret['vx_d'] * (1 - self.momentumFactor) + df2.vx * self.momentumFactor
    ret['y'] = ret['vy_d'] * (1 - self.momentumFactor) + df2.vy * self.momentumFactor

    # Enforce the minimum speed through .loc so the change is written into ret
    # (assigning attributes on ret.iloc[index] would only modify a copy).
    # Zero-length velocities have no direction, so they are left unchanged.
    speedSq = ret['x'] ** 2 + ret['y'] ** 2
    tooSlow = (speedSq < self.MIN_SPEED_SQ) & (speedSq > 0)
    scale = math.sqrt(self.MIN_SPEED_SQ) / speedSq[tooSlow] ** 0.5
    ret.loc[tooSlow, 'x'] = ret.loc[tooSlow, 'x'] * scale
    ret.loc[tooSlow, 'y'] = ret.loc[tooSlow, 'y'] * scale
    return ret

  # Alternative, use the vectors of the 5 closest neighbours instead?
  def calcVector(self, bird, neighbour):
    """Return the (x, y) velocity contribution of one neighbour.

    ``neighbour`` is ``(x, y, dist)``. The contribution depends on the distance:
    - dist >= 15: pull towards the neighbour with strength ``(dist - 10) * 0.1``;
    - dist <= 5: push away from it with strength ``(10 - dist) * 0.1``;
    - otherwise: ``(0, 0)``.
    """
    neighbourX, neighbourY, neighbourDist = neighbour
    vecX = neighbourX - bird.x
    vecY = neighbourY - bird.y
    if neighbourDist >= 15:
      # Far away: be attracted towards a point 10 away, harder the further it is.
      attraction = neighbourDist - 10
      attractionFactor = 0.1
      return self.normalizeVector(vecX, vecY, attraction * attractionFactor)
    elif neighbourDist <= 5:
      # Within 5: push away linearly, stronger the closer the neighbour is.
      repulsion = 10 - neighbourDist
      repulsionFactor = 0.1
      return self.normalizeVector(-vecX, -vecY, repulsion * repulsionFactor)
    else:
      # Comfortable distance: no contribution.
      return (0, 0)

  def normalizeVector(self, x, y, scale=1):
    """Return (x, y) rescaled to length ``scale``; a zero vector yields (0.0, 0.0)."""
    hyp = math.sqrt(x**2 + y**2)
    if hyp == 0:
      # No direction to scale along (e.g. two birds at the same point).
      return (0.0, 0.0)
    return ((x / hyp) * scale, (y / hyp) * scale)

In [6]:
class BirdFlockSimulator:
  """Simulate a flock of birds on a 2D plane, one tick at a time.

  ``birds`` holds the current state. It starts with columns ``x``, ``y``,
  ``vx`` and ``vy``; after a tick it also carries the vector columns returned
  by VectorCalculator. The next state is computed lazily and cached in
  ``next``, so it can be inspected (e.g. for plotting) before ``tick()``
  commits it.
  """

  # Init
  def __init__(self, numBirds):
    self.numBirds = numBirds
    # Initial positions are drawn uniformly from [0, maxX) x [0, maxY).
    self.maxX = 100
    self.maxY = 100
    self.neighboursN = 5
    # Weight of a bird's current velocity in the update; 0.1 keeps them in place.
    self.momentumFactor = 0.2
    self.birds = self.createBirdsDataframe()
    # Cached (newBirds, nnDf) tuple from calculateTick, or None when not computed.
    self.next = None
    # Nearest-neighbour frame of the last committed tick; None until tick() runs.
    self.nnDf = None

  def createBirdsDataframe(self):
    """Return a frame of ``numBirds`` random birds.

    Positions are uniform in the area, and velocity components are uniform in [0, 10).
    """
    d = {
      'x': np.random.rand(self.numBirds) * self.maxX,
      'y': np.random.rand(self.numBirds) * self.maxY,
      'vx': np.random.rand(self.numBirds) * 10,
      'vy': np.random.rand(self.numBirds) * 10
    }
    return pd.DataFrame(data=d)

  def getNext(self):
    """Return the next bird state, computing and caching it on first request."""
    if self.next is None:
      self.next = self.calculateTick()
    return self.next[0]

  def tick(self, update=True):
    """Advance the simulation by one step.

    With ``update=False`` the next state is only computed and cached; the
    current state is left unchanged.
    """
    self.getNext()
    if update:
      self.birds, self.nnDf = self.next
      self.next = None

  def calculateTick(self):
    """Compute ``(newBirds, nnDf)`` for one step from the current ``birds``."""
    # NN
    nnCalc = NearestNeighbourCalculator(self.neighboursN)
    nnDf = nnCalc.getNNDf(self.birds)
    # Vectors
    vCalc = VectorCalculator(self.neighboursN, self.momentumFactor)
    vectors = vCalc.getVectors(nnDf, self.birds)
    # New positions = old positions + new velocity; velocity is stored as vx/vy.
    newD = self.birds[['x', 'y']] + vectors[['x', 'y']]
    vectors = vectors.rename(columns={'x': 'vx', 'y': 'vy'})
    ret = pd.concat([newD, vectors], axis=1)
    return ret, nnDf

  

## Plotting

In [7]:
%matplotlib widget

def positionPlot(df):
  """Scatter-plot bird positions from ``df``; hovering a point shows its row number."""
  fig, ax = plt.subplots(1, 1)
  scat = ax.scatter(df["x"], df["y"])
  # One label per plotted point, so hovering works for any flock size.
  scat.annotation_names = list(range(len(df)))
  cursor = mplcursors.cursor([scat], hover=True)
  # sel.index is the position of the hovered point within the scatter.
  cursor.connect("add", lambda sel: sel.annotation.set_text(sel.artist.annotation_names[sel.index]))
  plt.show()

In [8]:
def velPlot(curr, next, xlim=(-100, 500), ylim=(-100, 500)):
  """Draw one arrow per bird from its position in ``curr`` to its position in ``next``.

  Rows of the two frames are paired by position. ``xlim``/``ylim`` set the axis
  limits and default to (-100, 500). The figure is closed via ``plt.close``
  before being returned (presumably to keep it from being shown inline); it can
  still be saved with ``fig.savefig``.
  """
  fig, ax = plt.subplots()
  for cx, cy, nx, ny in zip(curr.x, curr.y, next.x, next.y):
    ax.annotate('', xytext=(cx, cy), xy=(nx, ny), arrowprops=dict(arrowstyle='->'))
  ax.set(xlim=xlim, ylim=ylim)
  plt.close(fig)
  return fig

## Test

In [9]:
# A flock of 100 randomly placed birds.
sim = BirdFlockSimulator(numBirds=100)

## 18.6

In [10]:
# Advance the simulation 200 ticks. Each frame is an arrow plot from the current
# positions to the next positions, saved as 000.png, 001.png, ...
for i in range(200):
  sim.tick()
  nextBirds = sim.getNext()  # named so the builtin next() is not shadowed
  fig = velPlot(sim.birds, nextBirds)
  fig.savefig(f'{i:03d}.png')