In [1]:
from hsflfm.analysis import (
    ResultManager,
    ResultPlotter,
    BulkAnalyzer,
    convert_to_percentile,
    get_random_percentile_index,
    sort_by_camera,
    get_percentiles,
)
from hsflfm.util import MetadataManager
from scipy.spatial import cKDTree
import os
from matplotlib import pyplot as plt
import torch
from hsflfm.util import (
    load_dictionary,
    save_dictionary,
    play_video,
    create_video_from_numpy_array,
)
from pathlib import Path
import numpy as np
from tqdm import tqdm
import matplotlib

In [2]:
# get all the filenames
all_filenames = []
f = "complete_results_20241227"
# f = "temporary_result_storage_5"
folders = os.listdir(f)
for inner in folders:
    path = Path(f) / inner
    if path.is_dir():
        filenames = os.listdir(path)
        for filename in filenames:
            if filename[-4:] == "json":
                all_filenames.append(str(path / filename))

In [3]:
# edit this if the naming convention changes
def get_filename(specimen_number, strike_number):
    return f"{f}/{specimen_number}/strike_{strike_number}_results.json"

In [4]:
print(len(all_filenames))
# print(all_filenames)

329


In [5]:
analyzer = BulkAnalyzer(all_filenames)

In [6]:
reload = False
res_filename = f"{f}/loaded_results.json"
# res_filename = "temp_loaded_results_5.json"


if reload:

    analyzer.load_results()
    save_dictionary(analyzer.all_results, res_filename)
else:
    analyzer.all_results = load_dictionary(res_filename)

    for key, value in analyzer.all_results.items():

        if key in ["specimen_number", "mandible_order"]:

            analyzer.all_results[key] = np.asarray(value)

            continue

        analyzer.all_results[key] = torch.asarray(value)

In [7]:
# define a "good point" for now
error_scores = analyzer.error_scores
good_indices = torch.where(error_scores < 0.0015)[0]

print("good ratio: {:.2f}".format(len(good_indices) / len(error_scores)))

good ratio: 0.92


In [8]:
# look at all results on the mesh

p = analyzer.all_results["mesh_points"]

# jitter the points

jitter = 10000

rand = (torch.rand(p.shape) - 0.5) * jitter

p = p + rand

v = analyzer.all_results["normalized_displacement"][:, 2]
v = analyzer.all_results["displacement"][:, 2]

v = v[good_indices]

v = convert_to_percentile(v)


fig = ResultPlotter.plot_mesh_with_points(
    points=p[good_indices],
    opacity=0.95,
    point_values=v,
    points_on_surface=False,
    marker_dict={"size": 0.7, "colorscale": "Turbo", "opacity": 0.7},
)

from hsflfm.ant_model import M_mesh_ant, mesh_scale
from hsflfm.util import matmul
import plotly.graph_objects as go

std_radius = 200
theta = np.linspace(0, 2 * np.pi, 200)
x = np.cos(theta) * std_radius
y = np.sin(theta) * std_radius * -1
z = np.zeros_like(theta)
points = np.concatenate((x[:, None], y[:, None], z[:, None]), axis=1)
mesh_points = matmul(np.linalg.inv(M_mesh_ant), points) * mesh_scale * 0.1

# fig.add_trace(
#     go.Scatter3d(
#         x=mesh_points[:, 0],
#         y=mesh_points[:, 1],
#         z=mesh_points[:, 2],
#         mode='lines',
#         marker={
#             "size": 100, "color": 'black'
#         }

#     )
# )

In [16]:
# try out getting a strength score for each video
specimen_names = MetadataManager.all_specimen_numbers()
specimen_names_old = specimen_names.copy()
scores = []
names = []
strike_nums = []

switch_names = ["20240418_OB_1", "20240422_OB_1", "20240427_OB_5"]
for n in switch_names:
    specimen_names_old[np.where(specimen_names == n)] = "2022" + n[4:]

all_error_scores = analyzer.error_scores
named_scores = {}
for name in tqdm(specimen_names_old):
    strike_numbers = MetadataManager(name).strike_numbers
    named_scores[name] = {}
    for strike_number in strike_numbers:
        strike_number = int(strike_number)
        idx = analyzer.get_specimen_indices(name, strike_number=strike_number)
        # only use points below the threshold
        idx = np.intersect1d(idx, good_indices)

        if len(idx) < 1:
            print(name, strike_number)

        # if len(idx) < 15:
        #     print(f"skipping {name} strike {strike_number}, {len(idx)} points")
        #     continue

        k = 15
        _, neighbor_indices = analyzer.get_closest_point_indices(
            k=k, indices=good_indices
        )
        neighbor_indices = neighbor_indices[idx]

        ratios = np.zeros(neighbor_indices.shape[0])
        for pi, neighbor_index in enumerate(neighbor_indices):
            displacements = analyzer.all_results["displacement"][neighbor_index]
            #disp_norm = displacements[:, :2]
            disp_norm = np.linalg.norm(displacements, axis=-1)

            #point_disp = analyzer.all_results["displacement"][idx[pi]][2]
            point_disp = np.linalg.norm(analyzer.all_results["displacement"][idx[pi]]) 
            ratios[pi] = point_disp / np.mean(disp_norm)

        score = np.mean(ratios)
        scores.append(score)
        strike_nums.append(strike_number)
        names.append(name)

        named_scores[name][strike_number] = score


Mean of empty slice.


invalid value encountered in scalar divide

 57%|█████▋    | 17/30 [00:11<00:10,  1.28it/s]

20240503_OB_3 14
20240503_OB_3 15


 70%|███████   | 21/30 [00:14<00:06,  1.41it/s]

20240506_OB_1 9
20240506_OB_1 10


 73%|███████▎  | 22/30 [00:14<00:05,  1.41it/s]

20240506_OB_3 2


100%|██████████| 30/30 [00:21<00:00,  1.38it/s]


In [12]:
analyzer.all_results["displacement"][idx[pi]]

tensor([ 0.0007, -0.0017,  0.0006])

In [None]:
scores = []
delays = []

