In [None]:
from hsflfm.analysis import (
    ResultManager,
    ResultPlotter,
    BulkAnalyzer,
    convert_to_percentile,
    get_random_percentile_index,
    sort_by_camera,
    get_percentiles,
)
from hsflfm.util import MetadataManager
from scipy.spatial import cKDTree
import os
from matplotlib import pyplot as plt
import torch
from hsflfm.util import load_dictionary, save_dictionary, play_video
from pathlib import Path
import numpy as np
from tqdm import tqdm

In [None]:
# get all the filenames: every <results dir>/<sub-folder>/<name>.json
all_filenames = []
f = "temporary_result_storage_4"
# os.listdir returns entries in arbitrary (filesystem-dependent) order, so
# sort both levels to make the file order reproducible between runs/machines.
for inner in sorted(os.listdir(f)):
    path = Path(f) / inner
    if not path.is_dir():
        continue
    for filename in sorted(os.listdir(path)):
        # check the actual extension rather than the last four characters,
        # so names that merely end in "json" are not picked up
        if Path(filename).suffix == ".json":
            all_filenames.append(str(path / filename))

In [None]:
# build the bulk analyzer over every result file collected above
analyzer = BulkAnalyzer(all_filenames)

In [None]:
# Per-file bookkeeping: how many points each result file tracks, which
# specimen it belongs to, and a 0/1 flag that is set whenever the specimen
# differs from the one in the previous file.
num_points, names = [], []
is_new_strike = np.zeros(len(all_filenames), dtype=int)
prev_name = None
for i, f in enumerate(tqdm(all_filenames)):
    info = load_dictionary(f)
    num_points.append(len(info["point_numbers"]))

    name = info["specimen_number"]
    names.append(name)
    is_new_strike[i] = name != prev_name
    prev_name = name

In [None]:
# Either rebuild the aggregated results from the per-strike files and cache
# them on disk (reload=True), or restore them from that cache (reload=False).
reload = True
res_filename = "temp_loaded_results_4.json"

if not reload:
    analyzer.all_results = load_dictionary(res_filename)
    # the cache holds plain lists: "specimen_number" goes back to a numpy
    # array, every other entry back to a torch tensor
    for key, value in analyzer.all_results.items():
        if key == "specimen_number":
            converted = np.asarray(value)
        else:
            converted = torch.asarray(value)
        analyzer.all_results[key] = converted
else:
    analyzer.load_results()
    save_dictionary(analyzer.all_results, res_filename)

In [None]:
# Show every low-error point on the mesh, colored by the percentile of its
# normalized displacement along column 2.

# indices of the points whose error score is under the threshold
good_point_indices = torch.where(analyzer.error_scores < 0.0015)[0]

# jitter the points with uniform noise in [-jitter / 2, jitter / 2)
jitter = 10000
mesh_points = analyzer.all_results["mesh_points"]
p = mesh_points + (torch.rand(mesh_points.shape) - 0.5) * jitter

z_displacement = analyzer.all_results["normalized_displacement"][:, 2]
v = convert_to_percentile(z_displacement[good_point_indices])

ResultPlotter.plot_mesh_with_points(
    points=p[good_point_indices],
    point_values=v,
    points_on_surface=False,
    opacity=0.0,
    marker_dict=dict(size=1, colorscale="Turbo"),
)

In [None]:
# try out getting a strength score for each video
#
# For each strike: every point's |displacement[2]| is divided by the mean
# absolute displacement (columns 0-1) of its k nearest low-error points, and
# the strike's score is the mean of that ratio over the strike's own
# low-error points. Strikes with too few (good) points are skipped.
# NOTE(review): the numerator uses column 2 while the denominator uses
# columns 0-1 of the neighbors' displacement -- confirm this is intended.
error_threshold = 0.0015
min_points = 15  # minimum number of (good) points a strike needs
k = 15  # number of neighbors per point

specimen_names = MetadataManager.all_specimen_numbers()
scores = []
names = []
strike_nums = []

all_error_scores = analyzer.error_scores
displacement = analyzer.all_results["displacement"]

# The neighbor lookup does not depend on the specimen or strike, so it is done
# once here instead of once per strike. The low-error index set is rebuilt
# locally so this cell does not depend on whichever cell last assigned the
# notebook-global `good_point_indices`.
low_error_indices = torch.where(all_error_scores < error_threshold)[0]
_, all_neighbor_indices = analyzer.get_closest_point_indices(
    k=k, indices=low_error_indices
)

