In [None]:
import sys, os
from pathlib import Path
import matplotlib.pyplot as plt

# Locate the repository root and put it on sys.path so `pyFMM` can be imported.
# The root is the first directory, walking up from the current working directory,
# that contains a `pyFMM` folder.
#
# The search runs inside a function so its loop variable does not leak into the
# notebook namespace. A module-level `for p in ...` loop would rebind the global
# `p`, which later cells use as the `p` argument of FMMTree; re-running this cell
# after those cells would then silently break the tree construction.
def _find_repo_root(marker: str = 'pyFMM') -> str:
    """Return the closest ancestor of the cwd (cwd included) that contains directory `marker`.

    Args:
        marker: name of the sub-directory that identifies the repository root.

    Returns:
        The matching directory as a string (suitable for sys.path).

    Raises:
        RuntimeError: if no directory from the cwd up to the filesystem root contains `marker`.
    """
    start_dir = Path.cwd().resolve()
    for candidate in (start_dir, *start_dir.parents):
        if (candidate / marker).is_dir():
            return str(candidate)
    raise RuntimeError(f"Could not find '{marker}' in any parent directory of cwd")


repo_root = _find_repo_root()
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)
    
from pyFMM import *
import numpy as np
import time as time

In [None]:
# Reproducible random source distribution: N charges of magnitude 1 and random sign,
# placed uniformly inside an axis-aligned cube of half-width `source_area_size`.
np.random.seed(42)
N = 10000  # number of source points

# Charges are drawn BEFORE the positions; keep this order so the seeded stream
# (and therefore X) stays the same. For all-positive charges use q = np.ones(N).
q = np.random.choice([-1, 1], size=N)

source_area_size = 1.0
LLC = np.full(3, -1.0) * source_area_size  # lower-left corner  (-s, -s, -s)
URC = np.full(3, 1.0) * source_area_size   # upper-right corner (+s, +s, +s)
size = URC - LLC                            # edge lengths of the bounding box
center = (LLC + URC) * 0.5                  # midpoint of the bounding box
X = np.random.uniform(low=LLC, high=URC, size=(N, 3))

p = 4  # value forwarded to FMMTree as `p` (presumably the expansion order — confirm in pyFMM)

In [None]:
def log4(x):
    """Base-4 logarithm (element-wise for numpy arrays)."""
    return np.log(x) / np.log(4)


def log8(x):
    """Base-8 logarithm (element-wise for numpy arrays)."""
    return np.log(x) / np.log(8)


def _floor_log(n, base):
    """Return floor(log_base(n)) for n >= 1 and integer base >= 2, computed exactly.

    Dividing floating-point logarithms can evaluate to slightly below an integer
    when n is an exact power of `base`. np.floor would then lose a whole level,
    so the exponent is found with exact integer powers instead.
    Returns 0 for n < base (including n <= 0).
    """
    exponent = 0
    while base ** (exponent + 1) <= n:
        exponent += 1
    return exponent


# Suggested depth: the deepest level L with base**L <= N (about one point per leaf
# in a uniform quadtree/octree) minus one, clamped so it never goes above the root (0).
max_level_2D = max(_floor_log(N, 4) - 1, 0)
max_level_3D = max(_floor_log(N, 8) - 1, 0)

avg_point_per_leaf_2D = N / (4 ** max_level_2D)
avg_point_per_leaf_3D = N / (8 ** max_level_3D)

print("max level 2D:", max_level_2D)
print("avg points per leaf 2D:", avg_point_per_leaf_2D)

print("max level 3D:", max_level_3D)
print("avg points per leaf 3D:", avg_point_per_leaf_3D)

In [None]:
# Reduce the 3-D setup to a plane: axis `index_collapse` (0=x, 1=y, 2=z) gets zero
# box extent, and every point's coordinate along it is set to 0.
# NOTE(review): later plotting cells use X[:, 0] vs X[:, 2], i.e. they assume index_collapse == 1.
index_collapse = 1

for collapsed_target in (size, X):
    # `...` selects the last axis for both the 1-D `size` and the 2-D `X`.
    collapsed_target[..., index_collapse] = 0.0

