In [12]:
# Import libraries
import os
os.chdir(r"C:\\Users\\Dell #050\\Documents\\MABe")
from glob import glob
from concurrent.futures import ThreadPoolExecutor, as_completed
from collections import Counter
import pandas as pd
from tqdm import tqdm

In [9]:
# Read annotation files
annotation_files = glob(os.path.join("datasets", "train_annotation","**", "*.parquet"))
print(f"Number of annotation files: {len(annotation_files)}")

Number of annotation files: 847


In [10]:
# Read a sample annotation file
sample_data = pd.read_parquet(annotation_files[0])
sample_data.head()

Unnamed: 0,agent_id,target_id,action,start_frame,stop_frame
0,1,3,chase,2,54
1,1,3,chase,128,234
2,3,2,avoid,324,342
3,3,1,avoid,324,342
4,1,2,chase,942,1052


In [13]:
# Count frequency of each action
def process_file(file_path):
    """Reads one file and returns a Counter of action frequencies."""
    data = pd.read_parquet(file_path)
    counter = Counter()
    for row in data.itertuples():
        counter[row.action] += (row.stop_frame - row.start_frame) + 1
    return counter

# Initialize global counter
action_freq_map = Counter()

# Use multithreading to process multiple files in parallel
with ThreadPoolExecutor(max_workers=8) as executor:  # adjust max_workers as needed
    futures = {executor.submit(process_file, f): f for f in annotation_files}
    for future in tqdm(as_completed(futures), total=len(futures), desc="Processing annotation files"):
        action_freq_map.update(future.result())

# Convert back to normal dict if desired
action_freq_map = dict(action_freq_map)
action_freq_map

Processing annotation files: 100%|██████████| 847/847 [00:00<00:00, 69326.66it/s]


{'rear': 235297,
 'approach': 89238,
 'avoid': 26204,
 'attack': 526332,
 'chase': 27530,
 'submit': 8564,
 'chaseattack': 5537,
 'sniff': 2188427,
 'mount': 287678,
 'intromit': 348655,
 'sniffgenital': 711793,
 'sniffbody': 106792,
 'reciprocalsniff': 42174,
 'dominancemount': 17966,
 'attemptmount': 6161,
 'escape': 92375,
 'defend': 91391,
 'sniffface': 74338,
 'shepherd': 29652,
 'selfgroom': 83117,
 'dig': 79897,
 'allogroom': 7625,
 'dominancegroom': 4777,
 'follow': 40377,
 'climb': 61344,
 'run': 1808,
 'rest': 87806,
 'huddle': 24447,
 'genitalgroom': 6320,
 'exploreobject': 3783,
 'flinch': 1861,
 'biteobject': 2359,
 'tussle': 5178,
 'ejaculate': 1614,
 'freeze': 31765,
 'dominance': 38483,
 'disengage': 12300}

Let's create a map where action will be stored such a way that each action will have a list of dicts, key will be file name or path and the value will be a list of (start frame, stop frame) where the action is found.

The frequency map should also include 'no action'.

The dataset split should be done such a way that it will traverse all the files. Number of frames for each action should be predefined. Now, a segment should be fetched randomly to create the dataset.
- Segment size should be decided from the average action duration distribution.
- Some rebalnce method is needed to handle overlapping actions which can make the sampled data improper.
- For each video file action map should be prepared and stored in json for reference.
- For each action a map should be prepared for non-overlapping action segment which can be used for oversampling.

## Create Frequency Map

In [14]:
# Import libraries
import os
import pandas as pd
pd.set_option('display.max_columns', None)  # To display all columns in the dataframe
from tqdm import tqdm

In [6]:
# Read metadata file
metadata_df = pd.read_csv(r"C:\Users\Dell #050\Documents\MABe\datasets\train.csv")
metadata_df.head()

