<a href="https://colab.research.google.com/github/healthonrails/annolid/blob/main/docs/tutorials/Annolid_post_processing_fix_left_right_switch.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Post-processing to fix left/right label switches


In [None]:
import pandas as pd
import numpy as np
import ast
import warnings
from google.colab import files
from pathlib import Path

warnings.filterwarnings("ignore")
%matplotlib inline

In [None]:
# Location of the Annolid tracking CSV; change this if your file lives elsewhere.
TRACKING_CSV_FILE = "/content/mask_rcnn_tracking_results_with_segmentation.csv"
df = pd.read_csv(filepath_or_buffer=TRACKING_CSV_FILE)

In [None]:
df.head(n=5)  # preview the first rows of the raw tracking results

In [None]:
# Frame size comes from the first row's segmentation string, parsed with
# ast.literal_eval (literals only); its "size" entry is unpacked as (height, width).
first_segmentation = ast.literal_eval(df.iloc[0]["segmentation"])
height, width = first_segmentation["size"]

# Calculate the bbox center point x, y locations

In [None]:
# Bounding-box centre: midpoint of the x1/x2 and y1/y2 corner coordinates.
# The centres are kept as standalone Series and also stored as df columns.
cx, cy = (df.x1 + df.x2) / 2, (df.y1 + df.y2) / 2
df["cx"], df["cy"] = cx, cy

## Fix left/right switches by comparing each object's horizontal position with the midpoint of the video frame width. This works best for stationary objects or objects that do not cross the midline.
### We assume your labels contain "Left" or "Right", e.g. LeftZone, RightZone, LeftTeaball, or RightTeaball.

In [None]:
def switch_left_right(row, width=800):
    """Relabel a Left*/Right* instance according to which half of the frame it is in.

    The x position is the bbox centre ``cx`` when the row has one, otherwise ``x1``.
    A name containing "Left" whose x is at or right of ``width / 2`` gets "Left"
    replaced by "Right"; a name containing "Right" whose x is left of the
    midline gets "Right" replaced by "Left". Any other name is returned as-is.

    Args:
        row (pd.Series): one tracking-result row with an ``instance_name`` field.
        width (int, optional): frame width in pixels. Defaults to 800.

    Returns:
        str: the (possibly corrected) instance name.
    """
    name = row["instance_name"]
    x_position = row["cx"] if "cx" in row else row["x1"]
    midline = width / 2
    if "Left" in name and x_position >= midline:
        return name.replace("Left", "Right")
    if "Right" in name and x_position < midline:
        return name.replace("Right", "Left")
    return name

In [None]:
# Correct every Left*/Right* label using the real frame width parsed above;
# DataFrame.apply forwards the extra keyword argument to switch_left_right.
df["instance_name"] = df.apply(switch_left_right, axis=1, width=width)

In [None]:
df.tail(n=5)  # inspect the last rows after relabelling

## Fill frames missing LeftZone or RightZone with that zone's most frequent (mode) values

In [None]:
def fill_missing_with_mode(data, instance_name):
    """Add a row for ``instance_name`` to every frame of ``data`` that lacks one.

    Each added row copies the column-wise mode of the existing ``instance_name``
    rows, with ``frame_number`` replaced by the missing frame.

    Args:
        data (pd.DataFrame): tracking results with ``frame_number`` and
            ``instance_name`` columns.
        instance_name (str): instance to fill, e.g. "LeftZone".

    Returns:
        pd.DataFrame: ``data`` with the filled rows appended, or ``data`` itself
        when the instance never appears or no frame is missing it.
    """
    instance_rows = data[data.instance_name == instance_name]
    if instance_rows.empty:
        # .mode() of an empty frame has no rows, so there is no value to copy.
        print(f"No {instance_name} rows found; skipping fill.")
        return data
    fill_value = instance_rows.mode().iloc[0]
    frames_present = set(instance_rows.frame_number)
    # Frames in first-appearance order, matching a row-by-row scan of ``data``.
    missing_frames = [
        frame for frame in data.frame_number.unique() if frame not in frames_present
    ]
    if not missing_frames:
        return data
    fill_rows = pd.DataFrame([fill_value] * len(missing_frames)).reset_index(drop=True)
    fill_rows["frame_number"] = missing_frames
    # DataFrame.append was removed in pandas 2.0; concatenate once instead of
    # copying the whole frame for every missing row.
    return pd.concat([data, fill_rows.infer_objects()], ignore_index=True)


# Fill missing LeftZone rows, then missing RightZone rows.
for zone_name in ("LeftZone", "RightZone"):
    df = fill_missing_with_mode(df, zone_name)

In [None]:
def get_missing_instances_names(frame_number, expected_instance_names=None, data=None):
    """Return the expected instance names that have no row in the given frame.

    Args:
        frame_number (int): current video frame number
        expected_instance_names (list, optional): expected instances,
            e.g. ["mouse_1", "mouse_2"]. ``None`` or an empty list means
            nothing is expected, so an empty set is returned.
        data (pd.DataFrame, optional): tracking results to search. Defaults
            to the notebook-level ``df``.

    Returns:
        set: expected names absent from ``frame_number``.
    """
    if not expected_instance_names:
        # With the default None, set(None) used to raise TypeError.
        return set()
    if data is None:
        data = df
    names_in_frame = set(data.loc[data.frame_number == frame_number, "instance_name"])
    return set(expected_instance_names) - names_in_frame

