In [1]:
import pandas as pd

# Load the merged model outputs (one parquet file, path relative to this notebook).
# Later cells read these columns from it: pass_number, agent_index, attack_team,
# receiver, and xReceiver_{broadcast,imputed,ground_truth}.
full = pd.read_parquet('../merged_outputs.parquet')

In [2]:
def is_correct(frame, k, column):
    """Check whether a flagged receiver is among the k highest-scored attackers.

    Args:
        frame: DataFrame with `attack_team`, `receiver` and
            `xReceiver_<column>` columns. In this notebook it is called once
            per `pass_number` group.
        k: Number of top-ranked rows to inspect.
        column: Suffix selecting the score column `xReceiver_<column>`.

    Returns:
        Truthy if any of the k rows with the highest `xReceiver_<column>`
        score, among rows where `attack_team` is set, has a truthy `receiver`.
    """
    score_column = f'xReceiver_{column}'

    # Only rows of the attacking team are candidates.
    candidates = frame[frame['attack_team']]

    # Highest score first, then keep the leading k rows.
    ranked = candidates.sort_values(score_column, ascending=False)
    top_k = ranked.iloc[:k]

    return top_k['receiver'].any()

columns = ['broadcast', 'imputed', 'ground_truth']
k_s = [1, 2, 3, 4]

# Keep only passes that have at least one row flagged as receiver.
# This filter depends on neither `name` nor `k`, so it is computed once here
# instead of being rebuilt on every iteration of the nested loops.
successful = full.groupby('pass_number', group_keys=False).filter(lambda x: x.receiver.any())

for name in columns:
    for k in k_s:
        # Fraction of passes whose flagged receiver is within the top-k
        # attackers ranked by xReceiver_<name>.
        accuracy = successful.groupby('pass_number').apply(
            lambda x: (is_correct(x, k=k, column=name))
        ).mean()

        # NOTE(review): the value printed is a top-k accuracy percentage, yet
        # the label says "loss". The label is left unchanged so the recorded
        # output below still matches; consider renaming it on the next re-run.
        print(f"{name} top-{k} loss = {round(accuracy*100, 2)}")

broadcast top-1 loss = 42.03
broadcast top-2 loss = 59.42
broadcast top-3 loss = 71.01
broadcast top-4 loss = 80.43
imputed top-1 loss = 55.43
imputed top-2 loss = 76.81
imputed top-3 loss = 88.04
imputed top-4 loss = 93.84
ground_truth top-1 loss = 59.06
ground_truth top-2 loss = 79.35
ground_truth top-3 loss = 90.94
ground_truth top-4 loss = 93.48


In [3]:
def get_iou(frame, column, threshold=0.1):
    """Intersection-over-union of the agents selected by two score columns.

    An agent is selected by a column when its score is strictly greater than
    `threshold`. The selection from `xReceiver_ground_truth` is compared with
    the selection from `xReceiver_<column>`.

    Args:
        frame: DataFrame with `agent_index`, `xReceiver_ground_truth` and
            `xReceiver_<column>` columns. In this notebook it is called once
            per `pass_number` group.
        column: Suffix selecting the score column compared against the
            ground truth.
        threshold: Score cut-off for selecting an agent. Defaults to 0.1,
            the value previously hard-coded.

    Returns:
        |intersection| / |union| of the two agent-id sets as a float, or NaN
        when neither column selects any agent. The ratio is undefined in that
        case; NaN avoids a ZeroDivisionError and is skipped by pandas'
        default `.mean()`.
    """
    gt_mask = frame['xReceiver_ground_truth'] > threshold
    remote_mask = frame[f'xReceiver_{column}'] > threshold

    gt_ids = set(frame.loc[gt_mask, 'agent_index'])
    remote_ids = set(frame.loc[remote_mask, 'agent_index'])

    union = len(gt_ids | remote_ids)
    if union == 0:
        # Nothing selected on either side: IoU is 0/0, report it as undefined.
        return float('nan')

    intersection = len(gt_ids & remote_ids)
    return intersection / union


# Group the rows by pass once; each IoU below is the mean of get_iou over passes.
per_pass = full.groupby('pass_number', group_keys=False)

imputed_iou = per_pass.apply(get_iou, column='imputed').mean()
broadcast_iou = per_pass.apply(get_iou, column='broadcast').mean()

print(f"Imputed IOU = {round(imputed_iou, 4)}")
print(f"Raw Broadcast IOU = {round(broadcast_iou, 4)}")

Imputed IOU = 0.7121
Raw Broadcast IOU = 0.5105
