In [34]:
# Comment: Import everything we need for this test notebook.
import os
import numpy as np
import pandas as pd

from data_loading.gather_mat_files_multiple import gather_mat_files_multiple_condition
from data_loading.load_intervals import load_intervals_data
from data_loading.load_scores import load_scores_data
from data_loading.load_postprocessed import load_postprocessed_data

# Comment: Define the base folder that contains your multiple-condition subfolders.
base_folder = r"D:\behavior_ethogram_project_Ilya"

# Comment: Gather .mat files per condition (list-of-lists) and discover condition names.
all_condition_files, condition_names = gather_mat_files_multiple_condition([base_folder])

# Comment: Load interval data for all flies in all conditions.
all_flies_intervals, intervals_index_map, intervals_cond_names = load_intervals_data(
    all_condition_files,
    single=False,
    condition_names=condition_names
)

# Comment: Load continuous scores data (with per-file scoreNorm).
all_flies_scores, scores_index_map, scores_cond_names = load_scores_data(
    all_condition_files,
    single=False,
    condition_names=condition_names
)

# Comment: Load postprocessed (binary) data.
all_flies_post, post_index_map, post_cond_names = load_postprocessed_data(
    all_condition_files,
    single=False,
    condition_names=condition_names
)

# Comment: Let's pick the "last loaded fly" from each dataset to inspect.
# The total number of flies is len(all_flies_*) for each approach.
# We'll just pick the final entry in these lists:
last_intervals = all_flies_intervals[-1]   # This is a list of intervals for that fly.
last_scores   = all_flies_scores[-1]       # This is a (T, B) array of scores for that fly.
last_post     = all_flies_post[-1]         # This is a (T, B) binary array for that fly.

########################################################################################
# Comment: Build a table for intervals (the "first intervals" for each behavior).
# Note: The 'load_intervals_data' structure might have appended intervals from multiple .mat files
#       If each .mat corresponds to a separate behavior, we can sample the earliest interval from each.

# We'll assume the "last_intervals" is a single list of dictionaries:
#  [ {"tStart": val, "tEnd": val}, ... ] across all behaviors.
# If you want them grouped by behavior, you'd need to store that info. We'll assume we can chunk them.

intervals_df_rows = []
# Suppose each .mat is a separate behavior => we can just enumerate them in chunks if we know how many
# intervals per behavior. If that's not tracked, we'll just show the entire list or the first intervals.
# Here, we'll pick the first interval from each chunk of intervals if you prefer that logic.

# For simplicity, let's treat each dictionary as a separate "interval row" in a single table.
for i, iv in enumerate(last_intervals):
    intervals_df_rows.append({
        "IntervalIndex": i,
        "tStart": iv["tStart"],
        "tEnd":   iv["tEnd"]
    })

intervals_df = pd.DataFrame(intervals_df_rows)
intervals_df


ValueError: too many values to unpack (expected 2)

In [30]:
# Comment: Build a table for scores. The user wants the maximum for this last loaded fly,
#          plus we demonstrate before/after normalization info. However, the provided
#          load_scores_data function already divides by scoreNorm, so we only have normalized values
#          in 'last_scores'. If we want the "raw" data, we'd need to modify that loader
#          to store raw arrays as well. For demo, let's just show the maximum post-normalized score.

# last_scores is shape (T, B). We'll compute max per behavior (column).
beh_count = last_scores.shape[1]
score_table_rows = []
for b_idx in range(beh_count):
    col_data = last_scores[:, b_idx]
    # ignoring NaNs
    valid_col = col_data[~np.isnan(col_data)]
    max_val = valid_col.max() if len(valid_col) > 0 else np.nan
    score_table_rows.append({
        "BehaviorIndex": b_idx,
        "MaxPostNormScore": max_val,
        # If we had the raw array or separate scoreNorm, we could show them, but
        # by default load_scores_data divides by scoreNorm in place.
        "ScoreNorm_PLACEHOLDER": "Need custom code to show actual norm"
    })

scores_df = pd.DataFrame(score_table_rows)
scores_df


In [31]:
# Comment: Build a table for binary postprocessed data. The user wants for each behavior:
#          the number of 1's and the number of 0's in last_fly_post (excluding NaN).

binary_table_rows = []
if last_post.ndim == 2:
    B = last_post.shape[1]
    for b_idx in range(B):
        col = last_post[:, b_idx]
        # filter out nan
        col_no_nan = col[~np.isnan(col)]
        # count how many are 1, how many are 0
        num_ones  = np.sum(col_no_nan == 1)
        num_zeros = np.sum(col_no_nan == 0)
        binary_table_rows.append({
            "BehaviorIndex": b_idx,
            "Count1s": int(num_ones),
            "Count0s": int(num_zeros)
        })
