In [None]:

from typing import List
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import random
import itertools

from matplotlib import rc
# Display FuncAnimation objects as HTML5 <video> elements in the notebook output.
rc('animation', html='html5')


In [None]:
class Ball:
    """A disc moving in a rectangular box; all balls share one radius.

    By default the methods read the notebook-level globals ``dt``, ``r``,
    ``width`` and ``height`` at call time. Each method also accepts an explicit
    value so the class can be used (and tested) without those globals.
    """

    def __init__(self, x, y, vx, vy, color):
        self.x = x    # centre position
        self.y = y
        self.vx = vx  # velocity, in position units per unit of time step
        self.vy = vy
        self.color = color  # stored only; not read by any method of this class

    def update_pos(self, step=None):
        """Advance the centre by one explicit Euler step.

        Args:
            step: time step length; defaults to the global ``dt``.
        """
        if step is None:
            step = dt
        self.x += self.vx * step
        self.y += self.vy * step

    def check_collision(self, ball, radius=None):
        """Resolve an elastic collision with ``ball`` if the discs touch and approach.

        Both balls are treated as equal-mass discs of radius ``radius`` (default:
        the global ``r``). On impact, the velocity components along the line of
        centres are exchanged and the tangential components are kept. Both balls'
        velocities are updated in place; positions are not touched.

        Args:
            ball: the other Ball.
            radius: common ball radius; defaults to the global ``r``.
        """
        if radius is None:
            radius = r
        dx = self.x - ball.x
        dy = self.y - ball.y
        dist_sq = dx * dx + dy * dy
        if dist_sq > 4 * radius ** 2:
            return  # not in contact

        dvx = self.vx - ball.vx
        dvy = self.vy - ball.vy
        # The separation shrinks iff the relative velocity points against the
        # separation vector. Checking x and y independently also fired for some
        # pairs that were already moving apart.
        approach = dx * dvx + dy * dvy
        if approach >= 0:
            return  # separating, grazing, or coincident centres (dist_sq == 0)

        # Equal masses: subtract the relative velocity's component along the line
        # of centres from one ball and add it to the other (i.e. swap the normal
        # components). dist_sq > 0 here because approach < 0.
        k = approach / dist_sq
        self.vx -= k * dx
        self.vy -= k * dy
        ball.vx += k * dx
        ball.vy += k * dy

    def check_boundary(self, x_max=None, y_max=None):
        """Reflect the velocity off the walls of the box [0, x_max] x [0, y_max].

        Only the centre is compared against the walls (the radius is ignored).
        A component is flipped only while it still points outward, so a ball
        that has overshot a wall is not flipped back and forth.

        Args:
            x_max: right wall position; defaults to the global ``width``.
            y_max: top wall position; defaults to the global ``height``.
        """
        if x_max is None:
            x_max = width
        if y_max is None:
            y_max = height
        if (self.x >= x_max and self.vx > 0) or (self.x <= 0 and self.vx < 0):
            self.vx = -self.vx
        if (self.y >= y_max and self.vy > 0) or (self.y <= 0 and self.vy < 0):
            self.vy = -self.vy


In [None]:
# Check if check_collision works fine: ball1 moves right into the stationary
# ball2 and hits it off-centre.

# Init plot
width = 10
height = 10
fig = plt.figure()
ax = fig.add_subplot(autoscale_on=False,
                     xlim=(-1, width + 1), ylim=(-1, height + 1))
ax.set_aspect('equal', adjustable='box')
# Hide every tick and tick label; only the balls are of interest.
ax.tick_params(axis='both', which='both',
               bottom=False, labelbottom=False,
               left=False, labelleft=False,
               top=False, right=False)
line, = ax.plot([], [], 'o', markersize=20)
plt.close()  # Hide plot img; the animation below is displayed instead

# Declare constants
duration = 2  # Seconds
dt = 0.05
r = 0.5

# Create balls
ball1 = Ball(4, 4.5, 1, 0, 1)
ball2 = Ball(6, 5, 0, 0, 1)


