In [1]:
import numpy as np
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
import random
import shutil
import scipy.interpolate
import math

# Hyperparameters

In [2]:
CONF_THRESH = 0.95
FPS = 30
FRAME_RANGE = 6
OFFSET = 13 # cos we use window of 14 for rnn

In [3]:
DOMAIN_ROOTDIR = 'domain_probs/'
RESNET_ROOTDIR = 'resnet_probs/'
LABEL_ROOTDIR = '../datasets/'

## Sanity check for probs file

In [4]:
npy_rootdir = 'resnet_probs/pro'
label_rootdir = '../datasets/pro'
for matchdir in os.listdir(npy_rootdir):
    for npy in os.listdir(os.path.join(npy_rootdir, matchdir)):
        basename = npy.split('_probs.npy')[0]
        label_file = os.path.join(label_rootdir, matchdir, 'player_hit', basename + '.mp4_player_hit.csv')
        prob_file = os.path.join(npy_rootdir, matchdir, npy)
        
        df_label = pd.read_csv(label_file)
        hit_labels = df_label['player_hit'].values
        
        probs = np.load(prob_file)
        
        if len(probs) != len(hit_labels):
            print(matchdir, basename, len(probs), len(hit_labels))
        

In [5]:
npy_rootdir = 'domain_probs/pro'
label_rootdir = '../datasets/pro'
for matchdir in os.listdir(npy_rootdir):
    for npy in os.listdir(os.path.join(npy_rootdir, matchdir)):
        basename = npy.split('_probs.npy')[0]
        label_file = os.path.join(label_rootdir, matchdir, 'player_hit', basename + '.mp4_player_hit.csv')
        prob_file = os.path.join(npy_rootdir, matchdir, npy)
        
        df_label = pd.read_csv(label_file)
        hit_labels = df_label['player_hit'].values
        
        probs = np.load(prob_file)
        
        if len(probs) != len(hit_labels):
            print(matchdir, basename, len(probs), len(hit_labels))
        

match3 1_08_10 127 140
match3 2_18_15 1155 1168
match3 2_10_12 213 226
match3 1_12_17 335 348
match3 2_04_07 623 636
match3 3_11_10 206 219
match3 1_01_00 440 453
test_match1 1_07_04 202 215
test_match1 1_09_07 668 681
test_match1 1_09_06 120 133
test_match1 2_03_08 133 146
test_match1 2_03_10 588 601
test_match1 1_07_06 330 343
test_match1 1_06_03 336 349
test_match1 2_02_07 142 155
test_match1 1_07_03 674 687
test_match1 1_05_03 382 395
test_match1 1_05_02 504 517
match23 2_07_03 570 583
match23 1_06_04 514 527
match23 2_12_08 516 529
match23 1_17_13 532 545
match23 1_11_08 568 581
match23 2_02_03 314 327
match4 2_05_07 392 405
match4 2_02_05 463 476
match4 1_15_10 389 402
match4 1_03_02 345 358
match4 2_14_17 184 197
match4 3_18_17 543 556
match4 3_07_05 154 167
match4 3_02_00 341 354
match10 2_14_08 739 752
match10 2_04_02 322 335
match10 1_12_16 468 481
match10 1_03_01 567 580
match10 1_03_03 512 525
match2 1_06_08 341 354
match2 1_00_02 598 611
match2 1_02_03 346 359
match2 1_04_

## Optimisation algorithms

In [62]:
# memo[numframes][3]. i.e. memo[f][0] stores the score when label 0 is assigned to current frame f, tgt with hitlabels, numhits
NEGNUM = -100000
# returns score, hit labels, numhits
def maximise_hit_score(hit_probs, fps): # takes in prediction probabiltiies for each frame
    fnum = len(hit_probs)
    if fnum == 1: # base case
        return np.amax(hit_probs), [np.argmax(hit_probs)], 1 if np.argmax(hit_probs) != 0 else 0
    
    # init memoisation table
    memo = {}
    memo[0] = {}
    memo[0][0] = (hit_probs[0,0], [0], 0)
    memo[0][1] = (hit_probs[0,1], [1], 1)
    memo[0][2] = (hit_probs[0,2], [2], 1)
    
    for f in range(1,fnum):
        memo[f] = {}
        
        # get frame indices of positive hit labels 
        positive_hit_labels0 = [(i, label) for i, label in enumerate(memo[f-1][0][1]) if label != 0]
        last_positive_hit_label0 = (0,0) if len(positive_hit_labels0)==0 else positive_hit_labels0[-1]
        
        # assigning hit0 to current frame
        memo[f][0] = max([ (memo[f-1][0][0] + hit_probs[f,0], memo[f-1][0][1]+[0], memo[f-1][0][2]),\
                               (memo[f-1][1][0] + hit_probs[f,0], memo[f-1][1][1]+[0], memo[f-1][1][2]), \
                               (memo[f-1][2][0] + hit_probs[f,0], memo[f-1][2][1]+[0], memo[f-1][2][2])], key=lambda x: x[0]) # max score
        
        # for enforcing number of hits smaller than numsecs
        maxhits_allowed = math.ceil(f/fps)+1
        
        # 3 constraints: (1) consec hits must be 0.5s apart, (2) numhits < numsecs, (3) hits must alternate between players
        satisfy_constraints0_1 = (f - last_positive_hit_label0[0] > 0.5*fps) and (memo[f-1][0][2] < maxhits_allowed) and (last_positive_hit_label0[1] != 1)
        satisfy_constraints0_2 = (f - last_positive_hit_label0[0] > 0.5*fps) and (memo[f-1][0][2] < maxhits_allowed) and (last_positive_hit_label0[1] != 2)
        
        # assigning hit1 to current frame
        if satisfy_constraints0_1:
            memo[f][1] = (memo[f-1][0][0] + hit_probs[f,1], memo[f-1][0][1]+[1], memo[f-1][0][2]+1)
        else:
            memo[f][1] = (NEGNUM, [], 0) # assign large negative number so that this will not be in optimal sequence
        
        # assigning hit2 to current frame
        if satisfy_constraints0_2:
            memo[f][2] = (memo[f-1][0][0] + hit_probs[f,2], memo[f-1][0][1]+[2], memo[f-1][0][2]+1)
        else:
            memo[f][2] = (NEGNUM, [], 0) # assign large negative number so that this will not be in optimal sequence
            
    return np.array(memo[fnum-1][0][1]).astype('int') # for sure last frame would be no hit

