# EXPERIMENTAL NOTEBOOK, DO NOT USE!

In [None]:
# Make the parent directory importable (added once, never duplicated) -
# presumably so the `wtracker` package one level above this notebook resolves.
import os
import sys

module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from wtracker.eval import *
from wtracker.sim.config import TimingConfig
from wtracker.utils.gui_utils import UserPrompt
from wtracker.utils.path_utils import join_paths

### Timing configuration and log file selection

In [None]:
################################ User Input ################################

# Path to the experiment folder to analyze. The folder is expected to hold the
# log file and the timing config file.
# Set it to None to choose the folder interactively in a directory dialog.
# NOTE(review): this default is an absolute path on the author's machine -
# change it (or set it to None) before running anywhere else.
experiment_folder = "/mnt/c/Users/freid/Desktop/FinalEvaluations/Exp1_config1_CSV"

############################################################################


experiment_folder = (
    experiment_folder
    if experiment_folder is not None
    else UserPrompt.open_directory(title="Select log directory to analyze")
)

# Inputs (log + timing config) and the analysis output all live in the experiment folder.
log_file = join_paths(experiment_folder, "bboxes.csv")
time_config_path = join_paths(experiment_folder, "time_config.json")
analysis_save_path = join_paths(experiment_folder, "analyzed.csv")

print("Base directory: ", experiment_folder)
print("Log file: ", log_file)
print("Time config file: ", time_config_path)
print("Analysis save file: ", analysis_save_path)

In [None]:
# Load the timing configuration and the bbox log of the selected experiment.
from pprint import pprint

# load the data from the directory
timing_config = TimingConfig.load_json(time_config_path)
analyzer = DataAnalyzer.load(timing_config, log_file)

# Show the loaded timing configuration as a quick sanity check.
pprint(timing_config)

### Analyze log data

The analyzer cleans up the data according to our needs and can display useful statistics.  
Afterwards, the analyzed data is passed to a `Plotter` class, which draws graphs of the resulting analyzed data.

In [None]:
################################ User Input ################################

# initialize the analyzer on the log data.
# NOTE(review): the meaning/unit of `period` is defined by DataAnalyzer.initialize - confirm there.
analyzer.initialize(period=10)

# remove unwanted frames from the data
analyzer.clean(
    trim_cycles=True,
    imaging_only=True,
    bounds=None,
)

# find anomalies in the resulting data, and remove them if needed.
# A threshold of np.inf presumably disables that particular check, so with the
# values below only the `no_preds` check is active - confirm against
# DataAnalyzer.calc_anomalies.
analyzer.calc_anomalies(
    no_preds=True,
    min_bbox_error=np.inf,
    min_dist_error=np.inf,
    min_speed=np.inf,
    min_size=np.inf,
    remove_anomalies=True,  # whether to remove the anomalies from the data, or only detect them
)

# change the units of the time and distance of the resulting analyzed data
# (uncomment the line below to enable)
#analyzer.change_unit("sec")

############################################################################

In [None]:
from wtracker.sim.config import ExperimentConfig
from wtracker.utils.bbox_utils import *
from tqdm.auto import tqdm
import cv2 as cv
import matplotlib.pyplot as plt

# Load the experiment configuration stored in the experiment folder.
exp_config = ExperimentConfig.load_json(join_paths(experiment_folder, "exp_config.json"))

# Original full-frame resolution, unpacked as (H, W). The names suggest
# (height, width) order - TODO confirm against ExperimentConfig.orig_resolution.
H, W = exp_config.orig_resolution

