In [None]:
! pip install --upgrade numpy
! pip install --upgrade jaxlib
! pip install --upgrade jax
! pip install --upgrade plotly

In [2]:
import numpy as np, pandas as pd, jax as jx, jax.numpy as jnp, plotly.express as px
# Shared NumPy random generator created with the fixed seed 42; Particles.initialize
# draws the starting walls, positions and velocities from it.
rng = np.random.default_rng(42)

def pprint(arr):
    """Show ``arr`` as a pandas DataFrame through the notebook's ``display``.

    Args:
        arr: Anything ``pd.DataFrame`` accepts (e.g. a 2-D array or nested list).
    """
    frame = pd.DataFrame(arr)
    display(frame)

def dot(x, y, axis=-1, keepdims=False):
    """Inner product of ``x`` and ``y`` along one axis.

    The operands are multiplied elementwise (so they broadcast against each
    other) and the product is summed over ``axis``.

    Args:
        x, y: Array operands supporting ``*`` and producing an object with ``.sum``.
        axis: Axis that is summed away (default: the last one).
        keepdims: When True, the summed axis is kept with length 1.

    Returns:
        The summed elementwise product.
    """
    product = x * y
    return product.sum(axis=axis, keepdims=keepdims)

def unit(x, axis=-1):
    """Scale ``x`` to unit Euclidean length along ``axis``.

    Args:
        x: Array of vectors.
        axis: Axis holding the vector components (default: the last one).

    Returns:
        ``x`` divided by its norm along ``axis``; the norm is computed with
        ``keepdims=True`` so it broadcasts back against ``x``.
    """
    squared_norm = dot(x, x, axis, keepdims=True)
    return x / squared_norm ** 0.5

def solve_quadratic(a, b, c, wall_mask):
    """Elementwise smallest admissible "positive" root of ``a*t**2 + b*t + c = 0``.

    "Positive" means ``> -tol`` (``tol = 1e-5``).  Entries whose ``|a| <= tol``
    are treated as the linear equation ``b*t + c = 0``.

    Args:
        a, b, c: Coefficient arrays of a common (broadcastable) shape.  They are
            NOT modified: the sign normalisation below works on new arrays.
        wall_mask: Boolean array of the same shape.  Where True, the smallest
            admissible root is discarded (it is the collision that just
            happened) and the next one, if any, is returned instead.

    Returns:
        Array of root values; entries without an admissible root hold the
        sentinel ``1e12``.
    """
    tol = 1e-5
    inf = 1e12
    # We insist that a is non-negative to avoid redundant code to handle pos & neg a.
    # Multiply out of place: the in-place operators (``a *= ...``) would silently
    # overwrite the caller's arrays whenever NumPy arrays are passed in.
    a_sign = 2*(a >= 0) - 1
    a = a * a_sign
    b = b * a_sign
    c = c * a_sign
    d = b**2 - 4*a*c
    a_mask = a > tol
    b_mask = jnp.abs(b) > tol
    d_mask = a_mask & (d >= 0)
    # From here on d holds sqrt(discriminant); entries that are not a real
    # quadratic get the sentinel so they never produce NaN in the comparisons.
    d = jnp.where(d_mask, jnp.sqrt(d), inf)

    # It can be tricky to avoid "redetecting" the same collision twice.
    # Typically, if we compute next collision times between particle & walls,
    # the smallest dt will be the wall it JUST hit.  Without care, the particle
    # could "stick" on the wall.

    # One naive solution sets a minimal positive tol
    # and ignores smaller dt's.  However, this will inevitably break when a collision
    # near a corner gives a very small, but legit, dt.
    # Conversely, we can not simply ignore the wall the particle just hit.
    # In the presence of force or a curved wall, a particle can legitimately
    # hit the same wall twice.
    # Instead, we track which wall the particle just hit (wall_mask) and discard the
    # smallest positive dt between them.

    # Note, we imposed a > 0 so we can be certain
    # that (-b-d)/2a < (-b+d)/2a.  Thus, if (-b-d)/2a > 0, then it is the smallest positive dt.
    # For numerical reasons, we define "positive" as > -tol
    # For efficiency, do partial calculations to test
    # ex: (-b+-d)/2a > -tol => +-d > b-2*a*tol
    # compute RHS once and create boolean mask where true
    e = b - 2*a*tol
    q0_mask = d_mask & (-d > e)
    q1_mask = d_mask & ( d > e)
    l = -c / b
    l_mask = ~a_mask & b_mask & (l > -tol)

    # plain Python logic would be:
    # if q0_mask:
    #     if ~wall_mask:
    #         return (-b-d)/2a
    #     else:
    #         return (-b+d)/2a
    # elif q1_mask:
    #     if ~wall_mask:
    #         return (-b+d)/2a
    #     else:
    #         return inf
    # elif l_mask:
    #     if ~wall_mask:
    #         return -c/b
    #     else:
    #         return inf
    # else:
    #     return inf
    small_root = (-b - d) / (2*a)
    large_root = (-b + d) / (2*a)
    take_small = q0_mask & ~wall_mask
    take_large = (q0_mask & wall_mask) | (q1_mask & ~wall_mask)
    take_linear = l_mask & ~wall_mask
    return jnp.where(take_small, small_root,
                     jnp.where(take_large, large_root,
                               jnp.where(take_linear, l, inf)))