for name, old_name in zip(specimen_names, specimen_names_old):
    if "20220418_OB_1" in old_name:
        continue
    mm = MetadataManager(name)
    strike_numbers = mm.strike_numbers
    for num in strike_numbers:
        try:
            num = int(num)
            mandible_frames = mm.mandible_start_frames(strike_number=num)
            diff = mandible_frames[0] - mandible_frames[1]
            frame_rate = int(mm.get_strike_data(num)["Frame Rate"])
            diff = diff / frame_rate * 1e3

            scores.append(named_scores[old_name][num])
            delays.append(diff)
        except Exception as e:
            print(f"skipping {name}, strike {num}, {e}")
plt.scatter(delays, scores)
plt.xlabel("Mandible delay (ms)")
plt.ylabel("Strength Score")

In [85]:
# indices = np.where(analyzer.all_results["mandible_order"] == "R only")[0]
# np.random.shuffle(indices)
# index = indices[0]
# specimen_name = analyzer.all_results["specimen_number"][index]
# strike_number = analyzer.all_results["specimen_number"][index]
# filename = f"temporary_result_storage_5/{name}/strike_{num}_results.json"
# plotter = ResultPlotter(load_dictionary(filename))
# video = plotter.get_arrow_video(movement_mag=5, disp_threshold=5e-3,
#                                 force_arrow_after_strike=False)

In [86]:
# reassign the strength scores to each point
strength_scores = torch.zeros_like(analyzer.error_scores)
for name, strike_num, score in zip(names, strike_nums, scores):
    idx0 = np.where(analyzer.all_results["specimen_number"] == name)[0]
    idx1 = np.where(analyzer.all_results["strike_number"] == strike_num)[0]

    indices = np.intersect1d(idx0, idx1)
    strength_scores[indices] = score

In [None]:
_ = plt.hist(scores, bins=50)
plt.title("Strength Scores")
plt.ylabel("Count")
plt.xlabel("Score")

In [None]:
# I'm curious which specimens have the largest variation in strike strength
# so somewhat inefficiently re-find the strength score for each strike
# and see how it varies
stds = []
for name in tqdm(specimen_names):
    strike_numbers = MetadataManager(name).strike_numbers
    scores = []
    for strike_number in strike_numbers:
        idx = analyzer.get_specimen_indices(name, strike_number=strike_number)
        if len(idx) < 1:
            continue
        strength_score = strength_scores[idx[0]]
        scores.append(strength_score)

    stds.append(np.std(scores))
    print(name, "{:.2f}".format(np.std(scores)))

In [None]:
# this can probably be deleted
name = "20240503_OB_4"
highlight_point = 23
strike_numbers = MetadataManager(name).strike_numbers

# cheating a bit here
# min_disp = np.inf
# max_disp = -np.inf

import matplotlib

cmap = matplotlib.cm.coolwarm
plt.figure()

disp = []
vel = []
c = []
for num in strike_numbers:
    filename = f"temporary_result_storage_5/{name}/strike_{num}_results.json"
    filename = f"test_full_results_from_manual_strike_transfer/{name}/strike_{num}_results.json"
    res = load_dictionary(filename)
    point_index = np.where(np.asarray(res["point_numbers"]) == highlight_point)[0][0]
    rm = ResultManager(res)

    # peak_index = rm.get_peak_indices()[point_index]
    peak_disp = rm.peak_displacements()[point_index]

    strike_center = int(rm.strike_center_index())
    rel_displacements = rm.rel_displacements[
        :, strike_center - 5 : strike_center + 11, 2
    ]
    frame_rate = 100e3  # shouldn't be hardcoding this...
    x_axis = np.arange(rel_displacements.shape[1]) / frame_rate * 1e3
    plt.figure()
    for p in rel_displacements:
        plt.plot(x_axis, p * 1e3, color="black")
    plt.plot(
        x_axis,
        1e3 * rel_displacements[point_index],
        color="red",
        label=f"point {highlight_point}",
    )
    plt.legend()
    plt.xlabel("Time (ms)")
    plt.ylabel("Displacement (um)")
    plt.title(f"{name}, strike{num}")

    plotter = ResultPlotter(res)
    # plotter.plot_all_displacement(good_only=True, highlight_point=highlight_point)
    plotter.scatter_peak_disp(good_only=True, highlight_point=highlight_point)
    ax = plt.gca()
    ax.set_title(
        f"{name} strike {num}, \n point {highlight_point} circled \n displacement (mm)"
    )


plotter.show_image_numbers()

In [None]:
# this can probably be deleted
name = "20240503_OB_4"
highlight_point = 23
strike_numbers = MetadataManager(name).strike_numbers

# cheating a bit here
# min_disp = np.inf
# max_disp = -np.inf

import matplotlib

cmap = matplotlib.cm.coolwarm
plt.figure()

disp = []
vel = []
c = []

mm = MetadataManager(name)

for num in strike_numbers:
    filename = f"temporary_result_storage_5/{name}/strike_{num}_results.json"
    filename = f"test_full_results_from_manual_strike_transfer/{name}/strike_{num}_results.json"
    res = load_dictionary(filename)
    point_index = np.where(np.asarray(res["point_numbers"]) == highlight_point)[0][0]
    rm = ResultManager(res)

    # peak_index = rm.get_peak_indices()[point_index]
    peak_disp = rm.peak_displacements()[point_index]

    min_disp = min(peak_disp, min_disp)
    max_disp = max(peak_disp, max_disp)

    c_value = (peak_disp - min_disp) / (max_disp - min_disp)
    color = cmap(c_value)

    strike_center = int(rm.strike_center_index())
    velocities, vel_indices = rm.max_abs_velocity(return_indices=True)
    strike_center = vel_indices[point_index]

    frame_rate = 100e3  # shouldn't be hardcoding this...
    x_axis = np.arange(-9, 9) / frame_rate * 1e3
    plt.plot(
        x_axis,
        1e3
        * rm.rel_displacements[point_index, strike_center - 9 : strike_center + 9, 2]
        / peak_disp,
        color=color,
    )

    mandible_frame = mm.mandible_start_frames(strike_number=num)[0] - strike_center
    plt.axvline(mandible_frame / frame_rate * 1e3, color=color)

    disp.append(peak_disp)
    vel.append(velocities[point_index])
    c.append(color)
    # plotter = ResultPlotter(res)
    # plotter.plot_all_displacement(good_only=True, highlight_point=23)
    # plotter.scatter_peak_disp(good_only=True, highlight_point=23)