class ImageCreator:
    """Produce a synthetic image for each worm bounding box in the analyzer's log.

    A single all-white uint8 canvas of shape (h, w, 3) is allocated once.
    Indexing the creator with a row index of ``analyzer._orig_data`` returns the
    part of that canvas covered by the row's discretized worm bbox. It returns
    ``None`` when ``BoxUtils.discretize`` flagged that bbox as invalid.
    """

    def __init__(self, analyzer: DataAnalyzer, h: int, w: int) -> None:
        # Canvas filled with 255 in every channel (white).
        self.image = np.full((h, w, 3), 255, dtype=np.uint8)

        # Worm boxes in XYWH order, one row per entry of analyzer._orig_data.
        raw_boxes = analyzer._orig_data[["wrm_x", "wrm_y", "wrm_w", "wrm_h"]].to_numpy()
        discrete_boxes, valid_mask = BoxUtils.discretize(raw_boxes, (h, w), BoxFormat.XYWH)

        self.bboxes = discrete_boxes
        self.is_valid = valid_mask

    def __getitem__(self, idx: int):
        # Invalid boxes have no image.
        if not self.is_valid[idx]:
            return None

        box = self.bboxes[idx]
        x, y, box_w, box_h = box[0], box[1], box[2], box[3]
        assert box_w > 0 and box_h > 0, f"{idx}"

        # Array rows are indexed by y and columns by x.
        crop = self.image[y : y + box_h, x : x + box_w]
        assert crop.shape[0] > 0 and crop.shape[1] > 0, f"{idx}"
        return crop
        

def save_images(analyzer: DataAnalyzer, save_folder: str, h: int, w: int):
    """Write one PNG per row with a valid worm bbox into ``save_folder``.

    Each image is the crop produced by ``ImageCreator`` for that row and is saved
    as ``{row_index:09d}.png``. Rows whose bbox is invalid are skipped.
    Failures are reported but do not stop the loop, so this is best-effort.

    Args:
        analyzer: analyzer whose ``_orig_data`` holds the worm bboxes.
        save_folder: destination directory; it is created if it does not exist.
        h: canvas height passed to ``ImageCreator``.
        w: canvas width passed to ``ImageCreator``.
    """
    # cv.imwrite does not create missing directories; it would simply fail for every file.
    if save_folder:
        os.makedirs(save_folder, exist_ok=True)

    image_creator = ImageCreator(analyzer, h, w)
    for i in tqdm(range(len(analyzer._orig_data))):
        image = image_creator[i]
        if image is None:
            continue

        try:
            # cv.imwrite reports most failures (bad path, unsupported format)
            # by returning False rather than raising, so check its result too.
            saved = cv.imwrite(join_paths(save_folder, f"{i:09d}.png"), image)
        except Exception as e:
            print(f"Failed to save image {i:09d}: {e}")
            continue

        if not saved:
            print(f"Failed to save image {i:09d}: cv.imwrite returned False")

In [None]:
#save_images(analyzer, "images", H, W)

In [None]:
class ImageLoader:
    """Index-addressable, uncached access to images stored as numbered files.

    ``loader[idx]`` reads ``<folder>/<name_format.format(idx)>`` with
    ``cv.imread`` on every access. According to the OpenCV documentation,
    ``cv.imread`` returns ``None`` when a file is missing or unreadable. That
    ``None`` is returned unchanged to the caller.
    """

    def __init__(self, folder: str, name_format: str, imread_flags) -> None:
        # folder: directory containing the images.
        # name_format: str.format template applied to the index, e.g. "{:09d}.png".
        # imread_flags: forwarded unchanged to cv.imread (e.g. cv.IMREAD_COLOR).
        self.folder, self.name_format, self.imread_flags = folder, name_format, imread_flags

    def __getitem__(self, idx: int) -> np.ndarray:
        file_name = self.name_format.format(idx)
        return cv.imread(join_paths(self.folder, file_name), flags=self.imread_flags)

In [None]:
""" loader = ImageLoader("images", "{:09d}.png", cv.IMREAD_COLOR)

bg = np.zeros((H, W, 3), dtype=np.uint8)

analyzer.calc_precise_error(loader, background=bg, diff_thresh=10)

print(analyzer.data["precise_error"].mean()) """

In [None]:
print(analyzer.data["bbox_error"].mean())

In [None]:
# Debug helper: find the first row where the bbox-based error and the precise
# error differ by more than 1e-2, and print that row's details.
# "precise_error" only exists after analyzer.calc_precise_error(...) has run,
# so report its absence instead of failing with a KeyError.
if "precise_error" not in analyzer.data.columns:
    print("'precise_error' column is missing - run analyzer.calc_precise_error(...) first")