In [None]:
# Tree parameters forwarded to FMMTree: maximum depth and minimum leaf size.
max_level = 5
min_leaf_size = 10

# 1) create the tree over the bounding box (center, size) holding points X with charges q,
# 2) create all of its nodes breadth-first (BFS=True),
# 3) construct the near-neighbour lists.
tree_kwargs = {"p": p, "min_leaf_size": min_leaf_size, "max_level": max_level}
tree = FMMTree(center, size, X, q, **tree_kwargs)
tree.build_tree(BFS=True)
tree.make_lists()

In [None]:
#for node in tree.node_list:
#    print(f"Level: {node.level}, Center: {node.center}, Size: {node.size}, Num Points: {node.num_points}", " len indices:", len(node.indices))

In [None]:
# Overview figure: the tree's boxes with every source point drawn in the x-z plane.
# Assumes the y axis is the collapsed one (index_collapse == 1), as the x/z scatter implies.
fig, ax = plot_tree(tree)
ax.scatter(X[:, 0], X[:, 2], color='red', s=5)
ax.set_title("Tree and Points")
ax.set_xlabel("X-axis")
ax.set_ylabel("Z-axis")
plt.show()

In [None]:
# Probe point at (0.499, 0, 0.499); `points` keeps the (1, 3) shape used by later cells.
points = np.array([[4.99, 0.0, 4.99]]) * 0.1
probe = points[0]

fig, axes = plt.subplots(1, 3, figsize=(18, 10))
ax_points, ax_lists, ax_near = axes

# Panel 1: tree boxes plus all source points (x-z projection).
plot_tree(tree, axis=ax_points)
ax_points.scatter(X[:, 0], X[:, 2], color='red', s=5)

# Panel 2: near-neighbour and interaction lists of the probe point.
plot_tree(tree, axis=ax_lists)
plot_neighbors_point_on_axis(probe, tree, ax_lists)
plot_interaction_point_on_axis(probe, tree, ax_lists)

# Panel 3: near-neighbour list of the probe point only.
node = tree.find_leaf_for_point(probe)
plot_tree(tree, axis=ax_near)
plot_neighbors_point_on_axis(probe, tree, ax_near)

panel_titles = ("Tree and Points", "Neighbors and Interaction Points", "Node and Its Neighbors")
for panel, panel_title in zip(axes, panel_titles):
    panel.set_title(panel_title)
    panel.set_xlabel("X-axis")
    panel.set_ylabel("Z-axis")

plt.tight_layout()
plt.show()

In [None]:
# Choose a probe point, find the leaf containing it, and draw the probe's
# near-neighbour boxes on top of the tree.
points = 0.1 * np.array([[4.99, 0.0, 4.99]])
probe = points[0]
node = tree.find_leaf_for_point(probe)

fig, ax = plot_tree(tree)
plot_neighbors_point_on_axis(probe, tree, ax)

In [None]:
# Overlay both the near-neighbour list and the interaction list of the probe point.
# (To show the interaction list of the whole leaf instead: plot_interaction_node_on_axis(node, ax).)
fig, ax = plot_tree(tree)
probe_point = points[0]
for draw_list in (plot_neighbors_point_on_axis, plot_interaction_point_on_axis):
    draw_list(probe_point, tree, ax)

In [None]:
# Build the tree's moments and report how long it took.
# perf_counter is monotonic and high-resolution, so the interval is not affected
# by system clock adjustments the way time.time() is.
# (To profile instead, run `%prun tree.construct_moments()` in its own cell.)
start_time = time.perf_counter()
tree.construct_moments()
end_time = time.perf_counter()
print(f"Moment construction time: {end_time - start_time}")

In [None]:
# Sanity check at the probe point: FMM evaluation vs. direct summation over all sources.
P_mom = tree.eval_P(points)[0]
P_dir = pot_eval.P_direct_cart(X, q, points)[0]
print("Direct potential:", P_dir)
print("FMM potential:", P_mom)

