In [None]:
!pip install tqdm

In [None]:
import ast
import keras
import random
import numpy as np
import tensorflow as tf
import pandas as pd
import pathlib
from math import ceil, sqrt
from tqdm import tqdm
import csv
import os
# Path of the submission CSV. The same value is assigned again in the cell that
# writes the CSV header row.
output_file = "submission.csv"

In [None]:
# These need to be here to load the model
@keras.saving.register_keras_serializable()
def make_pairs(x):
    """Form every ordered pair of entries along axis 1 of ``x``.

    ``x`` must be rank 4 (the tile multiples below have five entries once one
    axis has been inserted); the original notes name its axes (B, M, T, E).

    Returns a rank-6 tensor laid out as (B, M, M, T, 2, E): position [b, i, j]
    holds entry ``i`` at index 0 of the size-2 axis and entry ``j`` at index 1.
    """
    count = tf.shape(x)[1]
    # "First" member of each pair: insert an axis after M and repeat along it.
    first = tf.tile(tf.expand_dims(x, axis=2), [1, 1, count, 1, 1])
    # "Second" member: insert an axis before M and repeat along it.
    second = tf.tile(tf.expand_dims(x, axis=1), [1, count, 1, 1, 1])
    # The pair axis is placed just before the last (feature) axis.
    return tf.stack([first, second], axis=-2)


@keras.saving.register_keras_serializable()
def scale_broadcast(x):
    """Spread one scale value per sample over a single-channel copy of the embedding.

    ``x`` is a ``(scale, embedding)`` pair. ``scale`` is reshaped to (-1, 1, 1, 1)
    and multiplied with a tensor of ones shaped like ``embedding`` with its last
    axis cut down to length 1.
    # assumes embedding is rank 4 so the broadcast lines up - TODO confirm
    """
    per_sample, features = x
    column_scale = tf.reshape(per_sample, (-1, 1, 1, 1))
    unit_channel = tf.ones_like(features[..., :1])
    return column_scale * unit_channel


@keras.saving.register_keras_serializable()
def time_broadcast(x):
    """Build a scaled ramp channel along axis 2 of the embedding.

    ``x`` is a ``(scale, embedding)`` pair. A ramp of evenly spaced values from
    0.0 to 20.0 is generated with one value per position on axis 2 of
    ``embedding``, repeated over axes 0 and 1, and multiplied by ``scale``
    reshaped to (-1, 1, 1, 1). The result has a trailing axis of length 1.

    The previously computed but unused size of axis 3 has been dropped; the
    returned values are unchanged.
    """
    scale, embedding = x  # original notes: scale (batch,), embedding (BATCH, M, T, E)

    dims = tf.shape(embedding)
    batch_size, num_rows, num_steps = dims[0], dims[1], dims[2]

    # NOTE(review): the 0..20 range presumably has to match what the saved model
    # was trained with - do not change it without retraining.
    ramp = tf.linspace(0.0, 20.0, num_steps)
    ramp = tf.reshape(ramp, (1, 1, num_steps, 1))
    ramp = tf.tile(ramp, [batch_size, num_rows, 1, 1])

    # One multiplier per sample, broadcast over the remaining axes.
    scale = tf.reshape(scale, (-1, 1, 1, 1))
    return scale * ramp


# Load the saved Keras model from the Kaggle input dataset. The functions registered
# above are defined first because, per the note on them, they are needed for loading.
# NOTE(review): safe_mode=False switches off Keras' safe deserialization, which allows
# code stored in the model file (e.g. lambdas) to run on load - only use with a trusted file.
model = keras.models.load_model('/kaggle/input/mabe-resnet/keras/default/14/median_300.keras', safe_mode=False)

In [None]:
# Print the layer-by-layer summary of the loaded model as a sanity check.
model.summary()

In [None]:
# Number of frames in one model input window; videos are cut into segments of this length.
TRAIN_SEQ_SIZE = 2000
# Root folder of the competition data (metadata CSVs and tracking parquet files).
data_dir = pathlib.Path('/kaggle/input/MABe-mouse-behavior-detection')

# Seed every random number generator used by the notebook with the same value.
seed = 12
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

# Silence chained-assignment warnings and opt in to pandas' future no-downcast behaviour.
pd.set_option('mode.chained_assignment', None)
pd.set_option('future.no_silent_downcasting', True)

# Per-video metadata for the train and test splits.
train_metadata = pd.read_csv(data_dir / 'train.csv')
test_metadata = pd.read_csv(data_dir / 'test.csv')