for name in tqdm(specimen_names):
    for strike_number in MetadataManager(name).strike_numbers:
        idx = analyzer.get_specimen_indices(name, strike_number=strike_number)

        if len(idx) < min_points:
            print(f"skipping {name} strike {strike_number}, {len(idx)} points")
            continue

        # only use points below the threshold; check this before computing
        # the ratios so skipped strikes cost nothing
        strike_good_point_indices = all_error_scores[idx] < error_threshold
        num_good = torch.count_nonzero(strike_good_point_indices)
        if num_good < min_points:
            print(f"skipping {name} strike {strike_number}, {num_good} good points")
            continue

        # if "20240503_OB_3" in name:
        #    # will address this later, I just know this video is bad
        #    continue

        neighbor_indices = all_neighbor_indices[idx]
        ratios = np.zeros(neighbor_indices.shape[0])
        for pi, neighbor_index in enumerate(neighbor_indices):
            neighbor_disp = displacement[neighbor_index][:, :2]
            point_disp = displacement[idx[pi]][2]
            ratios[pi] = torch.abs(point_disp) / torch.mean(torch.abs(neighbor_disp))

        scores.append(np.mean(ratios[strike_good_point_indices]))
        strike_nums.append(strike_number)
        names.append(name)

In [None]:
# distribution of the per-strike scores collected in the previous cell
_ = plt.hist(scores, bins=50)
# plt.xlim(0, 10)

In [None]:
# Inspect the strike with the highest score (np.nanargmax ignores NaN scores).
n = np.nanargmax(scores)
name = names[n]
num = strike_nums[n]

strike_indices = analyzer.get_specimen_indices(name, num)
print(name, "strike", num, "score:", scores[n])

filename = f"temporary_result_storage_4/{name}/strike_{num}_results.json"
assert os.path.exists(filename), filename

result_info = load_dictionary(filename)
plotter = ResultPlotter(result_info)
strike_good_point_indices = all_error_scores[strike_indices] < 0.0015

plotter.scatter_peak_disp(highlight_point=10)

fig = plotter.scatter_values(analyzer.error_scores[strike_indices])
fig.suptitle("error scores")
# mark the points that are NOT below the error threshold with a red x
ant_start_locs = plotter.result_manager.point_start_locs_ant_mm
ant_start_locs = ant_start_locs[np.where(~strike_good_point_indices)]
ax = fig.axes[0]
ax.scatter(ant_start_locs[:, 1], ant_start_locs[:, 0], marker="x", color="red", s=7)


fig = plotter.plot_all_displacement(highlight_point=10)
# overlay the points that are NOT below the error threshold as dashed blue
# lines; draw on the figure's first axes explicitly instead of relying on
# pyplot's implicit "current axes"
ax = fig.axes[0]
bad_disp = plotter.result_manager.rel_displacements[~strike_good_point_indices]
for disp in bad_disp:
    ax.plot(disp[:, 2] * 1e3, "--", color=(0.5, 0.5, 1))
# plotter.plot_displacement(10)

vid = plotter.get_arrow_video(cam_num=2)

In [None]:
# play back the video built at the end of the previous cell
play_video(vid)

In [None]:
# show the first entry of the video as a still image
plt.imshow(vid[0])

In [None]:
# histogram flow differences
key = "average_flow_error"
# per-point mean of the absolute values along axis 1, sorted ascending
all_flow, _ = torch.sort(torch.mean(torch.abs(analyzer.all_results[key]), axis=1))

# cut-off at some percentile
cutoff = 0.995
cutoff_index = int(len(all_flow) * cutoff)

_, bin_edges, _ = plt.hist(
    all_flow[:cutoff_index], bins=50, alpha=0.5, label="all 3 cameras"
)
width = bin_edges[1] - bin_edges[0]

# add in the top two, using the same bin width as the first histogram
flow = analyzer.get_top_values(key)
flow, _ = torch.sort(torch.mean(torch.abs(flow), axis=1))
flow = flow[:cutoff_index]
# tensor reductions rather than the builtins min()/max(), which would walk
# the tensor one element at a time in Python
bins = np.arange(flow.min(), flow.max() + width, width)
_ = plt.hist(flow, bins=bins, alpha=0.5, label="top 2 cameras")

plt.legend()
plt.ylabel("Point Count")
plt.xlabel("Flow Error (pixels)")
plt.title("Flow error in region around strike")

In [None]:
# histogram flow differences
key = "average_flow_sq"
# per-point mean of the absolute values along axis 1, sorted ascending
all_flow, _ = torch.sort(torch.mean(torch.abs(analyzer.all_results[key]), axis=1))