class Base():
    """Mixin giving subscript access to instance attributes.

    ``obj['name']`` reads and ``obj['name'] = value`` writes the attribute
    ``obj.name``.
    """

    def __getitem__(self, key):
        """Return the attribute called ``key`` (AttributeError if it is missing)."""
        attribute = getattr(self, key)
        return attribute

    def __setitem__(self, key, val):
        """Bind ``val`` to the attribute called ``key``."""
        setattr(self, key, val)

class Ngon(Base):
    """Regular polygon described by its vertices, edge vectors and edge normals.

    Attributes:
        sides:  number of sides (int).
        length: side length (float).
        angle:  central angle of one side, ``2*pi / sides``.
        radius: circumradius, ``length / (2*sin(angle/2))``.
        vert:   ``(sides+1, 2)`` array of vertices; the angles cover one full
                turn, so the last row closes the polygon on the first.
        tang:   ``(sides, 2)`` array of edge vectors ``vert[k+1] - vert[k]``
                (not normalised).
        norm:   ``(sides, 2)`` array of edge normals ``(-tang_y, tang_x)``,
                pointing inward.
    """

    def __init__(self, sides=6, length=10):
        self.sides = int(sides)
        self.length = float(length)
        self.angle = 2*np.pi/self.sides
        self.radius = self.length / (2*np.sin(self.angle/2))

        # Vertex angles are offset by half a step relative to multiples of ``angle``.
        theta = (np.arange(self.sides+1) - 1/2) * self.angle
        corners = np.stack([np.cos(theta), np.sin(theta)], axis=-1) * self.radius
        edges = np.diff(corners, axis=0)

        self.vert = corners
        self.tang = edges
        # Rotate each edge vector by +90 degrees: (tx, ty) -> (-ty, tx), i.e. inward.
        self.norm = np.stack([-edges[:, 1], edges[:, 0]], axis=-1)

