In [None]:
from hsflfm.processing import Aligner, StrikeProcessor
from hsflfm.util import MetadataManager, load_dictionary, save_dictionary

import os
from tqdm import tqdm
import numpy as np

In [None]:
import warnings

warnings.filterwarnings("ignore")

In [None]:
# Output folder for this run. It must already exist; per-specimen subfolders
# are created inside it by the processing loop.
save_folder = "temporary_result_storage_5"
assert os.path.exists(save_folder)

# Bookkeeping file recording which strikes of each specimen have already been
# processed, so an interrupted run can resume. Start empty if it is missing.
saved_results_filename = f"{save_folder}/processed.json"
processed_vids = (
    load_dictionary(saved_results_filename)
    if os.path.exists(saved_results_filename)
    else {}
)

In [None]:
# Specimen identifiers to process, as returned by the metadata manager.
# NOTE(review): keys of `processed_vids` round-trip through JSON as strings, so
# these identifiers are presumably strings too — confirm.
specimen_numbers = MetadataManager.all_specimen_numbers()

In [None]:
# Show the resume state: strike numbers already recorded as processed for each
# specimen (an empty dict on a fresh run).
processed_vids

In [None]:
# Maximum number of tracked points a strike may lose relative to the strike it
# was aligned from before it is judged poorly tracked (a poorly tracked strike
# is not used as the alignment base for the following strike).
ACCEPTABLE_POINT_LOSS = 5

# (specimen, strike) pairs that are skipped entirely; the following strike is
# then aligned from strike 1.
# NOTE(review): the reason for skipping 20240506_OB_3 strike 2 is not recorded
# here — TODO document it.
SKIPPED_STRIKES = {("20240506_OB_3", 2)}


def _align_strike(aligner, strike_num, start_strike, bad_strikes, acceptable_loss):
    """Align one strike, retrying from an earlier strike if too many points were lost.

    Tracking quality is measured as the number of entries in
    ``aligner.stored_point_numbers[strike]`` after ``prepare_strike_results``
    has run. This is not a perfect heuristic but is good enough for now
    (2024/11/18).

    Parameters
    ----------
    aligner : Aligner
        Aligner for the current specimen.
    strike_num : int
        Strike to align.
    start_strike : int or None
        Strike to align from; passed straight to ``prepare_strike_results``.
    bad_strikes : list
        Strikes already judged poorly tracked in this run. ``strike_num`` is
        appended to it when it is judged poorly tracked as well.
    acceptable_loss : int
        Largest allowed drop in point count relative to the base strike.

    Returns
    -------
    tuple
        ``(result_info, based_strike, is_good_strike)``, where ``based_strike``
        is the strike the first alignment attempt reported aligning from.
    """
    result_info = aligner.prepare_strike_results(strike_num, start_strike=start_strike)

    based_strike = result_info["aligned_from_strike_number"]
    prev_points = len(aligner.stored_point_numbers[based_strike])
    cur_points = len(aligner.stored_point_numbers[strike_num])

    if prev_points - cur_points <= acceptable_loss:
        return result_info, based_strike, True

    if based_strike == 1:
        # Already aligned from the first strike: nothing earlier to fall back to.
        bad_strikes.append(strike_num)
        return result_info, based_strike, False

    # Retry from the most recent earlier strike that was not judged bad.
    # NOTE(review): strikes skipped, failed, or processed in a previous run are
    # not in `bad_strikes`, so the retry base may be one this aligner has not
    # processed — confirm prepare_strike_results handles that.
    retry_start = based_strike - 1
    while retry_start in bad_strikes:
        retry_start -= 1
    retry_info = aligner.prepare_strike_results(strike_num, start_strike=retry_start)
    retry_points = len(aligner.stored_point_numbers[strike_num])

    if retry_points < cur_points:
        # The retry was worse: re-run the original alignment so the aligner's
        # stored state matches the result that is kept.
        result_info = aligner.prepare_strike_results(strike_num, start_strike=start_strike)
        best_points = cur_points
    else:
        result_info = retry_info
        best_points = retry_points

    is_good_strike = prev_points - best_points <= acceptable_loss
    if not is_good_strike:
        bad_strikes.append(strike_num)
    return result_info, based_strike, is_good_strike


for i, num in enumerate(specimen_numbers, start=1):
    print(num, f"specimen {i} of {len(specimen_numbers)}")

    try:
        data_manager = MetadataManager(num)
        aligner = Aligner(num)
        strike_numbers = np.sort(data_manager.strike_numbers)
        f = save_folder + f"/{num}"
        os.makedirs(f, exist_ok=True)
        processed_vids.setdefault(num, [])
    except Exception as e:
        print(f"failed on specimen {num}, {e}")
        # Skip this specimen. Falling through would run the strike loop with the
        # previous specimen's aligner/strikes/folder (or raise NameError/KeyError).
        continue

    # Strike to align the next strike from (None for the first one processed in
    # this run), and strikes judged poorly tracked during this run.
    start_strike = None
    bad_strikes = []

    for strike_num in tqdm(strike_numbers, desc=str(num)):
        if strike_num in processed_vids[num]:
            print(f"skipping {num} strike {strike_num}")
            continue

        if (num, int(strike_num)) in SKIPPED_STRIKES:
            start_strike = 1
            continue

        try:
            result_info, based_strike, is_good_strike = _align_strike(
                aligner, strike_num, start_strike, bad_strikes, ACCEPTABLE_POINT_LOSS
            )
            # A well-tracked strike becomes the base for the next one; otherwise
            # keep aligning from the strike this one was aligned from.
            start_strike = strike_num if is_good_strike else based_strike

            processor = StrikeProcessor(result_info)
            processor.get_flow_vectors()
            processor.run_regression()
            processor.get_relative_displacements()

            result_info = processor.condense_info()
            save_name = f + f"/strike_{strike_num}_results.json"
            save_dictionary(result_info, save_name)

            assert os.path.exists(save_name)
            processed_vids[num].append(int(strike_num))

            # Persist after every strike so an interrupted run can resume.
            save_dictionary(processed_vids, saved_results_filename)

        except Exception as e:
            print(f"failed on specimen {num} strike {strike_num}, {e}")