# cut-off at some percentile
cutoff = 0.95
cutoff_index = int(len(all_flow) * cutoff)

_, bin_edges, _ = plt.hist(
    all_flow[:cutoff_index], bins=50, alpha=0.5, label="all 3 cameras"
)
width = bin_edges[1] - bin_edges[0]

# add in the top two, using the same bin width as the first histogram
flow = analyzer.get_top_values(key)
flow, _ = torch.sort(torch.mean(torch.abs(flow), axis=1))
# cut-off at some percentile (a different one than for all cameras; the
# index is still computed from len(all_flow))
cutoff = 0.99
cutoff_index = int(len(all_flow) * cutoff)
flow = flow[:cutoff_index]
# tensor reductions rather than the builtins min()/max(), which would walk
# the tensor one element at a time in Python
bins = np.arange(flow.min(), flow.max() + width, width)
_ = plt.hist(flow, bins=bins, alpha=0.5, label="top 2 cameras")

plt.legend()
plt.ylabel("Point Count")
plt.xlabel("Flow Error (pixels)")
plt.title("Flow error in region around strike")

In [None]:
# histogram huber loss
key = "average_huber_loss"
# per-point mean of the absolute values along axis 1, sorted ascending
all_flow, _ = torch.sort(torch.mean(torch.abs(analyzer.all_results[key]), axis=1))

# cut-off at some percentile
cutoff = 0.995
cutoff_index = int(len(all_flow) * cutoff)

_, bin_edges, _ = plt.hist(
    all_flow[:cutoff_index], bins=50, alpha=0.5, label="all 3 cameras"
)
width = bin_edges[1] - bin_edges[0]

# add in the top two, using the same bin width as the first histogram
flow = analyzer.get_top_values(key)
flow, _ = torch.sort(torch.mean(torch.abs(flow), axis=1))
flow = flow[:cutoff_index]
# tensor reductions rather than the builtins min()/max(), which would walk
# the tensor one element at a time in Python
bins = np.arange(flow.min(), flow.max() + width, width)
_ = plt.hist(flow, bins=bins, alpha=0.5, label="top 2 cameras")

plt.legend()
plt.ylabel("Point Count")
plt.xlabel("Huber Loss")
plt.title("Huber Loss in region around strike")

In [None]:
# percentile differences between flow error ("average_flow_error") and
# squared flow ("average_flow_sq") in top 2 cameras
p0 = analyzer.get_percentile("average_flow_error", num_cams=2)
p1 = analyzer.get_percentile("average_flow_sq", num_cams=2)

# absolute per-point gap between the two percentile rankings
diffs = torch.abs(p1 - p0)

_ = plt.hist(diffs, bins=50)

In [None]:
# How far apart are each point's Huber-loss and flow-error percentiles when
# both are taken over the top 2 cameras?
huber_percentiles = analyzer.get_percentile("average_huber_loss", num_cams=2)
flow_percentiles = analyzer.get_percentile("average_flow_error", num_cams=2)
diffs = torch.abs(huber_percentiles - flow_percentiles)
_ = plt.hist(diffs, bins=50)

In [None]:
# Which points differ most from their surroundings? Compare each point's
# normalized displacement with the mean over its k nearest neighbors.
k = 25
points = analyzer.all_results["start_locations_std"]
values = analyzer.all_results["normalized_displacement"]

# ask for k + 1 neighbors and discard the first hit, which is expected to be
# the query point itself
tree = cKDTree(points)
distances, indices = tree.query(points, k=k + 1)
neighbor_indices = indices[:, 1:]

# column 2 of (point value - mean of its neighbors' values)
neighbor_diff_z = (values - values[neighbor_indices].mean(axis=1))[:, 2]

_ = plt.hist(torch.abs(neighbor_diff_z), bins=50)

In [None]:
# Same neighbor comparison as for the normalized values, but on the raw
# (not normalized) "displacement" entries.
k = 25
points = analyzer.all_results["start_locations_std"]
values = analyzer.all_results["displacement"]

# ask for k + 1 neighbors and discard the first hit, which is expected to be
# the query point itself
tree = cKDTree(points)
distances, indices = tree.query(points, k=k + 1)
neighbor_indices = indices[:, 1:]

# column 2 of (point value - mean of its neighbors' values)
neighbor_diff_z = (values - values[neighbor_indices].mean(axis=1))[:, 2]

_ = plt.hist(torch.abs(neighbor_diff_z), bins=50)

