## Step 07: Vectorized Treecode

In [91]:
def _multipole_weights(dx, dy, dz, r):
    """Return the 10 weights that contract a cell's multipole vector into a potential.

    The weights are 1/r and its first and second derivatives with respect to
    the offset (dx, dy, dz) between the target particle and the cell centre,
    with r = |offset|. They are ordered as: 1/r; the three first derivatives
    (x, y, z); the three diagonal second derivatives (xx, yy, zz); and the
    three mixed ones (xy, yz, zx). The multipole vector is built outside this
    cell; its layout is assumed to follow this same order.
    """
    r3 = r**3
    r5 = r3*r**2
    return [1/r, -dx/r3, -dy/r3, -dz/r3,
            3*dx**2/r5 - 1/r3, 3*dy**2/r5 - 1/r3, 3*dz**2/r5 - 1/r3,
            3*dx*dy/r5, 3*dy*dz/r5, 3*dz*dx/r5]


def evaluate(particles, p, t, cells, n_crit, theta):
    """Add to every particle of target cell t the potential due to source cell p.

    Source cell p is handled one of two ways:

    * twig (``cells[p].nleaf < n_crit``): direct summation of ``m / r`` over
      p's leaf particles. Pairs at zero separation are skipped, so a particle
      never interacts with itself.
    * non-twig: every existing child c of p (bit ``octant`` set in
      ``cells[p].nchild``) is tested with the criterion
      ``cells[c].r + cells[t].r > theta * distance(cells[t], cells[c])``.
      Children meeting it are near-field and are visited recursively. The
      rest are far-field, and their multipole expansion is evaluated at each
      target particle instead.

    Arguments:
        particles: the list of particles; their ``phi`` is incremented in place
        p:         index in ``cells`` of the source cell
        t:         index in ``cells`` of the target (leaf) cell
        cells:     the list of cells
        n_crit:    maximum number of leaves in a single cell
        theta:     tolerance parameter of the near/far criterion

    Returns:
        None -- results are accumulated into ``particles[...].phi``.
    """
    source_cell = cells[p]
    target_cell = cells[t]

    if source_cell.nleaf < n_crit:
        # Twig source: particle-particle summation.
        sources = [particles[source_cell.leaf[j]] for j in range(source_cell.nleaf)]
        for i in range(target_cell.nleaf):
            target = particles[target_cell.leaf[i]]
            for src in sources:
                dist = target.distance(src)
                if dist != 0:  # zero separation means target is src itself
                    target.phi += src.m / dist
        return

    # Non-twig source: inspect each of the 8 possible child octants.
    for octant in range(8):
        if not (source_cell.nchild & (1 << octant)):
            continue  # no child in this octant
        c = source_cell.child[octant]
        child = cells[c]
        if (child.r + target_cell.r) > theta * target_cell.distance(child):
            # Near field: the expansion is not accurate enough, open the child.
            evaluate(particles, c, t, cells, n_crit, theta)
            continue
        # Far field: apply the child's multipole expansion to every target.
        for i in range(target_cell.nleaf):
            target = particles[target_cell.leaf[i]]
            weights = _multipole_weights(target.x - child.x,
                                         target.y - child.y,
                                         target.z - child.z,
                                         target.distance(child))
            target.phi += numpy.dot(child.multipole, weights)

In [92]:
def eval_potential(particles, cells, leaves, n_crit, theta):
    """Run the treecode evaluation for every target leaf cell.

    Each target cell index in ``leaves`` is evaluated against the whole tree,
    starting from the root cell (index 0 of ``cells``). ``evaluate``
    accumulates the resulting potential into ``particles[...].phi`` in place.

    Arguments:
        particles: the list of particles
        cells:     the list of cells (root at index 0)
        leaves:    indices in ``cells`` of the target leaf cells
        n_crit:    maximum number of leaves in a single cell
        theta:     tolerance parameter of the near/far criterion
    """
    root_index = 0
    for leaf_index in leaves:
        evaluate(particles, root_index, leaf_index, cells, n_crit, theta)

In [93]:
import numpy
import time
from treecode_helper import *
from matplotlib import pyplot, rcParams
%matplotlib inline

# customizing plot parameters (matplotlib defaults for the figures in this notebook)
rcParams['figure.dpi'] = 100              # figure resolution, dots per inch
rcParams['font.size'] = 14                # default font size
rcParams['font.family'] = 'StixGeneral'   # serif STIX font family

