In [1]:
from hsflfm.util import MetadataManager, load_dictionary, save_dictionary
from hsflfm.analysis import ResultManager
from hsflfm.calibration import CalibrationInfoManager
import numpy as np
import torch 
import os 
import pandas as pd
from tqdm import tqdm

In [2]:
# Location of the re-imaging metadata spreadsheet. The default is a machine-specific
# absolute path; set the HSFLFM_METADATA_XLSX environment variable to point at another
# copy of the file without editing the notebook.
new_metadata_filename = os.environ.get(
    "HSFLFM_METADATA_XLSX",
    "C:/Users/clare/OneDrive - Duke University/Projects/Re-imaging/Odontomachus brunneus/Metadata.xlsx",
)
assert os.path.exists(new_metadata_filename), (
    f"metadata spreadsheet not found: {new_metadata_filename}"
)
# The export loop below looks rows up through the "Specimen #", "Strike #" and
# "VideoFileName" columns of this table.
metadata = pd.read_excel(new_metadata_filename)

In [3]:
def get_result_filename(specimen_number, strike_number):
    """Build the relative path of the complete-results JSON for one strike.

    Parameters
    ----------
    specimen_number : specimen identifier, used as the sub-folder name.
    strike_number : strike number, embedded in the file name.

    Returns
    -------
    str
        "complete_results_20241227/<specimen_number>/strike_<strike_number>_results.json"
    """
    results_root = "complete_results_20241227"
    strike_file = f"strike_{strike_number}_results.json"
    return f"{results_root}/{specimen_number}/{strike_file}"

def get_save_filename(specimen_number, strike_number):
    """Return the output path for one strike's basic results, creating its folder.

    Parameters
    ----------
    specimen_number : specimen identifier; used as the sub-folder name and as the
        file-name prefix.
    strike_number : strike number, embedded in the file name.

    Returns
    -------
    str
        "basic_results_20241227/<specimen_number>/<specimen_number>_strike_<strike_number>.json"

    Side effects
    ------------
    Creates the folder "basic_results_20241227/<specimen_number>" (relative to the
    current working directory) if it does not exist yet.
    """
    folder = f"basic_results_20241227/{specimen_number}"
    # makedirs also creates the top-level "basic_results_20241227" folder when it is
    # missing (os.mkdir would raise FileNotFoundError in that case), and exist_ok=True
    # makes the call safe when the folder is already there.
    os.makedirs(folder, exist_ok=True)
    return f"{folder}/{specimen_number}_strike_{strike_number}.json"

In [4]:
# Keys that every exported results dictionary must contain; the export loop asserts
# that each of them is present before saving.
save_keys = [
    "image_crop_indices", 
    "match_points",
    "point_numbers", 
    "error_scores",
    "error_threshold",
    "global_movement",
    "specimen_number", 
    "strike_number", 
    "point_start_locations", 
    "point_displacement_ant_coords",
    "metadata",
    "key_descriptions"
]

# Human-readable description of every key in `save_keys`. This text is stored in each
# exported file under "key_descriptions", so it must stay in sync with `save_keys`.
descriptions = """
    "image_crop_indices" - crops in pixels for the three sub-images within each Photron image/video,
                           This is a dictionary, where each key is the number of the sub-image. 
                           Each value is a list of length 4, [startx, endx, starty, endy]. 
                           So a loaded frame could be cropped as "image_1 = frame[startx:endx, starty:endy]".
    "match_points" - pixel locations for the tracked points in the three sub-images. This is a dictionary 
                     where each key is the number of the sub-image. The value is a list of shape (# points, 2).
    "point_numbers" - point numbers for the tracked points included in these results. These numbers may be relevant 
                      when comparing results across strikes for the same specimen. 
    "error_scores" - error score for each point. List of length (# points). 
    "error_threshold" - threshold for error score used to select points included in the results.  
    "global_movement" - Estimated global movement for the ant head, in camera coordinates. This is a dictionary
                        with keys ["x", "y", "z", "roll", "pitch", "yaw"]. "x", "y", and "z" are in millimeters,
                        and "roll", "pitch", "yaw" are in radians. 
    "specimen_number" - ant identifier. 
    "strike_number" - the strike number. 
    "point_start_locations" - list of shape [# points, 3]. Contains the start location (x, y, z) of all the points, in millimeters,
                              in the ant's coordinate system. 
    "point_displacement_ant_coords" - list of shape (# points, # frames, 3). Displacements for each point in millimeters,
                                      in the ant coordinate system. 
    "metadata" - all info pulled from "Metadata.xlsx" for this strike 
    "key_descriptions" - this text, describing every key stored in the file. 

"""