In [None]:
# look at some points with an extreme metric value
#
# The metric being sampled is the neighbor z-difference from the cell above.
# The flow-error alternative is left commented out: computing it and then
# immediately overwriting it only wasted a get_top_values call.
# array = torch.mean(analyzer.get_top_values("average_flow_sq", num_cams=2), axis=1)
# metric_label = "flow error (pixels)"
array = neighbor_diff_z
metric_label = "neighbor z difference"

# presumably draws a random index whose value lies between the 95th and 100th
# percentile of `array` -- confirm against get_random_percentile_index
index = get_random_percentile_index(array.numpy(), 95, 100)

specimen_number = analyzer.all_results["specimen_number"][index]
point_number = int(analyzer.all_results["point_number"][index])
strike_number = int(analyzer.all_results["strike_number"][index])

print(specimen_number, "point", point_number, "strike", strike_number)
# label the value with the metric actually selected above
print("{}: {:.5f}".format(metric_label, array[index]))
print(
    "percentile: {:.0f}%".format(
        100 * get_percentiles(array.numpy(), float(array[index]))
    )
)

In [None]:
# specimen_number = "20240503_OB_3"
# point_number = 15
# strike_number = 3

In [None]:
# specimen_number = str(analyzer.all_results["specimen_number"][0])
# point_number = 30
# strike_number = 5

In [None]:
# rows of the aggregated results that belong to the selected specimen AND the
# selected strike
same_specimen = analyzer.all_results["specimen_number"] == specimen_number
same_strike = analyzer.all_results["strike_number"] == strike_number
indices = np.intersect1d(np.where(same_specimen)[0], np.where(same_strike)[0])

In [None]:
# Load the result file of the selected strike and wrap it in a plotter.
# The file is read from "temporary_result_storage_4", the directory the
# analyzer was built from, because the following cells combine
# analyzer-derived per-point values with this plotter's points.
# NOTE(review): this path previously pointed at "temporary_result_storage_3";
# confirm no comparison against the older results was intended.
filename = (
    f"temporary_result_storage_4/{specimen_number}/strike_{strike_number}_results.json"
)
assert os.path.exists(filename), filename

result_info = load_dictionary(filename)
plotter = ResultPlotter(result_info)

# fields that are handy to inspect by hand:
# result_info["point_numbers"]
# plotter.result_info["removed_points"]
# plotter.result_info["points_used_in_gm"]

In [None]:
# Error of every point in this strike: mean of "average_flow_sq" over the
# top 2 cameras, restricted to the strike's rows.
array = torch.mean(analyzer.get_top_values("average_flow_sq", num_cams=2), axis=1)
error_values = array[indices]
good_point_indices = error_values < 0.0015
fig = plotter.scatter_values(error_values, highlight_point=point_number)

# mark the points above the error threshold with a red x
ant_start_locs = plotter.result_manager.point_start_locs_ant_mm[
    np.where(~good_point_indices)
]
fig.axes[0].scatter(
    ant_start_locs[:, 1], ant_start_locs[:, 0], marker="x", color="red", s=7
)

In [None]:
# Scatter each point's flow difference around the strike, averaged over the
# first two columns of the output of sort_by_camera.
flow_diffs = plotter.result_manager.flow_diff_around_strike()
# bound to `sorted_diffs` rather than `sorted`, so the builtin sorted() is
# not shadowed for the rest of the notebook session
_, sorted_diffs = sort_by_camera(flow_diffs[:, :, None], treat_individually=False)
values = sorted_diffs.squeeze()[:, :2]
values = torch.mean(values, axis=1)
_ = plotter.scatter_values(values, highlight_point=point_number)

In [None]:
# Peak displacement of every point, first on its own ...
plotter.scatter_peak_disp(highlight_point=point_number, cmap="turbo")

# ... with a black x drawn over each point that is above the error threshold
rejected_locs = plotter.result_manager.point_start_locs_ant_mm[
    np.where(~good_point_indices)
]
plt.gca().scatter(
    rejected_locs[:, 1], rejected_locs[:, 0], marker="x", color="black", s=15
)

# ... and once more with with_image=True
plotter.scatter_peak_disp(highlight_point=point_number, cmap="turbo", with_image=True)

In [None]:
# camera-weight plot for the selected point
_ = plotter.plot_camera_weight(point_number)

In [None]:
# displacement plot for the selected point
_ = plotter.plot_displacement(point_number)

In [None]:
# flow differences for the selected point
plotter.show_flow_differences(point_number)

In [None]:
fig = plotter.plot_all_displacement(highlight_point=point_number)

