# N-Body Problem (parallel)

24 July 2017 | Python

In [32]:
import time
import numpy as np
from multiprocessing import Pool

### Functions for calculation

In [33]:
def remove_i(x, i):
    """Return a float copy of ``x`` with its ``i``-th entry along axis 0 dropped.

    Parameters
    ----------
    x : numpy.ndarray
        Array with at least one dimension; it is not modified.
    i : int
        Index (along the first axis) of the entry to leave out.

    Returns
    -------
    numpy.ndarray
        New ``dtype=float`` array whose first axis is one shorter than ``x``'s;
        trailing dimensions are unchanged.
    """
    out_shape = (x.shape[0] - 1,) + x.shape[1:]
    trimmed = np.empty(out_shape, dtype=float)

    # Entries before i keep their slot; entries after i shift up by one.
    trimmed[:i], trimmed[i:] = x[:i], x[i + 1:]

    return trimmed

In [34]:
def a(i, x, G, m):
    """Gravitational acceleration on body ``i`` due to every other body.

    Parameters
    ----------
    i : int
        Index of the body whose acceleration is computed.
    x : numpy.ndarray
        Positions, one row per body (coordinates are summed over axis 1).
    G : float
        Gravitational constant.
    m : numpy.ndarray
        Masses, one entry per body.

    Returns
    -------
    numpy.ndarray
        ``G * sum_{j != i} m_j * (x_j - x_i) / |x_j - x_i|**3``.
    """
    # Separation vectors from body i to all other bodies, and their masses.
    offsets = remove_i(x, i) - x[i]
    masses = remove_i(m, i)

    # |r_ij|^3 per pair. There is no softening term: if two bodies coincide
    # this is 0 and the division below produces inf/nan.
    dist_cubed = (offsets**2).sum(axis=1) ** 1.5
    weights = masses / dist_cubed

    return G * (offsets * weights[:, None]).sum(axis=0)

In [35]:
def timestep(x0, v0, G, m, dt, pool):
    """Advance every body by one step of size ``dt``, distributing work over ``pool``.

    One ``timestep_i`` call is issued per body; the per-body results are then
    written back into freshly allocated position and velocity arrays.

    Parameters
    ----------
    x0, v0 : numpy.ndarray
        Current positions and velocities, one row per body.
    G : float
        Gravitational constant.
    m : numpy.ndarray
        Masses, one per body.
    dt : float
        Step size.
    pool : multiprocessing.Pool
        Worker pool whose ``map`` evaluates the per-body updates.

    Returns
    -------
    tuple of numpy.ndarray
        ``(x1, v1)``: float arrays with the same shapes as ``x0`` and ``v0``.
    """
    # Every task carries the full state, since body i's update depends on all bodies.
    jobs = [(idx, x0, v0, G, m, dt) for idx in range(len(x0))]
    per_body = pool.map(timestep_i, jobs)

    x_next = np.empty(x0.shape, dtype=float)
    v_next = np.empty(v0.shape, dtype=float)

    # Each result echoes its body index, so scatter it into that row.
    for idx, pos, vel in per_body:
        x_next[idx], v_next[idx] = pos, vel

    return x_next, v_next

In [36]:
def initial_condition(N, D, seed=None):
    """Generate an initial condition for ``N`` masses in ``D``-dimensional space.

    Positions are drawn uniformly from [0, 1) in each coordinate, all bodies
    start at rest, and every mass is 1.

    Parameters
    ----------
    N : int
        Number of bodies.
    D : int
        Number of spatial dimensions.
    seed : int, optional
        If given, positions come from a private ``np.random.RandomState(seed)``,
        so the configuration is reproducible. The values are the same as
        ``np.random.seed(seed); np.random.rand(N, D)`` would give, and the
        global random state is left untouched. If ``None`` (default), the
        global NumPy random state is used.

    Returns
    -------
    x0 : numpy.ndarray, shape (N, D)
        Initial positions.
    v0 : numpy.ndarray, shape (N, D)
        Initial velocities (all zero).
    m : numpy.ndarray, shape (N,)
        Masses (all one).
    """
    # A private RandomState keeps seeded runs from disturbing the global stream.
    rand = np.random.rand if seed is None else np.random.RandomState(seed).rand

    x0 = rand(N, D)
    v0 = np.zeros((N, D), dtype=float)
    m = np.ones(N, dtype=float)

    return x0, v0, m

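### Multiprocessing

`timestep_i` is defined at module level and takes a single tuple argument so that `Pool.map` can pickle it. The pool then sends one `(i, x0, v0, G, m, dt)` task per body to the worker processes.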

In [37]:
def timestep_i(args):
    """Advance a single body by one time step (worker function for ``Pool.map``).

    The arguments arrive packed in one tuple so the function can be handed
    straight to ``pool.map``.

    Parameters
    ----------
    args : tuple
        ``(i, x0, v0, G, m, dt)``: index of the body to update, current
        positions and velocities of all bodies, gravitational constant,
        masses, and step size.

    Returns
    -------
    tuple
        ``(i, x_i1, v_i1)``: the body index echoed back (so the caller can put
        the result in the right row), plus the body's new position and velocity.
    """
    i, x0, v0, G, m, dt = args
    a_i0 = a(i, x0, G, m)

    # Acceleration is held constant over the step:
    #   v1 = v0 + a*dt
    #   x1 = x0 + v0*dt + (1/2)*a*dt^2   (the 1/2 was previously missing)
    v_i1 = a_i0 * dt + v0[i]
    x_i1 = 0.5 * a_i0 * dt**2 + v0[i] * dt + x0[i]

    return i, x_i1, v_i1

### Simulate time steps

In [38]:
def simulate(P, N, D, S, G, dt):
    """Run an ``S``-step N-body simulation using a pool of ``P`` worker processes.

    Parameters
    ----------
    P : int
        Number of worker processes.
    N : int
        Number of bodies.
    D : int
        Number of spatial dimensions.
    S : int
        Number of time steps.
    G : float
        Gravitational constant.
    dt : float
        Step size.

    Returns
    -------
    tuple of numpy.ndarray
        Final ``(positions, velocities)`` after ``S`` steps.
    """
    # define initial condition variables
    x0, v0, m = initial_condition(N, D)

    # The context manager terminates the worker processes when the run ends,
    # so repeated calls (e.g. a benchmark loop) don't leave idle pools alive.
    # NOTE(review): workers must be able to find timestep_i / a; with the
    # 'spawn' start method (Windows, macOS on Python 3.8+) functions defined
    # in a notebook's __main__ may not be importable by the children — confirm
    # on the target platform or move them into a .py module.
    with Pool(P) as pool:
        for _ in range(S):
            x0, v0 = timestep(x0, v0, G, m, dt, pool)

    return x0, v0

In [40]:
# Benchmark: run the same-sized simulation with an increasing number of worker
# processes and record the wall-clock time of each run. Each timing includes
# everything simulate() does, including creating its worker pool.
N_BODIES = 128   # number of masses
N_DIMS = 3       # spatial dimensions
N_STEPS = 300    # time steps per run
G_CONST = 1.0    # gravitational constant
DT = 1e-3        # step size

# list of process counts to compare
lst_p = [1, 2, 4, 8]

# wall-clock run time (seconds) for each entry of lst_p
runtimes = []

for P in lst_p:
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # system clock and can jump if the clock is adjusted during a run.
    start = time.perf_counter()
    simulate(P, N_BODIES, N_DIMS, N_STEPS, G_CONST, DT)
    stop = time.perf_counter()
    runtimes.append(stop - start)

print(runtimes)

[2.3386130332946777, 1.330625057220459, 1.127058982849121, 1.290069818496704]


<i>Notebook by <a href="https://www.michaelsjoeberg.com">Michael Sjoeberg</a>, updated 24 July 2017.</i>