In [54]:
%matplotlib qt
# %matplotlib qt 

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection, Line3DCollection

import adaptoctree.morton as morton
import adaptoctree.tree as tree
from fmm import Fmm
from fmm.kernel import laplace_p2p_serial

In [112]:
# Build the FMM experiment object. The argument presumably names a bundled
# configuration (particle distribution / tree settings) — TODO confirm in `fmm`.
CONFIG_NAME = 'spherical2'
e = Fmm(CONFIG_NAME)

In [113]:
# Number of octree leaves (presumably len(e.leaves), the nodes plotted below).
e.nleaves

270

In [114]:
# Run the FMM; presumably this fills e.target_potentials, which is inspected below.
e.run()

In [115]:
def find_vertices(anchor, r0, x0):
    """Return the corners and faces of an octree node's box in physical space.

    The root box is centred at ``x0`` with half side length ``r0``, so it spans
    ``[x0 - r0, x0 + r0]`` on every axis (this matches the ``x0 - r0`` lower
    corner used below and adaptoctree's convention). A node at ``level``
    therefore has side length ``2 * r0 / 2**level``. Its lower corner is at
    ``x0 - r0 + anchor[:3] * side``.

    Parameters
    ----------
    anchor : array-like of int, length >= 4
        ``(ix, iy, iz, ..., level)``. The first three entries are the integer
        grid coordinates of the node's lower corner, and the last entry is
        the node's level (as returned by ``morton.decode_key``).
    r0 : float
        Half side length of the root box.
    x0 : array-like, shape (3,)
        Centre of the root box.

    Returns
    -------
    vertices : np.ndarray, shape (8, 3)
        Corner coordinates. Indices 0-3 form the z-low face and 4-7 the z-high
        face. Each face is ordered counter-clockwise (seen from +z), starting
        at the x-low/y-low corner.
    faces : list of 6 lists of 4 np.ndarray
        The six box faces as vertex quads, ready for ``Poly3DCollection``.
    """
    anchor = np.asarray(anchor)
    # Cast to a Python int: `1 << np.uint64(...)` is rejected by legacy NumPy
    # type promotion, and anchors produced by bit manipulation may be uint64.
    level = int(anchor[-1])
    # The root spans 2*r0 and each level halves the side length.
    side = 2.0 * r0 / (1 << level)

    unit_cube = np.array([
        [0, 0, 0],
        [1, 0, 0],
        [1, 1, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 1],
        [1, 1, 1],
        [0, 1, 1]
    ])

    # Integer grid coordinates of each corner, mapped to physical space.
    vertices = (unit_cube + anchor[:3]) * side + (np.asarray(x0) - r0)

    faces = [
        [vertices[0], vertices[1], vertices[2], vertices[3]],  # z low
        [vertices[4], vertices[5], vertices[6], vertices[7]],  # z high
        [vertices[0], vertices[1], vertices[5], vertices[4]],  # y low
        [vertices[2], vertices[3], vertices[7], vertices[6]],  # y high
        [vertices[1], vertices[2], vertices[6], vertices[5]],  # x high
        [vertices[4], vertices[7], vertices[3], vertices[0]],  # x low
    ]
    return vertices, faces


def plot_node(ax, key, r0, x0):
    """Draw one octree node onto a 3D axes.

    The Morton ``key`` is decoded to its anchor, and the node's box is built
    with ``find_vertices``. Its eight corners are scattered in gray, and its
    six faces are added as a translucent (alpha 0.1), black-edged
    ``Poly3DCollection``.
    """
    corners, faces = find_vertices(morton.decode_key(key), r0, x0)
    xs, ys, zs = corners.T
    ax.scatter3D(xs, ys, zs, c='gray')
    box = Poly3DCollection(faces, edgecolors='black', alpha=0.1)
    ax.add_collection3d(box)

In [116]:
# Octree leaves drawn over the source particles.
# Create a fresh figure explicitly. plt.axes() would attach to whatever figure
# is current, and with the qt backend windows stay open, so re-running this
# cell would stack a new 3D axes on top of an old plot.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

ax.scatter3D(e.sources[:, 0], e.sources[:, 1], e.sources[:, 2], c='k', s=0.1)

for leaf in e.leaves:
    plot_node(ax, leaf, e.r0, e.x0)

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.set_title('Octree leaves and source particles')
plt.show()

In [117]:
# FMM-computed potential at each target (float32). These are compared against
# the laplace_p2p_serial reference below.
e.target_potentials

array([817.0498 , 794.9653 , 799.9665 , ..., 778.36066, 782.69214,
       799.25214], dtype=float32)

In [118]:
# Reference potentials from the P2P Laplace kernel, evaluated on the same
# sources, targets and densities the FMM used. Presumably a direct all-pairs
# sum, serving as ground truth for the error check.
direct = laplace_p2p_serial(
    e.sources,
    e.targets,
    e.source_densities,
)

In [119]:
# Relative error of the FMM potentials against the direct reference.
# Divide by |direct| so a negative reference potential (e.g. from signed source
# densities) still gives a non-negative error. Otherwise the value would be
# negative and silently fall outside the histogram's (0, 0.5) range.
error = np.abs(direct - e.target_potentials) / np.abs(direct)

In [120]:
# Distribution of the FMM relative error, in percent.
# The histogram range is capped, and Axes.hist drops values outside `range`.
# Report the maximum error and how many targets exceed the cap so large errors
# are not hidden.
HIST_MAX_PERCENT = 0.5
error_percent = error * 100
n_outside = np.count_nonzero(error_percent > HIST_MAX_PERCENT)

fig, ax = plt.subplots()
ax.hist(error_percent, bins=100, range=(0, HIST_MAX_PERCENT))
ax.set_xlabel('Percentage Error')
ax.set_ylabel('Frequency')
ax.set_title(
    f'max error {error_percent.max():.3g}%, '
    f'{n_outside} targets above {HIST_MAX_PERCENT}%'
)
plt.show()