In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from DataLoader import *
import datetime
import random
import warnings
from BehaviourFeatureExtractor import *
from BehaviourAnnotation import *
from VocalFeatureExtractor import *
from colour import Color
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
import copy
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import warnings

# 0. Classes and directories #

In [2]:
# Single source of truth for the configuration file: every helper below is
# built from `config_path`, so editing it here reconfigures all of them
# consistently instead of only the annotator.
config_path = "config.json"
data_dir = "data/"

DL = DataLoader(data_dir, path_to_config_file=config_path)
BA = BehaviourAnnotator(config_path)
BF = BehaviourFeatureExtractor(config_path)
VF = VocalFeatureExtractor(config_path)

In [3]:
# Relative directories used by the pipeline stages below.
processed_data_dir = "processed_data"                  # stage 2 export / stage 3 input
annotated_data_dir = "annotated_resolved_data"         # stage 3 export / stage 4 input
transition_path_export_dir = "transition_paths"        # stage 4: transition paths
export_csv_dir = "annotated_cleaned_resolved_data"     # stage 4: cleaned CSV export

# Figures produced alongside the stage-4 transition-path export.
plot_dir = "full_cleaned_resolved_annotation_plots"


# 1. Data loading #

In [None]:
# Every sub-directory of `data_dir` is treated as one animal. `os.listdir`
# returns entries in arbitrary, filesystem-dependent order, so sort the IDs to
# make the processing order (and therefore logs and outputs) reproducible.
mouse_ids = sorted(
    entry for entry in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, entry))
)
print("Mouse IDs:", mouse_ids)

# Sessions d1..d6.
days = [f"d{i}" for i in range(1, 7)]
print("Days:", days)

# 2. USV + DLC preprocessing #

In [None]:
# Run USV + DLC preprocessing over every selected (mouse, day); with export
# enabled the results are presumably also written under `processed_data_dir`
# (the next section reloads from there) — TODO confirm in DataLoader.
experiment_data_processed = DL.collect_and_process_experiment_data(
    mouse_ids,
    days,
    BF,
    VF,
    processed_data_dir=processed_data_dir,
    export=True,
)

# 3. Behaviour annotation #

In [None]:
# Reload the preprocessed data from disk so this section does not depend on the
# preprocessing cell having run in this kernel. Only directories whose names
# start with "M" or "V" count as animals (presumably mothers / virgins — TODO
# confirm naming convention). Sorted because os.listdir order is arbitrary.
mouse_ids = sorted(
    entry for entry in os.listdir(processed_data_dir)
    if os.path.isdir(os.path.join(processed_data_dir, entry)) and entry.startswith(("M", "V"))
)
days = [f"d{i}" for i in range(1, 7)]
processed_data = load_processed_data(processed_data_dir, mouse_ids, days)


In [None]:
# Annotate pup-directed behaviour for every (mouse, day, trial) and export it
# with `annotated_data_dir` as the output directory. A failing trial does not
# abort the loop; its ID goes into `error_examples` and the exception itself
# into `error_details` so failures can actually be diagnosed afterwards.
n_trials_per_day = 10  # trials are numbered 1..n_trials_per_day

error_examples = []   # (mouse_id, day, trial_num) tuples that raised
error_details = {}    # (mouse_id, day, trial_num) -> repr of the raised exception

# Trials with known timing errors; they are skipped entirely.
timing_error_examples = [("VACO143LLR", "d6", 1), ("MBQ0231R", "d3", 5), ('MBQ0231R', 'd5', 8), ("MBQ0012RR", "d6", 4)]
trials_to_avoid = timing_error_examples

for mouse_id in mouse_ids:
    for day in days:
        for trial_num in range(1, n_trials_per_day + 1):

            print(f" ==== Example: {mouse_id} - {day} - {trial_num} ==== ")
            example = (mouse_id, day, trial_num)

            if example in trials_to_avoid:
                print(f" ==== Skipping example: {example} ==== ")
                continue

            try:
                # Deep copies keep `processed_data` untouched even if the
                # annotator mutates its inputs.
                trial_df = copy.deepcopy(processed_data[mouse_id][day]["trials"][trial_num]["dlc_data"])
                df_summary = copy.deepcopy(processed_data[mouse_id][day]["Behavior"]["df_summary"])
                pup_locations = copy.deepcopy(processed_data[mouse_id][day]["trials"][trial_num]["pup_locations"])

                BA.run_pup_directed_behavior_annotation(mouse_id, day, trial_num,
                                                        trial_df, df_summary, pup_locations,
                                                        processed_data_dir=annotated_data_dir, export=True)
            except Exception as e:
                # Best-effort batch run: keep going, but surface the cause
                # instead of silently discarding it.
                print(f"**** !!!! **** !!!! Error on example: {example} !!!! **** !!!! ****")
                print(f"     {type(e).__name__}: {e}")
                error_examples.append(example)
                error_details[example] = repr(e)


