In [3]:
import numpy as np
import pandas as pd
import plotly
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from scipy.optimize import linear_sum_assignment


## Setup

In [5]:
"""Standardize subject colors for plotting"""

subject_colors = plotly.colors.qualitative.Plotly
# Pair every subject ID with the palette entry at the same position, so each
# subject keeps one fixed color across all figures.
subject_colors_dict = dict(
    zip(
        ("BAA-1104045", "BAA-1104047", "BAA-1104048", "BAA-1104049"),
        subject_colors,
    )
)


In [None]:
"""Define functions"""


def plot_xy(df):
    """Plot the x and y positions of every subject against the frame index.

    Builds a two-row figure with a shared x-axis: row 1 shows ``x`` and row 2
    shows ``y``, both plotted against ``df.index``. For each distinct value in
    ``df["class"]`` one marker trace is added per row, colored through
    ``subject_colors_dict``; hovering a marker shows that row's ``speed``.

    The two traces of a subject share a ``legendgroup`` and only the first one
    is listed in the legend, so each subject appears once and clicking its
    legend entry toggles it in both panels.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``"class"``, ``"x"``, ``"y"`` and ``"speed"``.
        Every value in ``"class"`` must be a key of ``subject_colors_dict``.

    Returns
    -------
    None
        The figure is displayed with ``fig.show()``.

    Raises
    ------
    KeyError
        If a required column is missing or a class has no entry in
        ``subject_colors_dict``.
    """
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True)
    panels = ((1, "x"), (2, "y"))  # (subplot row, coordinate column)
    for class_ in df["class"].unique():
        data = df[df["class"] == class_]
        color = subject_colors_dict[class_]
        speeds = data["speed"].tolist()
        for row, coord in panels:
            fig.add_trace(
                go.Scatter(
                    x=data.index,
                    y=data[coord],
                    mode="markers",
                    name=class_,  # Use the class as the name of the trace
                    # Tie both panels of a subject to one legend entry.
                    legendgroup=str(class_),
                    showlegend=(row == 1),
                    marker=dict(color=color),
                    hovertemplate="Speed: %{text}",
                    text=speeds,
                ),
                row=row,
                col=1,
            )
    for row, coord in panels:
        fig.update_yaxes(title_text=f"{coord} position", row=row, col=1)
    fig.show()


## Load data

In [2]:
# Load the centroid table from a Feather file (path is relative to the
# notebook's working directory) and display it. Per the output below, rows are
# indexed by "time" with one row per detected subject per timestamp.
centroid_df_cp = pd.read_feather("centroid_df_cp.feather")
centroid_df_cp


Unnamed: 0_level_0,class,class_likelihood,x,y,speed
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2024-02-18 11:00:00.000000000,BAA-1104047,0.999277,362,155,
2024-02-18 11:00:00.000000000,BAA-1104045,0.552562,1309,499,
2024-02-18 11:00:00.059999943,BAA-1104047,0.999123,370,147,188.561987
2024-02-18 11:00:00.059999943,BAA-1104045,0.554072,1309,499,0.000000
2024-02-18 11:00:00.119999886,BAA-1104047,0.998829,378,142,157.233168
...,...,...,...,...,...
2024-02-18 11:05:00.840000153,BAA-1104045,0.999992,397,933,0.000000
2024-02-18 11:05:00.900000095,BAA-1104047,0.996929,1309,509,0.000000
2024-02-18 11:05:00.900000095,BAA-1104045,0.999986,397,931,33.333366
2024-02-18 11:05:00.960000038,BAA-1104047,0.996008,1309,509,0.000000


In [None]:
# Speeds above this value are treated as violations (same units as the
# "speed" column — presumably pixels per second; TODO confirm).
speed_threshold = 500
dtypes_dict = centroid_df_cp.dtypes.to_dict()  # column name -> dtype
timestamps = centroid_df_cp.index.unique()  # assuming timestamps are sorted
classes = centroid_df_cp["class"].unique()
# Flag rows whose speed is a finite number above the threshold; NaN and
# infinite speeds are never flagged.
speed_mask = centroid_df_cp["speed"].gt(speed_threshold) & np.isfinite(
    centroid_df_cp["speed"].values
)