else:
    # If it's not 2D, handle differently
    pass

binary_df = pd.DataFrame(binary_table_rows)
binary_df



=== Testing Scores Data Loading Across Multiple Conditions ===
Gathering .mat files from these folders:
D:\\behavior_ethogram_project_Ilya\Assa_Females_Mated_Unknown_RigA_20220207T130211
D:\\behavior_ethogram_project_Ilya\Assa_Females_Singles_Unknown_RigA_20220206T100525
Flies Data Paths Structure:
[]
Number of conditions detected: 0


IndexError: list index out of range

In [12]:
#Test Postprocessed Data Loading
print("\n=== Testing Postprocessed Data Loading ===")
all_flies, _, _ = load_postprocessed_data(flies_data_paths[:1], single=True)

table = PrettyTable()
table.field_names = ["Local Index", "Global Index", "Condition Name", "First Non-Empty Value"]

global_index = 0
for fly_idx, fly_data in enumerate(all_flies[:20]):  # Limit to 20 entries
    for behavior_idx in range(fly_data.shape[1]):
        behavior_data = fly_data[:, behavior_idx]
        non_empty_indices = (~pd.isna(behavior_data) & (behavior_data != 0)).nonzero()[0]

        if len(non_empty_indices) > 0:
            first_non_empty = behavior_data[non_empty_indices[0]]
            table.add_row([fly_idx, global_index, condition_names[0], first_non_empty])

        global_index += 1

print(table)


=== Testing Postprocessed Data Loading ===


RuntimeError: For multiple conditions, index_map and condition_names must be provided.

In [11]:
#Test Intervals Data Loading
print("\n=== Testing Intervals Data Loading ===")
all_flies, _, _ = load_intervals_data(flies_data_paths[:1], single=True)  # Test only the first condition

table = PrettyTable()
table.field_names = ["Local Index", "Global Index", "Condition Name", "First Start", "First End"]

global_index = 0
for fly_idx, fly_data in enumerate(all_flies[:20]):  # Limit to 20 entries
    if len(fly_data) > 0:
        first_start = fly_data[0]["tStart"] if "tStart" in fly_data[0] else None
        first_end = fly_data[0]["tEnd"] if "tEnd" in fly_data[0] else None
        table.add_row([fly_idx, global_index, condition_names[0], first_start, first_end])

    global_index += 1

print(table)


=== Testing Intervals Data Loading ===
+-------------+--------------+----------------------------------------------------+-------------+-----------+
| Local Index | Global Index |                   Condition Name                   | First Start | First End |
+-------------+--------------+----------------------------------------------------+-------------+-----------+
|      0      |      0       | Assa_Females_Grouped__Unknown_RigA_20220206T090743 |     435     |    481    |
|      1      |      1       | Assa_Females_Grouped__Unknown_RigA_20220206T090743 |      36     |     37    |
|      2      |      2       | Assa_Females_Grouped__Unknown_RigA_20220206T090743 |      27     |     28    |
|      3      |      3       | Assa_Females_Grouped__Unknown_RigA_20220206T090743 |     877     |    880    |
|      4      |      4       | Assa_Females_Grouped__Unknown_RigA_20220206T090743 |     182     |    195    |
|      5      |      5       | Assa_Females_Grouped__Unknown_RigA_20220206T09074

In [37]:
print(gather_mat_files_multiple_condition([base_folder]))

(['D:\\behavior_ethogram_project_Ilya\\Assa_Females_Mated_Unknown_RigA_20220207T130211\\scores_Grooming.mat', 'D:\\behavior_ethogram_project_Ilya\\Assa_Females_Mated_Unknown_RigA_20220207T130211\\scores_Jump.mat', 'D:\\behavior_ethogram_project_Ilya\\Assa_Females_Mated_Unknown_RigA_20220207T130211\\scores_Long_Distance_Approach.mat', 'D:\\behavior_ethogram_project_Ilya\\Assa_Females_Mated_Unknown_RigA_20220207T130211\\scores_Long_Lasting_Interaction.mat', 'D:\\behavior_ethogram_project_Ilya\\Assa_Females_Mated_Unknown_RigA_20220207T130211\\scores_Short_Distance_Approach.mat', 'D:\\behavior_ethogram_project_Ilya\\Assa_Females_Mated_Unknown_RigA_20220207T130211\\scores_Social_Clustering.mat', 'D:\\behavior_ethogram_project_Ilya\\Assa_Females_Mated_Unknown_RigA_20220207T130211\\scores_Stable_Interaction.mat', 'D:\\behavior_ethogram_project_Ilya\\Assa_Females_Mated_Unknown_RigA_20220207T130211\\scores_Stop.mat', 'D:\\behavior_ethogram_project_Ilya\\Assa_Females_Mated_Unknown_RigA_20220207T