Unnamed: 0,lab_id,video_id,mouse1_strain,mouse1_color,mouse1_sex,mouse1_id,mouse1_age,mouse1_condition,mouse2_strain,mouse2_color,mouse2_sex,mouse2_id,mouse2_age,mouse2_condition,mouse3_strain,mouse3_color,mouse3_sex,mouse3_id,mouse3_age,mouse3_condition,mouse4_strain,mouse4_color,mouse4_sex,mouse4_id,mouse4_age,mouse4_condition,frames_per_second,video_duration_sec,pix_per_cm_approx,video_width_pix,video_height_pix,arena_width_cm,arena_height_cm,arena_shape,arena_type,body_parts_tracked,behaviors_labeled,tracking_method
0,AdaptableSnail,44566106,CD-1 (ICR),white,male,10.0,8-12 weeks,wireless device,CD-1 (ICR),white,male,24.0,8-12 weeks,wireless device,CD-1 (ICR),white,male,38.0,8-12 weeks,wireless device,CD-1 (ICR),white,male,51.0,8-12 weeks,wireless device,30.0,615.6,16.0,1228,1068,60.0,60.0,square,familiar,"[""body_center"", ""ear_left"", ""ear_right"", ""head...","[""mouse1,mouse2,approach"", ""mouse1,mouse2,atta...",DeepLabCut
1,AdaptableSnail,143861384,CD-1 (ICR),white,male,3.0,8-12 weeks,,CD-1 (ICR),white,male,17.0,8-12 weeks,,CD-1 (ICR),white,male,31.0,8-12 weeks,,CD-1 (ICR),white,male,44.0,8-12 weeks,,25.0,3599.0,9.7,968,608,60.0,60.0,square,familiar,"[""body_center"", ""ear_left"", ""ear_right"", ""late...","[""mouse1,mouse2,approach"", ""mouse1,mouse2,atta...",DeepLabCut
2,AdaptableSnail,209576908,CD-1 (ICR),white,male,7.0,8-12 weeks,,CD-1 (ICR),white,male,21.0,8-12 weeks,,CD-1 (ICR),white,male,35.0,8-12 weeks,,CD-1 (ICR),white,male,48.0,8-12 weeks,,30.0,615.2,16.0,1266,1100,60.0,60.0,square,familiar,"[""body_center"", ""ear_left"", ""ear_right"", ""late...","[""mouse1,mouse2,approach"", ""mouse1,mouse2,atta...",DeepLabCut
3,AdaptableSnail,278643799,CD-1 (ICR),white,male,11.0,8-12 weeks,wireless device,CD-1 (ICR),white,male,25.0,8-12 weeks,wireless device,CD-1 (ICR),white,male,39.0,8-12 weeks,wireless device,,,,,,,30.0,619.7,16.0,1224,1100,60.0,60.0,square,familiar,"[""body_center"", ""ear_left"", ""ear_right"", ""head...","[""mouse1,mouse2,approach"", ""mouse1,mouse2,atta...",DeepLabCut
4,AdaptableSnail,351967631,CD-1 (ICR),white,male,14.0,8-12 weeks,,CD-1 (ICR),white,male,28.0,8-12 weeks,,CD-1 (ICR),white,male,42.0,8-12 weeks,,,,,,8-12 weeks,,30.0,602.6,16.0,1204,1068,60.0,60.0,square,familiar,"[""body_center"", ""ear_left"", ""ear_right"", ""late...","[""mouse1,mouse2,approach"", ""mouse1,mouse2,atta...",DeepLabCut


In [13]:
# Get ids of all training videos
video_ids = metadata_df['video_id'].tolist()
lab_ids = metadata_df['lab_id'].tolist()
print(f"Number of training videos: {len(video_ids)}")
print(f"Number of unique labs: {len(set(lab_ids))}")

Number of training videos: 8789
Number of unique labs: 21


In [10]:
# Declare a frequency map for actions
action_freq_map = {}

In [12]:
# Decalre the tracking and annotation directories
tracking_dir = r"C:\Users\Dell #050\Documents\MABe\datasets\train_tracking"
annotation_dir = r"C:\Users\Dell #050\Documents\MABe\datasets\train_annotation"

In [21]:
# Get valid video and lab ids
issues_found = 0
valid_video_ids = []
for lab_id, vid_id in tqdm(zip(lab_ids, video_ids), total=len(video_ids), desc="Counting action frequencies"):
    tracking_file_path = os.path.join(tracking_dir, lab_id, f"{vid_id}.parquet")
    annoatation_file_path = os.path.join(annotation_dir, lab_id, f"{vid_id}.parquet")
    if not os.path.exists(annoatation_file_path) or not os.path.exists(tracking_file_path):
        issues_found += 1
    else:
        valid_video_ids.append((vid_id, lab_id))
print(f"Proper tracking and annotation files found for {len(valid_video_ids)} videos. Issues found for {issues_found} videos.")        

Counting action frequencies: 100%|██████████| 8789/8789 [00:00<00:00, 88689.16it/s]

Proper tracking and annotation files found for 847 videos. Issues found for 7942 videos.





In [22]:
valid_video_ids

[(44566106, 'AdaptableSnail'),
 (143861384, 'AdaptableSnail'),
 (209576908, 'AdaptableSnail'),
 (278643799, 'AdaptableSnail'),
 (351967631, 'AdaptableSnail'),
 (355542626, 'AdaptableSnail'),
 (678426900, 'AdaptableSnail'),
 (705948978, 'AdaptableSnail'),
 (878123481, 'AdaptableSnail'),
 (1212811043, 'AdaptableSnail'),
 (1260392287, 'AdaptableSnail'),
 (1351098077, 'AdaptableSnail'),
 (1408652858, 'AdaptableSnail'),
 (1596473327, 'AdaptableSnail'),
 (1643942986, 'AdaptableSnail'),
 (1717182687, 'AdaptableSnail'),
 (2078515636, 'AdaptableSnail'),
 (402963089, 'BoisterousParrot'),
 (459610814, 'BoisterousParrot'),
 (613246188, 'BoisterousParrot'),
 (1059582964, 'BoisterousParrot'),
 (1184291605, 'BoisterousParrot'),
 (1201849558, 'BoisterousParrot'),
 (1459695188, 'BoisterousParrot'),
 (1985626297, 'BoisterousParrot'),
 (363958890, 'CRIM13'),
 (415181540, 'CRIM13'),
 (670907179, 'CRIM13'),
 (793202924, 'CRIM13'),
 (840324395, 'CRIM13'),
 (1009459450, 'CRIM13'),
 (1057221056, 'CRIM13'),
 (