# plotter.show_image_numbers()
plt.xlabel("Time (ms)")
plt.ylabel("Displacement (um)")
plt.title(f"Point {highlight_point} trajectories for {name}")
ax = plt.gca()
ax.set_facecolor("black")

plt.figure()
plt.scatter(disp, vel, c=c)

In [93]:
# n = np.nanargmax(scores)
# name = names[n]
# num = strike_nums[n]

# strike_indices = analyzer.get_specimen_indices(name, num)
# good_strike_indices = np.intersect1d(strike_indices, good_indices)
# print(name, "strike", num, "score:", scores[n])

# filename = f"temporary_result_storage_4/{name}/strike_{num}_results.json"
# assert os.path.exists(filename)

# result_info = load_dictionary(filename)
# plotter = ResultPlotter(result_info)

# plotter.scatter_peak_disp(highlight_point=10)


# fig = plotter.plot_all_displacement(highlight_point=10)
# # highlight the points below the error threshold in blue
# ax = fig.axes[0]
# bad_disp = plotter.result_manager.rel_displacements[~strike_good_point_indices]
# for p in bad_disp:
#     plt.plot(p[:, 2] * 1e3, "--", color=(0.5, 0.5, 1))
# # plotter.plot_displacement(10)

# #vid = plotter.get_arrow_video(cam_num=2)

In [None]:
n = np.nanargmax(scores)
name = names[n]
num = strike_nums[n]

strike_indices = analyzer.get_specimen_indices(name, num)
print(name, "strike", num, "score:", scores[n])

filename = f"temporary_result_storage_4/{name}/strike_{num}_results.json"
assert os.path.exists(filename)

result_info = load_dictionary(filename)
plotter = ResultPlotter(result_info)
good_strike_indices, plotter.good_indices, _ = np.intersect1d(
    strike_indices, good_indices, return_indices=True
)
plotter.scatter_peak_disp(highlight_point=10)
plotter.scatter_peak_disp(good_only=False)


fig = plotter.plot_all_displacement(highlight_point=10)

vid = plotter.get_arrow_video(cam_num=2)

In [None]:
play_video(vid)

In [None]:
# want to plot a bunch of points from different strikes
# find points in the region I want
distances = torch.linalg.norm(
    analyzer.all_results["start_locations_std"][:, :2], axis=1
)
low_indices = torch.where(distances < 200)[0]

# and for simplicity, only those with positive velocity
pos_vel_indices = torch.where(analyzer.all_results["max_z_velocity"][:, 2] > 0)

indices = np.intersect1d(good_indices, low_indices)
indices = np.intersect1d(indices, pos_vel_indices)

min_disp = np.inf
max_disp = -np.inf


# silly way to handle this, but its fine....
data = np.zeros((len(indices), 18))
strike_centers = np.zeros(len(indices))

plt.figure()
for i, index in enumerate(tqdm(indices)):
    try:
        point_number = analyzer.all_results["point_number"][index]
        name = analyzer.all_results["specimen_number"][index]
        strike_number = int(analyzer.all_results["strike_number"][index])

        # filename = (
        #     f"temporary_result_storage_5/{name}/strike_{strike_number}_results.json"
        # )
        filename = f"test_full_results_from_manual_strike_transfer/{name}/strike_{num}_results.json"
        res = load_dictionary(filename)
        rm = ResultManager(res)

        point_index = np.where(np.asarray(res["point_numbers"]) == point_number)[0][0]

        peak_disp = rm.peak_displacements()[point_index]

        min_disp = min(peak_disp, min_disp)
        max_disp = max(peak_disp, max_disp)

        velocities, vel_indices = rm.max_abs_velocity(return_indices=True)
        strike_center = vel_indices[point_index]

        frame_rate = int(
            MetadataManager(name).get_strike_data(strike_number)["Frame Rate"]
        )
        if frame_rate < 0.9e5:
            print(f"small frame rate for {name} {strike_number}")
            continue
        x_axis = np.arange(-9, 9) / frame_rate * 1e3
        plt.plot(
            x_axis,
            1e3
            * rm.rel_displacements[
                point_index, strike_center - 9 : strike_center + 9, 2
            ]
            / peak_disp,
            color="black",
        )

        data[i] = rm.rel_displacements[
            point_index, strike_center - 9 : strike_center + 9, 2
        ]
        strike_centers[i] = strike_center

    except Exception as e:
        print(f"failed on {name}, strike {strike_number}, point {point_number}, {e}")

    # if i > 50:
    #    break

In [None]:
len(indices)

In [None]:
count = 0

maxv = np.max(data)
minv = 0.008
print(minv, maxv)

import matplotlib

cmap = matplotlib.cm.magma

for i, p in enumerate(data):
    p = p - p[0]
    if np.max(np.abs(p)) < minv:
        continue

    # p = np.diff(p)

    index = indices[i]
    specimen_name = analyzer.all_results["specimen_number"][index]
    strike_number = int(analyzer.all_results["strike_number"][index])
    mm = MetadataManager(specimen_name)
    frame = mm.mandible_start_frames(strike_number)[0]
    # print(frame, strike_centers[i])
    frame = frame - strike_centers[i]
    frame = frame + (np.random.random() * 0.5) - 0.25

    # print(frame)

    peak = np.max(np.abs(p))
    color = cmap((peak - minv) / (maxv - minv))

    # p2 = np.diff(p)
    # x_axis = (np.arange(len(p2)) - 9) / 1e5 * 1e3

    # p2 is in mm/frame
    # p2 = p2 * 100000 / 1000 #(m/s)

    # p = np.diff(p)
    plt.plot(
        p, linewidth=0.4, color=color  # / np.max(np.abs(p[1:])),
    )  # / np.max(np.abs(p)),

    mandible_order = mm.mandible_order(strike_number)
    if mandible_order == "L":
        c2 = "red"
    elif mandible_order == "R":
        c2 = "green"
    elif mandible_order == "S":
        c2 = "blue"
    elif mandible_order == "L only":
        c2 = "purple"
    else:
        c2 = "white"

    # plt.axvline(frame / 1e5 * 1e3, color=c2)
    count += 1
