In [1]:
from rekall import IntervalSetMapping, IntervalSet, Interval
from rekall.bounds import Bounds1D
from rekall.predicates import *

In [2]:
import numpy as np 

In [3]:
# Labeling-function output matrix for the heart-MRI task.
# NOTE(review): despite the `_dev` suffix this loads the *test* split (L_test.npy).
# TODO(review): hardcoded absolute cluster path; consider a configurable data dir.
L_dev = np.load(
    file='/dfs/scratch1/danfu/rekall_experiments/heartmri/L_test.npy',
)

In [4]:
# Ground-truth labels for the heart-MRI task.
# NOTE(review): despite the `_dev` suffix this loads the *test* split (y_test.npy).
# TODO(review): hardcoded absolute cluster path; consider a configurable data dir.
y_dev = np.load(
    file='/dfs/scratch1/danfu/rekall_experiments/heartmri/y_test.npy',
)

In [5]:
# Expected layout (num_lfs, num_frames): each column is one frame's LF outputs,
# and frames are grouped 6 per patient further down.
np.shape(L_dev)

(5, 540)

In [6]:
# One ground-truth label per frame; the frame count should match L_dev's columns.
np.shape(y_dev)

(540,)

In [7]:
# Ground truth as one IntervalSet per patient: patient i owns frame labels
# y_dev[6*i : 6*i + 6], each stored as a unit-length interval [frame, frame+1].
# Convert to a Python list once. Calling y_dev.tolist() inside the inner
# comprehension re-converted the whole array for every single frame (O(n^2)).
y_dev_labels = y_dev.tolist()
gt_ism_frame = IntervalSetMapping({
    patient: IntervalSet([
        Interval(Bounds1D(frame, frame + 1), payload=y_dev_labels[patient * 6 + frame])
        for frame in range(6)
    ])
    for patient in range(y_dev.shape[0] // 6)
})
# Merge each patient's contiguous frame intervals into one spanning interval.
# NOTE(review): no payload merge op is passed, so the patient label comes from
# rekall's default payload merge (presumably the first frame's label). Confirm
# that all frames of a patient carry the same label.
gt_ism_patient = gt_ism_frame.coalesce(
    ('t1', 't2'), Bounds1D.span
)

In [8]:
# Per-frame labeling-function (LF) outputs, one IntervalSet per patient.
# L_dev is (num_lfs, num_frames), so each row of L_dev.T is one frame's LF
# outputs; consecutive groups of 6 frames belong to the same patient.
# Convert to nested lists once. The original re-ran L_dev.T.tolist(), a full
# copy of the matrix, once per patient.
lf_votes_per_frame = L_dev.T.tolist()
lf_ism = IntervalSetMapping({
    patient: IntervalSet([
        Interval(Bounds1D(frame, frame + 1), payload=lfs)
        for frame, lfs in enumerate(lf_votes_per_frame[patient * 6:patient * 6 + 6])
    ])
    for patient in range(L_dev.T.shape[0] // 6)
})


def _majority_vote(votes):
    """Return 1 if `votes` contains strictly more 1s than -1s, otherwise -1.

    Ties go to -1; values other than 1 and -1 are not counted.
    """
    return 1 if votes.count(1) > votes.count(-1) else -1


# Majority vote across LFs for each frame. The label is wrapped in a list so the
# coalesce below can concatenate the frame labels into one list per patient.
frame_mv = lf_ism.map(
    lambda intrvl: Interval(
        intrvl['bounds'],
        payload=[_majority_vote(intrvl['payload'])]
    )
)
# Collapse each patient's frames into a single interval, then majority-vote
# over its per-frame labels.
patient_mv = frame_mv.coalesce(
    ('t1', 't2'), Bounds1D.span, lambda p1, p2: p1 + p2
).map(
    lambda intrvl: Interval(
        intrvl['bounds'],
        payload=_majority_vote(intrvl['payload'])
    )
)

In [9]:
def tp_fp_fn_patient(predicted, gt):
    """Split patient-level predictions into true/false positives and false negatives.

    Args:
        predicted: IntervalSetMapping whose interval payloads are labels (1 = positive).
        gt: IntervalSetMapping of ground truth using the same payload convention.

    Returns:
        (tp, fp, fn):
          tp - positive predictions matched (via filter_against) by a gt interval
               carrying the same payload;
          fp - positive predictions minus tp;
          fn - positive gt intervals minus all positive predictions.

    NOTE(review): the matching predicate compares payloads only, not bounds, so
    this presumably assumes one interval per key (patient level). Confirm before
    reusing it at frame level.
    """
    def is_positive(label):
        return label == 1

    positive_preds = predicted.filter(payload_satisfies(is_positive))
    positive_gt = gt.filter(payload_satisfies(is_positive))

    true_pos = positive_preds.filter_against(
        gt,
        predicate=lambda pred, truth: pred['payload'] == truth['payload'],
    )
    false_pos = positive_preds.minus(true_pos)
    false_neg = positive_gt.minus(positive_preds)

    return true_pos, false_pos, false_neg

In [10]:
def prf1_patient(predicted, gt):
    """Patient-level precision, recall and F1 of `predicted` against `gt`.

    Counts come from tp_fp_fn_patient(predicted, gt).

    Returns:
        (precision, recall, f1, n_tp, n_fp, n_fn)

    A ratio whose denominator is zero (nothing predicted positive, no positive
    ground truth, or precision + recall == 0) is reported as 0.0 instead of
    raising ZeroDivisionError.
    """
    tp, fp, fn = tp_fp_fn_patient(predicted, gt)
    n_tp, n_fp, n_fn = len(tp), len(fp), len(fn)

    pre = n_tp / (n_tp + n_fp) if n_tp + n_fp else 0.0
    rec = n_tp / (n_tp + n_fn) if n_tp + n_fn else 0.0
    f1 = 2 * pre * rec / (pre + rec) if pre + rec else 0.0

    return pre, rec, f1, n_tp, n_fp, n_fn

In [38]:
def _concat_payloads(p1, p2):
    """Coalesce payload merger: concatenate two per-frame payload lists."""
    return p1 + p2


def _sum_payloads_elementwise(p1, p2):
    """Coalesce payload merger: element-wise sum of two count lists (zip-truncated)."""
    return [i + j for i, j in zip(p1, p2)]


def _per_patient_label(lf_ism, frame_payload, merge_payloads, patient_label):
    """Reduce each patient's per-frame LF outputs to a single label.

    Each frame interval's payload is mapped with `frame_payload`. The patient's
    frames are then coalesced into one interval whose payloads are combined
    pairwise with `merge_payloads`, and the merged payload is mapped to the
    final label with `patient_label`.
    """
    return lf_ism.map(
        lambda intrvl: Interval(
            intrvl['bounds'],
            payload=frame_payload(intrvl['payload'])
        )
    ).coalesce(
        ('t1', 't2'), Bounds1D.span, merge_payloads
    ).map(
        lambda intrvl: Interval(
            intrvl['bounds'],
            payload=patient_label(intrvl['payload'])
        )
    )


def _join_labels(left, right, combine):
    """Join two patient-level label mappings on equal bounds.

    The joined payload is 1 if combine(left_is_1, right_is_1) is true, else -1.
    """
    return left.join(
        right,
        predicate=equal(),
        merge_op=lambda i1, i2: Interval(
            i1['bounds'],
            1 if combine(i1['payload'] == 1, i2['payload'] == 1) else -1
        )
    )


def bav_query(lf_ism):
    """Heuristic patient-level label built from labeling-function (LF) outputs.

    `lf_ism` maps each patient to one interval per frame; each payload is that
    frame's list of LF outputs (1 / -1 / other). Returns a mapping with one
    interval per patient whose payload is 1 when

        (rule A and not vetoed by rule B) or rule C

    and -1 otherwise, where
      - Rule A: at least 3 frames have at least 4 LF outputs equal to 1.
      - Rule B (veto): more than one LF outputs -1 on more than one frame.
      - Rule C: the second and fifth LFs (indices 1 and 4) output 1 on all 6 frames.
    """
    # Rule A: frame is 1 if >= 4 LFs output 1; patient is 1 if >= 3 such frames.
    three_with_four_positive = _per_patient_label(
        lf_ism,
        frame_payload=lambda lfs: [1 if lfs.count(1) >= 4 else -1],
        merge_payloads=_concat_payloads,
        patient_label=lambda frame_labels: 1 if frame_labels.count(1) >= 3 else -1,
    )

    # Rule B: count -1 outputs per LF across the patient's frames; the patient
    # gets -1 when more than one LF has a count greater than one.
    two_columns_multiple_negative = _per_patient_label(
        lf_ism,
        frame_payload=lambda lfs: [1 if lf == -1 else 0 for lf in lfs],
        merge_payloads=_sum_payloads_elementwise,
        patient_label=lambda neg_counts: (
            -1 if sum(1 for count in neg_counts if count > 1) > 1 else 1
        ),
    )

    three_four_positive_columns_not_negative = _join_labels(
        three_with_four_positive,
        two_columns_multiple_negative,
        lambda a_pos, b_pos: a_pos and b_pos,
    )

    # Rule C helper: 1 if LF `column` outputs 1 on exactly 6 of the patient's frames.
    def column_all_positive(column):
        return _per_patient_label(
            lf_ism,
            frame_payload=lambda lfs: [lfs[column]],
            merge_payloads=_concat_payloads,
            patient_label=lambda outputs: 1 if outputs.count(1) == 6 else -1,
        )

    second_all_fifth_all = _join_labels(
        column_all_positive(1),
        column_all_positive(4),
        lambda a_pos, b_pos: a_pos and b_pos,
    )

    attempt_final = _join_labels(
        three_four_positive_columns_not_negative,
        second_all_fifth_all,
        lambda a_pos, b_pos: a_pos or b_pos,
    )

    return attempt_final

In [39]:
# Run the BAV query over the per-frame labeling-function intervals.
bav_result = bav_query(lf_ism=lf_ism)

In [40]:
# (precision, recall, f1, n_tp, n_fp, n_fn) of the BAV query vs. patient-level ground truth.
bav_scores = prf1_patient(bav_result, gt_ism_patient)
print(bav_scores)

(0.6666666666666666, 0.6666666666666666, 0.6666666666666666, 2, 1, 1)