In [72]:
# memo[numframes][3]. i.e. memo[f][0] stores the score when label 0 is assigned to current frame f, tgt with hitlabels, numhits
NEGNUM = -100000
# returns score, hit labels, numhits
def maximise_hit_score2(hit_probs, fps): # takes in prediction probabiltiies for each frame
    fnum = len(hit_probs)
    if fnum == 1: # base case
        return np.amax(hit_probs), [np.argmax(hit_probs)], 1 if np.argmax(hit_probs) != 0 else 0
    
    # init memoisation table
    memo = {}
    memo[0] = {}
    memo[0][0] = (0, [0], 0) # force no hit detections on first frame
    memo[0][1] = (NEGNUM, [], 0)
    memo[0][2] = (NEGNUM, [], 0)
    
    for f in range(1,fnum):
        memo[f] = {}
        
        # get frame indices of positive hit labels 
        positive_hit_labels0 = [(i, label) for i, label in enumerate(memo[f-1][0][1]) if label != 0]
        last_positive_hit_label0 = (0,0) if len(positive_hit_labels0)==0 else positive_hit_labels0[-1]
        
        # assigning hit0 to current frame
        memo[f][0] = max([ (memo[f-1][0][0], memo[f-1][0][1]+[0], memo[f-1][0][2]),\
                               (memo[f-1][1][0], memo[f-1][1][1]+[0], memo[f-1][1][2]), \
                               (memo[f-1][2][0], memo[f-1][2][1]+[0], memo[f-1][2][2])], key=lambda x: x[0]) # max score
        
        # for enforcing number of hits smaller than numsecs, and encouraging detection of more hits to hit this target
        maxhits_allowed = math.ceil(f/fps)
        
        # 3 constraints: (1) consec hits must be 0.5s apart, (2) numhits < numsecs, (3) hits must alternate between players
        satisfy_constraints0_1 = (f - last_positive_hit_label0[0] > 0.5*fps) and (memo[f-1][0][2] <= maxhits_allowed) and (last_positive_hit_label0[1] != 1)
        satisfy_constraints0_2 = (f - last_positive_hit_label0[0] > 0.5*fps) and (memo[f-1][0][2] <= maxhits_allowed) and (last_positive_hit_label0[1] != 2)
        
        # assigning hit1 to current frame, setting previous five frames to hits to ensure window of six hits consecutively
        if satisfy_constraints0_1 and f > 5:
            memo[f][1] = (memo[f-6][0][0] + np.sum(hit_probs[f-5:f+1,1]), memo[f-6][0][1]+[1,1,1,1,1,1], memo[f-1][0][2]+1)
        else:
            memo[f][1] = (NEGNUM, [], 0) # assign large negative number so that this will not be in optimal sequence
        
        # assigning hit2 to current frame
        if satisfy_constraints0_2 and f > 5:
            memo[f][2] = (memo[f-6][0][0] + np.sum(hit_probs[f-5:f+1,2]), memo[f-6][0][1]+[2,2,2,2,2,2], memo[f-1][0][2]+1)
        else:
            memo[f][2] = (NEGNUM, [], 0) # assign large negative number so that this will not be in optimal sequence
            
    return np.array(memo[fnum-1][0][1]).astype('int') # for sure last frame would be no hit

In [37]:
def optimise_hits_naive(hit_preds, fps=30): # takes in predicted labels for each frame
    # ensure no two hits are within 0.5s of each other
    min_frames_apart = int(fps/2)
    frames_hit = np.where(np.array(hit_preds) > 0)[0]
    if len(frames_hit) > 0:
        frames_filtered = []
        frames_filtered.append(frames_hit[0])
        for i, fr in enumerate(frames_hit):
            if i < len(frames_hit)-1:
                if frames_hit[i+1] > fr + min_frames_apart:
                    frames_filtered.append(frames_hit[i+1])

        out_pred = np.zeros(len(hit_preds))
        for i in range(len(frames_filtered)):
            frame_hit = frames_filtered[i]
            out_pred[frame_hit] = hit_preds[frame_hit]
            
    else:
        out_pred = np.array([])
    
    return out_pred.astype('int')