else:
    mismatch = (analyzer.data["bbox_error"] - analyzer.data["precise_error"]).abs() > 1e-2

    if not mismatch.any():
        # np.argmax of an all-False mask is 0, which would wrongly point at row 0.
        print("bbox_error and precise_error agree (within 1e-2) on every row")
    else:
        # Positional index of the first True element.
        first_true = int(np.argmax(mismatch.to_numpy()))
        print(first_true)

        print(analyzer.data["bbox_error"].iloc[first_true])
        print(analyzer.data["precise_error"].iloc[first_true])

        print("frame:", analyzer.data["frame"].iloc[first_true])
        print(analyzer.data[["wrm_x", "wrm_y", "wrm_w", "wrm_h"]].iloc[first_true])
        print(analyzer.data[["mic_x", "mic_y", "mic_w", "mic_h"]].iloc[first_true])

In [None]:
# Print the analyzer's statistics for the current (cleaned) data.
analyzer.print_stats()

In [None]:
# Describe the distribution of the speed/error columns at the listed percentiles.
# The finer upper-tail percentiles are for inspecting the worst cases.
# NOTE(review): "precise_error" only exists after calc_precise_error has run -
# confirm how describe() handles a missing column, or drop it from the list.
analyzer.describe(
    columns=["wrm_speed", "bbox_error", "precise_error", "worm_deviation"],
    percentiles=[0.25, 0.5, 0.75, 0.8, 0.9, 0.95, 0.97, 0.98, 0.99],
)

In [None]:
# save the analyzed data back to the experiment directory
# (analysis_save_path, i.e. <experiment_folder>/analyzed.csv)
analyzer.save(analysis_save_path)

### Plotting

Notice that all of the plots below accept `condition` as a parameter.
`condition` is expected to be a function with the following signature, returning a boolean mask over the DataFrame's rows:

```python
def cond_func1(input_df: pd.DataFrame) -> pd.Series:
    return (input_df["wrm_speed"] > 5) &  (input_df["wrm_speed"] <= 30)
```

In Python, such functions can also be declared inline, without an explicit name or `def` statement, using the following syntax
(for more information, read about lambda functions):

```python
cond_func1 = lambda input_df: (input_df["wrm_speed"] > 5) & (input_df["wrm_speed"] <= 30)
cond_func2 = lambda input_df: input_df["phase"] == "imaging"
```

In [None]:
# Show every data column with its position, as "<position>: <name>" strings.
pprint([f"{position}: {name}" for position, name in enumerate(analyzer.column_names())])

In [None]:
# create the plotter from the analyzed data we created previously
# (Plotter receives a list of datasets; a single one is passed here).
pltr = Plotter([analyzer.data], plot_height=7, palette="bright")

In [None]:
# Trajectory plot of the analyzed data, then render the figure.
pltr.plot_trajectory()
plt.show()

In [None]:
# Head-size plot of the analyzed data, then render the figure.
pltr.plot_head_size()
plt.show()

In [None]:
# Speed plot limited to rows with wrm_speed <= 800 (presumably to drop extreme
# outliers; the units are whatever the analyzer currently uses).
pltr.plot_speed(
    aspect=0.5,
    condition=lambda df: df["wrm_speed"] <= 800,
)
plt.show()

In [None]:
# Per-cycle bbox error. k_depth / outlier_prop look like seaborn boxenplot
# options and are presumably forwarded to it - confirm in Plotter.plot_cycle_error.
pltr.plot_cycle_error(
    error_kind="bbox",
    log_wise=True,
    k_depth="proportion",
    outlier_prop=0.01,
)
plt.show()

In [None]:
# Per-cycle distance error. Same settings as the bbox-error plot except for a
# larger outlier proportion (0.02).
pltr.plot_cycle_error(
    error_kind="dist",
    log_wise=True,
    k_depth="proportion",
    outlier_prop=0.02,
)
plt.show()

In [None]:
# Speed vs. bbox error per cycle, restricted to speeds below 1000 and to rows with
# a non-negligible bbox error (> 1e-5), i.e. excluding error-free frames.
pltr.plot_speed_vs_error(
    error_kind="bbox",
    cycle_wise=True,
    condition=lambda df: (df["wrm_speed"] < 1000) & (df["bbox_error"] > 1e-5),
)
plt.show()

In [None]:
# Speed vs. distance error per cycle, restricted to speeds below 1000 and to
# rows with worm_deviation < 300.
pltr.plot_speed_vs_error(
    error_kind="dist",
    cycle_wise=True,
    condition=lambda df: (df["wrm_speed"] < 1000) & (df["worm_deviation"] < 300),
)
plt.show()