# Plotting and Analysis

This notebook plots and analyzes the logs produced by one or more runs of a simulator, given some fixed timing configuration.
These logs (bboxes.csv) are obtained by running a simulator on some experiments. The goal of these plots is to analyze the worm's behavior,
and to analyze the system's error and how it is affected by the different behaviors the worm exhibits.

For the analysis to be valid, all the experiments analyzed by this notebook *at once* must share the same timing configuration (`TimingConfig`) parameters.
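One way to guard against mixing incompatible experiments is to compare their timing configuration files before loading any logs. The following is a minimal sketch, assuming each experiment keeps its timing configuration in its own JSON file. The helper name `timing_configs_match` is hypothetical and not part of wtracker.

```python
import json
from pathlib import Path
from typing import Sequence


def timing_configs_match(config_paths: Sequence[str]) -> bool:
    """Return True if all the given JSON files hold identical content.

    Hypothetical helper, not part of wtracker: it compares the parsed JSON
    as-is and makes no assumption about which fields the files contain.
    """
    configs = [json.loads(Path(p).read_text()) for p in config_paths]
    # An empty or single-element list is trivially consistent.
    return all(cfg == configs[0] for cfg in configs[1:])
```

If the function returns `False`, the corresponding experiments should be analyzed in separate runs of this notebook.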

In [None]:
import matplotlib.pyplot as plt
from wtracker.eval import *
from wtracker.sim.config import TimingConfig
from wtracker.utils.gui_utils import UserPrompt

### Timing configuration and log files selection

In [None]:
from pprint import pprint

################################ User Input ################################

# path to the timing config file. 
# If None, a file dialog will open to select a file
timing_config_path = "logs/time_config.json"

############################################################################

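if timing_config_path is None:
    timing_config_path = UserPrompt.open_file(title="Select timing config file", filetypes=[("JSON files", "*.json")])

timing_config = TimingConfig.load_json(timing_config_path)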

pprint(timing_config)

In [None]:
################################ User Input ################################

# list containing paths to simulation log files.
# All of these simulations must have been run with the above timing config.
# If empty, a file dialog will open to select files.
log_files = ["logs/bboxes.csv"]

############################################################################

if len(log_files) == 0:
    log_files = UserPrompt.open_file(title="Select log files", filetypes=[("Log files", "*.csv")], multiple=True)

pprint(log_files)

In [None]:
from wtracker.eval.plotter import Plotter
from wtracker.eval.data_analyzer import DataAnalyzer


data = DataAnalyzer(
    time_config=timing_config,
    log_path=log_files[0],
    unit="sec",
)

data.initialize(
    serial=0,
    period=10,
    imaging_only=True,
    legal_bounds=None,
)

### Plotting configuration

Note that all of the plots below accept `condition` as a parameter.
`condition` is expected to be a function that receives the data as a DataFrame and returns a boolean mask over its rows, with the following signature:

```python
def cond_func1(input_df: pd.DataFrame) -> pd.Series:
    return (input_df["wrm_speed"] > 5) & (input_df["wrm_speed"] <= 30)
```

In Python, such functions can also be written inline, without a `def` statement, using the following syntax
(for more information, read about lambda functions):

```python
cond_func1 = lambda input_df: (input_df["wrm_speed"] > 5) & (input_df["wrm_speed"] <= 30)
cond_func2 = lambda input_df: input_df["phase"] == "imaging"
```
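To illustrate what a condition evaluates to, here is a small self-contained sketch on a toy DataFrame. The column names follow the ones used in this notebook, but the values are invented, and the final filtering step (`toy_df[mask]`) only shows how a boolean mask is commonly applied in pandas. It is not a description of the plotter's internals.

```python
import pandas as pd

# Toy data: column names follow this notebook, values are invented.
toy_df = pd.DataFrame(
    {
        "wrm_speed": [2.0, 10.0, 25.0, 40.0],
        "phase": ["imaging", "moving", "imaging", "imaging"],
    }
)

speed_cond = lambda df: (df["wrm_speed"] > 5) & (df["wrm_speed"] <= 30)
imaging_cond = lambda df: df["phase"] == "imaging"

# A condition returns one boolean per row of the input.
mask = speed_cond(toy_df)

# Conditions can be combined with & (and), | (or) and ~ (not).
combined_cond = lambda df: speed_cond(df) & imaging_cond(df)

# Keep only the rows for which the combined condition holds.
filtered = toy_df[combined_cond(toy_df)]
```

Each comparison must be wrapped in parentheses when combined, because `&` and `|` bind more tightly than `>` or `==` in Python.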

##### Optionally, Calculate precise error

To calculate the precise error of the system, run the cells in this section; otherwise, skip them.
Note that running them might take a while.

For each frame, the exact pixels occupied by the worm's head are calculated. This requires access to the worm images that were extracted during the experiment initialization process.
The error is then calculated as the proportion of worm pixels that lie outside of the microscope view.
Because this calculation has to load images from the disk, it is relatively slow.
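The idea can be sketched with NumPy as follows. This is a simplified illustration and not the implementation of `ErrorCalculator.calculate_precise`: it assumes grayscale images, a worm mask obtained by thresholding the absolute difference from the background, and a microscope view given as an `(x, y, w, h)` box in the pixel coordinates of the worm image. The function name `precise_error_sketch` is hypothetical.

```python
import numpy as np


def precise_error_sketch(worm_view, background_view, mic_bbox, diff_thresh=20):
    """Fraction of worm pixels that fall outside the microscope box.

    Hypothetical helper. `worm_view` and `background_view` are 2D arrays of
    the same shape; `mic_bbox` is (x, y, w, h) in the same pixel coordinates.
    """
    # Pixels that differ enough from the background are treated as worm.
    diff = np.abs(worm_view.astype(np.int32) - background_view.astype(np.int32))
    worm_mask = diff > diff_thresh

    total = worm_mask.sum()
    if total == 0:
        # Design choice of this sketch: no detected worm pixels means no error.
        return 0.0

    # Mark the region covered by the microscope view.
    x, y, w, h = mic_bbox
    inside_view = np.zeros_like(worm_mask)
    inside_view[y : y + h, x : x + w] = True

    # Count the worm pixels that are not covered by the view.
    outside = (worm_mask & ~inside_view).sum()
    return float(outside / total)
```

In this sketch a lower `diff_thresh` marks more pixels as worm, so the threshold directly affects the resulting error value.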

In [None]:
import numpy as np

# TODO: TEST
# TODO: ADD DOCS FOR THIS SECTION

################################ User Input ################################

background_path = "data\\Exp1_GuyGilad_logs_yolo\\background.npy"

worm_folder_path = "D:\\Guy_Gilad\\Exp1_GuyGilad\\logs_yolo\\worms"

diff_thresh = 20

############################################################################

if background_path is None:
    background_path = UserPrompt.open_file(title="Select background images", filetypes=[("Numpy files", "*.npy")])

if worm_folder_path is None:
    worm_folder_path = UserPrompt.open_directory(title="Select worm image folders")

print("Background Files: ", background_path)
print("Worm Image Folders: ", worm_folder_path)

background = np.load(background_path, allow_pickle=True)

##### Calibrate Threshold [Optional]

In [None]:
from wtracker.eval.vlc import StreamViewer
from wtracker.eval.error_calculator import ErrorCalculator
from wtracker.utils.frame_reader import FrameReader
import pandas as pd
import numpy as np

viewer = StreamViewer(window_name="Threshold Calibration")

In [None]:

################################ User Input ################################
threshold = 30
exp_number = 0 # the number of the experiment in the list
delay = 0
############################################################################
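def show_segmentation(wrm_view: np.ndarray, wrm_mask: np.ndarray) -> None:
    # black out every pixel outside the mask, then display the result
    wrm_view[~wrm_mask] = 0
    viewer.imshow(wrm_view)
    viewer.waitKey(delay)


# register the function as the probe hook of ErrorCalculator
ErrorCalculator.probe_hook = show_segmentation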

reader = FrameReader.create_from_directory(worm_folder_path)
log = pd.read_csv(log_files[exp_number])

viewer.open()
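# match the background's dimensions to those of the worm frames, keeping its height and width
shape = [*reader.frame_shape]
shape[:2] = background.shape[:2]
# reshape returns a new array, so the result must be assigned back
background = background.reshape(shape)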

ErrorCalculator.calculate_precise(
    background=background,
    worm_bboxes=log[["wrm_x", "wrm_y", "wrm_w", "wrm_h"]].to_numpy(),
    mic_bboxes=log[["mic_x", "mic_y", "mic_w", "mic_h"]].to_numpy(),
    frame_nums=log['frame'].astype(int).to_list(),
    worm_reader=reader,
    diff_thresh=threshold
)

##### Calculate Precise error

In [None]:
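# imported here as well, since the optional calibration cells above may have been skipped
from wtracker.utils.frame_reader import FrameReader

worm_reader = FrameReader.create_from_directory(worm_folder_path)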

data.calc_precise_error(
    worm_image_paths=worm_folder_path,
    background=background,
    diff_thresh=diff_thresh,
)

##### Save Data

In [None]:
data_save_path = None

if data_save_path is None:
    data_save_path = UserPrompt.save_file(title="Save data", filetypes=[("Pickle files", "*.pkl")])

data.save(data_save_path)

### Plotting and analysis

In [None]:
data_list = [data]

if len(data_list) == 0:
    file_paths = UserPrompt.open_file(title="Select data files", filetypes=[("Pickle files", "*.pkl")], multiple=True)
    data_list = [DataAnalyzer.load(path) for path in file_paths]

In [None]:
# create the plotter
pltr = Plotter(data_list, plot_height=7)

In [None]:
# print column names of the data
pprint([f"{i}: {col}" for i, col in enumerate(data.column_names())])

In [None]:
data.print_stats()

In [None]:
pltr.plot_trajectory(hue_col="log_num", condition=lambda x: x["wrm_y"] >= 0)
plt.show()

In [None]:
data.data.head(20)

In [None]:
pltr.plot_speed(log_wise=False, condition=lambda x: x["wrm_speed"] <= 800)
plt.show()

In [None]:
pltr.plot_error(log_wise=False, error_kind="dist", condition=lambda x: x["bbox_error"] > 0)
plt.show()

In [None]:
pltr.plot_speed_vs_error(error_kind="bbox", condition=lambda x: x["wrm_speed"] < 2000)
plt.show()

In [None]:
pltr.plot_deviation(percentile=0.999, log_wise=False)
plt.show()

In [None]:
pltr.plot_head_size()
plt.show()

In [None]:
data.describe(columns=["wrm_speed", "bbox_error", "worm_deviation"], num=9)

In [None]:
import numpy as np

# find anomalies in the data
data.find_anomalies(
    no_preds=True,
    min_bbox_error=1.0,
    min_dist_error=np.inf,
    min_speed=np.inf,
    min_size=300,
)