In [77]:
def optimise_hits_naive2(hit_preds, fps=30): # takes in predicted labels for each frame
    # ensure no two hits are within 0.5s of each other
    min_frames_apart = int(fps/2)
    frames_hit = np.where(np.array(hit_preds) > 0)[0]
    if len(frames_hit) > 0:
        frames_filtered = []
        frames_filtered.append(frames_hit[0])
        for i, fr in enumerate(frames_hit):
            if i < len(frames_hit)-1:
                if frames_hit[i+1] > fr + min_frames_apart:
                    frames_filtered.append(frames_hit[i+1])

        out_pred = np.zeros(len(hit_preds))
        for i in range(len(frames_filtered)):
            frame_hit = frames_filtered[i]
            out_pred[frame_hit-3:frame_hit+3] = hit_preds[frame_hit]
            
    else:
        out_pred = np.array([])
    
    return out_pred.astype('int')

## Functions to process of probabilities

In [8]:
def domain_probs2labels(probfile, opt='max', conf_thresh=0.9):
    probs = np.load(probfile)
    df_label = pd.read_csv(labelfile)
    hit_labels = df_label['player_hit'].values
    
    if opt=='max':
        pred_labels = maximise_hit_score(probs, FPS)
    elif opt=='naive':
        pred_labels = []
        for prob in probs:
            if np.amax(prob) > conf_thresh:
                pred_labels.append(np.argmax(prob))
            else:
                pred_labels.append(0)

        pred_labels = optimise_hits_naive(pred_labels, fps=FPS)
    else:
        print('no optimisation used')
        pred_labels = np.argmax(probs, axis=1)
        
    return pred_labels


def resnet_probs2labels(probfile, opt='naive', conf_thresh=0.9):
    probs = np.load(probfile) # note that there are 2 sets of probs, one for each player
    
    df_label = pd.read_csv(label_file)
    hit_labels = df_label['player_hit'].values

    pred_labels = []
    for prob in probs:
        prob1 = prob[:3]
        prob2 = prob[3:]

        label1, conf1 = np.argmax(prob1), np.amax(prob1)
        label2, conf2 = np.argmax(prob2), np.amax(prob2)

        assigned_label = 0
        if label1 == 1 and label2 == 0:
            if conf1 > conf_thresh:
                assigned_label = 1
        elif label1 == 0 and label2 == 2:
            if conf2 > conf_thresh:
                assigned_label = 2

        pred_labels.append(assigned_label)
        
    if opt == 'naive':
        pred_labels = optimise_hits_naive(pred_labels, fps=FPS)
    
    return np.array(pred_labels)

def gtfile2labels(labelfile):
    df_label = pd.read_csv(labelfile)
    hit_labels = df_label['player_hit'].values
    return hit_labels

### Sample processing of probabilities

### Process probs from domain-rnn
OK, practically perfect on pro dataset, as before.

In [84]:
'''
probfile = 'domain_probs/am_singles/match_china2/singles0_probs.npy' 
probs = np.load(probfile)
label_file = '../datasets/am_singles/match_china2/player_hit/singles0.mp4_player_hit.csv'
'''
probfile = 'domain_probs/pro/test_match1/2_02_07_probs.npy' 
probs = np.load(probfile)
label_file = '../datasets/pro/test_match1/player_hit/2_02_07.mp4_player_hit.csv'

df_label = pd.read_csv(label_file)
hit_labels = df_label['player_hit'].values

pred_labels = []
for prob in probs:
    if np.amax(prob) > CONF_THRESH:
        pred_labels.append(np.argmax(prob))
    else:
        pred_labels.append(0)
        
pred_labels1 = optimise_hits_naive2(pred_labels, fps=FPS)
pred_labels2 = maximise_hit_score(probs, FPS)
pred_labels3 = maximise_hit_score2(probs, FPS)

print(np.where(hit_labels[13:]))
print(np.where(np.array(pred_labels)))
print(np.where(pred_labels1))
print(np.where(pred_labels2))
print(np.where(pred_labels3))

(array([21, 42, 58, 78, 88]),)
(array([ 0, 22, 23, 24, 25, 26, 42, 43, 44, 45, 46, 47, 48, 59, 60, 61, 62,
       63, 78, 79, 80, 81, 82, 89, 90, 91]),)
(array([19, 20, 21, 22, 23, 24, 39, 40, 41, 42, 43, 44]),)
(array([  0,  16,  43,  61, 116]),)
(array([ 12,  13,  14,  15,  16,  17,  28,  29,  30,  31,  32,  33,  45,
        46,  47,  48,  49,  50,  61,  62,  63,  64,  65,  66,  87,  88,
        89,  90,  91,  92, 116, 117, 118, 119, 120, 121]),)


In [81]:
print(pred_labels1)
print()
print(hit_labels[13:])

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 2 2 2 2 2 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 1 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]


In [10]:
probfile = 'domain_probs/am_singles/match_china2/singles0_probs.npy' 
probs = np.load(probfile)
label_file = '../datasets/am_singles/match_china2/player_hit/singles0.mp4_player_hit.csv'
df_label = pd.read_csv(label_file)
hit_labels = df_label['player_hit'].values

pred_labels = []
for prob in probs:
    if np.amax(prob) > CONF_THRESH:
        pred_labels.append(np.argmax(prob))
    else:
        pred_labels.append(0)

print(hit_labels[13:])
print()
print(np.array(pred_labels))

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0]