print(count)

ax = plt.gca()
# ax.set_ylim(-1, 1)
ax.set_facecolor("black")
plt.ylabel("Velocity")
plt.xlabel("Time (ms)")
plt.title(
    f"Velocity normalized by max displacement for {count} points in saddle region"
)
ax = plt.gca()
# ax.set_yticks([])

# r = np.random.random((500, 500)) * (maxv - minv) + minv
# plt.figure()
# plt.imshow(r * 1e3, cmap=cmap)
# plt.colorbar()

In [None]:
1e5

In [None]:
v = np.linalg.norm(analyzer.all_results["max_z_velocity"][good_indices], axis=1)
d = np.linalg.norm(analyzer.all_results["displacement"][good_indices], axis=1)
plt.scatter(d, v, s=2)

distances = torch.linalg.norm(
    analyzer.all_results["start_locations_std"][good_indices, :2], axis=1
)
low_indices = torch.where(distances < 200)[0]

plt.figure()
frame_rate = 1e5
plt.scatter(d[low_indices] * 1e3, v[low_indices] * frame_rate, s=2)
ax = plt.gca()
ax.set_xlabel(r"max |displacement| (um)")
ax.set_ylabel(r"max |velocity| (mm/sec)")
ax.set_title(f"Displacement vs Velocity for {len(low_indices)} Points in Saddle")

In [None]:
locs = analyzer.all_results["start_locations_std"][good_indices]
plt.scatter(locs[:, 1], locs[:, 0], c=d, s=2)
plt.colorbar()
ax = plt.gca()
ax.set_aspect("equal")

In [None]:
v = analyzer.all_results["max_z_velocity"][good_indices, 2]
d = analyzer.all_results["displacement"][good_indices, 2]
plt.scatter(d, v, s=2)
plt.ylabel("max velocities (mm/frame)")
plt.xlabel("max displacement(mm)")

# then restrict to points in the saddle kind of
# for simplicity, going within 100 standardized units of (0, 0)
# plt.figure()
# x = analyzer.all_results["start_locations_std"][:, 0]
# y = analyzer.all_results["start_locations_std"][:, 1]
# c = analyzer.all_results["normalized_displacement"][:, 2]
# plt.scatter(y[good_indices], x[good_indices], c=c[good_indices], s=2, cmap='turbo',
#             clim=(-2, 4))
# ax = plt.gca()
# ax.set_aspect("equal")
# ax.invert_xaxis()
# plt.colorbar()

distances = torch.linalg.norm(
    analyzer.all_results["start_locations_std"][:, :2], axis=1
)
low_indices = torch.where(distances < 200)[0]
indices = np.intersect1d(good_indices.numpy(), low_indices.numpy())

v = analyzer.all_results["max_z_velocity"][indices, 2]
d = analyzer.all_results["displacement"][indices, 2]
orders = analyzer.all_results["mandible_order"][indices]
r_only_indices = np.where(orders == "R only")[0]


idx = np.where(v > 0)[0]
slope, offset = np.polyfit(v[idx], d[idx], deg=1)
x_vals = np.linspace(0, 0.01, 100)
y_vals = x_vals * slope + offset
plt.figure()
R = np.corrcoef(v[idx], d[idx])[0, 1]
R_Squared = R * R
# plt.plot(x_vals, y_vals, label="r^2 = {:.2f}".format(R_Squared))
# plt.legend()


plt.scatter(d, v, s=2)
plt.scatter(d[r_only_indices], v[r_only_indices], s=2, color="red")
plt.ylabel("max velocities (mm/frame)")
plt.xlabel("max displacement(mm)")

In [None]:
locs = analyzer.all_results["start_locations_std"][good_indices]
v = analyzer.all_results["max_z_velocity"][good_indices, 2]
d = analyzer.all_results["displacement"][good_indices, 2]
plt.scatter(locs[:, 1], locs[:, 0], c=v, s=2, cmap="turbo")
plt.colorbar()
ax = plt.gca()
ax.set_aspect("equal")

In [None]:
confusing_indices = indices[np.where(v < 0)[0]]
for i in confusing_indices:
    # print(analyzer.all_results["mandible_order"][i])
    # break
    if "R only" not in analyzer.all_results["mandible_order"][i]:
        print(
            analyzer.all_results["specimen_number"][i],
            int(analyzer.all_results["strike_number"][i]),
            int(analyzer.all_results["point_number"][i]),
            analyzer.all_results["mandible_order"][i],
        )
        # print(analyzer.all_results["displacement"][i])

In [None]:
filename = "temporary_result_storage_5/20220427_OB_4/strike_1_results.json"
info = load_dictionary(filename)
plotter = ResultPlotter(info)
plotter.scatter_peak_disp(highlight_point=14)
plotter.plot_all_displacement(highlight_point=14)
plotter.scatter_peak_disp(with_image=True)

video = plotter.get_arrow_video()

In [None]:
play_video(video, fps=2)

In [None]:
# pick a band
idx0 = torch.where(d > 0.019)[0]
idx1 = torch.where(d < 0.021)[0]
idx = np.intersect1d(idx0.numpy(), idx1.numpy())
idx = indices[idx]

for i in idx:
    specimen_number = analyzer.all_results["specimen_number"][i]
    point_number = int(analyzer.all_results["point_number"][i])
    strike_number = int(analyzer.all_results["strike_number"][i])
    # print(specimen_number, "point", point_number, "strike", strike_number, "index", i)


spec_indices = analyzer.get_specimen_indices("20220427_OB_5")
point_indices = torch.where(analyzer.all_results["point_number"] == 49)[0]
spec_indices1 = np.intersect1d(spec_indices, point_indices)
spec_indices1 = np.intersect1d(spec_indices1, good_indices)
strike_nums = analyzer.all_results["strike_number"][spec_indices1]
args = np.argsort(strike_nums)
spec_indices1 = spec_indices1[args]

point_indices2 = torch.where(analyzer.all_results["point_number"] == 5)[0]
spec_indices2 = np.intersect1d(spec_indices, point_indices2)
spec_indices2 = np.intersect1d(spec_indices2, good_indices)
strike_nums = analyzer.all_results["strike_number"][spec_indices2]
args = np.argsort(strike_nums)
spec_indices2 = spec_indices2[args]