In [None]:
# Trials that raised during annotation, shown as this cell's output.
error_examples

# 4. Transition paths and counts matrix #

In [5]:
# Load the processed + behaviour-annotated data back from disk into a dictionary.
# Only directories whose names start with "M" or "V" count as animals; IDs are
# sorted because os.listdir returns entries in arbitrary order.
mouse_ids = sorted(
    entry for entry in os.listdir(annotated_data_dir)
    if os.path.isdir(os.path.join(annotated_data_dir, entry)) and entry.startswith(("M", "V"))
)
days = [f"d{i}" for i in range(1, 7)]
processed_and_annotated_data = load_processed_data(annotated_data_dir, mouse_ids, days)

## Run transition path extraction loop ##

In [None]:
# Extract transition paths for every animal/day and export them, together with
# the cleaned CSVs and annotation plots, to the directories configured above.
BA.get_and_export_transition_paths_for_animal(
    processed_and_annotated_data,
    mouse_ids,
    days,
    export=True,
    transition_path_export_dir=transition_path_export_dir,
    export_csv_dir=export_csv_dir,
    plot_export_dir=plot_dir,
)

## Load transition paths from directory ##

In [5]:
# Read back the exported transition paths for the selected mouse_ids/days.
transition_paths_dict = load_transition_paths_dict(transition_path_export_dir, mouse_ids, days)

# Animal groups of the study. NOTE(review): not consumed by the load call above,
# and rebound as a loop variable in the plotting cell at the end of the notebook.
category = ["Mother", "Virgin"]

In [None]:
# Sanity check: show which top-level keys were loaded (this cell's output).
transition_paths_dict.keys()

In [None]:
# Aggregate the transition paths into transition matrices; the plotting cell
# reads the result as dict_transition_matrices[category][day].
dict_transition_matrices = (
    BA.create_transition_matrices_from_transition_paths(
        mouse_ids,
        days,
        transition_paths_dict,
    )
)

In [None]:
# Plot one transition graph per (category, day): first as individual figures,
# then all together on a (categories x days) grid.

# dict_transition_matrices is keyed by category, then by day.
categories = list(dict_transition_matrices.keys())

# Union of days over *all* categories in first-seen order, so a session present
# only in a later category is still plotted. When every category has the same
# days this is exactly the first category's day list.
days = []
for category in categories:
    for day in dict_transition_matrices[category]:
        if day not in days:
            days.append(day)

######    Individual plots    #######
for category in categories:
    for day in days:
        if day in dict_transition_matrices[category]:
            BA.plot_transition_graph(dict_transition_matrices[category][day],
                                     title=f"{category} - session {day}", ax=None)
        else:
            # plt.title() would retitle whatever axes happen to be current
            # (i.e. the previous graph), so give the placeholder its own axes.
            fig_empty, ax_empty = plt.subplots()
            ax_empty.axis('off')
            ax_empty.set_title(f"{category} - session {day}\n(No data)")

#####    Plot each graph on a grid    #######
n_rows, n_cols = len(categories), len(days)
if n_rows and n_cols:
    # Grid shape follows the data instead of a hard-coded 2 x 6 (which raised
    # IndexError for larger shapes); the per-panel size reproduces the original
    # (50, 15) figure for a 2 x 6 grid. squeeze=False keeps axs 2-D even for a
    # single row or column.
    fig, axs = plt.subplots(n_rows, n_cols,
                            figsize=(50 / 6 * n_cols, 15 / 2 * n_rows),
                            squeeze=False)

    for i, category in enumerate(categories):
        for j, day in enumerate(days):
            ax = axs[i, j]
            ax.axis('off')

            if day in dict_transition_matrices[category]:
                BA.plot_transition_graph(dict_transition_matrices[category][day],
                                         title=f"{category} - session {day}", ax=ax)
            else:
                # Same placeholder as the individual plots.
                ax.set_title(f"{category} - session {day}\n(No data)")

    plt.tight_layout()
else:
    print("No transition matrices to plot.")

plt.show()