[1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 2 2
 2 2 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 2 2 2 2 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 2 0 0 0 0 0 2 2 2 2 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 2 2 2 2 2
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1
 1

## Process probs from ResNet image classifier for indv crops

In [11]:
probfile = 'resnet_probs/pro/test_match1/1_05_02_probs.npy' # contains probs for near and far player, so nfr x 6 
probs = np.load(probfile)
label_file = '../datasets/pro/test_match1/player_hit/1_05_02.mp4_player_hit.csv'
df_label = pd.read_csv(label_file)
hit_labels = df_label['player_hit'].values

pred_labels = []
for prob in probs:
    prob1 = prob[:3]
    prob2 = prob[3:]
    
    label1, conf1 = np.argmax(prob1), np.amax(prob1)
    label2, conf2 = np.argmax(prob2), np.amax(prob2)
    
    assigned_label = 0
    if label1 == 1 and label2 == 0:
        if conf1 > CONF_THRESH:
            assigned_label = 1
    elif label1 == 0 and label2 == 2:
        if conf2 > CONF_THRESH:
            assigned_label = 2
    
    pred_labels.append(assigned_label)

print(hit_labels)
print()
print(np.array(pred_labels))

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 

In [12]:
matchdir = 'test_match3'
basename = '1_09_15'
prefix = 'pro'
opt = ''
npy = basename + '_probs.npy'
labelfile = os.path.join(LABEL_ROOTDIR, prefix, matchdir, 'player_hit', basename + '.mp4_player_hit.csv')
probfile = os.path.join(RESNET_ROOTDIR, prefix, matchdir, npy)

pred_labels = resnet_probs2labels(probfile, opt=opt, conf_thresh=CONF_THRESH)
gt_labels = gtfile2labels(labelfile)

print(pred_labels)
print()
print(gt_labels)

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 2 2 2 2 2 0 2 2 2 2 0 0 0 0 0 1 0 0 1 0 0 0
 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 2 2 2 2 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 1 1 1 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 

# Evaluation

## Frame-level precision

In [13]:
def Intersection(lst1, lst2):
  return list(set(lst1).intersection(lst2))

def Union(lst1, lst2):
  final_list = list(set().union(lst1, lst2))
  return final_list

def clean_pred_set(true_hit_set, pred_hit_set, frame_range=3):
    pred_hit_set = [[tup[0], tup[1]] for tup in pred_hit_set]
    # clean pred_hit_set, such that if predicted hit frame is within +-3 of true hit frame, we adjust the predicted hit frame
    min_j = 0
    for i in range(len(pred_hit_set)):
        for j in range(min_j, len(true_hit_set)):
            if pred_hit_set[i][0] <= true_hit_set[j][0]+frame_range and pred_hit_set[i][0] >= true_hit_set[j][0]-frame_range:
                pred_hit_set[i][0] = true_hit_set[j][0]
                min_j = j
                break
    pred_hit_set = np.unique(pred_hit_set, axis=0)
    pred_hit_set = [(ls[0], ls[1]) for ls in pred_hit_set]
    
    return pred_hit_set

def eval_metrics(gt_label, pred_label, frame_range=3):
  '''
  only look at frames where hits occurred
  returns accuracy, recall, precision, f1
  '''
  true_hit_set = [(i, label) for i, label in enumerate(gt_label) if label!=0] 
  pred_hit_set = [(i, label) for i, label in enumerate(pred_label) if label!=0]
  
  pred_hit_set = clean_pred_set(true_hit_set, pred_hit_set, frame_range=frame_range)

  if len(true_hit_set) == 0 or len(pred_hit_set) == 0:
        acc, recall, prec, f1 = None, None, None, None
        return acc, recall, prec, f1

  acc = len(Intersection(true_hit_set, pred_hit_set)) / len(Union(true_hit_set, pred_hit_set))
  recall = len(Intersection(true_hit_set, pred_hit_set)) / len(true_hit_set)
  prec = len(Intersection(true_hit_set, pred_hit_set)) / len(pred_hit_set)
  if recall+prec == 0:
    f1 = 0
  else:
    f1 = 2*(recall*prec)/(recall+prec)

  return prec, recall, acc, f1

## 0.2s window level precision

In [None]:
def expand_label_frame_range(gt_label, frame_range=3):
    # expand gt hit labels into a window of labels
    orig = gt_label.copy()
    for i in range(len(orig)):
        if orig[i] != 1:
            gt_label[i-frame_range:i+frame_range+1] = orig[i]
    
    return gt_label

def preds2seframes(preds, target_class):
    frames_se = []
    for fr, label in enumerate(preds):
        if fr > 0:
            if label == target_class and label != preds[fr-1]:
                start_frame = fr
            elif label == target_class and label != preds[fr+1]:
                end_frame = fr
                frames_se.append((start_frame, end_frame))
    return frames_se

def timewin_intersection_over_union(timewinA, timewinB):
    # determine the framenums of the intersection rectangle
    frameA = max(timewinA[0], timewinB[0])
    frameB = min(timewinA[1], timewinB[1])
    
    # compute length of intersection
    interLength = max(0, frameB - frameA + 1)
    
    # compute length of union
    lenA = frameA[1] - frameA[0] + 1
    lenB = frameB[1] - frameB[0] + 1
    unionLength = float(lenA + lenB - interLength)
    
    iou = interLength / unionLength
    
    return iou

def calc_precision(frames_se_gt, frames_se_pred):
    tp = 0
    min_j = 0
    for i in range(len(frames_se_pred)):
        timewin_pred = frames_se_pred[i]
        for j in range(min_j, len(frames_se_gt)):
            timewin_gt = frames_se_gt[j]
            t_iou = timewin_intersection_over_union(timewin_pred, timewin_gt)
            if t_iou > iou_thresh:
                tp += 1
                min_j = j
                break
    precision = tp / len(frames_se_pred)
    
    return precision

# returns precision (array of size 2) of temporal proposals for each target class, for each video
def eval_metrics(gt_label, pred_label, frame_range=3, iou_thresh = 0.5):
    
    # expand gt hit labels into a window of labels
    gt_label = expand_label_frame_range(gt_label, frame_range=frame_range)
            
    precision = np.zeros(2)
    for k, target_class in enumerate([1,2]):
        # get start, end frames for each proposed time window of hit
        frames_se_gt = preds2seframes(gt_label, 1)
        frames_se_pred = preds2seframes(pred_label, 1)

        # calculate precision of target class
        precision[k] = calc_precision(frames_se_gt, frames_se_pred)
        
    return precision

## Evaluate resnet probs

### Pro

In [55]:
prefix = 'pro'
opt = 'naive'

rows_list = []
for matchdir in sorted(os.listdir(os.path.join(RESNET_ROOTDIR, prefix))):
    if 'test' in matchdir:
        for npy in sorted(os.listdir(os.path.join(RESNET_ROOTDIR, prefix, matchdir))):
            basename = npy.split('_probs.npy')[0]
            labelfile = os.path.join(LABEL_ROOTDIR, prefix, matchdir, 'player_hit', basename + '.mp4_player_hit.csv')
            probfile = os.path.join(RESNET_ROOTDIR, prefix, matchdir, npy)

            pred_labels = resnet_probs2labels(probfile, opt=opt, conf_thresh=CONF_THRESH)
            gt_labels = gtfile2labels(labelfile)
            gt_labels = gt_labels[OFFSET:]
            pred_labels = pred_labels[OFFSET:]

            prec, rec, acc, f1 = eval_metrics(gt_labels, pred_labels, frame_range=FRAME_RANGE)

            row_dict = {}
            row_dict['match'] = matchdir
            row_dict['vid'] = basename
            row_dict['acc'] = acc
            row_dict['rec'] = rec
            row_dict['prec'] = prec
            row_dict['f1'] = f1
            rows_list.append(row_dict.copy())

df_resnet_pro = pd.DataFrame(rows_list)
df_resnet_pro.describe()

Unnamed: 0,acc,rec,prec,f1
count,28.0,28.0,28.0,28.0
mean,0.162007,0.193248,0.497506,0.266606
std,0.103749,0.132912,0.220197,0.143468
min,0.0,0.0,0.0,0.0
25%,0.111111,0.129076,0.333333,0.2
50%,0.142857,0.166667,0.5,0.25
75%,0.208882,0.232955,0.666667,0.345577
max,0.5,0.666667,1.0,0.666667


### Am-singles

In [54]:
prefix = 'am_singles'
opt = 'naive'

rows_list = []
for matchdir in sorted(os.listdir(os.path.join(RESNET_ROOTDIR, prefix))):
    for npy in sorted(os.listdir(os.path.join(RESNET_ROOTDIR, prefix, matchdir))):
        basename = npy.split('_probs.npy')[0]
        labelfile = os.path.join(LABEL_ROOTDIR, prefix, matchdir, 'player_hit', basename + '.mp4_player_hit.csv')
        probfile = os.path.join(RESNET_ROOTDIR, prefix, matchdir, npy)

        pred_labels = resnet_probs2labels(probfile, opt=opt, conf_thresh=CONF_THRESH)
        gt_labels = gtfile2labels(labelfile)
        gt_labels = gt_labels[OFFSET:]
        pred_labels = pred_labels[OFFSET:]

        prec, rec, acc, f1 = eval_metrics(gt_labels, pred_labels, frame_range=FRAME_RANGE)

        row_dict = {}
        row_dict['match'] = matchdir
        row_dict['vid'] = basename
        row_dict['acc'] = acc
        row_dict['rec'] = rec
        row_dict['prec'] = prec
        row_dict['f1'] = f1
        rows_list.append(row_dict.copy())

df_resnet_am_singles = pd.DataFrame(rows_list)
df_resnet_am_singles.describe()

Unnamed: 0,acc,rec,prec,f1
count,35.0,35.0,35.0,35.0
mean,0.088023,0.109886,0.245238,0.146512
std,0.10367,0.125014,0.298975,0.165668
min,0.0,0.0,0.0,0.0
25%,0.0,0.0,0.0,0.0
50%,0.076923,0.083333,0.2,0.142857
75%,0.171569,0.230159,0.409722,0.292857
max,0.333333,0.4,1.0,0.5


## Evaluate domain probs

### Evaluate max optimiser

#### domain_pro

In [57]:
prefix = 'pro'
opt = 'max'

rows_list = []
for matchdir in sorted(os.listdir(os.path.join(DOMAIN_ROOTDIR, prefix))):
    if 'test' in matchdir:
        for npy in sorted(os.listdir(os.path.join(DOMAIN_ROOTDIR, prefix, matchdir))):
            basename = npy.split('_probs.npy')[0]
            labelfile = os.path.join(LABEL_ROOTDIR, prefix, matchdir, 'player_hit', basename + '.mp4_player_hit.csv')
            probfile = os.path.join(DOMAIN_ROOTDIR, prefix, matchdir, npy)

            pred_labels = domain_probs2labels(probfile, opt=opt, conf_thresh=CONF_THRESH)
            gt_labels = gtfile2labels(labelfile)
            gt_labels = gt_labels[OFFSET:]

            prec, rec, acc, f1 = eval_metrics(gt_labels, pred_labels, frame_range=FRAME_RANGE)

            row_dict = {}
            row_dict['match'] = matchdir
            row_dict['vid'] = basename
            row_dict['acc'] = acc
            row_dict['rec'] = rec
            row_dict['prec'] = prec
            row_dict['f1'] = f1
            rows_list.append(row_dict.copy())

df_max_domain_pro = pd.DataFrame(rows_list)
df_max_domain_pro.describe()

Unnamed: 0,acc,rec,prec,f1
count,28.0,28.0,28.0,28.0
mean,0.541118,0.852347,0.578119,0.664464
std,0.266408,0.184436,0.268392,0.230379
min,0.095238,0.333333,0.117647,0.173913
25%,0.333333,0.8,0.333333,0.5
50%,0.506098,0.913043,0.585714,0.672043
75%,0.7125,1.0,0.726923,0.831933
max,1.0,1.0,1.0,1.0


#### domain_am_singles

In [58]:
prefix = 'am_singles'
opt = 'max'

rows_list = []
for matchdir in sorted(os.listdir(os.path.join(DOMAIN_ROOTDIR, prefix))):
    for npy in sorted(os.listdir(os.path.join(DOMAIN_ROOTDIR, prefix, matchdir))):
        basename = npy.split('_probs.npy')[0]
        labelfile = os.path.join(LABEL_ROOTDIR, prefix, matchdir, 'player_hit', basename + '.mp4_player_hit.csv')
        probfile = os.path.join(DOMAIN_ROOTDIR, prefix, matchdir, npy)

        pred_labels = domain_probs2labels(probfile, opt=opt, conf_thresh=CONF_THRESH)
        gt_labels = gtfile2labels(labelfile)
        gt_labels = gt_labels[OFFSET:]

        prec, rec, acc, f1 = eval_metrics(gt_labels, pred_labels, frame_range=FRAME_RANGE)

        row_dict = {}
        row_dict['match'] = matchdir
        row_dict['vid'] = basename
        row_dict['acc'] = acc
        row_dict['rec'] = rec
        row_dict['prec'] = prec
        row_dict['f1'] = f1
        rows_list.append(row_dict.copy())

df_max_domain_am_singles = pd.DataFrame(rows_list)
df_max_domain_am_singles.describe()

Unnamed: 0,acc,rec,prec,f1
count,35.0,35.0,35.0,35.0
mean,0.2199,0.666053,0.248709,0.348811
std,0.104485,0.229788,0.122168,0.141495
min,0.0,0.0,0.0,0.0
25%,0.158004,0.5,0.180828,0.272868
50%,0.2,0.666667,0.222222,0.333333
75%,0.286364,0.833333,0.333333,0.445055
max,0.418605,1.0,0.5,0.590164


### domain am_doubles

In [51]:
prefix = 'am_doubles'
opt = 'max'

rows_list = []
for matchdir in sorted(os.listdir(os.path.join(DOMAIN_ROOTDIR, prefix))):
    for npy in sorted(os.listdir(os.path.join(DOMAIN_ROOTDIR, prefix, matchdir))):
        basename_pair = npy.split('_probs.npy')[0]
        basename = basename_pair.split('_')[0]
        labelfile = os.path.join(LABEL_ROOTDIR, prefix, matchdir, 'player_hit', basename + '.mp4_player_hit.csv')
        probfile1 = os.path.join(DOMAIN_ROOTDIR, prefix, matchdir, npy)
        probfile2 = os.path.join(DOMAIN_ROOTDIR, prefix, matchdir, basename + '_pair2_probs.npy')

        pred_labels1 = domain_probs2labels(probfile1, opt=opt, conf_thresh=CONF_THRESH)
        pred_labels2 = domain_probs2labels(probfile2, opt=opt, conf_thresh=CONF_THRESH)
        pred_labels = []
        for i in range(len(pred_labels1)):
            y1, y2 = pred_labels1[i], pred_labels2[i]
            if (y1 == 1 and y2 == 0) or (y1 == 0 and y2 == 1):
                pred_labels.append(1)
            elif (y1 == 2 and y2 == 0) or (y1 == 0 and y2 == 2):
                pred_labels.append(2)
            else:
                pred_labels.append(0)
        
        gt_labels = gtfile2labels(labelfile)
        gt_labels = gt_labels[OFFSET:]

        prec, rec, acc, f1 = eval_metrics(gt_labels, pred_labels, frame_range=FRAME_RANGE)

        row_dict = {}
        row_dict['match'] = matchdir
        row_dict['vid'] = basename
        row_dict['acc'] = acc
        row_dict['rec'] = rec
        row_dict['prec'] = prec
        row_dict['f1'] = f1
        rows_list.append(row_dict.copy())

df_max_domain_am_doubles = pd.DataFrame(rows_list)
df_max_domain_am_doubles.describe()

Unnamed: 0,acc,rec,prec,f1
count,20.0,20.0,20.0,20.0
mean,0.237435,0.460271,0.326628,0.371759
std,0.108549,0.19249,0.148506,0.145297
min,0.0,0.0,0.0,0.0
25%,0.169715,0.378205,0.234804,0.290179
50%,0.25,0.444444,0.348485,0.4
75%,0.287815,0.561404,0.4,0.44697
max,0.5,0.888889,0.666667,0.666667


## Domain naive

### Pro

In [18]:
prefix = 'pro'
opt = 'naive'

rows_list = []
for matchdir in sorted(os.listdir(os.path.join(DOMAIN_ROOTDIR, prefix))):
    if 'test' in matchdir:
        for npy in sorted(os.listdir(os.path.join(DOMAIN_ROOTDIR, prefix, matchdir))):
            basename = npy.split('_probs.npy')[0]
            labelfile = os.path.join(LABEL_ROOTDIR, prefix, matchdir, 'player_hit', basename + '.mp4_player_hit.csv')
            probfile = os.path.join(DOMAIN_ROOTDIR, prefix, matchdir, npy)

            pred_labels = domain_probs2labels(probfile, opt=opt, conf_thresh=CONF_THRESH)
            gt_labels = gtfile2labels(labelfile)
            gt_labels = gt_labels[OFFSET:]

            prec, rec, acc, f1 = eval_metrics(gt_labels, pred_labels, frame_range=FRAME_RANGE)

            row_dict = {}
            row_dict['match'] = matchdir
            row_dict['vid'] = basename
            row_dict['acc'] = acc
            row_dict['rec'] = rec
            row_dict['prec'] = prec
            row_dict['f1'] = f1
            rows_list.append(row_dict.copy())

df_naive_domain_pro = pd.DataFrame(rows_list)
df_naive_domain_pro.describe()

Unnamed: 0,acc,rec,prec,f1
count,28.0,28.0,28.0,28.0
mean,0.749109,0.782911,0.933832,0.847059
std,0.156501,0.146212,0.082397,0.111349
min,0.333333,0.4,0.666667,0.5
25%,0.692308,0.718571,0.897222,0.818182
50%,0.75,0.786041,0.947368,0.857143
75%,0.827899,0.84375,1.0,0.905844
max,1.0,1.0,1.0,1.0


### Am-singles

In [40]:
prefix = 'am_singles'
opt = 'naive'

rows_list = []
for matchdir in sorted(os.listdir(os.path.join(DOMAIN_ROOTDIR, prefix))):
    for npy in sorted(os.listdir(os.path.join(DOMAIN_ROOTDIR, prefix, matchdir))):
        basename = npy.split('_probs.npy')[0]
        labelfile = os.path.join(LABEL_ROOTDIR, prefix, matchdir, 'player_hit', basename + '.mp4_player_hit.csv')
        probfile = os.path.join(DOMAIN_ROOTDIR, prefix, matchdir, npy)

        pred_labels = domain_probs2labels(probfile, opt=opt, conf_thresh=CONF_THRESH)
        gt_labels = gtfile2labels(labelfile)
        gt_labels = gt_labels[OFFSET:]

        prec, rec, acc, f1 = eval_metrics(gt_labels, pred_labels, frame_range=FRAME_RANGE)

        row_dict = {}
        row_dict['match'] = matchdir
        row_dict['vid'] = basename
        row_dict['acc'] = acc
        row_dict['rec'] = rec
        row_dict['prec'] = prec
        row_dict['f1'] = f1
        rows_list.append(row_dict.copy())

df_naive_domain_am_singles = pd.DataFrame(rows_list)
df_naive_domain_am_singles.describe()

Unnamed: 0,acc,rec,prec,f1
count,35.0,35.0,35.0,35.0
mean,0.527182,0.627345,0.698353,0.651286
std,0.263542,0.239785,0.263108,0.23771
min,0.0,0.0,0.0,0.0
25%,0.333333,0.478095,0.5,0.5
50%,0.5,0.625,0.75,0.666667
75%,0.775,0.841667,0.891813,0.873016
max,1.0,1.0,1.0,1.0


### domain am_doubles

In [39]:
prefix = 'am_doubles'
opt = 'naive'

rows_list = []
for matchdir in sorted(os.listdir(os.path.join(DOMAIN_ROOTDIR, prefix))):
    for npy in sorted(os.listdir(os.path.join(DOMAIN_ROOTDIR, prefix, matchdir))):
        basename_pair = npy.split('_probs.npy')[0]
        basename = basename_pair.split('_')[0]
        labelfile = os.path.join(LABEL_ROOTDIR, prefix, matchdir, 'player_hit', basename + '.mp4_player_hit.csv')
        probfile1 = os.path.join(DOMAIN_ROOTDIR, prefix, matchdir, npy)
        probfile2 = os.path.join(DOMAIN_ROOTDIR, prefix, matchdir, basename + '_pair2_probs.npy')

        pred_labels1 = domain_probs2labels(probfile1, opt=opt, conf_thresh=CONF_THRESH)
        pred_labels2 = domain_probs2labels(probfile2, opt=opt, conf_thresh=CONF_THRESH)
        pred_labels = []
        for i in range(len(pred_labels1)):
            y1, y2 = pred_labels1[i], pred_labels2[i]
            if (y1 == 1 and y2 == 0) or (y1 == 0 and y2 == 1):
                pred_labels.append(1)
            elif (y1 == 2 and y2 == 0) or (y1 == 0 and y2 == 2):
                pred_labels.append(2)
            else:
                pred_labels.append(0)
        
        gt_labels = gtfile2labels(labelfile)
        gt_labels = gt_labels[OFFSET:]

        prec, rec, acc, f1 = eval_metrics(gt_labels, pred_labels, frame_range=FRAME_RANGE)

        row_dict = {}
        row_dict['match'] = matchdir
        row_dict['vid'] = basename
        row_dict['acc'] = acc
        row_dict['rec'] = rec
        row_dict['prec'] = prec
        row_dict['f1'] = f1
        rows_list.append(row_dict.copy())

df_naive_domain_am_doubles = pd.DataFrame(rows_list)
df_naive_domain_am_doubles.describe()

Unnamed: 0,acc,rec,prec,f1
count,19.0,19.0,19.0,19.0
mean,0.311026,0.459026,0.470262,0.453988
std,0.159126,0.187218,0.212119,0.179577
min,0.071429,0.125,0.142857,0.133333
25%,0.222527,0.333333,0.333333,0.363971
50%,0.266667,0.444444,0.4,0.421053
75%,0.381016,0.589474,0.665179,0.551192
max,0.615385,0.8,0.916667,0.761905


## Check performance on best and worst videos

### Worst videos

In [22]:
df_naive_domain_pro[df_naive_domain_pro['f1'] < 0.7]

Unnamed: 0,match,vid,acc,rec,prec,f1
8,test_match1,2_02_07,0.333333,0.4,0.666667,0.5
25,test_match3,1_08_09,0.428571,0.5,0.75,0.6


In [24]:
df_naive_domain_am_singles[df_naive_domain_am_singles['f1'] < 0.4]

Unnamed: 0,match,vid,acc,rec,prec,f1
8,match24,1_04_05,0.0,0.0,0.0,0.0
10,match25,1_01_00,0.2,0.333333,0.333333,0.333333
14,match25,1_05_00,0.15,0.25,0.272727,0.26087
17,match25,1_05_03,0.2,0.333333,0.333333,0.333333


In [43]:
df_naive_domain_am_doubles[df_naive_domain_am_doubles['f1'] < 0.4]

Unnamed: 0,match,vid,acc,rec,prec,f1
2,match_china,doubles1,0.230769,0.333333,0.428571,0.375
12,match_clementi,doubles2,0.214286,0.375,0.333333,0.352941
22,match_msia,doubles0,0.098039,0.192308,0.166667,0.178571
32,match_msia,doubles5,0.071429,0.125,0.142857,0.133333
36,match_yewtee,doubles1,0.166667,0.266667,0.307692,0.285714
38,match_yewtee,doubles2,0.16129,0.277778,0.277778,0.277778


### Best videos

In [25]:
df_naive_domain_pro[df_naive_domain_pro['f1'] > 0.9]

Unnamed: 0,match,vid,acc,rec,prec,f1
0,test_match1,1_05_02,0.875,0.875,1.0,0.933333
1,test_match1,1_05_03,1.0,1.0,1.0,1.0
3,test_match1,1_07_03,0.826087,0.826087,1.0,0.904762
4,test_match1,1_07_04,1.0,1.0,1.0,1.0
6,test_match1,1_09_06,1.0,1.0,1.0,1.0
9,test_match1,2_03_08,1.0,1.0,1.0,1.0
12,test_match2,1_04_04,0.875,1.0,0.875,0.933333
26,test_match3,1_09_15,0.833333,0.833333,1.0,0.909091


In [27]:
df_naive_domain_am_singles[df_naive_domain_am_singles['f1'] > 0.9]

Unnamed: 0,match,vid,acc,rec,prec,f1
3,match24,1_01_03,0.875,0.875,1.0,0.933333
4,match24,1_01_04,0.857143,1.0,0.857143,0.923077
9,match24,1_05_05,1.0,1.0,1.0,1.0
11,match25,1_02_00,0.888889,0.888889,1.0,0.941176
15,match25,1_05_01,1.0,1.0,1.0,1.0
21,match26,1_00_03,0.875,1.0,0.875,0.933333
24,match26,1_02_04,0.85,0.85,1.0,0.918919


In [45]:
df_naive_domain_am_doubles[df_naive_domain_am_doubles['f1'] > 0.7]

Unnamed: 0,match,vid,acc,rec,prec,f1
4,match_china,doubles2,0.583333,0.777778,0.7,0.736842
8,match_clementi,doubles0,0.615385,0.8,0.727273,0.761905
20,match_clementi,doubles6,0.55,0.578947,0.916667,0.709677


### Select videos for analysis

In [29]:
good_pro_vid = 'test_match1/1_07_03'
bad_pro_vid = 'test_match3/1_08_09'
good_am_vid = 'match26/1_00_03'
bad_am_vid = 'match25/1_01_00'
test_vids = [good_pro_vid, bad_pro_vid, good_am_vid, bad_am_vid]

In [47]:
num = 3

if num in [0,1]:
    prefix = 'pro'
elif num in [2,3]:
    prefix = 'am_singles'
test_vid = test_vids[num]
opt = 'None'

matchdir, basename = test_vid.split('/')
labelfile = os.path.join(LABEL_ROOTDIR, prefix, matchdir, 'player_hit', basename + '.mp4_player_hit.csv')
probfile_dom = os.path.join(DOMAIN_ROOTDIR, prefix, matchdir, basename + '_probs.npy')
probfile_resnet = os.path.join(RESNET_ROOTDIR, prefix, matchdir, basename + '_probs.npy')

pred_labels_dom = domain_probs2labels(probfile_dom, opt=opt, conf_thresh=CONF_THRESH)
pred_labels_resnet = resnet_probs2labels(probfile_resnet, opt=opt, conf_thresh=CONF_THRESH)
gt_labels = gtfile2labels(labelfile)
gt_labels = gt_labels[OFFSET:]

print(gt_labels)
print()
print(pred_labels_dom)
print()
print(pred_labels_resnet[OFFSET:])

no optimisation used
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

[2 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 2
 2 2 2 2 2 2 2 2 2 0 0 0 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2
 2 2 2 2 2 2 0 0 0 0 0 0 0 0 1 1 1 0 0 1 1 1 1 1 1 0 0 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 2 2 2 2 2 2 2 2 2