# Track Identity Assignment Problem

In [1]:
import numpy as np
import pandas as pd
import cv2
import queue
import plotly
import plotly.graph_objs as go
from plotly.subplots import make_subplots

## Setup

In [2]:
"""Define functions"""


def plot_xy(df):
    """Scatter the x and y positions of every class on two stacked panels.

    The top panel shows ``df["x"]`` (circle markers) and the bottom panel
    shows ``df["y"]`` (square markers), both against ``df.index`` on a shared
    horizontal axis. Each class contributes one trace per panel, and hovering
    a marker shows that row's ``speed`` value.

    Args:
        df: DataFrame with ``class``, ``x``, ``y`` and ``speed`` columns.

    Returns:
        The plotly figure holding both panels (not yet shown).

    Raises:
        KeyError: if a class in ``df`` has no entry in the module-level
            ``subject_colors_dict``.
    """
    # (column to plot, marker symbol, subplot row, y-axis title)
    panels = (
        ("x", "circle", 1, "x position"),
        ("y", "square", 2, "y position"),
    )
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True)
    for class_ in df["class"].unique():
        subset = df[df["class"] == class_]
        hover_speeds = subset["speed"].tolist()
        for column, symbol, row, _ in panels:
            trace = go.Scatter(
                x=subset.index,
                y=subset[column],
                mode="markers",
                name=class_,  # legend entry is the class itself
                marker=dict(color=subject_colors_dict[class_], symbol=symbol),
                hovertemplate="Speed: %{text}",
                text=hover_speeds,
            )
            fig.add_trace(trace, row=row, col=1)
    for _, _, row, axis_title in panels:
        fig.update_yaxes(title_text=axis_title, row=row, col=1)
    return fig


def compute_class_speed(df):
    """Compute the instantaneous speed of each class.

    For every row, the speed is the Euclidean distance between this row's
    ``(x, y)`` and the previous row of the same class, divided by the time
    elapsed between those two rows in seconds.

    Args:
        df: DataFrame with ``class``, ``x`` and ``y`` columns whose index is
            a datetime index named ``time``. Rows are used in their existing
            order.

    Returns:
        A float Series carrying the same index as ``df``. The first row of
        each class is NaN (no previous row). A row whose previous same-class
        row has the identical timestamp yields ``inf`` if the position
        changed, or NaN if it did not.
    """
    # Per-class step in x and y; the first row of every class becomes NaN.
    displacement = df.groupby("class")[["x", "y"]].diff()
    # One vectorised norm over all rows instead of a Python call per row.
    distance = pd.Series(
        np.linalg.norm(displacement.to_numpy(dtype=float, na_value=np.nan), axis=1),
        index=df.index,
    )
    # Taken as a plain array so the division is positional: the time index
    # may contain repeated labels and must not be used for alignment.
    elapsed_seconds = (
        df.reset_index().groupby("class")["time"].diff().dt.total_seconds().to_numpy()
    )
    return distance / elapsed_seconds


def hex_to_bgr(hex_color):
    """Convert a hex colour string to a (blue, green, red) tuple for OpenCV.

    Args:
        hex_color: Hexadecimal colour such as ``"#636EFA"``; leading ``#``
            characters are optional.

    Returns:
        Tuple ``(blue, green, red)`` with each channel taken from one byte of
        the parsed value.
    """
    packed_rgb = int(hex_color.lstrip("#"), 16)
    # Peel off one byte at a time, starting with the least significant (blue).
    packed_rgb, blue = divmod(packed_rgb, 256)
    packed_rgb, green = divmod(packed_rgb, 256)
    red = packed_rgb % 256
    return blue, green, red

In [3]:
"""Standardize subject colors for plotting"""

# Plotly's qualitative palette; the four subjects take its first four
# entries, in the order listed below.
subject_colors = plotly.colors.qualitative.Plotly
subject_colors_dict = dict(
    zip(
        ("BAA-1104045", "BAA-1104047", "BAA-1104048", "BAA-1104049"),
        subject_colors,
    )
)

# The same mapping as (blue, green, red) tuples, for OpenCV drawing calls
bgr_colors = dict(
    zip(subject_colors_dict, map(hex_to_bgr, subject_colors_dict.values()))
)

## Load data

In [4]:
# Read the pose table and attach the per-class instantaneous speed column
pose_df = pd.read_feather("data/pose.feather").assign(speed=compute_class_speed)
pose_df

Unnamed: 0_level_0,class,class_likelihood,x,y,speed
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2024-02-18 11:00:00.000000000,BAA-1104047,0.999277,362,155,
2024-02-18 11:00:00.000000000,BAA-1104045,0.552562,1309,499,
2024-02-18 11:00:00.059999943,BAA-1104047,0.999123,370,147,1.885620e+02
2024-02-18 11:00:00.059999943,BAA-1104045,0.554072,1309,499,0.000000e+00
2024-02-18 11:00:00.119999886,BAA-1104047,0.998829,378,142,1.572332e+02
...,...,...,...,...,...
2024-02-18 11:05:00.840000153,BAA-1104045,0.999992,397,933,inf
2024-02-18 11:05:00.900000095,BAA-1104045,0.996929,1309,509,1.676241e+04
2024-02-18 11:05:00.900000095,BAA-1104045,0.999986,397,931,inf
2024-02-18 11:05:00.960000038,BAA-1104045,0.996008,1309,509,1.674838e+04


In [5]:
# Load the RFID detections; the rendered table below shows rfid, class, x, y
# and location columns on a time index.
rfid_df = pd.read_feather("data/rfid.feather")
rfid_df

Unnamed: 0_level_0,rfid,class,x,y,location
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2024-02-18 11:00:07.467743874,977200010164158,BAA-1104047,1214,426,Nest2
2024-02-18 11:01:02.616256237,977200010377711,BAA-1104045,1218,642,Nest1
2024-02-18 11:01:13.198847771,977200010377711,BAA-1104045,195,563,Gate
2024-02-18 11:01:13.628160000,977200010377711,BAA-1104045,195,563,Gate
2024-02-18 11:01:13.965439796,977200010377711,BAA-1104045,195,563,Gate
...,...,...,...,...,...
2024-02-18 11:04:20.957727909,977200010377711,BAA-1104045,600,753,Patch2
2024-02-18 11:04:22.397247791,977200010377711,BAA-1104045,600,753,Patch2
2024-02-18 11:04:24.645919800,977200010377711,BAA-1104045,600,753,Patch2
2024-02-18 11:04:25.432896137,977200010377711,BAA-1104045,600,753,Patch2


## Visualise data

In [6]:
# Plot the unmodified pose table: x (top) and y (bottom) per class
fig = plot_xy(pose_df)
fig.show()

Issues
- Track IDs swap between frames (temporal discontinuities)
- Same Track ID is assigned to multiple animals in the same frame/timestamp

In [7]:
"""Assign the row with duplicated ID with lower likelihood to another ID"""

# Seeded generator so that a fresh Restart & Run All reproduces the same
# reassignment (it only matters when there are more than two classes).
rng = np.random.default_rng(0)

# reset_index() returns a new frame with a unique positional index, which the
# row-level selection below relies on (the time index repeats across classes).
pose_df_cp = pose_df.reset_index()
classes = np.array(pose_df_cp["class"].unique())
# Mask for rows with multiple assignments of the same ID at the same time
many_to_one_mask = pose_df_cp.groupby(["time", "class"]).transform("size") > 1
duplicated_data = pose_df_cp.loc[many_to_one_mask]
# Within each (time, class) group keep the most likely row; all the others
# are the low-likelihood rows that get relabelled.
best_idx = duplicated_data.groupby(["time", "class"])["class_likelihood"].idxmax()
low_likelihood_idx = duplicated_data.index[~duplicated_data.index.isin(best_idx)]
# This assigns another class randomly (in 2-animal case, it's the other animal, but in >2-animal case, it may assign duplicate IDs again)
if len(low_likelihood_idx) > 0:
    current_classes = pose_df_cp.loc[low_likelihood_idx, "class"]
    pose_df_cp.loc[low_likelihood_idx, "class"] = [
        rng.choice(classes[classes != current]) for current in current_classes
    ]
pose_df_cp = pose_df_cp.set_index("time")
# Relabelling changes which rows are consecutive within a class, so the
# speed column has to be recomputed.
pose_df_cp["speed"] = compute_class_speed(pose_df_cp)
fig = plot_xy(pose_df_cp)
fig.show()

Temporal discontinuities
- Typically we use distance between consecutive frames to determine potential swaps
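- However, this is not always reliable because there can be missing data (e.g. occlusions), so the time gap between consecutive detections varies and a large jump in distance may be genuine &rarr; use _speed_ (distance divided by elapsed time) instead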

Pseudocode
```
While speed violation exists
    Flip IDs
    Compute speed
```

In [8]:
# Box plot of the speed distribution, one box per class
fig = go.Figure()
for class_ in pose_df_cp["class"].unique():
    class_rows = pose_df_cp["class"] == class_
    fig.add_trace(
        go.Box(
            y=pose_df_cp.loc[class_rows, "speed"],
            name=class_,
            marker=dict(color=subject_colors_dict[class_]),
        )
    )
fig.show()

In [8]:
# Speeds above this value are treated as violations by the correction loop
# (presumably pixels per second, given x/y look like pixel coordinates - TODO confirm)
speed_threshold = 1000
dtypes_dict = pose_df_cp.dtypes.to_dict()
timestamps = pose_df_cp.index.unique()  # assuming timestamps are sorted
classes = pose_df_cp["class"].unique()
# True where the speed exceeds the threshold AND is finite: infinite speeds
# (zero elapsed time) and NaN are deliberately left out of the mask.
speed_mask = pose_df_cp["speed"].gt(speed_threshold) & np.isfinite(
    pose_df_cp["speed"].to_numpy()
)

In [102]:
# Repeat until no row has a finite speed above `speed_threshold`: every
# flagged row is given a different class, drawn at random from the other
# classes, and the speeds are then recomputed for the relabelled tracks.
# NOTE(review): nothing visible here bounds the number of iterations, and the
# random draw is unseeded, so run time and (with more than two classes) the
# result may differ between runs - confirm this is acceptable.
while speed_mask.any():
    pose_df_cp.loc[speed_mask, "class"] = pose_df_cp.loc[speed_mask].apply(
        lambda x: np.random.choice(classes[classes != x["class"]]), axis=1
    )
    # recompute speed and speed_mask
    pose_df_cp["speed"] = compute_class_speed(pose_df_cp)
    # Same rule as the initial mask: finite speeds above the threshold only
    speed_mask = (np.isfinite(pose_df_cp["speed"].values)) & (
        pose_df_cp["speed"] > speed_threshold
    )
# Persist the relabelled table once no violations remain
pose_df_cp.to_feather("data/pose_cleaned.feather")

In [9]:
# Reload the cleaned pose table from disk (presumably so the cells below can
# run without repeating the correction loop - TODO confirm)
pose_df_cp = pd.read_feather("data/pose_cleaned.feather")
pose_df_cp

Unnamed: 0_level_0,class,class_likelihood,x,y,speed
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2024-02-18 11:00:00.000000000,BAA-1104047,0.999277,362,155,
2024-02-18 11:00:00.000000000,BAA-1104045,0.552562,1309,499,
2024-02-18 11:00:00.059999943,BAA-1104047,0.999123,370,147,188.561987
2024-02-18 11:00:00.059999943,BAA-1104045,0.554072,1309,499,0.000000
2024-02-18 11:00:00.119999886,BAA-1104047,0.998829,378,142,157.233168
...,...,...,...,...,...
2024-02-18 11:05:00.840000153,BAA-1104045,0.999992,397,933,0.000000
2024-02-18 11:05:00.900000095,BAA-1104047,0.996929,1309,509,0.000000
2024-02-18 11:05:00.900000095,BAA-1104045,0.999986,397,931,33.333366
2024-02-18 11:05:00.960000038,BAA-1104047,0.996008,1309,509,0.000000


In [10]:
# Plot the cleaned pose table: x (top) and y (bottom) per class
fig = plot_xy(pose_df_cp)
fig.show()

In [11]:
# Overlay the RFID detections of each class on the existing figure `fig`.
# Traces are labelled "<class> GT" and coloured from the palette at offset
# i + 2, i.e. not with the class's own pose colour.
for i, class_ in enumerate(classes):
    data = rfid_df[rfid_df["class"] == class_]
    # (column to plot, marker symbol, subplot row): x on top, y below
    for column, symbol, row in (("x", "circle", 1), ("y", "square", 2)):
        fig.add_trace(
            go.Scatter(
                x=data.index,
                y=data[column],
                mode="markers",
                name=f"{class_} GT",
                marker=dict(color=subject_colors[i + 2], symbol=symbol),
            ),
            row=row,
            col=1,
        )
fig.show()

## Render video with pose data

In [80]:
# Load video
cap = cv2.VideoCapture("data/videos/AEON3_social0.2_2024-02-18_11-00-00_11-05-00.mp4")

# Get the frame rate of the video
fps = cap.get(cv2.CAP_PROP_FPS)

# Define the codec and create a VideoWriter object with the input's frame size
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
frame_size = (
    int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
    int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
)
out = cv2.VideoWriter("data/videos/output.mp4", fourcc, fps, frame_size)

# Convert timestamp to ms and put in a FIFO queue.
# The offset is built from minute/second/microsecond only, so it equals the
# time since the top of the hour; this lines up with the video position only
# if the recording starts exactly on the hour, as the file name suggests.
timestamp_ms = queue.Queue()
for ts in pose_df_cp.index.unique():
    timestamp_ms.put(
        (ts, ts.minute * 60 * 1000 + ts.second * 1000 + ts.microsecond / 1000)
    )

frame_count = 0
try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # Get the current timestamp in ms
        current_ts = cap.get(cv2.CAP_PROP_POS_MSEC)
        frame_period_ms = 1000.0 / fps
        # Discard pose timestamps the video has already moved past: they can
        # never match a later frame and would otherwise block the queue head.
        while (
            not timestamp_ms.empty()
            and timestamp_ms.queue[0][1] <= current_ts - frame_period_ms
        ):
            timestamp_ms.get()
        # The queue runs dry when the video outlasts the pose data; remaining
        # frames are then written without an overlay.
        if (
            not timestamp_ms.empty()
            and abs(current_ts - timestamp_ms.queue[0][1]) < frame_period_ms
        ):
            pose_ts = timestamp_ms.get()
            # A list indexer always yields a DataFrame, even when a single
            # row matches, so every dot is coloured by its own row's class.
            df = pose_df_cp.loc[[pose_ts[0]]]
            for class_ in classes:
                points = df[df["class"] == class_]
                # Draw these points on the frame
                for _, point in points.iterrows():
                    frame = cv2.circle(
                        frame,
                        (int(point["x"]), int(point["y"])),
                        radius=5,
                        color=bgr_colors[class_],
                        thickness=-1,
                    )
        # Write the frame
        out.write(frame)
        frame_count += 1
finally:
    # Release everything when job is finished, even if a frame failed
    cap.release()
    out.release()
    cv2.destroyAllWindows()

## Questions

- More efficient ways to assign track IDs?
- Incorporate RFID data to help with track ID assignment?
- Generic approach for situations where number of subjects > 2?