In [1]:
import os
import wfdb
import pickle
import pandas as pd
import numpy as np
from keras.utils import to_categorical
from tqdm import tqdm_notebook
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from keras.models import load_model
import matplotlib.animation as animation
np.set_printoptions(suppress=True)
%matplotlib inline

Using TensorFlow backend.


In [2]:
cross_idx = 3
win_len = 100
dataset_root = "./dataset_RRI/"
model_root = "./results/models/"
keys = [i.split(".")[0] for i in os.listdir(dataset_root) if "pickle" in i and "cross" not in i ]
model_idx = {
    0:8,
    1:12,
    2:21,
    3:33,
    4:42
}

In [3]:
def int2label(label):
    if label==0:
        return "Normal"
    elif label==1:
        return "AF"

In [4]:
def get_normalization_param(test_keys):
    max_norm = 0
    min_norm = 10000
    for key in keys:
        if key in test_keys:continue
        with open(os.path.join(dataset_root,key+".pickle"),"rb") as f:
            dataset = pickle.load(f)
        X = dataset["X"]
        if np.max(X) > max_norm:
            max_norm = np.max(X)
        if np.min(X) < min_norm:
            min_norm = np.min(X)
    return min_norm,max_norm

In [5]:
def update(i, confidence,confidences,segments,labels,offset,ax1,ax2):
    if i !=0:
        ax1.cla()
        ax2.cla()
    c = confidence[i]
    segment = segments[i]
    label = labels[i]
    if c<0.5:
        color = "red"
    elif c<0.7:
        color = "yellow"
    else:
        color = "green"
        
    ax1.plot([i+j+offset for j in range(win_len)],segment,color="b")
    ax1.set_ylim(0,1)
    ax1.set_xlabel("RRI Index",fontsize=15)
    ax1.set_ylabel("Normalized RRI value",fontsize=15)
    ax1.set_title("Label:"+int2label(label),fontsize=15)
    ax2.bar(range(1, 2), c, color=color, align='center')
    ax2.set_title("Confidence of predict true label",fontsize=15)
    ax2.set_xticklabels([])
    ax2.set_ylim(0,1)

In [6]:
def plot_animation(seg_idx,segments,labels,confidence,confidences,save_dir):
    fig,(ax1,ax2) = plt.subplots(1,2,figsize=(15,5)) #figure objectを取得
    ax3 = ax2.twinx()
    ax3.boxplot(confidences)
    ax3.set_ylim(ax2.get_ylim())
    ax3.set_yticklabels([])
    ani = animation.FuncAnimation(fig, update, fargs = (confidence,confidences,segments,labels,seg_idx,ax1,ax2), interval = 100)
    save_path = os.path.join(save_dir,str(seg_idx)+".mp4")
    ani.save(save_path, writer = 'ffmpeg')
    plt.close()

In [7]:
def check_shift_invariance_once(X,label,seg_idx,save_dir,confidences,return_confidence=False):
    interval = 100
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    continue_seg = np.concatenate( [X[seg_idx],X[seg_idx+1]],axis=0)
    sliding_win_start = 0
    sliding_win_end = 100
    step = 1
    confidence = []
    shifted_segments = []
    shifted_labels = []
    
    consistency_list = [] 
    
    while(sliding_win_end<len(continue_seg)):
        data = continue_seg[sliding_win_start:sliding_win_end,:]

        logits = model.predict(np.expand_dims(data,axis=0))[0]
        confidence.append(logits[label])
        shifted_segments.append(data)
        shifted_labels.append(label)
        if sliding_win_start!=0:
            consistency_list.append(np.argmax(logits)==label)
        sliding_win_start += step
        sliding_win_end += step
    
    if not return_confidence:
        plot_animation(seg_idx,shifted_segments,shifted_labels,confidence,confidences,save_dir)
    else:
        if len(np.where(np.array(confidence) < 0.5)[0])>0:
            print(seg_idx,"contains red predictions:",len(np.where(np.array(confidence) < 0.5)[0]))
        return confidence,consistency_list

In [8]:
def check_shift_invariance(X,y,test_key,save_root = "./Animations"):
    if not os.path.exists(save_root):
        os.mkdir(save_root)
    save_dir = os.path.join(save_root,test_key)

    start = dataset["start"]
    end = dataset["end"]
    confidences = []
    shifted_segments = []
    shifted_labels = []
    consistency = []
    if len(X)==1:
        print("cosistency of ",test_key,"can not calculate cosistency")
        return []
        
    for seg_idx in tqdm_notebook(range(len(X)-1)):
        seg = X[seg_idx]
        next_seg = X[seg_idx+1]
        seg_end = end[seg_idx]
        next_seg_start = start[seg_idx+1]
        label = np.argmax(y[seg_idx])
        next_label = np.argmax(y[seg_idx+1])
        if not (seg_end == next_seg_start and label==next_label):continue
        confidence,consistency_list =  check_shift_invariance_once(X,label,seg_idx,save_dir,confidences=None,return_confidence=True)
        confidences.append(confidence)
        consistency.append(consistency_list)
        
    confidences = np.concatenate(confidences,axis=0)
    consistency = np.concatenate(consistency,axis=0)
    consistency = np.mean(consistency)
    print("cosistency of ",test_key,consistency)
    
    for seg_idx in tqdm_notebook(range(len(X)-1)):
        seg = X[seg_idx]
        next_seg = X[seg_idx+1]
        seg_end = end[seg_idx]
        next_seg_start = start[seg_idx+1]
        label = np.argmax(y[seg_idx])
        next_label = np.argmax(y[seg_idx+1])
        if not (seg_end == next_seg_start and label==next_label):continue
        check_shift_invariance_once(X,label,seg_idx,save_dir,confidences,return_confidence=False)
    return consistency