v = analyzer.all_results["max_z_velocity"][spec_indices1, 2]
d = analyzer.all_results["displacement"][spec_indices1, 2]

v2 = analyzer.all_results["max_z_velocity"][spec_indices2, 2]
d2 = analyzer.all_results["displacement"][spec_indices2, 2]
plt.scatter(v, d)
plt.scatter(v2, d2)


diff = d - d2
plt.scatter(v, diff)

# cheating with this...
mandible_diffs = [14, 2, 3, 0, 0, 4, 0, 2, 2, 1, 2, 2, 1, 2, 0, 2, 3, 1, 0, 2, 1]

mand_start = np.asarray(
    [29, 14, 15, 16, 19, 19, 4, 13, 8, 22, 12, 12, 18, 5, 16, 29, 35, 7, 17, 5, 9]
)

strike_end = np.asarray(
    [55, 22, 24, 24, 28, 28, 11, 20, 17, 29, 19, 18, 25, 13, 22, 37, 42, 15, 24, 13, 18]
)
frame_diff = strike_end - mand_start

plt.figure()
plt.scatter(diff, mandible_diffs)

plt.figure()
plt.scatter(diff, frame_diff - mandible_diffs)

mm = MetadataManager("20220427_OB_5")
# plt.figure()
# for strike_num in mm.strike_numbers:
filename = f"{f}/20220427_OB_5/strike_{strike_num}_results.json"
res = load_dictionary(filename)
#     point_index = np.where(np.asarray(res["point_numbers"])==49)[0][0]
#     plt.plot(np.asarray(res["rel_displacements"])[point_index, :, 2])
plotter = ResultPlotter(res)
# plotter.scatter_peak_disp(highlight_point=49)
plotter.show_image_numbers()

In [None]:
image = mm.get_start_images(strike_number=1)[2]
match_points = np.asarray(load_dictionary(mm.match_points_filename)[2])
plt.imshow(image, cmap="gray")

points = [5, 49]
plt.scatter([match_points[points, 1]], [match_points[points, 0]])
ax = plt.gca()
ax.set_xlim(40, 120)
ax.set_ylim(150, 50)

In [None]:
name = "20240430_OB_2"
num = 1
filename = f"temporary_result_storage_4/{name}/strike_{num}_results.json"

spec_indices = analyzer.get_specimen_indices(name)
point_indices = torch.where(analyzer.all_results["point_number"] == 28)[0]
spec_indices1 = np.intersect1d(spec_indices, point_indices)
spec_indices1 = np.intersect1d(spec_indices1, good_indices)
strike_nums = analyzer.all_results["strike_number"][spec_indices1]
args = np.argsort(strike_nums)
spec_indices1 = spec_indices1[args]

point_indices2 = torch.where(analyzer.all_results["point_number"] == 8)[0]
spec_indices2 = np.intersect1d(spec_indices, point_indices2)
print(len(spec_indices2))
spec_indices2 = np.intersect1d(spec_indices2, good_indices)
print(len(spec_indices2))
strike_nums = analyzer.all_results["strike_number"][spec_indices2]
args = np.argsort(strike_nums)
spec_indices2 = spec_indices2[args]

v = analyzer.all_results["max_z_velocity"][spec_indices1, 2]
d = analyzer.all_results["displacement"][spec_indices1, 2]

v2 = analyzer.all_results["max_z_velocity"][spec_indices2, 2]
d2 = analyzer.all_results["displacement"][spec_indices2, 2]
plt.scatter(v, d)
plt.scatter(v2, d2)


# diff = d - d2
# plt.scatter(v, diff)

# # cheating with this...
# mandible_diffs = [14, 2, 3, 0, 0,
#                   4, 0, 2, 2, 1,
#                   2, 2, 1, 2, 0,
#                   2, 3, 1, 0, 2, 1]

# mand_start = np.asarray([
#     29, 14, 15, 16, 19,
#     19, 4, 13, 8, 22,
#     12, 12, 18, 5, 16,
#     29, 35, 7, 17, 5, 9
# ])

# strike_end = np.asarray([
#     55, 22, 24, 24, 28,
#     28, 11, 20, 17, 29,
#     19, 18, 25, 13, 22,
#     37, 42, 15, 24, 13, 18
# ])
# frame_diff = strike_end - mand_start

# plt.figure()
# plt.scatter(diff, mandible_diffs)

# plt.figure()
# plt.scatter(frame_diff, mandible_diffs)

# mm = MetadataManager("20220427_OB_5")
# plt.figure()
# for strike_num in mm.strike_numbers:
#     filename = f"{f}/20220427_OB_5/strike_{strike_num}_results.json"
res = load_dictionary(filename)
#     point_index = np.where(np.asarray(res["point_numbers"])==49)[0][0]
#     plt.plot(np.asarray(res["rel_displacements"])[point_index, :, 2])
plotter = ResultPlotter(res)
plotter.scatter_peak_disp()
plotter.show_image_numbers()

In [None]:
# look at all results on the mesh
p = analyzer.all_results["mesh_points"]

# jitter the points
jitter = 10000
rand = (torch.rand(p.shape) - 0.5) * jitter
p = p + rand

v = analyzer.all_results["normalized_displacement"][:, 2]
# v = analyzer.all_results["displacement"][:, 2]
good_point_indices = torch.where(analyzer.error_scores < 0.0015)[0]
strength_threshold = 1.4

weak_indices = np.intersect1d(
    good_point_indices, np.where(strength_scores < strength_threshold)[0]
)
v = v[weak_indices]
v = convert_to_percentile(v)

ResultPlotter.plot_mesh_with_points(
    points=p[weak_indices],
    opacity=0.2,
    point_values=v,
    points_on_surface=False,
    marker_dict={"size": 1, "colorscale": "Turbo"},
)

In [None]:
strength_threshold = 1

strong_indices = np.intersect1d(
    good_point_indices, np.where(strength_scores > strength_threshold)[0]
)
v = analyzer.all_results["displacement"][:, 2]
v = v[strong_indices]
v = convert_to_percentile(v)