In [None]:
def instance_center_distances(old_instances, cur_instances):
    """Compute pairwise centre distances between previous- and current-frame instances.

    Every (current, previous) row pair is compared, except pairs that share the
    same frame number and the same integer-truncated centre; those are treated
    as the same detection and skipped.

    Args:
        old_instances (pd.DataFrame): instances in the previous frame, with
            ``frame_number``, ``cx``, ``cy`` and ``instance_name`` columns.
        cur_instances (pd.DataFrame): instances in the current frame, same columns.

    Returns:
        dict: maps (prev frame_number, prev int(cx), prev int(cy),
        current frame_number, current int(cx), current int(cy)) to
        (Euclidean distance, previous instance name, current instance name).
        A later pair overwrites an earlier one that maps to the same key.
    """
    distances = {}
    for _, current in cur_instances.iterrows():
        for _, previous in old_instances.iterrows():
            same_frame = current["frame_number"] == previous["frame_number"]
            if (
                same_frame
                and int(current["cx"]) == int(previous["cx"])
                and int(current["cy"]) == int(previous["cy"])
            ):
                continue
            dx = current["cx"] - previous["cx"]
            dy = current["cy"] - previous["cy"]
            pair_key = (
                previous["frame_number"],
                int(previous["cx"]),
                int(previous["cy"]),
                current["frame_number"],
                int(current["cx"]),
                int(current["cy"]),
            )
            distances[pair_key] = (
                np.sqrt(dx ** 2 + dy ** 2),
                previous["instance_name"],
                current["instance_name"],
            )
    return distances

In [None]:
def find_last_show_position(
    instance_name="Female_52", frame_number=0, frames_backward=30, data=None
):
    """Find the most recent detection of an instance before the given frame.

    Searches frames in ``[frame_number - frames_backward, frame_number)``. Within
    the latest matching frame, the row with the highest ``class_score`` is kept.

    Args:
        instance_name (str, optional): Instance name. Defaults to 'Female_52'.
        frame_number (int, optional): frame number. Defaults to 0.
        frames_backward (int, optional): how many frames back to search.
            Defaults to 30.
        data (pd.DataFrame, optional): tracking results to search. Defaults to
            the notebook-level ``df``.

    Returns:
        pd.DataFrame: at most one row; empty when nothing is found.
    """
    if data is None:
        data = df
    in_window = (
        (data.instance_name == instance_name)
        & (data.frame_number < frame_number)
        # Inclusive lower bound so exactly ``frames_backward`` frames are
        # searched; the previous ``>`` silently dropped the oldest frame.
        & (data.frame_number >= frame_number - frames_backward)
    )
    return (
        data[in_window]
        .sort_values(by=["frame_number", "class_score"], ascending=False)
        .head(1)
    )

In [None]:
def find_future_show_position(
    instance_name="Female_52", frame_number=0, frames_forward=30, data=None
):
    """Find the next detection of an instance after the given frame.

    Searches frames in ``(frame_number, frame_number + frames_forward]``. Within
    the earliest matching frame, the row with the highest ``class_score`` is kept.

    Args:
        instance_name (str, optional): Instance name. Defaults to 'Female_52'.
        frame_number (int, optional): frame number. Defaults to 0.
        frames_forward (int, optional): how many frames forward to search.
            Defaults to 30.
        data (pd.DataFrame, optional): tracking results to search. Defaults to
            the notebook-level ``df``.

    Returns:
        pd.DataFrame: at most one row; empty when nothing is found.
    """
    if data is None:
        data = df
    in_window = (
        (data.instance_name == instance_name)
        & (data.frame_number > frame_number)
        & (data.frame_number <= frame_number + frames_forward)
    )
    # Earliest frame first, but the most confident detection inside it; sorting
    # class_score ascending (as before) picked the least confident one.
    return (
        data[in_window]
        .sort_values(by=["frame_number", "class_score"], ascending=[True, False])
        .head(1)
    )

In [None]:
def get_missing_instance_frames(df, instance_name="mouse_1"):
    """Get the frame numbers that do not have a prediction for the named instance.

    Only frames from 0 up to the instance's last detected frame are considered.

    Args:
        df (pd.DataFrame): tracking results with ``frame_number`` and
            ``instance_name`` columns.
        instance_name (str, optional): instance name. Defaults to 'mouse_1'.

    Returns:
        set: frame numbers without a prediction; empty when the instance is
        never detected (previously ``max()`` raised ValueError here).
    """
    frames_with_preds = set(df.loc[df.instance_name == instance_name, "frame_number"])
    if not frames_with_preds:
        return set()
    # int() lets float frame numbers (e.g. after concatenating object rows)
    # still be used as the range bound.
    last_frame = int(max(frames_with_preds))
    return set(range(last_frame + 1)) - frames_with_preds