# Arena size in cm along each axis: pixel size divided by the pixels-per-cm estimate.
train_metadata = train_metadata.assign(
    x_cm_scale=lambda df: df['video_width_pix'] / df['pix_per_cm_approx'],
    y_cm_scale=lambda df: df['video_height_pix'] / df['pix_per_cm_approx'],
)

# Test scales are additionally divided by the largest train scale on the same axis.
test_metadata = test_metadata.assign(
    x_cm_scale=lambda df: (df['video_width_pix'] / df['pix_per_cm_approx'])
    / train_metadata['x_cm_scale'].max(),
    y_cm_scale=lambda df: (df['video_height_pix'] / df['pix_per_cm_approx'])
    / train_metadata['y_cm_scale'].max(),
)

# Parse the stringified label lists and flatten them to one "agent,target,action"
# entry per row. Missing cells are detected with isinstance(str) rather than an
# identity test against np.nan, and Series.explode() is called without the stray
# column name (it takes no column argument; the string was read as ignore_index).
behaviors_train = (train_metadata['behaviors_labeled']
                   .apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
                   .explode())
behaviors = (test_metadata['behaviors_labeled']
             .apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
             .explode())
# NOTE(review): the vocabulary comes from whichever split has more exploded entries
# (not more distinct actions). The class order presumably has to match the one the
# model was trained with - confirm this selection is what is intended.
if len(behaviors_train) > len(behaviors):
    behaviors = behaviors_train
# Keep only the action (third comma-separated field); entries without one are dropped.
behaviors = behaviors.dropna().str.split(',').str[2].dropna().unique()
# Strip stray quote characters, de-duplicate, sort, and reserve index 0 for "nothing".
behaviors = ["nothing"] + sorted({name.replace("'", "") for name in behaviors})
NUM_BEHAVIORS = len(behaviors)
# number to behavior
behaviors_map = dict(enumerate(behaviors))
# behavior to number
behaviors_map_rev = {name: idx for idx, name in behaviors_map.items()}
# Behavior names prefixed with the field separator.
_behaviors = [',' + name for name in behaviors]

# One record per test video that has labeled behaviors, holding the metadata the
# generator needs plus the indices of the behaviors annotated for that video.
labeled_videos = []
for _, row in test_metadata.loc[test_metadata['behaviors_labeled'].notna()].iterrows():
    # "nothing" is always an allowed class.
    behaviors_found = {behaviors_map_rev["nothing"]}
    # Match on the parsed action field instead of searching the raw string:
    # a substring test lets ",sniff" match ",sniffgenital" and misses quoted names.
    for entry in ast.literal_eval(row['behaviors_labeled']):
        if not isinstance(entry, str):
            continue
        fields = entry.split(',')
        if len(fields) < 3:
            continue
        # get rid of the extra quote pairs
        action = fields[2].replace("'", "")
        # Actions outside the vocabulary are skipped.
        if action in behaviors_map_rev:
            behaviors_found.add(behaviors_map_rev[action])
    labeled_videos.append({
        'lab': row['lab_id'],
        'video': row['video_id'],
        'seconds_per_frame': 1 / row['frames_per_second'],
        'video_width_pix': row['video_width_pix'],
        'video_height_pix': row['video_height_pix'],
        'x_cm_scale': row['x_cm_scale'],
        'y_cm_scale': row['y_cm_scale'],
        'behaviors': list(behaviors_found)
    })

# Distinct tracked body-part names in each split. Missing cells are detected with
# isinstance(str) rather than an identity test against np.nan.
body_parts_train = (train_metadata['body_parts_tracked']
                    .apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
                    .explode().unique())
body_parts = (test_metadata['body_parts_tracked']
              .apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
              .explode().unique())
# Use whichever split lists more distinct body parts.
if len(body_parts_train) > len(body_parts):
    body_parts = body_parts_train
# Missing cells survive explode() as NaN; keep only names so sorting cannot fail
# on a float/str comparison.
body_parts = sorted(part for part in body_parts if isinstance(part, str))
BODY_PARTS = len(body_parts)
# One extra feature slot beyond the named body parts.
FEATURES = BODY_PARTS + 1
# body part name -> feature index
body_parts_map = {x: i for i, x in enumerate(body_parts)}

In [None]:
output_file = "submission.csv"

# Always start from a fresh file containing only the header. The inference loop
# appends rows and numbers them from 0, so keeping a file left over from an earlier
# run would produce duplicated rows and repeated row ids.
with open(output_file, "w", newline="") as f:
    csv.writer(f).writerow([
        "row_id",
        "video_id",
        "agent_id",
        "target_id",
        "action",
        "start_frame",
        "stop_frame"
    ])