ResultPlotter.plot_mesh_with_points(
    points=p[strong_indices],
    opacity=0.2,
    point_values=v,
    points_on_surface=False,
    marker_dict={"size": 1, "colorscale": "Turbo"},
)

In [118]:
# hacky way to do this but it's okay for now
not_a_strike_names = np.asarray(
    [
        "20240418_OB_1_rp",
        "20240506_OB_7",
        "20240507_OB_2",
        "20240502_OB_1",
        "20240503_OB_4",
        "20240507_OB_2",
        "20240507_OB_2",
        "20240507_OB_2",
        "20240418_OB_1",
        "20240507_OB_2",
        "20240507_OB_2",
        "20240418_OB_1_rp",
        "20240417_OB_1",
        "20240418_OB_1_rp",
        "20220427_OB_5",
        "20240503_OB_3",
        "20240507_OB_2",
        "20240507_OB_2",
        "20240507_OB_2",
        "20240507_OB_2",
        "20240418_OB_1_rp",
        "20240418_OB_1_rp",
        "20240507_OB_2",
        "20240503_OB_5",
        "20240506_OB_7",
        "20240417_OB_2",
        "20240507_OB_2",
        "20240417_OB_2",
        "20240418_OB_1_rp",
        "20240418_OB_1_rp",
        "20240418_OB_1_rp",
        "20240418_OB_1_rp",
        "20240418_OB_1",
        "20240506_OB_7",
        "20240506_OB_7",
    ]
)
not_a_strike_numbers = np.asarray(
    [
        18,
        14,
        4,
        10,
        1,
        9,
        8,
        16,
        1,
        10,
        11,
        16,
        5,
        15,
        1,
        5,
        1,
        5,
        6,
        15,
        10,
        4,
        7,
        4,
        8,
        12,
        2,
        5,
        1,
        14,
        2,
        6,
        3,
        1,
        9,
    ]
)


def is_a_strike(specimen_name, strike_number):
    indices = np.where(not_a_strike_names == specimen_name)
    if len(indices) == 0:
        return True
    if strike_number in not_a_strike_numbers[indices[0]]:
        return False
    return True

In [None]:
is_a_strike("20240418_OB_1", 3)

In [120]:
# show results from the weakest strikes
low_indices = np.argsort(scores)

for i, idx in enumerate(low_indices):
    if i <= 40 or i > 50:
        continue

    name = names[idx]
    strike_number = strike_nums[idx]

    indices = analyzer.get_specimen_indices(name, strike_number=strike_number)
    points = analyzer.all_results["start_locations_mm"][indices]
    values = analyzer.all_results["displacement"][indices][:, 2]

    plt.figure()
    plt.scatter(points[:, 1], points[:, 0], c=values)
    plt.title(
        f"""{name}, strike {strike_number}
              score: {scores[idx]} 
              strike? {is_a_strike(name, strike_number)} 
              mandible order {analyzer.all_results["mandible_order"][indices][0]}"""
    )
    ax = plt.gca()
    ax.set_aspect("equal")
    plt.colorbar()

In [None]:
# do some comparisons based on mandible order

In [None]:
# histogram flow differences
key = "average_flow_error"
all_flow = torch.mean(torch.abs(analyzer.all_results[key]), axis=1)
all_flow, _ = torch.sort(all_flow)

# cut-off at some percentile
cutoff = 0.995
cutoff_index = int(len(all_flow) * cutoff)

bins = plt.hist(all_flow[:cutoff_index], bins=50, alpha=0.5, label="all 3 cameras")
width = bins[1][1] - bins[1][0]

# add in the top two
flow = analyzer.get_top_values(key)
flow, _ = torch.sort(torch.mean(torch.abs(flow), axis=1))
flow = flow[:cutoff_index]
bins = np.arange(min(flow), max(flow) + width, width)
_ = plt.hist(flow, bins=bins, alpha=0.5, label="top 2 cameras")

plt.legend()
plt.ylabel("Point Count")
plt.xlabel("Flow Error (pixels)")
plt.title("Flow error in region around strike")

In [None]:
# histogram flow differences
key = "average_flow_sq"
all_flow = torch.mean(torch.abs(analyzer.all_results[key]), axis=1)
all_flow, _ = torch.sort(all_flow)

# cut-off at some percentile
cutoff = 0.95
cutoff_index = int(len(all_flow) * cutoff)

bins = plt.hist(all_flow[:cutoff_index], bins=50, alpha=0.5, label="all 3 cameras")
width = bins[1][1] - bins[1][0]

# # add in the top two
flow = analyzer.get_top_values(key)
flow, _ = torch.sort(torch.mean(torch.abs(flow), axis=1))
# cut-off at some percentile
cutoff = 0.99
cutoff_index = int(len(all_flow) * cutoff)
flow = flow[:cutoff_index]
bins = np.arange(min(flow), max(flow) + width, width)
_ = plt.hist(flow, bins=bins, alpha=0.5, label="top 2 cameras")

plt.legend()
plt.ylabel("Point Count")
plt.xlabel("Flow Error (pixels)")
plt.title("Flow error in region around strike")

In [None]:
# histogram flow differences
key = "average_huber_loss"
all_flow = torch.mean(torch.abs(analyzer.all_results[key]), axis=1)
all_flow, _ = torch.sort(all_flow)

# cut-off at some percentile
cutoff = 0.995
cutoff_index = int(len(all_flow) * cutoff)

bins = plt.hist(all_flow[:cutoff_index], bins=50, alpha=0.5, label="all 3 cameras")
width = bins[1][1] - bins[1][0]

# add in the top two
flow = analyzer.get_top_values(key)
flow, _ = torch.sort(torch.mean(torch.abs(flow), axis=1))
flow = flow[:cutoff_index]
bins = np.arange(min(flow), max(flow) + width, width)
_ = plt.hist(flow, bins=bins, alpha=0.5, label="top 2 cameras")

plt.legend()
plt.ylabel("Point Count")
plt.xlabel("Huber Loss")
plt.title("Huber Loss in region around strike")

In [None]:
# percentile differences between huber and flow in top 2 cameras
p0 = analyzer.get_percentile("average_flow_error", num_cams=2)
p1 = analyzer.get_percentile("average_flow_sq", num_cams=2)