class Particles(Base):
    """Ensemble of point particles moving under a constant force inside a polygon table.

    Typical use: construct, ``initialize(tbl)``, call ``update(tbl)`` once per
    collision event, then ``draw(p)`` to animate one particle.

    State arrays (lowercase attributes) hold the current values; the history
    attributes listed in ``_HISTORY`` hold one snapshot per recorded step.
    """

    # History attribute names; each mirrors the lowercase state attribute of the same name.
    _HISTORY = ('Pos', 'Vel', 'Acc', 'Time', 'Wall')

    def __init__(self, p=2, mass=1.0, speed=1.0, force=(0.0, 0.0), collision_law='specular'):
        """
        Args:
            p: Number of particles.
            mass: Particle mass (acceleration is ``force / mass``).
            speed: Scale applied to the random initial velocities; its absolute value is used.
            force: Constant force vector, one component per spatial dimension.
            collision_law: Name of the collision rule; selects the method
                ``<collision_law>_collision``.

        Raises:
            ValueError: If ``collision_law`` is not a supported law.
        """
        self.p = int(p)
        self.mass = float(mass)
        self.speed = float(abs(speed))
        self.force = jnp.array(force, dtype=float)
        self.collision_law = collision_law
        laws = ['specular']
        if collision_law not in laws:
            raise ValueError(f"collision_law must be one of {laws}; got '{collision_law}'")
        self.collide = self[f'{collision_law}_collision']

    def initialize(self, tbl):
        """Place every particle on a random wall of ``tbl`` with a random inward velocity.

        Also attaches broadcast-ready copies of the wall data to ``tbl``
        (``tbl.x``, ``tbl.t``, ``tbl.n``), remembers ``tbl`` for ``draw`` and
        starts a fresh history containing the initial state.
        """
        # Keep the table so draw() does not have to rely on a global variable.
        self.tbl = tbl

        w = rng.integers(0, tbl.sides, self.p)
        v = tbl.vert[w]
        t = tbl.tang[w]
        a = rng.uniform(low=0, high=1, size=[self.p, 1])
        self.pos = jnp.array(a*t + v)
        # I do not think I have the correct stationary distribution for particle velocity leaving the surface.
        # I recall that the tangential components should be gaussian, but the normal should
        # have some sort of Maxwell-Boltzmann or chi or chi-square or rayleigh or something.
        # I'll put a guess here, but you should fix.
        # (Note: the square root of a chi-square(2) draw is a Rayleigh(scale=1) draw.)
        t = unit(tbl.tang[w])
        n = unit(tbl.norm[w])
        a = rng.normal(size=[self.p, 1])
        b = np.sqrt(rng.chisquare(2, size=[self.p, 1]))
        self.vel = jnp.array(a*t + b*n) * self.speed

        # Push table vectors to jax and expand dims so broadcasting is easier later
        # reshape is [#walls, #particles, dim] & make tangents & normals unit vectors.
        # jnp.tile (the function) is used because arrays are not guaranteed to
        # offer a ``.tile`` method.
        reps = (1, self.p, 1)
        tbl.x = jnp.tile(jnp.array(     tbl.vert [:-1,None,:]), reps)
        tbl.t = jnp.tile(jnp.array(unit(tbl.tang)[:  ,None,:]), reps)
        tbl.n = jnp.tile(jnp.array(unit(tbl.norm)[:  ,None,:]), reps)
        self.acc = jnp.tile(jnp.array(self.force / self.mass), (self.p, 1))
        # Quadratic coefficient of the signed wall distance: (n . acc) / 2, shape [#walls, #particles].
        self.a = dot(tbl.n, self.acc) / 2

        # initialize wall mask: True at [wall, particle] for the wall each particle starts on
        self.wall = jnp.array(w)
        m = np.zeros(self.a.shape, dtype=bool)
        m[w, np.arange(self.p)] = True
        self.mask = jnp.array(m)
        self.time = jnp.zeros(self.p)

        # Start a fresh history so a re-initialised ensemble does not mix runs.
        for name in self._HISTORY:
            self[name] = []
        self.record()

    def record(self):
        """Append a NumPy snapshot of pos/vel/acc/time/wall to the history lists."""
        for name in self._HISTORY:
            history = self.__dict__.get(name)
            if not isinstance(history, list):
                # finalize() replaces the lists by stacked arrays; turn them back
                # into lists so the simulation can continue after finalize()/draw().
                history = [] if history is None else list(history)
                self[name] = history
            history.append(np.array(self[name.lower()]))

    def finalize(self):
        """Stack each history into a single NumPy array (first axis = recorded step)."""
        for val in self._HISTORY:
            self[val] = np.array(self[val])

    def specular_collision(self, tbl):
        """Mirror each velocity about the wall the particle is on: v <- v - 2 (v.n) n."""
        n = tbl.n.at[self.wall, 0].get()  # get normal for the wall each particle hits
        c = 2 * dot(self.vel, n, keepdims=True)
        self.vel -= (c * n)

    def update(self, tbl, report=False):
        """Advance every particle to its next wall collision, apply the collision law and record.

        Args:
            tbl: Table previously passed to ``initialize`` (provides ``tbl.x`` and ``tbl.n``).
            report: When True, print the intermediate arrays of this step.
        """
        # Signed distance to each wall along its normal is c + b*t + a*t**2.
        dx = (self.pos - tbl.x)
        self.b  = dot(tbl.n, self.vel)
        self.c  = dot(tbl.n, dx)
        self.T = solve_quadratic(self.a, self.b, self.c, self.mask)
        # Next event for each particle = earliest time over all walls.
        self.dt = self.T.min(axis=0)
        if report:
            print(f'================================================\n================================================\nSTEP {len(self.Time)-1}')
            for val in ['wall', 'mask', 'pos', 'vel', 'acc', 'a', 'b', 'c', 'T', 'dt']:
                print(val)
                pprint(self[val])
        # Remember which wall(s) are hit so the next solve discards the repeat root.
        self.mask = self.T == self.dt
        self.wall = self.T.argmin(axis=0)
        self.time += self.dt
        dt = self.dt[:,None]
        self.pos += self.vel*dt + 1/2*self.acc*dt**2
        self.vel += self.acc*dt
        if report:
            for val in ['wall', 'mask', 'pos', 'vel']:
                print(f'new {val}')
                pprint(self[val])
        self.collide(tbl)
        self.record()
        if report:
            print('post collision vel')
            pprint(self.vel)

    def draw(self, p, frame_length=0.1):
        """Animate the recorded trajectory of particle ``p`` with plotly.

        Positions between recorded events are interpolated with the constant
        acceleration formula from the latest event at or before each frame time.
        The result is stored in ``self.df_interp`` / ``self.fig`` and shown.

        Args:
            p: Index of the particle to draw.
            frame_length: Target time between animation frames.

        Raises:
            ValueError: If no ``update`` step has been recorded yet.
        """
        self.finalize()
        t   = self.Time[:,p].copy()
        pos = self.Pos [:,p].copy()
        vel = self.Vel [:,p].copy()
        acc = self.Acc [:,p].copy()
        if len(t) < 2:
            raise ValueError("draw() needs at least one recorded update() step")

        frames = np.ceil(np.diff(t) / frame_length).astype(int)
        t_interp = np.unique(np.hstack([np.linspace(t[k], t[k+1], frames[k]) for k in range(len(t)-1)]))
        # dt[i, k] = time elapsed at frame i since event k; keep only the latest past event.
        dt = (t_interp[:,None] - t)[...,None]
        dt[dt < 0] = np.nan
        dt[dt > np.nanmin(dt, axis=1, keepdims=True)] = np.nan
        pos = pos + vel * dt + 1/2 * acc * dt**2
        vel = vel + acc * dt
        # All but the selected event are NaN, so nanmin just extracts the surviving value.
        vel = np.nanmin(vel, axis=1).T
        pos = np.nanmin(pos, axis=1).T
        t = t_interp.round(2)
        self.df_interp = pd.DataFrame({'t' : t, 'x' : pos[0], 'y' : pos[1], 'vx': vel[0], 'vy': vel[1], 'size':10})
        self.fig =  px.scatter(self.df_interp, x='x', y='y', animation_frame='t', size='size', width=750, height=700)
        self.fig.add_scatter(x=self.df_interp['x'], y=self.df_interp['y'], line=dict(color="gray", width=0.5))
        # Outline of the table the particles were initialised on.
        self.fig.add_scatter(x=self.tbl.vert[:,0], y=self.tbl.vert[:,1], line=dict(color="red", width=2))
        self.fig.update(layout_showlegend=False)
        self.fig.show()


# Hexagonal table with side length 10.
tbl = Ngon(length=10, sides=6)

# Five particles under a constant force in the -y direction, reflecting specularly.
part = Particles(
    p=5,
    force=[0.0, -0.5],
    collision_law='specular',
)
part.initialize(tbl)

# Each update() advances every particle to its next wall collision.
for k in range(10):
    part.update(tbl)

# Stack the recorded histories into arrays.
part.finalize()

In [3]:
# Animate the recorded trajectory of particle 0; frame_length sets the target
# time between animation frames.
part.draw(p=0, frame_length=0.1)