In [None]:
# Repair identity swaps. A row flagged in `speed_mask` is taken as evidence that
# subject labels were exchanged. Each pass re-labels the rows between the first
# two violating timestamps: the detections of every timestamp in that window
# are matched to the detections of the preceding timestamp (Hungarian
# algorithm, cost = implied speed) and inherit their labels. Speeds and the
# mask are then recomputed and the next pass starts from the new first
# violation.
while speed_mask.any():  # while there are speed violations
    # Snapshot of the state this pass starts from, used by the stall check.
    mask_before = np.asarray(speed_mask, dtype=bool).copy()
    classes_before = centroid_df_cp["class"].to_numpy().copy()

    # work on "windows" of 2 consecutive violations
    first_two_violations = (
        centroid_df_cp.index[mask_before].unique().sort_values()[:2]
    )
    start_window = first_two_violations[0]
    if len(first_two_violations) > 1:
        swap_window = centroid_df_cp.loc[start_window : first_two_violations[1]]
    else:  # last violation: process through the end of the table
        swap_window = centroid_df_cp.loc[start_window:]

    for curr_timestamp in swap_window.index.unique():
        # use previous rows to correct id assignment based on speed
        curr_pos = timestamps.get_loc(curr_timestamp)
        if curr_pos == 0:
            # No earlier timestamp to match against; `curr_pos - 1` would
            # otherwise wrap around to the last timestamp.
            continue
        prev_timestamp = timestamps[curr_pos - 1]
        # List-style selection always yields a DataFrame, even for one row.
        prev_rows = centroid_df_cp.loc[[prev_timestamp]]
        curr_rows = centroid_df_cp.loc[[curr_timestamp]]
        prev_classes = prev_rows["class"].to_numpy()

        # Hungarian cost matrix: cost_matrix[i, j] is the speed implied by
        # previous detection i moving to current detection j.
        prev_xy = prev_rows[["x", "y"]].to_numpy(dtype=float)
        curr_xy = curr_rows[["x", "y"]].to_numpy(dtype=float)
        time_diff = (curr_timestamp - prev_timestamp).total_seconds()
        distances = np.linalg.norm(
            prev_xy[:, None, :] - curr_xy[None, :, :], axis=2
        )
        cost_matrix = distances / time_diff
        row_ind, col_ind = linear_sum_assignment(cost_matrix)

        if len(prev_rows) < len(curr_rows):
            # More detections now than before: matched detections inherit the
            # previous label, the rest receive the remaining entries of
            # `classes` in their current order.
            classes_to_assign = classes.copy()
            swap_order = {prev_classes[r]: c for r, c in zip(row_ind, col_ind)}
            for subject_id, new in swap_order.items():
                # get the index of ID
                old = np.where(classes_to_assign == subject_id)[0][0]
                if old != new:
                    # swap the values
                    classes_to_assign[new], classes_to_assign[old] = (
                        classes_to_assign[old],
                        classes_to_assign[new],
                    )
            new_classes = classes_to_assign[: len(curr_rows)]
        else:
            # Every current detection is matched. `col_ind[k]` is the current
            # row paired with previous row `row_ind[k]`, so reorder the pairs
            # by current-row position before writing the labels back.
            new_classes = prev_classes[row_ind[np.argsort(col_ind)]]
        centroid_df_cp.loc[curr_timestamp, "class"] = new_classes

    # recompute speed and speed_mask
    centroid_df_cp["speed"] = (
        centroid_df_cp.groupby("class")[["x", "y"]].diff().apply(np.linalg.norm, axis=1)
        / centroid_df_cp.reset_index()
        .groupby("class")["time"]
        .diff()
        .dt.total_seconds()
        .values
    )
    speed_mask = (np.isfinite(centroid_df_cp["speed"].values)) & (
        centroid_df_cp["speed"] > speed_threshold
    )

    # Stall check: if this pass changed neither the labels nor the violation
    # mask, every later pass would repeat it exactly, so the loop could never
    # finish. Stop and report instead of spinning forever.
    if np.array_equal(
        mask_before, np.asarray(speed_mask, dtype=bool)
    ) and np.array_equal(classes_before, centroid_df_cp["class"].to_numpy()):
        print(
            f"Stopping: {int(mask_before.sum())} speed violation(s) could not "
            "be resolved by re-assigning identities."
        )
        break


In [None]:
# Compute final speed: per-subject displacement between consecutive rows
# divided by the elapsed time between those rows.
step_xy = centroid_df_cp.groupby("class")[["x", "y"]].diff()
step_length = step_xy.apply(np.linalg.norm, axis=1)
# The time deltas come from a reset-index copy, so they are combined with the
# step lengths by position (via `.values`) rather than by index label.
step_seconds = (
    centroid_df_cp.reset_index().groupby("class")["time"].diff().dt.total_seconds()
)
centroid_df_cp["speed"] = step_length / step_seconds.values


In [10]:
# Plot x and y per subject for the table as modified by the cells above.
plot_xy(centroid_df_cp)