In [9]:
for cross_idx in range(5):
    with open(os.path.join(dataset_root,"dataset-cross"+str(cross_idx)+".pickle"),"rb") as f:
            dataset = pickle.load(f)
            test_keys = dataset["test_key"]
            X_test = dataset["X_test"]
            y_test = dataset["y_test"]
    model = load_model(os.path.join(model_root,str(model_idx[cross_idx])+"-model.h5"))
    min_norm,max_norm = get_normalization_param(test_keys)
    consistencys = []

    for test_key in test_keys:
        with open(os.path.join(dataset_root,test_key+".pickle"),"rb") as f:
            dataset = pickle.load(f)
        X = np.expand_dims( (dataset["X"] - min_norm) / (max_norm - min_norm) ,axis=2)
        y = to_categorical (dataset["y"],num_classes=2)

        consistency = check_shift_invariance(X,y,test_key,save_root = os.path.join( "./Animations","cross-"+str(cross_idx)))
        consistencys.append(consistency)

    consis = np.mean(consistencys)
    print("consistency for cross",cross_idx,"=",consis)





HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

4 contains red predictions: 9
34 contains red predictions: 2

cosistency of  08434 0.9977324263038548


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))




HBox(children=(IntProgress(value=0, max=22), HTML(value='')))

11 contains red predictions: 4
14 contains red predictions: 17
15 contains red predictions: 13

cosistency of  04746 0.9843893480257117


HBox(children=(IntProgress(value=0, max=22), HTML(value='')))




HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

3 contains red predictions: 5
4 contains red predictions: 16
11 contains red predictions: 6
12 contains red predictions: 1
14 contains red predictions: 3
15 contains red predictions: 12
16 contains red predictions: 3
30 contains red predictions: 7
31 contains red predictions: 3
34 contains red predictions: 15
39 contains red predictions: 1
58 contains red predictions: 2
59 contains red predictions: 5
60 contains red predictions: 1
64 contains red predictions: 5
67 contains red predictions: 1
73 contains red predictions: 2
75 contains red predictions: 12
76 contains red predictions: 1
83 contains red predictions: 1
85 contains red predictions: 8
86 contains red predictions: 4
92 contains red predictions: 5
96 contains red predictions: 1
103 contains red predictions: 1
104 contains red predictions: 2
107 contains red predictions: 26
108 contains red predictions: 1
112 contains red predictions: 1
113 contains red predictions: 1
144 contains red predictions: 1
147 contains red predictions:

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))


consistency for cross 0 = 0.9759052740751349


HBox(children=(IntProgress(value=0, max=43), HTML(value='')))

1 contains red predictions: 1
2 contains red predictions: 35
12 contains red predictions: 38
32 contains red predictions: 1
33 contains red predictions: 20
34 contains red predictions: 63
35 contains red predictions: 18
39 contains red predictions: 52
40 contains red predictions: 87
41 contains red predictions: 10
42 contains red predictions: 60

cosistency of  08219 0.9102654451491661


HBox(children=(IntProgress(value=0, max=43), HTML(value='')))




HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

0 contains red predictions: 2
1 contains red predictions: 1
3 contains red predictions: 9
4 contains red predictions: 52
5 contains red predictions: 3
6 contains red predictions: 2
10 contains red predictions: 4
11 contains red predictions: 6
12 contains red predictions: 4
14 contains red predictions: 7
15 contains red predictions: 38
16 contains red predictions: 2
17 contains red predictions: 3
18 contains red predictions: 3
21 contains red predictions: 1
24 contains red predictions: 6
25 contains red predictions: 2
26 contains red predictions: 1
30 contains red predictions: 15
31 contains red predictions: 18
32 contains red predictions: 5
33 contains red predictions: 4
34 contains red predictions: 18
35 contains red predictions: 1
36 contains red predictions: 1
37 contains red predictions: 2
46 contains red predictions: 3
47 contains red predictions: 2
48 contains red predictions: 2
55 contains red predictions: 1
58 contains red predictions: 8
59 contains red predictions: 20
60 conta

