In [None]:
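# Grid-based analysis of strike displacement results for the trap jaw ant:
# load the bulk results, drop points above an error threshold, bin the
# remaining points into a grid over the ant model, and plot the per-cell
# mean / standard deviation of the displacement, the mean in-plane
# displacement as arrows, and the displacement angles.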

In [None]:
import matplotlib
import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn.gaussian_process.kernels import ConstantKernel as C

from hsflfm.analysis import build_bulk_analyzer, ResultPlotter, convert_to_percentile
from hsflfm.calibration import FLF_System
from hsflfm.processing import world_frame_to_pixel
from hsflfm.util import load_dictionary, MetadataManager, matmul

from analysis_example_scripts._build_grid import build_grid

In [2]:
# Folder holding the results to analyze.
# Earlier result set, kept for reference: "../complete_results_20241227"
result_folder = "../complete_results_pseudo_inverse_20250803"

# Forwarded unchanged to build_bulk_analyzer as its `reload` argument.
reload = False
analyzer = build_bulk_analyzer(result_folder, reload=reload)

In [3]:
# Remove all points that don't meet a defined error threshold.
# The error threshold is in pixels^2.
ERROR_THRESHOLD = 0.0015

total_points = len(analyzer.all_results["specimen_number"])
# NOTE(review): the point count is read again right after this call, so it
# appears to filter `analyzer.all_results` in place -- re-running this cell
# would then report the already-filtered count as the total.
analyzer.enforce_error_threshold(ERROR_THRESHOLD)
safe_points = len(analyzer.all_results["specimen_number"])

print(f"total # points: {total_points}")
print(f"after enforcing error threshold: {safe_points}")
if total_points:
    # skip the percentage when there are no points at all (avoids 0 / 0)
    print(f"{safe_points/total_points * 100:.1f}%")

# Get indices for "R only" strikes so we can neglect those.
mandible_order = np.asarray(analyzer.all_results["mandible_order"])
is_R_only = mandible_order == "R only"
R_indices = np.where(is_R_only)[0]
not_R_indices = np.where(~is_R_only)[0]

total # points: 14061
after enforcing error threshold: 11556
82.2%


In [None]:
# Show all points on a 3D model of the trap jaw ant.

JITTER = 10000  # full width of the uniform jitter added to every coordinate
JITTER_SEED = 0  # fixed seed so the figure is identical on every run

mesh_points = analyzer.all_results["mesh_points"]

# Use a dedicated, seeded generator: the jitter becomes reproducible and the
# global torch RNG state is left untouched.
jitter_rng = torch.Generator().manual_seed(JITTER_SEED)
offsets = (torch.rand(mesh_points.shape, generator=jitter_rng) - 0.5) * JITTER
points_jittered = mesh_points + offsets

# Color each point by the third (index 2) displacement component, passed
# through convert_to_percentile.
z_displacement = analyzer.all_results["displacement"][:, 2]
point_colors = convert_to_percentile(z_displacement)

fig = ResultPlotter.plot_mesh_with_points(
    points=points_jittered,
    opacity=0.95,
    point_values=point_colors,
    points_on_surface=False,
    marker_dict={"size": 0.7, "colorscale": "Turbo", "opacity": 0.7},
)

In [None]:
# Assess results based on a defined grid.

# Load the outline for display.
outline_filename = "../hsflfm/ant_model/model_outline.npy"
outline = np.load(outline_filename)
avg_ant_scale = 1693  # use average so results are in mm

# Form the grid.
x_bounds, y_bounds = build_grid(analyzer=analyzer, show=True)

# Number of grid cells along each axis. Later cells index
# x_bounds[i + 1] / y_bounds[j + 1] for every cell (i, j), so the bounds are
# treated as cell edges: one more edge than there are cells.
x_count = len(x_bounds) - 1
y_count = len(y_bounds) - 1

In [None]:
# Every point gets a grid index (i, j); (-1, -1) means it is outside the grid.
#
# NOTE(review): `locs` is not defined in any cell visible in this notebook.
# From its use below it must hold one row per point, with column 0 compared
# against x_bounds and column 1 against y_bounds -- confirm where it is
# supposed to come from so the notebook survives Restart & Run All.

# A signed integer dtype is required here: an unsigned dtype such as uint8
# cannot represent the -1 "outside the grid" sentinel.
grid_indices = np.full((len(locs), 2), -1, dtype=np.int64)

# Count of how many points fall in each grid cell.
grid_counts = np.zeros((x_count, y_count), dtype=np.int64)

for i, j in np.ndindex((x_count, y_count)):
    # A point belongs to cell (i, j) when it lies inside the half-open
    # rectangle [x_bounds[i], x_bounds[i + 1]) x [y_bounds[j], y_bounds[j + 1]).
    in_cell = (
        (locs[:, 0] >= x_bounds[i])
        & (locs[:, 0] < x_bounds[i + 1])
        & (locs[:, 1] >= y_bounds[j])
        & (locs[:, 1] < y_bounds[j + 1])
    )
    point_indices = np.where(in_cell)[0]

    grid_indices[point_indices] = [i, j]
    grid_counts[i, j] = len(point_indices)

In [None]:
# Plot mean/std for x, y, z displacement in grid regions.

MIN_POINTS_PER_CELL = 65  # cells with fewer points are left as NaN
exclude_R_only = True


def plot_grid_maps(values_3d, title, cmap):
    """Show the three components of a per-cell statistic side by side.

    Parameters
    ----------
    values_3d : np.ndarray
        Array indexed as [component, i, j]; NaN marks cells without a value.
    title : str
        Figure title.
    cmap : str
        Matplotlib colormap name.

    All three panels share one color range (the NaN-ignoring min/max of
    `values_3d`). Values are multiplied by 1e3 for display and every map is
    flipped along both axes before it is drawn.
    """
    fig, axes = plt.subplots(1, 3)
    color_limits = (np.nanmin(values_3d) * 1e3, np.nanmax(values_3d) * 1e3)
    n_rows, n_cols = values_3d.shape[1:]

    for dim, ax in enumerate(axes):
        im = ax.imshow(
            np.fliplr(np.flipud(values_3d[dim])) * 1e3,
            clim=color_limits,
            cmap=cmap,
        )
        ax.set_xticks([])
        ax.set_yticks([])

        # draw the cell borders (imshow centers cell k on coordinate k)
        for col in range(n_cols):
            ax.axvline(x=col - 0.5, color='black')
        for row in range(n_rows):
            ax.axhline(y=row - 0.5, color='black')

    fig.suptitle(title)
    fig.colorbar(im, ax=axes, orientation="vertical")
    plt.tight_layout()


means_3d = np.full((3, x_count, y_count), np.nan)
std_dev_3d = np.full((3, x_count, y_count), np.nan)

for i, j in np.ndindex((x_count, y_count)):
    # indices of all points assigned to grid cell (i, j)
    point_indices = np.where(
        (grid_indices[:, 0] == i) & (grid_indices[:, 1] == j)
    )[0]

    if exclude_R_only:
        point_indices = np.intersect1d(point_indices, not_R_indices)

    if len(point_indices) < MIN_POINTS_PER_CELL:
        continue  # too few points: leave this cell as NaN

    grid_disp = analyzer.all_results["displacement"][point_indices].numpy()
    for dim in range(3):
        component = grid_disp[:, dim]
        means_3d[dim, i, j] = np.mean(component)
        std_dev_3d[dim, i, j] = np.std(component)

# flip the sign of the second displacement component
means_3d[1] *= -1

# copies with the first two components swapped, used for display only
means_3d_temp = means_3d[[1, 0, 2]]
std_dev_3d_temp = std_dev_3d[[1, 0, 2]]

plot_grid_maps(means_3d_temp, "Magnitude", 'turbo')
plot_grid_maps(std_dev_3d_temp, "Standard Deviation", 'plasma')

In [None]:
# Then show colored arrows in each grid cell based on means_3d.
scale = 13  # multiplier applied to the in-plane means to set the arrow length
cmap = matplotlib.cm.turbo
# arrow color encodes means_3d[2], normalized to its NaN-ignoring range
maxz = np.nanmax(means_3d[2])
minz = np.nanmin(means_3d[2])

fig, ax = plt.subplots(1, 1)

ax.set_ylim(-1.2, 1.5)

# draw the grid lines in light gray
x_low = np.min(x_bounds) / avg_ant_scale
x_high = np.max(x_bounds) / avg_ant_scale
y_low = np.min(y_bounds) / avg_ant_scale
y_high = np.max(y_bounds) / avg_ant_scale
for yb in y_bounds:
    ax.plot([yb / avg_ant_scale, yb / avg_ant_scale], [x_low, x_high],
            color='lightgray', alpha=0.5)
for xb in x_bounds:
    ax.plot([y_low, y_high], [xb / avg_ant_scale, xb / avg_ant_scale],
            color='lightgray', alpha=0.5)

for i, j in np.ndindex(means_3d.shape[1:]):
    magx = means_3d[0, i, j] * scale
    magy = means_3d[1, i, j] * scale
    if np.isnan(magx) or np.isnan(magy):
        # cell has no mean (left as NaN): nothing to draw
        continue

    locx = (x_bounds[i] + x_bounds[i + 1]) / (2 * avg_ant_scale)
    # NOTE(review): the horizontal position uses the bound y_bounds[j]
    # (negated) rather than the midpoint of y_bounds[j] and y_bounds[j + 1]
    # -- confirm this is intended.
    locy = (y_bounds[j]) / (avg_ant_scale) * -1

    z = means_3d[2, i, j]
    color = cmap((z - minz) / (maxz - minz))
    # the arrow is centered on (locy, locx)
    ax.arrow(locy - magy / 2, locx - magx / 2, magy, magx,
             head_width=0.02, color=color)

ax.set_aspect("equal")
ax.plot(outline[:, 1] / avg_ant_scale, outline[:, 0] / avg_ant_scale,
        color='black')

ax.set_ylabel("y (mm)")
ax.set_xlabel("x (mm)");

In [None]:
# Plot over image of an ant.

specimen = "20240502_OB_1"
result_filename = f"{result_folder}/{specimen}/strike_1_results.json"
results = load_dictionary(result_filename)
mm = MetadataManager(specimen)
system = FLF_System(mm.calibration_filename)
# inverse of the stored "A_cam_to_ant_start" matrix
M = np.linalg.inv(np.asarray(results["A_cam_to_ant_start"]))

# ratio between the average ant scale and this specimen's scale
relative_scale = avg_ant_scale / results["ant_scale"]

image = mm.light_calibration_images[2]

fig, ax = plt.subplots(1, 1)
ax.imshow(image, cmap='gray')

# The arrows are colored by hand, so the colorbar is built from a standalone
# ScalarMappable covering exactly [minz, maxz] (displayed multiplied by 1e3).
z_norm = matplotlib.colors.Normalize(vmin=minz * 1e3, vmax=maxz * 1e3)
fig.colorbar(matplotlib.cm.ScalarMappable(norm=z_norm, cmap='turbo'), ax=ax)

for i, j in np.ndindex(means_3d.shape[1:]):
    locx = (x_bounds[i] + x_bounds[i + 1]) / (2 * avg_ant_scale)
    # locy = (y_bounds[j] + y_bounds[j + 1]) / (2 * avg_ant_scale) * -1
    locy = (y_bounds[j]) / (avg_ant_scale) * -1

    magx = means_3d[0, i, j] * scale
    magy = means_3d[1, i, j] * scale
    if np.isnan(magx) or np.isnan(magy):
        continue
    z = means_3d[2, i, j]
    color = cmap((z - minz) / (maxz - minz))

    # switch arrow into image coordinates
    # first need to go from ant coordinates to world coordinates
    # then from world coordinates to image coordinates
    locx *= relative_scale
    locy *= relative_scale
    magx *= relative_scale
    magy *= relative_scale
    start = [[locx - magx / 2, -(locy - magy / 2), 1]]
    end = [[start[0][0] + magx, start[0][1] + magy, 1]]
    start_world = matmul(M, start).squeeze()
    end_world = matmul(M, end).squeeze()
    start_pixel = world_frame_to_pixel(system, start_world)
    end_pixel = world_frame_to_pixel(system, end_world)

    # Index 0 of the pixel result is drawn on the vertical image axis and
    # index 1 on the horizontal axis.
    # NOTE(review): the sign conventions of mag0 / mag1 are kept exactly as
    # they were -- confirm against world_frame_to_pixel's output convention.
    mag0 = start_pixel[0][0] - end_pixel[0][0]
    mag1 = start_pixel[1][0] - end_pixel[1][0]
    ax.arrow(start_pixel[1][0], start_pixel[0][0], mag1, -mag0,
             head_width=2, color=color)

ax.set_xlim(20, 150)
ax.set_ylim(190, 10)
ax.set_xticks([])
ax.set_yticks([])

# and get the arrow scale
# this is the pixel scale with the display scale
pixel_size = system.calib_manager.pixel_size
mag = system.get_magnification_at_plane(2, 0, 0)
m_per_pixel = (pixel_size / mag)

bar_size = 15e-6
num_pixels = bar_size / m_per_pixel * scale

# white reference arrow of length num_pixels
start = [120, 30]
ax.arrow(start[0], start[1], num_pixels, 0,
         head_width=2, color='white');

In [None]:
# Plot angles for non-R-only strikes.
disp = analyzer.all_results["displacement"][not_R_indices].numpy()
rho = np.linalg.norm(disp, axis=-1)
# Clip before arccos: rounding can push |z / rho| marginally above 1, which
# would otherwise produce NaN.
polar_angles = np.arccos(np.clip(disp[:, 2] / rho, -1.0, 1.0))
azi_angles = np.arctan2(disp[:, 0], disp[:, 1])

# NOTE(review): `locs_jittered` is not defined in any cell visible in this
# notebook -- confirm where it is supposed to come from.
loc_temp = locs_jittered[not_R_indices] / avg_ant_scale


def plot_angle_map(locations, angles, cmap, colorbar_label):
    """Scatter the points colored by `angles`, with the ant outline on top.

    Parameters
    ----------
    locations : array
        One row per point; column 1 is drawn on the horizontal axis and
        column 0 on the vertical axis.
    angles : np.ndarray
        One value per point (radians), used as the marker color.
    cmap : str
        Matplotlib colormap name.
    colorbar_label : str
        Label shown next to the colorbar.
    """
    fig, ax = plt.subplots(1, 1)
    im = ax.scatter(locations[:, 1], locations[:, 0],
                    c=angles, s=0.5, cmap=cmap,
                    rasterized=True)
    ax.set_aspect("equal")
    ax.set_ylim(-1.2, 1.5)

    ax.set_ylabel("y (mm)")
    ax.set_xlabel("x (mm)")

    # manually placed colorbar axes on the right-hand side of the figure
    cax = fig.add_axes([0.8, 0.15, 0.02, 0.7])
    fig.colorbar(im, cax=cax, orientation="vertical", label=colorbar_label)

    ax.plot(outline[:, 1] / avg_ant_scale, outline[:, 0] / avg_ant_scale,
            color='black')


plot_angle_map(loc_temp, polar_angles, 'turbo_r', "polar angle (rad)")
plot_angle_map(loc_temp, azi_angles, 'twilight', "azimuthal angle (rad)")