def animate(frame):
    """Advance the two-ball scene by one frame and redraw it."""
    ball1.check_collision(ball2)

    pair = (ball1, ball2)
    for ball in pair:
        ball.check_boundary()
        ball.update_pos()

    line.set_data([ball.x for ball in pair], [ball.y for ball in pair])
    return line


# Show animation
ani = animation.FuncAnimation(fig, animate,
                              frames=range(int(duration / dt)),
                              interval=dt * 1000,  # milliseconds per frame
                              blit=False)

ani

In [None]:
# Declare constants
duration = 10  # Seconds
dt = 0.01
r = 0.3

width = 20
height = 20

n_balls = 100
seed = 0  # fixed so the initial layout is reproducible across kernel restarts

# Init plot
fig = plt.figure()
ax = fig.add_subplot(autoscale_on=False,
                     xlim=(-1, width+1), ylim=(-1, height+1))
ax.set_aspect('equal', adjustable='box')
ax.tick_params(axis='both', which='both',
               bottom=False, labelbottom=False,
               left=False, labelleft=False,
               top=False, right=False)
line, = ax.plot([], [], 'o')
plt.close()  # Hide plot img

# Init lists (cleared and refilled by the animate() functions below)
x_list = []
y_list = []

# Balls start in the [0, 10) x [0, 10) corner of the box with velocity
# components drawn from [0, 5). A private RNG leaves the global `random`
# state untouched.
# NOTE: the animation cells below move (and some of them sort) `balls` in place,
# so each one continues from the state the previous one left behind. Re-run this
# cell first to compare algorithms from the same starting layout.
rng = random.Random(seed)
balls: List[Ball] = [
    Ball(10 * rng.random(), 10 * rng.random(),
         5 * rng.random(), 5 * rng.random(), 1)
    for _ in range(n_balls)
]


In [None]:
# Uniform grid for the "static cells" broad phase; cell_size is the side of a
# square cell. The +1 keeps a centre lying exactly on the far wall
# (x == width or y == height) inside the grid.
cell_size = 1
cells_cols = int(width/cell_size+1)
cells_rows = int(height/cell_size+1)

# Setting ticks lowers performance, so they stay disabled. They are kept as
# comments rather than a bare string literal, so this cell produces no output.
# ax.set_xticks(np.arange(0, width+.1, cell_size))
# ax.set_yticks(np.arange(0, height+.1, cell_size))

In [None]:
# Check all combinations (brute force: every pair, every frame)
# 25 seconds

def animate(frame):
    """Brute-force broad phase: test each unordered pair of balls once per frame."""
    x_list.clear()
    y_list.clear()

    for first, second in itertools.combinations(balls, 2):
        first.check_collision(second)

    for ball in balls:
        ball.check_boundary()
        ball.update_pos()
    x_list.extend(b.x for b in balls)
    y_list.extend(b.y for b in balls)

    line.set_data(x_list, y_list)
    return line


ani = animation.FuncAnimation(fig, animate,
                              frames=range(int(duration / dt)),
                              interval=dt * 1000,  # milliseconds per frame
                              blit=False)

ani


In [None]:
# Static cells
#
# Broad phase on a fixed uniform grid: every ball is registered in each cell its
# bounding square [x-r, x+r] x [y-r, y+r] overlaps, and only balls that share a
# cell are tested against each other.
# Timings below were recorded with an earlier version of this cell:
# Cell size 1: 19 seconds
# Cell size 2: 20 seconds
# Cell size 3: 20 seconds
# Cell size 5: 21 seconds