In [None]:
def fill_missing_instances(df, instance_name="mouse_2"):
    fill_rows = []
    missing_frames = list(get_missing_instance_frames(df, instance_name=instance_name))
    for frame_number in sorted(missing_frames):
        fp = find_future_show_position(instance_name, frame_number)
        lp = find_last_show_position(instance_name, frame_number)
        if (
            frame_number - lp.frame_number.values[0]
            > fp.frame_number.values[0] - frame_number
        ):
            fp.frame_number = frame_number
            fill_rows.append(fp)
        else:
            lp.frame_number = frame_number
            fill_rows.append(lp)
    df = df.append(fill_rows, ignore_index=True)
    del fill_rows
    return df

In [None]:
# Fill gaps in the mouse_2 track (the function's default instance).
df = fill_missing_instances(df, instance_name="mouse_2")

In [None]:
# Instances that should be present in every frame.
expected_instance_names = "mouse_1 mouse_2".split()

In [None]:
# Number of (frame, instance) pairs found missing; starts at 0 so the total is exact.
missing = 0
missing_predictions = []
# Visit each frame once: iterating df.frame_number directly revisits a frame once
# per row and would queue duplicate fill rows for the same missing instance.
for frame_number in df.frame_number.unique():
    missing_instance_name = get_missing_instances_names(
        frame_number, expected_instance_names
    )
    for instance_name in missing_instance_name:
        missing += 1
        last_pos = find_last_show_position(instance_name, frame_number)
        future_pos = find_future_show_position(instance_name, frame_number)
        has_last, has_future = len(last_pos) > 0, len(future_pos) > 0
        if has_future and has_last:
            future_gap = future_pos.frame_number.values[0] - frame_number
            last_gap = frame_number - last_pos.frame_number.values[0]
            # Prefer the future detection on ties; otherwise take the closer one
            # (the past detection used to be dropped when it was closer).
            chosen = future_pos if future_gap <= last_gap else last_pos
        elif has_future:
            chosen = future_pos
        elif has_last:
            chosen = last_pos
        else:
            # No detection within the search windows; leave this frame unfilled.
            continue
        chosen = chosen.copy()
        chosen["frame_number"] = frame_number
        missing_predictions.append(chosen)
print("total missing: ", missing)

In [None]:
# DataFrame.append was removed in pandas 2.0; concatenate the queued one-row
# frames instead (pd.concat rejects an empty list, hence the guard).
if missing_predictions:
    df = pd.concat([df, *missing_predictions], ignore_index=True)

### Fill missing predicted instances in each frame using detections within the given moving window.

In [None]:
# Silence pandas' SettingWithCopyWarning for the notebook session (kept from the
# original; the fill rows below are explicit copies, so no chained assignment
# happens here).
pd.options.mode.chained_assignment = None
moving_window = 5
count = 0
excluded_instances = set(
    ["Nose", "Center", "Tailbase", "LeftInteract", "RightInteract"]
)
# do not fill body parts
all_instance_names = set(df.instance_name.unique()) - excluded_instances
print("Fill the instane with name in the list: ", all_instance_names)
missing_predictions = []
max_frame_number = df.frame_number.max()

# Instance names detected in each frame, keyed by frame in first-appearance order.
# Built once because df is not modified until after the loop. Iterating
# df.frame_number directly would revisit a frame once per row and queue
# duplicate fill rows for the same missing instance.
names_by_frame = {}
for frame, name in zip(df.frame_number, df.instance_name):
    names_by_frame.setdefault(frame, set()).add(name)

for frame_number, pred_instance in names_by_frame.items():
    for instance_name in all_instance_names - pred_instance:
        frame_range_end = frame_number + moving_window
        if frame_range_end > max_frame_number:
            # Near the end of the video, fall back to the last full window.
            window_start, window_end = max_frame_number - moving_window, max_frame_number
        else:
            window_start, window_end = frame_number, frame_range_end
        df_instance = df[
            df.frame_number.between(window_start, window_end)
            & (df.instance_name == instance_name)
        ]
        if df_instance.empty:
            # No detection of this instance in the window; move to the next one.
            continue
        # Copy the earliest detection in the window. .copy() ensures that
        # re-stamping frame_number cannot write through a view into df.
        fill_value = (
            df_instance.sort_values("frame_number", kind="stable").iloc[0].copy()
        )
        fill_value["frame_number"] = frame_number
        missing_predictions.append(fill_value)
        count += 1
        if count % 1000 == 0:
            print(f"Filling {count} missing {instance_name}")

# DataFrame.append was removed in pandas 2.0; build one frame from the queued rows.
if missing_predictions:
    df = pd.concat(
        [df, pd.DataFrame(missing_predictions).infer_objects()], ignore_index=True
    )

## Download the post-processed result CSV file to your local device

In [None]:
# Name the output after the input CSV so the original file is not overwritten.
csv_stem = Path(TRACKING_CSV_FILE).stem
tracking_results_csv = csv_stem + "_fixed_left_right_switches.csv"
df.to_csv(tracking_results_csv)
# Send the file to the browser via google.colab.files.
files.download(tracking_results_csv)