In [1]:
from rekall import IntervalSetMapping, IntervalSet, Interval
from rekall.bounds import Bounds1D
from rekall.predicates import *

In [2]:
import numpy as np 

In [3]:
# Dev-split labeling-function (LF) output matrix.  Later cells transpose it so
# that each row is one frame's LF votes, i.e. this is presumably
# (n_lfs, n_frames) — TODO confirm.
DEV_LF_MATRIX_PATH = '/dfs/scratch0/paroma/heart-MRI/bav-classification/coral/gen_unknown_L_dev.npy'
a = np.load(DEV_LF_MATRIX_PATH)

In [4]:
# Dev-split ground-truth labels, one entry per frame — presumably aligned with
# the frames of `a` and encoded as 1 / -1 (TODO confirm against the data).
DEV_GT_PATH = '/dfs/scratch0/paroma/heart-MRI/bav-classification/coral/gen_unknown_y_dev.npy'
b = np.load(DEV_GT_PATH)

In [5]:
# Ground truth as intervals: patient i owns six consecutive frames, and each
# frame becomes a [frame, frame + 1) interval carrying that frame's label.
# Convert the array to a list ONCE — calling b.tolist() inside the
# comprehension re-converted the whole array for every frame (quadratic).
gt_labels = b.tolist()
gt_ism_frame = IntervalSetMapping({
    i: IntervalSet([
        Interval(Bounds1D(frame, frame + 1), payload=gt_labels[i * 6 + frame])
        for frame in range(6)
    ])
    # Trailing frames that do not fill a full group of six are dropped.
    for i in range(b.shape[0] // 6)
})
# Coalesce each patient's frame intervals along (t1, t2), spanning the merged
# bounds.  No payload merge op is passed, so the patient payload comes from
# rekall's default merge (presumably the first frame's label — TODO confirm).
gt_ism_patient = gt_ism_frame.coalesce(
    ('t1', 't2'), Bounds1D.span
)

In [6]:
# LF votes as intervals: patient i owns six consecutive frames; each frame is a
# [frame, frame + 1) interval whose payload is the list of every LF's vote for
# that frame (one row of a.T).  Convert a.T to a list ONCE rather than once per
# patient — the old per-patient a.T.tolist() was quadratic in the frame count.
lf_votes_by_frame = a.T.tolist()
lf_ism = IntervalSetMapping({
    i: IntervalSet([
        Interval(Bounds1D(frame, frame + 1), payload=lfs)
        for frame, lfs in enumerate(lf_votes_by_frame[i * 6:i * 6 + 6])
    ])
    # Trailing frames that do not fill a full group of six are dropped.
    for i in range(a.T.shape[0] // 6)
})
# Per-frame majority vote across LFs: 1 if strictly more 1s than -1s, else -1
# (ties, including frames with no 1/-1 votes, become -1).  The vote is wrapped
# in a list so patient-level coalescing can concatenate frame votes.
frame_mv = lf_ism.map(
    lambda intrvl: Interval(
        intrvl['bounds'],
        payload = [1 if intrvl['payload'].count(1) > intrvl['payload'].count(-1) else -1]
    )
)
# Per-patient majority vote over the frame votes, same tie-breaking rule.
patient_mv = frame_mv.coalesce(
    ('t1', 't2'), Bounds1D.span, lambda p1, p2: p1 + p2
).map(
    lambda intrvl: Interval(
        intrvl['bounds'],
        payload = 1 if intrvl['payload'].count(1) > intrvl['payload'].count(-1) else -1
    )
)

In [7]:
def tp_fp_fn_patient(predicted, gt):
    """Split patient-level predictions into TP / FP / FN against ground truth.

    Both arguments are IntervalSetMappings whose payloads are compared with 1
    (positive); presumably one interval per patient — TODO confirm.

    Returns:
        (tp, fp, fn):
        tp -- predicted positives that match a gt interval with the same payload;
        fp -- predicted positives with tp subtracted;
        fn -- gt positives with the predicted positives subtracted.
    """
    def positives(ism):
        # Keep only the intervals labelled 1.
        return ism.filter(payload_satisfies(lambda p: p == 1))

    predicted_pos = positives(predicted)

    # A predicted positive is a hit when gt carries the same payload (i.e. 1).
    true_pos = predicted_pos.filter_against(
        gt,
        predicate = lambda i1, i2: i1['payload'] == i2['payload']
    )
    false_pos = predicted_pos.minus(true_pos)
    false_neg = positives(gt).minus(predicted_pos)

    return true_pos, false_pos, false_neg

In [8]:
def prf1_patient(predicted, gt):
    """Patient-level precision, recall and F1 of `predicted` against `gt`.

    Counts come from `tp_fp_fn_patient`.

    Returns:
        (precision, recall, f1, n_tp, n_fp, n_fn).  A metric whose denominator
        is zero (no predicted positives, no gt positives, or precision and
        recall both zero) is reported as 0.0 instead of raising
        ZeroDivisionError.
    """
    tp, fp, fn = tp_fp_fn_patient(predicted, gt)
    n_tp, n_fp, n_fn = len(tp), len(fp), len(fn)

    pre = n_tp / (n_tp + n_fp) if n_tp + n_fp else 0.0
    rec = n_tp / (n_tp + n_fn) if n_tp + n_fn else 0.0
    f1 = 2 * pre * rec / (pre + rec) if pre + rec else 0.0

    return pre, rec, f1, n_tp, n_fp, n_fn

In [10]:
def _lf_votes_per_patient(lf_ism, lf_index):
    """Collect LF `lf_index`'s per-frame votes into one list per patient."""
    return lf_ism.map(
        lambda intrvl: Interval(
            intrvl['bounds'],
            payload = [intrvl['payload'][lf_index]]
        )
    ).coalesce(
        ('t1', 't2'), Bounds1D.span, lambda p1, p2: p1 + p2
    )


def _apply_patient_rule(patient_votes, rule):
    """Replace each patient's vote list with `rule(votes)` (1 or -1)."""
    return patient_votes.map(
        lambda intrvl: Interval(
            intrvl['bounds'],
            payload = rule(intrvl['payload'])
        )
    )


def _and_join(left, right):
    """Join intervals with equal bounds; payload is 1 only if both sides are 1."""
    return left.join(
        right,
        predicate = equal(),
        merge_op = lambda i1, i2: Interval(
            i1['bounds'],
            1 if i1['payload'] == 1 and i2['payload'] == 1 else -1
        )
    )


def bav_query(lf_ism):
    """Hand-written BAV classification query over per-frame LF votes.

    Args:
        lf_ism: IntervalSetMapping with one interval per frame whose payload is
            the list of LF votes for that frame (values compared against 1, -1
            and 0; presumably 0 = abstain — TODO confirm).

    Returns:
        (attempt5, attempt4, attempt3, second_fourth): per-patient +1/-1
        results, each the logical AND of a growing set of rules, so later
        entries are progressively less restrictive.
    """
    # LF indices are 0-based: "second" LF = index 1, "fourth" = 3, "fifth" = 4.
    # Each column is coalesced once and reused by every rule that reads it.
    second_lf_votes = _lf_votes_per_patient(lf_ism, 1)
    fourth_lf_votes = _lf_votes_per_patient(lf_ism, 3)
    fifth_lf_votes = _lf_votes_per_patient(lf_ism, 4)

    # Positive only if LF #2 votes 1 on more than four frames.
    second_lf_unsure = _apply_patient_rule(
        second_lf_votes, lambda votes: 1 if votes.count(1) > 4 else -1)
    # Negative if LF #4 votes -1 on exactly two or three frames.
    fourth_lf_negative = _apply_patient_rule(
        fourth_lf_votes, lambda votes: -1 if votes.count(-1) in (2, 3) else 1)
    second_fourth = _and_join(second_lf_unsure, fourth_lf_negative)

    # Negative if LF #2 abstains (0) on any frame.
    second_has_zero = _apply_patient_rule(
        second_lf_votes, lambda votes: -1 if votes.count(0) > 0 else 1)
    attempt3 = _and_join(second_fourth, second_has_zero)

    # Negative if LF #5's first or last frame vote is 0.
    fifth_lf_starts_or_ends_with_0 = _apply_patient_rule(
        fifth_lf_votes,
        lambda votes: -1 if votes[0] == 0 or votes[-1] == 0 else 1)
    attempt4 = _and_join(attempt3, fifth_lf_starts_or_ends_with_0)

    # Per frame: -1 when at least three LFs vote -1; per patient: negative
    # when at least two frames were flagged that way.
    three_negatives_twice = _apply_patient_rule(
        lf_ism.map(
            lambda intrvl: Interval(
                intrvl['bounds'],
                payload = [-1 if intrvl['payload'].count(-1) >= 3 else 1]
            )
        ).coalesce(
            ('t1', 't2'), Bounds1D.span, lambda p1, p2: p1 + p2
        ),
        lambda flags: -1 if flags.count(-1) >= 2 else 1)

    # Negative if LF #5 has more than three -1 votes, or more than two -1 votes
    # together with exactly one 0 (`and` binds tighter than `or`).
    # NOTE(review): the name suggests "fewer than three positives", but this
    # rule does not flag e.g. three -1s with two 0s — confirm intended rule.
    fifth_lf_less_than_three_positive = _apply_patient_rule(
        fifth_lf_votes,
        lambda votes: (
            -1 if (votes.count(-1) > 3 or
                   (votes.count(-1) > 2 and votes.count(0) == 1))
            else 1
        ))

    attempt5 = _and_join(
        _and_join(attempt4, three_negatives_twice),
        fifth_lf_less_than_three_positive,
    )

    return attempt5, attempt4, attempt3, second_fourth

In [11]:
# Run the query on the labelled dev split:
# (attempt5, attempt4, attempt3, second_fourth).
bav_result = bav_query(lf_ism=lf_ism)

In [13]:
# Score the most restrictive stage (attempt5) against patient-level ground truth.
prf1_patient(predicted=bav_result[0], gt=gt_ism_patient)

(0.6, 0.8571428571428571, 0.7058823529411764, 6, 4, 1)

# Run the BAV query on the unlabelled training set

In [14]:
# Train-split LF output matrix; presumably the same layout as the dev matrix
# `a`, i.e. (n_lfs, n_frames) — TODO confirm.
TRAIN_LF_MATRIX_PATH = '/dfs/scratch0/paroma/heart-MRI/bav-classification/coral/gen_unknown_L_train.npy'
c = np.load(TRAIN_LF_MATRIX_PATH)

In [22]:
# Train-split LF votes as intervals: six consecutive frames per patient, each a
# [frame, frame + 1) interval whose payload is that frame's list of LF votes.
# Convert c.T to a list ONCE — calling c.T.tolist() for every patient made
# this cell quadratic in the number of frames.
train_lf_votes_by_frame = c.T.tolist()
lf_train_ism = IntervalSetMapping({
    i: IntervalSet([
        Interval(Bounds1D(frame, frame + 1), payload=lfs)
        for frame, lfs in enumerate(train_lf_votes_by_frame[i * 6:i * 6 + 6])
    ])
    # Trailing frames that do not fill a full group of six are dropped.
    for i in range(c.T.shape[0] // 6)
})

In [23]:
# Same query on the unlabelled train split; no ground truth, so only counts
# can be inspected.
bav_result_train = bav_query(lf_ism=lf_train_ism)

In [24]:
# Positive / negative patient counts for each query stage, in the order
# (attempt5, attempt4, attempt3, second_fourth).  The loop variable must not be
# named `bav_result`: that would overwrite the dev-set result tuple indexed in
# In [13], so re-running that cell would silently index a single
# IntervalSetMapping instead of the tuple.
for query_stage in bav_result_train:
    print(len(query_stage.filter(payload_satisfies(lambda p: p == 1))))
    print(len(query_stage.filter(payload_satisfies(lambda p: p == -1))))
    print()

443
4102

570
3975

719
3826

948
3597