def cell_span(coord, n_cells, radius, size):
    """Return the range of cell indices covered by [coord - radius, coord + radius].

    Args:
        coord: ball centre along one axis (world units).
        n_cells: number of cells along that axis.
        radius: ball radius (world units).
        size: cell side length (world units).
    """
    # Convert world coordinates to cell indices *after* applying the radius;
    # floor (not int truncation) keeps negative coordinates in the correct cell.
    first = int((coord - radius) // size)
    last = int((coord + radius) // size)
    # Balls can overshoot a wall by one step before check_boundary reflects them.
    # Clamp so an index never goes negative (which would silently wrap around)
    # or past the last cell.
    first = min(max(first, 0), n_cells - 1)
    last = min(max(last, 0), n_cells - 1)
    return range(first, last + 1)


def animate(frame):
    x_list.clear()
    y_list.clear()

    # cells[x][y]: indices of the balls touching that cell. Each cell needs its
    # own list; `[[]] * n` would make every row of a column share one list.
    cells = [[[] for _ in range(cells_rows)] for _ in range(cells_cols)]
    for i, ball in enumerate(balls):
        # Register the ball in every (column, row) combination it overlaps.
        for cell_x, cell_y in itertools.product(
                cell_span(ball.x, cells_cols, r, cell_size),
                cell_span(ball.y, cells_rows, r, cell_size)):
            cells[cell_x][cell_y].append(i)

    # Visit every cell (not only the diagonal). De-duplicate pairs that share
    # more than one cell so each pair is tested at most once per frame.
    pairs = set()
    for column in cells:
        for members in column:
            pairs.update(itertools.combinations(members, 2))
    for b1, b2 in sorted(pairs):
        balls[b1].check_collision(balls[b2])

    for ball in balls:
        ball.check_boundary()
        ball.update_pos()
        x_list.append(ball.x)
        y_list.append(ball.y)

    line.set_data(x_list, y_list)
    return line


ani = animation.FuncAnimation(fig, animate,
                              frames=range(int(duration / dt)),
                              interval=dt * 1000,  # milliseconds per frame
                              blit=False)
ani


In [None]:
# Sweep and prune
# 19 seconds (timing recorded with an earlier version of this cell)


def sweep_pairs(sorted_balls, reach):
    """Yield candidate collision pairs from balls sorted by ascending ``x``.

    Each ball is paired with every earlier ball whose x-coordinate is at most
    ``reach`` smaller. Pairs come out as ``(later_ball, earlier_ball)``, nearest
    first. Only the ``x`` attribute is read, and it must not change while the
    generator is being consumed.
    """
    active = []  # earlier balls still within reach, in ascending x order
    for ball1 in sorted_balls:
        # The input is sorted, so the balls that fell out of reach form a
        # prefix of `active`, and they stay out of reach for all later balls.
        stale = 0
        while stale < len(active) and ball1.x - active[stale].x > reach:
            stale += 1
        del active[:stale]
        for ball2 in reversed(active):
            yield ball1, ball2
        # Every ball joins the active window, not only those that evicted others.
        active.append(ball1)


def animate(frame):
    x_list.clear()
    y_list.clear()

    balls.sort(key=lambda b: b.x)
    # check_collision changes only velocities, so the x-order stays valid
    # for the whole sweep.
    for ball1, ball2 in sweep_pairs(balls, 2*r):
        ball1.check_collision(ball2)

    for ball in balls:
        ball.check_boundary()
        ball.update_pos()
        x_list.append(ball.x)
        y_list.append(ball.y)

    line.set_data(x_list, y_list)
    return line


ani = animation.FuncAnimation(fig, animate,
                              frames=range(int(duration / dt)),
                              interval=dt * 1000,  # milliseconds per frame
                              blit=False)
ani


In [None]:
# Sweep and prune with indices
# 19 seconds

def animate(frame):
    """Sweep-and-prune broad phase that tracks the active window with two indices.

    After sorting by x, balls[lo:hi] holds the earlier balls whose x-coordinate
    lies within 2*r of the current ball; only those pairs are tested.
    """
    x_list.clear()
    y_list.clear()

    balls.sort(key=lambda b: b.x)

    lo = 0
    for hi, ball1 in enumerate(balls):
        # x is sorted, so balls out of reach form a prefix. Once out of reach
        # for ball1, they are out of reach for every later ball too.
        while lo < hi and ball1.x - balls[lo].x > 2*r:
            lo += 1
        for ball2 in balls[lo:hi]:
            ball1.check_collision(ball2)

    for ball in balls:
        ball.check_boundary()
        ball.update_pos()
        x_list.append(ball.x)
        y_list.append(ball.y)

    line.set_data(x_list, y_list)
    return line


ani = animation.FuncAnimation(fig, animate,
                              frames=range(int(duration / dt)),
                              interval=dt * 1000,  # milliseconds per frame
                              blit=False)
ani


In [None]:
# TODO K-D Tree algorithm