In [5]:
# Export a reduced ("basic") results file for each strike of each listed specimen.
# For every strike this loads the complete results, keeps only the points whose error
# score is below `error_threshold`, attaches the matching row of the metadata
# spreadsheet and the calibration crop indices, and writes everything with
# `save_dictionary`.
specimen_numbers = ["20240507_OB_2"]#MetadataManager.all_specimen_numbers()

# Loop-invariant: points with an error score at or above this value are excluded.
error_threshold = 0.0015

for specimen in tqdm(specimen_numbers):
    mm = MetadataManager(specimen)
    strike_numbers = mm.strike_numbers
    for strike_number in strike_numbers:
        result_filename = get_result_filename(specimen, strike_number=strike_number)
        if not os.path.exists(result_filename):
            print(f"skipping {specimen} strike {strike_number}")
            continue

        # The exported name replaces the first four characters of the specimen name
        # with "2024"; names containing "rp" are exported under a fixed name instead.
        updated_specimen = "2024" + specimen[4:]
        if "rp" in updated_specimen:
            updated_specimen = "20240418_OB_1_alignment_3"

        save_filename = get_save_filename(updated_specimen, strike_number)
        # Existing exports are kept, except when the specimen name contains
        # "20240418_OB_1", in which case the file is always regenerated.
        # NOTE(review): this test uses the raw specimen name and the full
        # "20240418_OB_1" string, while the metadata lookup below matches on
        # "0418_OB_1" -- confirm both are meant to select the same specimens.
        if os.path.exists(save_filename) and "20240418_OB_1" not in specimen:
            continue

        results = load_dictionary(result_filename)

        rm = ResultManager(results)
        point_start_locations = rm.point_start_locs_ant_mm
        error_scores = rm.error_scores
        rel_displacements = rm.rel_displacements
        global_movement = results["global_movement"]

        # torch.where with a single argument returns a tuple of index tensors; the
        # same tuple is used below to index every per-point array.
        good_indices = torch.where(error_scores < error_threshold)
        num_good_points = len(good_indices[0])

        # Build a filtered copy rather than overwriting the entries of the loaded results.
        match_points = {
            key: np.asarray(item)[good_indices]
            for key, item in results["match_points"].items()
        }

        save_results = {
            "match_points": match_points,
            "point_numbers": np.asarray(results["point_numbers"])[good_indices],
            "error_threshold": error_threshold,
            "error_scores": error_scores[good_indices],
            "global_movement": global_movement,
            "specimen_number": updated_specimen,
            "strike_number": strike_number,
            "point_displacement_ant_coords": rel_displacements[good_indices],
            "point_start_locations": point_start_locations[good_indices],
        }

        # Sanity check: every per-point entry must have exactly one item per kept point.
        for key in ["point_displacement_ant_coords", "point_start_locations",
                    "point_numbers", "error_scores"]:
            assert len(save_results[key]) == num_good_points, (
                f"{key} has {len(save_results[key])} entries, expected {num_good_points}"
            )
        for item in save_results["match_points"].values():
            assert len(item) == num_good_points, (
                f"match_points entry has {len(item)} entries, expected {num_good_points}"
            )

        # find the new metadata entry
        if "0418_OB_1" in specimen:
            # This specimen's rows are split by video file name: names containing
            # "alignment3" go with the "rp" specimen, all other rows with the non-"rp" one.
            specimen_data = metadata.loc[metadata["Specimen #"] == "20240418_OB_1"]
            is_alignment3 = specimen_data['VideoFileName'].str.contains('alignment3', case=False, na=False)
            if "rp" in specimen:
                specimen_data = specimen_data[is_alignment3]
            else:
                specimen_data = specimen_data[~is_alignment3]
        else:
            specimen_data = metadata.loc[metadata["Specimen #"].isin([specimen, updated_specimen])]
        strike_data = specimen_data.loc[specimen_data["Strike #"] == strike_number]
        assert len(strike_data) == 1, (
            f"expected exactly one metadata row for {specimen} strike {strike_number}, "
            f"found {len(strike_data)}"
        )

        # DataFrame.to_dict() yields {column: {row_index: value}}; with a single row,
        # collapse each inner dictionary to its only value.
        data_dict = strike_data.to_dict()
        for key, item in data_dict.items():
            assert len(item) == 1
            data_dict[key] = next(iter(item.values()))

        save_results["metadata"] = data_dict
        save_results["key_descriptions"] = descriptions
        save_results["image_crop_indices"] = CalibrationInfoManager(mm.calibration_filename).crop_indices
        # Every documented key must be present before the file is written.
        for key in save_keys:
            assert key in save_results, f"missing key in export: {key}"

        save_dictionary(save_results, save_filename)


100%|██████████| 1/1 [00:00<00:00, 157.33it/s]