# overlay the points that are NOT below the error threshold as dashed blue
# lines; draw on the figure's first axes explicitly instead of relying on
# pyplot's implicit "current axes"
ax = fig.axes[0]
bad_disp = plotter.result_manager.rel_displacements[~good_point_indices]
for disp in bad_disp:
    ax.plot(disp[:, 2] * 1e3, "--", color=(0.5, 0.5, 1))

In [None]:
# build the point-track video with the selected point highlighted
vid = plotter.make_point_track_video(highlight_point=point_number)

In [None]:
# play back the video built in the previous cell
play_video(vid)

In [None]:
# build the arrow video; no cam_num is passed here, so the method's default applies
vid = plotter.get_arrow_video()

In [None]:
# play back the video built in the previous cell
play_video(vid)

In [None]:
# show this strike's point mesh locations on the mesh
m = plotter.result_manager.point_mesh_locations
ResultPlotter.plot_mesh_with_points(points=m)

In [None]:
# look into more:
# 20240507_OB_2 point 31 strike 10
# 20240502_OB_6 29
# 20240502_OB_2 alignment is super off
# "20220422_OB_1" something is wrong with strikes 6 and 7

# good examples
# 20240418_OB_1 barely any movement but clear pattern

In [None]:
# suggested threshold for being used in global movement calculation:
# 0.025 average error with top two cameras
# in region surrounding peak
# or... maybe squared error of 0.0015

In [None]:
# Mesh view restricted to the points whose error (mean "average_flow_sq" over
# the top 2 cameras) is below the threshold.
array = torch.mean(analyzer.get_top_values("average_flow_sq", num_cams=2), axis=1)
good_indices = array < 0.0015
# (set good_indices = ~good_indices to view the rejected points instead)

# jitter the selected points with uniform noise in [-jitter / 2, jitter / 2)
jitter = 1000
selected_mesh_points = analyzer.all_results["mesh_points"][good_indices]
p = selected_mesh_points + (torch.rand(selected_mesh_points.shape) - 0.5) * jitter

# color by the percentile of column 2 of the normalized displacement
v = convert_to_percentile(
    analyzer.all_results["normalized_displacement"][good_indices, 2]
)

ResultPlotter.plot_mesh_with_points(
    points=p,
    point_values=v,
    points_on_surface=False,
    opacity=0.1,
    marker_dict=dict(size=2, colorscale="bluered"),
)

In [None]:
# 2D scatter of the same data: one figure for the points below the error
# threshold, a second figure for the remaining points.
array = torch.mean(analyzer.get_top_values("average_flow_sq", num_cams=2), axis=1)
good_indices = array < 0.0015

# jittered start locations
jitter = 30
start_locations = analyzer.all_results["start_locations_std"]
p = start_locations + (torch.rand(start_locations.shape) - 0.5) * jitter

# percentiles are computed over ALL points, then masked per figure
v = convert_to_percentile(analyzer.all_results["normalized_displacement"][:, 2])

for figure_number, mask in enumerate((good_indices, ~good_indices)):
    if figure_number > 0:
        # the first scatter uses the current figure; later ones get their own
        plt.figure()
    plt.scatter(
        p[mask, 1],
        p[mask, 0],
        s=1.5,
        c=v[mask],
        cmap="coolwarm",
        clim=(0, 100),
    )
    ax = plt.gca()
    ax.set_aspect("equal")

In [None]:
# some strikes that don't look good
# based on differences from nearby points
# 20220427_OB_3 strike 2 (this whole ant might be weird, look closer)
# we know 20240503_OB_3 was having issues
# check on what other samples had to change the error threshold
# for global movement calculation

# 20220427_OB_4 is interesting because most of the selected points
# are in the saddle - probably makes computing normalized movement weird
# might want to think about other ways to get like a normalized score...
# like thinking about how much on average the points in that strike deviate
# from the expected strike strength in that region

# 20240507_OB_3 strike 9, example of a super weak strike that doesn't follow
# expected patterns. a lot of strikes from this ant look weak... investigate more

# you need to check on how many points made it over in later strikes
# for instance 20240502_OB_3 strike 9 has very few points
# it could be worth not actually dropping points that aren't CRAZY off
# but just saving information about the likelihood that the track was good
# and maybe just acknowledging in that way that it COULD be a different point

# 20240506_OB_7 strike 2, take a look at alignment here. Points are all really low

# 20240503_OB_3 strike 1 in particular, something is really weird here

In [None]:
# 2024/11/13
# goal for today:
# define some initial metric for a strength score
# then you can use that to do normalized strength measurements
# and also threshold strikes based on the strength score