In [None]:
def make_windows(pair, segment_start, video_stop, label_names=None):
    """Turn per-frame class ids into (behavior, start_frame, end_frame) runs.

    Args:
        pair: Sequence of class ids, one per frame of the segment.
        segment_start: Video frame number of ``pair[0]``; added to every offset.
        video_stop: Frames at offsets >= ``video_stop - segment_start`` are ignored.
        label_names: Optional mapping from class id to behavior name. Defaults to
            the module-level ``behaviors_map``.

    Returns:
        A list of ``(name, start, end)`` tuples in frame order, where ``end`` is
        the last frame of the run (inclusive). Runs named "nothing" and runs that
        are a single frame long are omitted. Fewer than two usable frames yields [].
    """
    if label_names is None:
        label_names = behaviors_map

    # NOTE(review): the caller passes the largest frame number as video_stop, so
    # the final frame of the video is never part of a run - confirm this is intended.
    up_limit = min(len(pair), video_stop - segment_start)
    if up_limit <= 1:
        return []

    segments = []
    run_start = 0
    # Walk one step past the end so the final run is closed by the same code path.
    for idx in range(1, up_limit + 1):
        if idx < up_limit and pair[idx] == pair[run_start]:
            continue
        run_end = idx - 1
        name = label_names[pair[run_start]]
        if name != "nothing" and run_end != run_start:
            segments.append((name, run_start + segment_start, run_end + segment_start))
        run_start = idx
    return segments

In [None]:
def get_sequence_from_df(track_df, video_width, video_height, start_frame, end_frame):
    """Rasterize the tracking rows of one frame window into a dense array.

    Args:
        track_df: Tracking rows of a whole video with columns ``video_frame``,
            ``mouse_id`` (used directly as the first array index), ``bodypart``,
            ``x`` and ``y``.
        video_width: Divisor applied to ``x``.
        video_height: Divisor applied to ``y``.
        start_frame: First frame of the window (inclusive).
        end_frame: End of the window (exclusive).

    Returns:
        ``None`` when ``track_df`` contains no mouse ids; otherwise an array of
        shape (num_mice, TRAIN_SEQ_SIZE, FEATURES, 3). The last axis is
        (x, y, flag); flag is 1.0 where a tracking row was written. For the named
        body parts, x/y are offsets from the per-mouse, per-frame median of the
        written body parts. The final feature slot stores that median with flag 1.0.
    """
    # Mice are not always incremented by 1 in order or 0 indexed in a given segment,
    # so the mouse count comes from the whole video, not from this window.
    num_mice = len(track_df['mouse_id'].unique())
    # Literally no mice data
    if num_mice == 0:
        return None

    in_window = (track_df['video_frame'] >= start_frame) & (track_df['video_frame'] < end_frame)
    window_df = track_df[in_window]

    # Integer coordinates of every tracking row inside the dense array.
    mouse_idx = window_df['mouse_id'].to_numpy().astype(np.intp)
    frame_idx = (window_df['video_frame'] - start_frame).to_numpy().astype(np.intp)
    part_idx = window_df['bodypart'].replace(body_parts_map).to_numpy().astype(np.intp)
    x_norm = (window_df['x'] / video_width).to_numpy()
    y_norm = (window_df['y'] / video_height).to_numpy()

    track_array = np.zeros((num_mice, TRAIN_SEQ_SIZE, FEATURES, 3))
    # Single vectorised scatter instead of a Python loop over rows; for repeated
    # indices the last row wins, as it did with sequential per-row writes.
    track_array[mouse_idx, frame_idx, part_idx] = np.column_stack(
        [x_norm, y_norm, np.ones(len(window_df))])

    coords = track_array[:, :, :-1, :2]
    mask = track_array[:, :, :-1, 2] == 1.0
    # Median over written body parts only; slots never written are ignored via NaN.
    masked_coords = np.where(mask[..., None], coords, np.nan)
    median_coords = np.nanmedian(masked_coords, axis=2)
    # Frames with no written body part give a NaN median; use 0 there.
    median_coords = np.nan_to_num(median_coords)
    # Re-centre written body parts on the median; unwritten slots stay at zero.
    track_array[:, :, :-1, :2] -= median_coords[:, :, None, :] * mask[..., None]
    track_array[:, :, -1, :2] = median_coords
    track_array[:, :, -1, 2] = 1.0
    return track_array