# Quantify the agreement; the relative error is skipped when the direct value is exactly 0.
abs_err = abs(P_mom - P_dir)
print("Absolute error:", abs_err)
if P_dir != 0:
    print("Relative error:", abs_err / abs(P_dir))

In [None]:
# Evaluation grid: `resolution` cell-centred samples per axis spanning [LLC, URC),
# i.e. twice as fine per axis as the deepest tree level.
resolution = 2 ** (max_level + 1)
grid_x, grid_y, grid_z = (
    np.linspace(lo, hi, resolution, endpoint=False) + (hi - lo) / resolution / 2
    for lo, hi in zip(LLC, URC)
)

# The collapsed axis is sampled only at coordinate 0 (where all points were placed).
grid_axes_1d = [grid_x, grid_y, grid_z]
if index_collapse in (0, 1, 2):
    grid_axes_1d[index_collapse] = np.array([0.0])
grid_x, grid_y, grid_z = grid_axes_1d

# Flatten the ij-ordered mesh into an (n_points, 3) array of xyz coordinates.
grid = np.meshgrid(grid_x, grid_y, grid_z, indexing='ij')
grid_points = np.vstack([axis_values.ravel() for axis_values in grid]).T

In [None]:
# Evaluate the potential on every grid point, once by direct summation and once
# with the FMM tree. Each is timed with perf_counter (monotonic, high resolution).
start_time = time.perf_counter()
P_dir = pot_eval.P_direct_cart(X, q, grid_points)
end_time = time.perf_counter()
print(f"Direct eval time for {grid_points.shape[0]} points: {end_time - start_time}")

start_time = time.perf_counter()
P_mom = tree.eval_P(grid_points)
end_time = time.perf_counter()
print(f"Tree eval time for {grid_points.shape[0]} points: {end_time - start_time}")

# Restore the grid layout. The mesh was built with indexing='ij', so the flat
# results reshape to (len(grid_x), len(grid_y), len(grid_z)). Dropping the
# length-1 collapsed axis gives a 2-D image for whichever axis `index_collapse`
# selects, rather than assuming the y axis is the collapsed one.
grid_shape = (grid_x.size, grid_y.size, grid_z.size)
P_dir_grid = P_dir.reshape(grid_shape).squeeze(axis=index_collapse)
P_mom_grid = P_mom.reshape(grid_shape).squeeze(axis=index_collapse)


In [None]:
grid_points_2D = grid_points[:, [0, 2]]  # x-z coordinates of the evaluation grid

# Signed error in percent of the largest |direct potential| on the grid.
difference = (P_dir_grid - P_mom_grid) / np.max(np.abs(P_dir_grid)) * 100

# One entry per panel: (image data, panel title, colour-bar label).
panel_specs = (
    (P_dir_grid, "Direct Potential", "Potential Value"),
    (P_mom_grid, "FMM Potential", "Potential Value"),
    (difference, "Difference (Direct - FMM)", "Relative Difference (%)"),
)

fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for panel, (panel_values, panel_title, cbar_label) in zip(axes, panel_specs):
    plot_potential_grid(grid_x, grid_z, panel_values, panel)
    panel.set_title(panel_title)
    panel.set_xlabel("X-axis")
    panel.set_ylabel("Z-axis")
    # The colour bar is attached to the first collection that plot_potential_grid drew.
    cbar = plt.colorbar(panel.collections[0], ax=panel)
    cbar.set_label(cbar_label)

plt.tight_layout()
plt.show()

# Test of adaptive tree capabilities

In [None]:
# #--------------- prune tree ----------------------
# #NOTE - pruning works, but pruned lists are more complicated and not fully implemented 
# tree.prune_tree()
# #----------------------------------------------------------------------
# #------------ plot tree and points again ------------------
# fig, ax = plot_tree(tree)
# ax.scatter(X[:,0], X[:,1], color='red', s=5)
# #------------------------------------------------------

In [None]:
# #-------------- select a point. Then plot the node it belongs to and its nearneighbohrs --------
# point = np.array([ 0.5, -0.6])
# fig, ax = plot_tree(tree)
# plot_neighbors_point_on_axis(point, tree, ax)
# #---------------------------------------------------------------------------------------------