375 contains red predictions: 9
376 contains red predictions: 53
377 contains red predictions: 37
378 contains red predictions: 45
379 contains red predictions: 24
380 contains red predictions: 49
381 contains red predictions: 23
382 contains red predictions: 10
383 contains red predictions: 5
384 contains red predictions: 24
385 contains red predictions: 25
386 contains red predictions: 20
387 contains red predictions: 1
388 contains red predictions: 20
389 contains red predictions: 1
390 contains red predictions: 13

cosistency of  07162 0.9184427394146064


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))




HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

26 contains red predictions: 21
27 contains red predictions: 38
31 contains red predictions: 15

cosistency of  05091 0.9766414141414141


HBox(children=(IntProgress(value=0, max=32), HTML(value='')))


consistency for cross 1 = 0.9351165329017288
cosistency of  07879 can not calculate cosistency


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

3 contains red predictions: 1
4 contains red predictions: 9
107 contains red predictions: 2
216 contains red predictions: 1
220 contains red predictions: 2
224 contains red predictions: 4
225 contains red predictions: 4
226 contains red predictions: 4
227 contains red predictions: 1
235 contains red predictions: 2
248 contains red predictions: 1
256 contains red predictions: 1
265 contains red predictions: 1
271 contains red predictions: 3
283 contains red predictions: 1
298 contains red predictions: 1
306 contains red predictions: 2
308 contains red predictions: 3
320 contains red predictions: 6
321 contains red predictions: 14
326 contains red predictions: 2
329 contains red predictions: 2
330 contains red predictions: 1
331 contains red predictions: 2
332 contains red predictions: 8
333 contains red predictions: 1
335 contains red predictions: 1
338 contains red predictions: 2
339 contains red predictions: 3
342 contains red predictions: 1
347 contains red predictions: 5
350 contain

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))




HBox(children=(IntProgress(value=0, max=71), HTML(value='')))


cosistency of  08455 1.0


HBox(children=(IntProgress(value=0, max=71), HTML(value='')))


consistency for cross 2 = []


HBox(children=(IntProgress(value=0, max=54), HTML(value='')))

0 contains red predictions: 19
1 contains red predictions: 15
2 contains red predictions: 12
3 contains red predictions: 2
21 contains red predictions: 6
29 contains red predictions: 21
30 contains red predictions: 97
31 contains red predictions: 100
32 contains red predictions: 100
33 contains red predictions: 100
34 contains red predictions: 44
35 contains red predictions: 4
49 contains red predictions: 2
53 contains red predictions: 24

cosistency of  05261 0.8991769547325102


HBox(children=(IntProgress(value=0, max=54), HTML(value='')))




HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

14 contains red predictions: 1
16 contains red predictions: 38

cosistency of  06453 0.9781144781144782


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))




HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

4 contains red predictions: 9
5 contains red predictions: 2
6 contains red predictions: 1
11 contains red predictions: 4
14 contains red predictions: 1
16 contains red predictions: 5
17 contains red predictions: 1
18 contains red predictions: 5
24 contains red predictions: 1
30 contains red predictions: 2
34 contains red predictions: 1
38 contains red predictions: 1
67 contains red predictions: 1
77 contains red predictions: 2
84 contains red predictions: 1
85 contains red predictions: 7
94 contains red predictions: 1
95 contains red predictions: 1
96 contains red predictions: 5
100 contains red predictions: 3
107 contains red predictions: 8
112 contains red predictions: 1
130 contains red predictions: 1
137 contains red predictions: 1
152 contains red predictions: 1
163 contains red predictions: 1
186 contains red predictions: 1
189 contains red predictions: 3
190 contains red predictions: 1
194 contains red predictions: 1
195 contains red predictions: 3
196 contains red predictions: 

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))


consistency for cross 3 = 0.9400834782098398


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

134 contains red predictions: 1
195 contains red predictions: 2
224 contains red predictions: 1
226 contains red predictions: 4
230 contains red predictions: 1
232 contains red predictions: 2
269 contains red predictions: 2
270 contains red predictions: 1
273 contains red predictions: 1
274 contains red predictions: 1
275 contains red predictions: 7
276 contains red predictions: 1
286 contains red predictions: 1
308 contains red predictions: 1
329 contains red predictions: 1
332 contains red predictions: 4
333 contains red predictions: 1
335 contains red predictions: 2
338 contains red predictions: 1
350 contains red predictions: 19
351 contains red predictions: 13
353 contains red predictions: 7
354 contains red predictions: 1
356 contains red predictions: 1
357 contains red predictions: 7
361 contains red predictions: 4
362 contains red predictions: 13
363 contains red predictions: 13
364 contains red predictions: 1
365 contains red predictions: 7
366 contains red predictions: 1
367 

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


cosistency of  04043 1.0


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))




HBox(children=(IntProgress(value=0, max=65), HTML(value='')))


cosistency of  08405 1.0


HBox(children=(IntProgress(value=0, max=65), HTML(value='')))


consistency for cross 4 = 0.9969774471053244
