---
layout: post
title: N-Body
---

In [1]:
from pathlib import Path
import os
import functools

from IPython.display import HTML, Image
import matplotlib.pyplot as plt
from numpy import *
from celluloid import Camera
import matplotlib.patches as patches
from scipy.integrate import odeint, solve_ivp
from simple_pid import PID

# Directory where generated images/animations are written.
ROOT = Path("./assets/img/")

# Create the directory (and any missing parents) in one call; exist_ok avoids
# the race between checking for the directory and creating it.
ROOT.mkdir(parents=True, exist_ok=True)

The N-body problem: every body accelerates under the gravitational attraction of all the other bodies.

$$
\frac{d^2 \mathbf{r}_i}{dt^2} = -G \sum_{j \neq i} m_j \frac{\mathbf{r}_i - \mathbf{r}_j}{|\mathbf{r}_i - \mathbf{r}_j|^3}
$$

In [82]:
# Gravitational constant used by the simulation.
G = 0.001

# Total number of bodies, including the heavy one listed first in MASS.
BODY_NUM = 10

# One mass per body: a single heavy body followed by BODY_NUM - 1 unit masses.
MASS = [1e4, *([1] * (BODY_NUM - 1))]

In [83]:
def get_color(idx, count):
    """Return the RGBA colour for item ``idx`` out of ``count`` items.

    Colours come from matplotlib's "hsv" colormap. That map is cyclic (it
    starts and ends on red), so it is sampled at ``idx / count``, i.e. on the
    half-open range [0, 1). Resampling the map to ``count`` entries instead
    would hand the first and the last item practically the same colour.

    Parameters
    ----------
    idx : int
        Index of the item, expected in ``range(count)``.
    count : int
        Total number of items that need distinct colours (must be non-zero).

    Returns
    -------
    tuple
        The RGBA value returned by the colormap.
    """
    return plt.get_cmap("hsv")(idx / count)

In [90]:
def get_bodies():
    """Build the initial state of every body.

    Returns an array with one row per body, each row being
    ``[x, y, z, vx, vy, vz]``. Row 0 is a body at rest at the origin. The
    other ``BODY_NUM - 1`` bodies start at equal angular steps on a circle of
    radius ``radius`` in the z = 0 plane, moving with speed ``speed``
    perpendicular to their position vector. Uniform noise in [-0.1, 0.1),
    drawn from a generator seeded with 0, is then added to every component of
    those rows, so the result is reproducible.
    """
    radius = 1
    speed = 1

    generator = random.default_rng(seed=0)

    # linspace includes both 0 and 2*pi, which are the same direction on the
    # circle, so the final sample is dropped.
    angles = linspace(0, 2*pi, BODY_NUM)[:-1]

    rows = [[0., 0., 0., 0., 0., 0.]]
    rows.extend(
        [radius*cos(angle), radius*sin(angle), 0,
         -speed*sin(angle), speed*cos(angle), 0]
        for angle in angles
    )
    states = array(rows)

    # Perturb every body except the one at the origin.
    orbiting = states[1:]
    states[1:] = orbiting + generator.uniform(-0.1, 0.1, size=orbiting.shape)
    return states

In [91]:
def motion_step(s0, t):
    """Time derivative of the flattened N-body state.

    Parameters
    ----------
    s0 : ndarray
        Flattened state; reshaped here to (N, 6) rows of
        ``[x, y, z, vx, vy, vz]``.
    t : float
        Time. Not used in the computation; accepted so the function has the
        (state, time) signature its caller passes.

    Returns
    -------
    ndarray
        Flattened (N, 6) array whose rows are ``[vx, vy, vz, ax, ay, az]``,
        where the acceleration of body i is
        ``-G * sum_j MASS[j] * (r_i - r_j) / |r_i - r_j|**3`` over j != i.
    """
    state = s0.reshape(-1, 6)
    pos = state[:, :3]
    vel = state[:, 3:]
    n = pos.shape[0]

    # 1 off the diagonal, 0 on it: removes each body's interaction with itself.
    not_self = (ones(shape=(n, n)) - eye(n))[..., None]

    # Masses broadcast along the "other body" axis (axis 1).
    m = array(MASS)[None, :, None]

    # sep[i, j] = r_i - r_j
    sep = pos[:, None, :] - pos[None, :, :]

    # Softening term keeps the division finite when two positions coincide
    # (including the i == j diagonal, where the separation is exactly zero).
    eps = 1e-7
    d2 = (sep**2).sum(axis=-1) + eps**2
    inv_d3 = 1.0 / (d2 * sqrt(d2))
    sep *= inv_d3[..., None]

    acc = -G*(not_self*m*sep).sum(axis=1)

    return concat([vel, acc], axis=1).reshape(-1)

In [111]:
def sim_n_body(tail=False):
    """Integrate the N-body system and render it as an animated GIF.

    The initial state comes from ``get_bodies()`` and is integrated with
    ``solve_ivp`` over ``[0, T]``, sampled at 250 evenly spaced times. Each
    sample becomes one frame of a 3D scatter plot, and the animation is saved
    to ``ROOT / "nbody.gif"`` at 10 frames per second.

    Parameters
    ----------
    tail : bool, default False
        If true, every frame also draws each body's trajectory from the first
        sample up to the current one.

    Returns
    -------
    IPython.display.Image
        Image object pointing at the saved GIF by URL.
    """
    T = 5
    t_eval = linspace(0, T, 250)

    s0 = get_bodies()
    body_num = s0.shape[0]

    # solve_ivp calls fun(t, y); motion_step takes (state, t), hence the lambda.
    sol = solve_ivp(lambda t, s: motion_step(s, t), (0, T), s0.reshape(-1), t_eval=t_eval)

    # sol.y is (state_dim, n_frames); reorganise once to (n_frames, body, 6)
    # instead of reshaping inside the frame loop.
    frames = sol.y.transpose(1, 0).reshape(-1, body_num, 6)

    # One colour per body; identical for every frame, so computed once.
    colors = [get_color(body_idx, body_num) for body_idx in range(body_num)]

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')

    camera = Camera(fig)

    for frame_idx, frame in enumerate(frames):
        ax.scatter(frame[:, 0], frame[:, 1], frame[:, 2], color=colors)

        if tail:
            for body_idx in range(body_num):
                # Positions of this body from the first frame up to the current one.
                path = frames[:frame_idx + 1, body_idx]
                ax.plot(path[:, 0], path[:, 1], path[:, 2], color=colors[body_idx])

        camera.snap()

    anim = camera.animate()
    # Close this specific figure so the notebook does not also show a static plot.
    plt.close(fig)

    gif_path = ROOT / "nbody.gif"
    anim.save(gif_path, writer="pillow", fps=10)

    # A URL must be a forward-slash string; as_posix() gives that on every platform.
    return Image(url=gif_path.as_posix())

# Run the simulation with trajectory tails; the returned Image is this cell's output.
sim_n_body(tail=True)