diffs = torch.abs(p1 - p0)

_ = plt.hist(diffs, bins=50)

In [None]:
# percentile differences between huber and flow in top 2 cameras
huber_percentiles = analyzer.get_percentile("average_huber_loss", num_cams=2)
flow_percentiles = analyzer.get_percentile("average_flow_error", num_cams=2)

diffs = torch.abs(huber_percentiles - flow_percentiles)

_ = plt.hist(diffs, bins=50)

In [None]:
# look at which points vary the most from their neighbors
points = analyzer.all_results["start_locations_std"]
values = analyzer.all_results["normalized_displacement"]
tree = cKDTree(points)
k = 25
distances, indices = tree.query(points, k=k + 1)
# exclude self
neighbor_indices = indices[:, 1:]

neighbor_avgs = values[neighbor_indices].mean(axis=1)
difference = values - neighbor_avgs
neighbor_diff_z = difference[:, 2]

_ = plt.hist(torch.abs(neighbor_diff_z), bins=50)

In [None]:
# look at which points vary the most from their neighbors, not normalized
points = analyzer.all_results["start_locations_std"]
values = analyzer.all_results["displacement"]
tree = cKDTree(points)
k = 25
distances, indices = tree.query(points, k=k + 1)
# exclude self
neighbor_indices = indices[:, 1:]

neighbor_avgs = values[neighbor_indices].mean(axis=1)
difference = values - neighbor_avgs
neighbor_diff_z = difference[:, 2]

_ = plt.hist(torch.abs(neighbor_diff_z), bins=50)

In [None]:
# look at some points with bad flow
array = torch.mean(analyzer.get_top_values("average_flow_sq", num_cams=2), axis=1)
array = neighbor_diff_z


index = get_random_percentile_index(array.numpy(), 95, 100)

specimen_number = analyzer.all_results["specimen_number"][index]
point_number = int(analyzer.all_results["point_number"][index])
strike_number = int(analyzer.all_results["strike_number"][index])

print(specimen_number, "point", point_number, "strike", strike_number)
print("flow error: {:.5f} pixels".format(array[index]))
print(
    "percentile: {:.0f}%".format(
        100 * get_percentiles(array.numpy(), float(array[index]))
    )
)

In [None]:
# specimen_number = "20240503_OB_3"
# point_number = 15
# strike_number = 3

In [None]:
# specimen_number = str(analyzer.all_results["specimen_number"][0])
# point_number = 30
# strike_number = 5

In [129]:
# get the indices related to this strike
indices1 = np.where(analyzer.all_results["specimen_number"] == specimen_number)[0]
indices2 = np.where(analyzer.all_results["strike_number"] == strike_number)[0]

indices = np.intersect1d(indices1, indices2)

In [131]:
filename = f"test_full_results_from_manual_strike_transfer/{specimen_number}/strike_{strike_number}_results.json"
assert os.path.exists(filename)

result_info = load_dictionary(filename)


plotter = ResultPlotter(result_info)


# result_info["point_numbers"]


# plotter.result_info["removed_points"]


# plotter.result_info["points_used_in_gm"]

In [None]:
array = torch.mean(analyzer.get_top_values("average_flow_sq", num_cams=2), axis=1)
error_values = array[indices]
good_point_indices = error_values < 0.0015
fig = plotter.scatter_values(error_values, highlight_point=point_number)

# mark the points above the error threshold with a red x
ant_start_locs = plotter.result_manager.point_start_locs_ant_mm
ant_start_locs = ant_start_locs[np.where(~good_point_indices)]
ax = fig.axes[0]
ax.scatter(ant_start_locs[:, 1], ant_start_locs[:, 0], marker="x", color="red", s=7)

In [None]:
flow_diffs = plotter.result_manager.flow_diff_around_strike()
_, sorted = sort_by_camera(flow_diffs[:, :, None], treat_individually=False)
values = sorted.squeeze()[:, :2]
values = torch.mean(values, axis=1)
_ = plotter.scatter_values(values, highlight_point=point_number)

In [None]:
plotter.scatter_peak_disp(highlight_point=point_number, cmap="turbo")

# mark the points above the error threshold with a black x
ant_start_locs = plotter.result_manager.point_start_locs_ant_mm
ant_start_locs = ant_start_locs[np.where(~good_point_indices)]
ax = plt.gca()
ax.scatter(ant_start_locs[:, 1], ant_start_locs[:, 0], marker="x", color="black", s=15)

plotter.scatter_peak_disp(highlight_point=point_number, cmap="turbo", with_image=True)

In [None]:
_ = plotter.plot_camera_weight(point_number)

In [None]:
_ = plotter.plot_displacement(point_number)

In [None]:
plotter.show_flow_differences(point_number)

In [None]:
fig = plotter.plot_all_displacement(highlight_point=point_number)

# highlight the points below the error threshold in blue
ax = fig.axes[0]
bad_disp = plotter.result_manager.rel_displacements[~good_point_indices]
for p in bad_disp:
    plt.plot(p[:, 2] * 1e3, "--", color=(0.5, 0.5, 1))

In [None]:
vid = plotter.make_point_track_video(highlight_point=point_number)

In [None]:
play_video(vid)

In [None]:
vid = plotter.get_arrow_video()

In [None]:
play_video(vid)

In [None]:
m = plotter.result_manager.point_mesh_locations
ResultPlotter.plot_mesh_with_points(points=m)

In [None]:
# look into more:
# 20240507_OB_2 point 31 strike 10
# 20240502_OB_6 29
# 20240502_OB_2 alignment is super off
# "20220422_OB_1" something is wrong with strikes 6 and 7

# good examples
# 20240418_OB_1 barely any movement but clear pattern

In [None]:
# suggested threshold for being used in global movement calculation:
# 0.025 average error with top two cameras
# in region surrounding peak
# or... maybe squared error of 0.0015

In [None]:
# look at results on the mesh
# with ONLY points below the error threshold
array = torch.mean(analyzer.get_top_values("average_flow_sq", num_cams=2), axis=1)
good_indices = array < 0.0015
# good_indices = ~good_indices
p = analyzer.all_results["mesh_points"][good_indices]