def ds_generator():
    """Yield one model input tuple per TRAIN_SEQ_SIZE-frame segment of each labeled test video.

    Each yielded tuple is, in order: the dense tracking array for the segment,
    the video's x and y cm scales, seconds per frame, the video id, the
    segment's first frame, the original mouse ids (position = zero-based index),
    the video's largest frame number, and a 0/1 vector over behavior indices
    marking the behaviors listed for the video.
    """
    for video in labeled_videos:
        parquet_path = data_dir / f"test_tracking/{video['lab']}/{video['video']}.parquet"
        track_df = pd.read_parquet(parquet_path).fillna(0)

        # Re-label mice as 0..n-1 in order of first appearance.
        original_ids = track_df['mouse_id'].unique()
        zero_ind_mice_ids = {mouse: position for position, mouse in enumerate(original_ids)}
        track_df['mouse_id'] = track_df['mouse_id'].replace(zero_ind_mice_ids)

        first_frame = track_df['video_frame'].min()
        last_frame = track_df['video_frame'].max()
        num_segments = ceil((last_frame - first_frame) / TRAIN_SEQ_SIZE)

        # 1 for every behavior index recorded for this video, 0 elsewhere.
        behaviors_mask = np.zeros((NUM_BEHAVIORS,))
        behaviors_mask[video['behaviors']] = 1

        for division in range(num_segments):
            seg_first = division * TRAIN_SEQ_SIZE + first_frame
            seg_end = (division + 1) * TRAIN_SEQ_SIZE + first_frame
            seq_out = get_sequence_from_df(
                track_df,
                video['video_width_pix'],
                video['video_height_pix'],
                seg_first,
                seg_end,
            )
            if seq_out is None:
                continue
            yield (
                seq_out,
                video['x_cm_scale'],
                video['y_cm_scale'],
                video['seconds_per_frame'],
                video['video'],
                seg_first,
                list(zero_ind_mice_ids.keys()),
                last_frame,
                behaviors_mask,
            )


# Wrap the generator in a tf.data pipeline: one segment per batch, prefetched.
# The signature entries follow the order of the tuple yielded by ds_generator.
tf_ds = (
    tf.data.Dataset.from_generator(
        ds_generator,
        output_signature=(
            tf.TensorSpec(shape=(None, TRAIN_SEQ_SIZE, FEATURES, 3), dtype=tf.float32),  # seq
            tf.TensorSpec(shape=(), dtype=tf.float32),  # x_cm
            tf.TensorSpec(shape=(), dtype=tf.float32),  # y_cm
            tf.TensorSpec(shape=(), dtype=tf.float32),  # time
            tf.TensorSpec(shape=(), dtype=tf.int32),  # video
            tf.TensorSpec(shape=(), dtype=tf.int32),  # segment start
            tf.TensorSpec(shape=(None,), dtype=tf.int32),  # mice ids
            tf.TensorSpec(shape=(), dtype=tf.int32),  # video stop
            tf.TensorSpec(shape=(NUM_BEHAVIORS,), dtype=tf.float32),  # behaviors mask
        ),
    )
    .batch(1)
    .prefetch(tf.data.AUTOTUNE)
)
# Run the model on every segment and append the predicted behavior runs to the
# submission CSV. The file is opened once for the whole loop (newline='' as the
# csv module requires) and uses the same path variable as the header cell.
row_number = 0
with open(output_file, mode='a', newline='') as f:
    writer = csv.writer(f)
    for x in tqdm(tf_ds):
        # The first four tuple entries are the model inputs; [0] drops the batch axis.
        out = model(x[:4])[0]
        mask = x[8][0]
        # Behaviors not listed for this video get a score of -inf so argmax can never
        # pick them. Multiplying by the 0/1 mask is only safe for non-negative scores:
        # a zeroed entry would beat any negative allowed score.
        out = tf.where(mask > 0, out, tf.constant(float('-inf'), dtype=out.dtype))
        out = tf.argmax(out, axis=-1).numpy()
        segment_start = x[5][0].numpy()
        video = x[4][0].numpy()
        id_map = x[6][0].numpy()
        video_stop = x[7][0].numpy()
        num_mice = len(id_map)
        # Pair order (agent-major, then target) must match the model's first output axis.
        pair_ids = [(id_map[i], id_map[j]) for i in range(num_mice) for j in range(num_mice)]
        for pair_id, pair in zip(pair_ids, out):
            agent = f"mouse{pair_id[0]}"
            # A mouse paired with itself is reported with target "self".
            target = "self" if pair_id[0] == pair_id[1] else f"mouse{pair_id[1]}"
            for window in make_windows(pair, segment_start, video_stop):
                writer.writerow([
                    row_number,
                    video,
                    agent,
                    target,
                    window[0],
                    window[1],
                    window[2]
                ])
                row_number += 1
                