In [94]:
# n_crit: a cell holding at least this many particles is split (non-twig)
# theta:  tolerance of the near-field / far-field acceptance criterion
n_crit, theta = 10, 0.2

In [95]:
#n = 10648           # number of particles
#particles = [ Particle(m=1.0/n) for i in range(n) ]

In [96]:
# load the particle set (presumably 1000 particles in a unit cube -- TODO confirm in treecode_helper)
particles = read_particle('cube1000')

# direct summation: the all-pairs reference solution that the treecode result is compared against
# NOTE(review): time.clock() was deprecated in Python 3.3 and removed in 3.8.
# On Python 2 it reports processor time on Unix but wall-clock time on Windows.
# timeit.default_timer is a portable alternative, but switch every timing cell
# together so that time_direct and time_tree use the same clock.
tic = time.clock()
direct_sum(particles)
toc = time.clock()

time_direct = toc - tic
# copy the reference potentials out before a later cell resets particle.phi
phi_direct = numpy.asarray([particle.phi for particle in particles])

In [97]:
# zero the accumulated potentials so the treecode evaluation starts from scratch
# (the direct-summation result has already been copied into phi_direct)
for particle in particles:
    particle.phi = 0.

In [98]:
# build tree
# The root cell is centred at (0.5, 0.5, 0.5) with r = 0.5. Presumably r is the
# half-width, so the root covers the unit cube -- TODO confirm against Cell.
tic = time.clock()
root = Cell(n_crit)
root.x = root.y = root.z = 0.5
root.r = 0.5
cells = build_tree(particles, root, n_crit)
toc = time.clock()
time_src = toc - tic

In [99]:
# P2M: particle to multipole
# get_multipole starts at the root (cell index 0) and fills `leaves` in place.
# Those indices are later used as the target cells t of evaluate().
tic = time.clock()
leaves = []
get_multipole(particles, 0, cells, leaves, n_crit)
toc = time.clock()

time_P2M = toc - tic

In [100]:
# M2M: multipole to multipole (upward translation)
# upward_sweep works on the cell list only, updating the cells' multipoles in place
tic = time.clock()
upward_sweep(cells)
toc = time.clock()

time_M2M = toc - tic

In [101]:
# evaluate potential: traverse the tree once for every target leaf cell
tic = time.clock()
eval_potential(particles, cells, leaves, n_crit, theta)
toc = time.clock()

time_eval = toc - tic
# treecode potentials, in the same particle order as phi_direct
phi_tree = numpy.asarray([particle.phi for particle in particles])

In [102]:
# total treecode time: tree build + P2M + M2M + evaluation (direct summation excluded)
time_tree = sum([time_src, time_P2M, time_M2M, time_eval])

In [103]:
#l2_err(phi_direct, phi_tree)

In [104]:
#plot_err(phi_direct, phi_tree)

In [105]:
# timing report (Python 2 print statements)
#print time_direct
# per-stage times: tree build, P2M, M2M, evaluation; then their total
print time_src, time_P2M, time_M2M, time_eval
print time_tree
# speed-up of the treecode over direct summation:
#print time_direct/time_tree

0.0309870000001 0.0200060000001 0.021065 5.224149
5.296207


In [13]:
# tree statistics: total number of cells, number of target leaf cells, and their ratio
print len(cells)
print len(leaves)
print float(len(leaves))/float(len(cells))

680
580
0.852941176471


In [14]:
# treecode potentials (numpy abbreviates the printout of large arrays with "...")
print phi_tree

[ 4.44409759  4.02221784  3.37993562 ...,  4.46223907  2.89910495
  3.80525284]


In [15]:
from IPython.core.display import HTML
def css_styling():
    """Load the notebook stylesheet and wrap it in an IPython HTML object.

    Reads './style/fmmstyle.css', resolved relative to the current working
    directory.

    Returns:
        HTML: the stylesheet contents, ready to be displayed by the notebook.
    """
    # `with` closes the file handle deterministically; the previous
    # open(...).read() one-liner never closed it.
    with open('./style/fmmstyle.css', 'r') as css_file:
        styles = css_file.read()
    return HTML(styles)

# as the cell's last expression, the returned HTML is what the notebook displays
css_styling()