# jitter the points

jitter = 1000
rand = (torch.rand(p.shape) - 0.5) * jitter
p = p + rand
v = analyzer.all_results["normalized_displacement"][good_indices, 2]
v = convert_to_percentile(v)

ResultPlotter.plot_mesh_with_points(
    points=p,
    opacity=0.1,
    point_values=v,
    points_on_surface=False,
    marker_dict={"size": 2, "colorscale": "bluered"},
)

In [None]:
# scattering the same stuff
array = torch.mean(analyzer.get_top_values("average_flow_sq", num_cams=2), axis=1)
good_indices = array < 0.0015

p = analyzer.all_results["start_locations_std"]
jitter = 30
rand = (torch.rand(p.shape) - 0.5) * jitter
p = p + rand

v = analyzer.all_results["normalized_displacement"][:, 2]
v = convert_to_percentile(v)
plt.scatter(
    p[good_indices, 1],
    p[good_indices, 0],
    s=1.5,
    c=v[good_indices],
    cmap="coolwarm",
    clim=(0, 100),
)
ax = plt.gca()
ax.set_aspect("equal")

plt.figure()
plt.scatter(
    p[~good_indices, 1],
    p[~good_indices, 0],
    s=1.5,
    c=v[~good_indices],
    cmap="coolwarm",
    clim=(0, 100),
)
ax = plt.gca()
ax.set_aspect("equal")

In [None]:
# some strikes that don't look good
# based on differences from nearby points
# 20220427_OB_3 strike 2 (this whole ant might be weird, look closer)
# we know 20240503_OB_3 was having issues
# check on what other samples had to change the error threshold
# for global movement calculation

# 20220427_OB_4 is interesting because most of the seleted points
# are in the saddle - probably makes computing normalized movement weird
# might want to think about other ways to get like a normalized score...
# like thinking about how much on average the points in that strike deviate
# from the expected strike strength in that region

# 20240507_OB_3 strike 9, example of a super weak strike that doesn't follow
# expected patterns. a lot of strikes from this ant look weak... investigate more

# you need to check on how many points made it over in later strikes
# for instance 20240502_OB_3 strike 9 has very few points
# it could be worth not actually dropping points that aren't CRAZY off
# but just saving information about the liklihood that the track was good
# and maybe just acknowledging in that way that it COULD be a different point

# 20240506_OB_7 strike 2, take a look at alignment here. Points are all really low

# 20240503_OB_3 strike 1 in particular, something is really weird here

In [None]:
# 2024/11/13
# goal for today:
# define some initial metric for a strength score
# then you can use that to do normalized strength measurements
# and also threshold strikes based on the strength score

In [154]:
# run through and look at images from every strike
i = 30
specimen_numbers = MetadataManager.all_specimen_numbers()

In [155]:
def plot_all_displacement(
        self, dim=2, relative=True, highlight_point=None, good_only=True,
        metadata_manager=None, crop=False, error_scores=None
    ):
        if relative:
            disp = torch.asarray(self.result_info["rel_displacements"])
        else:
            disp = torch.asarray(self.result_info["camera_point_displacements"])
        disp = disp[:, :, dim] * 1e3

        if good_only and self.good_indices is not None:
            disp_ = disp[self.good_indices]
        else:
            disp_ = disp

        fig = plt.figure()
        for i, p in enumerate(disp_):
            # hardcoding for now... 
            if error_scores is not None: 
                assert len(error_scores) == len(disp_)
                score = error_scores[i]
                cmap = matplotlib.cm.turbo 
                color = cmap(score / 0.0015)
            else:
                color = 'black'
            plt.plot(p, color=color)
        if highlight_point is not None:
            point_index = torch.where(
                torch.asarray(self.result_info["point_numbers"]) == highlight_point
            )[0][0]
            plt.plot(disp[point_index], color="red", label=f"point {highlight_point}")

        plt.xlabel("Frame #")
        plt.ylabel("Displacement (um)")

        if metadata_manager is not None:
            start_frames = metadata_manager.mandible_start_frames(strike_number=strike_num)
            if start_frames[0] == start_frames[1]:
                plt.axvline(x=start_frames[0], label="both mandibles", color='purple')
            else:
                plt.axvline(x=start_frames[0], label="left mandible", color='blue')
                plt.axvline(x=start_frames[1], label="right mandible", color='red')
            plt.legend()

        if crop:
            strike_center_index = self.result_manager.strike_center_index() 
            ax = plt.gca()
            ax.set_xlim(strike_center_index - 5, strike_center_index + 5)
        return fig

In [None]:
num = specimen_numbers[i]
strike_numbers = MetadataManager(num).strike_numbers
error_threshold = 0.0015
mm = MetadataManager(num)
for strike_num in strike_numbers:
    filename = get_filename(num, strike_num)
    if not os.path.exists(filename):
        print(f"no file for {num}, strike {strike_num}")
        continue

    info = load_dictionary(filename)
    result_manager = ResultManager(info)
    error_scores = result_manager.error_scores
    good_indices = torch.where(error_scores < error_threshold)
    plotter = ResultPlotter(
        info, good_indices=torch.where(error_scores < error_threshold)
    )
    fig = plotter.scatter_peak_disp(with_image=True)
    ax = plt.gca() 
    ax.set_title(f"{num}, {strike_num}")
    bad_indices = torch.where(error_scores >= error_threshold)
    #for bi in bad_indices:
    locations = result_manager.point_start_locs_ant_mm[bad_indices]
    ax.scatter(locations[:, 1], locations[:, 0], color='gray')

    plot_all_displacement(plotter, metadata_manager=mm,
                          error_scores=error_scores[good_indices])
    plot_all_displacement(plotter, metadata_manager=mm, crop=True,
                          error_scores=error_scores[good_indices])

In [None]:
video = plotter.get_arrow_video(good_only=False)

In [None]:
play_video(video[6:45], fps=2)

In [None]:
for num in specimen_numbers:
    mm = MetadataManager(num)
    strike_numbers = mm.strike_numbers
    
    for sn in strike_numbers:
        filename = mm.video_filename(strike_number=sn)
        #if str(sn) not in filename[-10:]:
        print(num, sn, Path(filename).name)