In [205]:
import json
import numpy as np
import pandas as pd
import math

from scipy.spatial.distance import cdist, euclidean
from scipy.optimize import linear_sum_assignment

In [247]:
with open("example_gt.traco") as f:
    gt_json = json.load(f)
    
with open("example_pred.traco") as f:
    pred_json = json.load(f)

In [198]:
PENALTY_MORE_THAN_FIVE_HEXBUGS = 5_000
PENALTY_ZERO_HEXBUGS = 5_000
PENALTY_N_WRONG_HEXBUGS = 1_000 # only applies if the amount of predicted hexbugs is between 1 and 5

In [101]:
def check_guidelines(pred_json):
    if 'rois' not in pred_json.keys():
        raise Exception("Submitted json should be a Dictionary and contain the key 'rois'.")
    
    pred_rois = pred_json['rois']
    
    if not type(pred_rois) == list:
        raise Exception("The dictionary entry 'rois' should contain a list of dictionaries.")
    
    for idx, elem in enumerate(pred_rois):
        if not type(elem) == dict:
            raise Exception("Submitted json should only contain dictionaries, nothing else.")
        
        if 'id' not in elem.keys() or 'z' not in elem.keys() or 'pos' not in elem.keys():
            raise Exception(f"Dictionary at index {idx} is wrong: {elem}.")
        
        if type(elem['pos']) is not list or len(elem['pos']) != 2:
            raise Exception(f"Entry 'pos' at index {idx} is wrong: {elem}.")
            
        if type(elem['id']) != int or type(elem['z']) != int:
            raise Exception(f"Entry 'id' or 'z' at index {idx} is wrong: {elem}.")

In [249]:
def get_score(pred_json, gt_json):
    
    # Check ths submitted json data
    check_guidelines(pred_json)
    
    # Create variable to accumulate the final score
    final_score = 0
    
    # Get the relevant information from the files content
    gt_df = pd.DataFrame(gt_json['rois'])
    pred_df = pd.DataFrame(pred_json['rois'])

    # Determine how many frames we have and how many hexbugs are used in the video of the ground truth
    n_frames = gt_df['z'].max()
    n_hexbugs_gt = gt_df['id'].max()    
    n_hexbugs_pred = pred_df['id'].max()
            
    # Apply penalties if the amount of predicted hexbugs is wrong
    if n_hexbugs_pred < 1:
        final_score += PENALTY_ZERO_HEXBUGS
    elif n_hexbugs_pred > 5:
        final_score += PENALTY_MORE_THAN_FIVE_HEXBUGS
    else:
        final_score += (np.abs(n_hexbugs_pred - n_hexbugs_gt) * PENALTY_N_WRONG_HEXBUGS)  
        
    # Match the predicted hexbugs in the first frame with the ground truth 
    # --> the matching IDs are used in all the following frames
    distance_matrix = cdist(list(gt_df[gt_df['z'] == 0]['pos']), list(pred_df[pred_df['z'] == 0]['pos']))
    row_ind, col_ind = linear_sum_assignment(distance_matrix)  # Hungarian algorithm
    
    matched_ids = {}
    for i, j in zip(row_ind, col_ind):
        matched_ids[f"{i}"] = j
    
    # Iterate over each frame and calculate distances
    for i in range(n_frames):
        frame_gt_df = gt_df[gt_df['z'] == i]
        frame_pred_df = pred_df[pred_df['z'] == i]
        for j in range(n_hexbugs_gt):
            row_gt = frame_gt_df[frame_gt_df['id'] == j]
            row_pred = frame_pred_df[frame_pred_df['id'] == matched_ids[str(j)]]
            score = euclidean(list(row_gt['pos'])[0], list(row_pred['pos'])[0])
            final_score += score
    
    return final_score

In [250]:
print("Score: ", get_score(pred_json, gt_json))

Score:  22.894304243126793
