# Evaluate Myth Commercial Detector

In [1]:
import pickle
import json
import rekall

In [35]:
with open('/app/data/myth-commercial-data/test/myth_test_output.json', 'r') as f:
    myth_output = json.load(f)
with open('/app/data/myth-commercial-data/val/mythcomm_val_output.json', 'r') as f:
    myth_output_val = json.load(f)

In [26]:
with open('/app/data/commercials/data/test.txt', 'r') as f:
    test_gt = [
        l.strip().split()
        for l in f.readlines()
    ]
with open('/app/data/commercials/data/val.txt', 'r') as f:
    val_gt = [
        l.strip().split()
        for l in f.readlines()
    ]

In [4]:
train_set = [293, 514, 529, 715, 755, 763, 1595, 2648, 3317, 3459, 3769,
             3952, 4029, 4143, 4421, 4611, 5281, 6161, 6185, 6532, 7206, 
             7262, 8220, 8859, 9480, 9499, 9758, 10335, 10621, 11003, 11293, 
             11555, 11792, 11940, 12655, 13058, 13141, 13247, 13556, 13791, 
             13827, 13927, 14482, 14638, 15855, 15916, 15964, 16215, 16542, 
             16599, 16693, 16879, 16964, 17458, 17983, 19882, 19959, 20230, 
             20380, 20450, 20920, 21572, 23181, 23184, 24193, 24784, 25463, 
             26060, 26231, 26386, 26824, 26836, 27175, 27188, 27410, 27927, 
             27963, 28613, 29001, 31378, 31994, 32472, 33004, 33387, 33541, 
             33977, 34050, 34359, 34642, 36211, 37107, 37113, 37170, 37927, 
             38275, 38420, 40856, 41235, 41480, 41725, 41836, 42027, 42362, 
             42756, 44998, 45472, 45573, 45645, 45655, 45698, 45744, 46041, 
             46058, 46753, 48140, 49225, 49931, 50883, 51145, 51175, 51469, 
             51482, 51606, 52075, 52749, 53355, 53684, 53932, 54238, 55016, 
             56051, 56764, 57310, 57384, 57592, 57708, 57798, 57804, 57962, 
             58389, 59122, 59398, 60186, 60581, 61359, 61930, 62400, 66092, 
             66666, 79265, 80121, 93033, 94663, 112580, 114248, 115653, 
             123531, 124234, 128012, 133584, 134007, 135812, 136446, 148080, 
             158981, 158982, 192899, 205173]
val_set = [559, 1791, 3730, 3754, 10323, 11579, 17386, 20689, 24847, 24992, 
           26175, 33800, 40203, 40267, 43637, 50561, 54377, 57990, 59028, 
           63965, 67300]
test_set = [385, 8697, 9215, 9901, 12837, 13993, 14925, 18700, 23541, 31902,
            32996, 36755, 50164, 52945, 55711, 57748, 59789, 60433, 136732,
            149097, 169420]

In [9]:
from rekall import Interval, IntervalSet, IntervalSetMapping
from rekall.bounds import Bounds3D
from rekall.predicates import *

In [36]:
test_gt_by_video = {}
for video_id, idx, label in test_gt:
    video_id = int(video_id)
    idx = int(idx)
    label = int(label)
    if video_id not in test_gt_by_video:
        test_gt_by_video[video_id] = []
    test_gt_by_video[video_id].append(
        Interval(Bounds3D(idx * 10, (idx + 1) * 10), label)
    )
    
gt_comms = IntervalSetMapping({
    video_id: IntervalSet(test_gt_by_video[video_id])
    for video_id in test_gt_by_video
})

val_gt_by_video = {}
for video_id, idx, label in val_gt:
    video_id = int(video_id)
    idx = int(idx)
    label = int(label)
    if video_id not in val_gt_by_video:
        val_gt_by_video[video_id] = []
    val_gt_by_video[video_id].append(
        Interval(Bounds3D(idx * 10, (idx + 1) * 10), label)
    )
    
gt_comms_val = IntervalSetMapping({
    video_id: IntervalSet(val_gt_by_video[video_id])
    for video_id in val_gt_by_video
})

In [41]:
myth_comms = IntervalSetMapping({
    int(video_id): IntervalSet([
        Interval(Bounds3D(comm[0], comm[1]))
        for comm in myth_output[video_id]
    ])
    for video_id in myth_output
})

myth_comms_val = IntervalSetMapping({
    int(video_id): IntervalSet([
        Interval(Bounds3D(comm[0], comm[1]))
        for comm in myth_output_val[video_id]
    ])
    for video_id in myth_output_val
})

In [19]:
positives = gt_comms.filter_against(
    myth_comms,
    predicate = overlaps(),
    window=0
)
negatives = gt_comms.minus(positives, axis=('t1', 't2'), window=0)

In [20]:
tp = positives.filter(lambda intrvl: intrvl['payload'] == 1)
fp = positives.filter(lambda intrvl: intrvl['payload'] == 0)
fn = negatives.filter(lambda intrvl: intrvl['payload'] == 1)

In [24]:
def precision_recall_f1(tp, fp, fn):
    def sum_values(obj):
        return sum([v for v in list(obj.values())])
    tp_count = sum_values(tp.size())
    fp_count = sum_values(fp.size())
    fn_count = sum_values(fn.size())
    
    precision = tp_count / (tp_count + fp_count)
    recall = tp_count / (tp_count + fn_count)
    f1 = 2 * precision * recall / (precision + recall)
    
    return (precision, recall, f1, tp_count, fp_count, fn_count)

In [25]:
precision_recall_f1(tp, fp, fn)

(0.9653779572994806, 0.7056094474905104, 0.8153021442495126, 1673, 60, 698)

In [42]:
positives_val = gt_comms_val.filter_against(
    myth_comms_val,
    predicate = overlaps(),
    window=0
)
negatives_val = gt_comms_val.minus(positives, axis=('t1', 't2'), window=0)

In [43]:
tp_val = positives_val.filter(lambda intrvl: intrvl['payload'] == 1)
fp_val = positives_val.filter(lambda intrvl: intrvl['payload'] == 0)
fn_val = negatives_val.filter(lambda intrvl: intrvl['payload'] == 1)

In [44]:
precision_recall_f1(tp_val, fp_val, fn_val)

(0.9526288391462779, 0.3988666085440279, 0.5622983561222922, 1830, 91, 2758)