In [1]:
import os
import wfdb
import pickle
import pandas as pd
import numpy as np
from keras.utils import to_categorical
from tqdm import tqdm_notebook
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from keras.models import load_model
import matplotlib.animation as animation
np.set_printoptions(suppress=True)
%matplotlib inline

Using TensorFlow backend.


In [2]:
cross_idx = 3
win_len = 100
dataset_root = "./dataset_RRI/"
model_root = "./results/models/"
keys = [i.split(".")[0] for i in os.listdir(dataset_root) if "pickle" in i and "cross" not in i ]
model_idx = {
    0:8,
    1:12,
    2:21,
    3:33,
    4:42
}

In [3]:
def int2label(label):
    if label==0:
        return "Normal"
    elif label==1:
        return "AF"

In [4]:
def get_normalization_param(test_keys):
    max_norm = 0
    min_norm = 10000
    for key in keys:
        if key in test_keys:continue
        with open(os.path.join(dataset_root,key+".pickle"),"rb") as f:
            dataset = pickle.load(f)
        X = dataset["X"]
        if np.max(X) > max_norm:
            max_norm = np.max(X)
        if np.min(X) < min_norm:
            min_norm = np.min(X)
    return min_norm,max_norm

In [5]:
def update(i, confidences,segments,labels,offset,ax1,ax2):
    if i !=0:
        ax1.cla()
        ax2.cla()
    c = confidences[i]
    segment = segments[i]
    label = labels[i]
    if c<0.5:
        color = "red"
    elif c<0.7:
        color = "yellow"
    else:
        color = "green"
    ax1.plot([i+j+offset for j in range(win_len)],segment,color="b")
    ax1.set_ylim(0,1)
    ax1.set_xlabel("RRI Index",fontsize=15)
    ax1.set_ylabel("Normalized RRI value",fontsize=15)
    ax1.set_title("Label:"+int2label(label),fontsize=15)
    ax2.bar(range(1, 2), c, color=color, align='center')
    ax2.set_title("Confidence of predict true label",fontsize=15)
    ax2.set_xticklabels([])
    ax2.set_ylim(0,1)

In [6]:
def plot_animation(seg_idx,segments,labels,confidences,save_dir):
    fig,(ax1,ax2) = plt.subplots(1,2,figsize=(15,5)) #figure objectを取得
    ax3 = ax2.twinx()
    ax3.boxplot(confidences)
    ax3.set_ylim(ax2.get_ylim())
    ax3.set_yticklabels([])
    ani = animation.FuncAnimation(fig, update, fargs = (confidences,segments,labels,seg_idx,ax1,ax2), interval = 100)
    save_path = os.path.join(save_dir,str(seg_idx)+".mp4")
    ani.save(save_path, writer = 'ffmpeg')
    plt.close()

In [7]:
def check_shift_invariance_once(X,label,seg_idx,save_dir,confidences,return_confidence=False):
    interval = 100
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    continue_seg = np.concatenate( [X[seg_idx],X[seg_idx+1]],axis=0)
    sliding_win_start = 0
    sliding_win_end = 100
    step = 1
    confidence = []
    shifted_segments = []
    shifted_labels = []
    
    consistency_list = [] 
    
    while(sliding_win_end<len(continue_seg)):
        data = continue_seg[sliding_win_start:sliding_win_end,:]

        logits = model.predict(np.expand_dims(data,axis=0))[0]
        confidence.append(logits[label])
        shifted_segments.append(data)
        shifted_labels.append(label)
        if sliding_win_start!=0:
            consistency_list.append(np.argmax(logits)==label)
        sliding_win_start += step
        sliding_win_end += step
    
    if not return_confidence:
        plot_animation(seg_idx,shifted_segments,shifted_labels,confidences,save_dir)
    else:
        if len(np.where(np.array(confidence) < 0.5)[0])>0:
            print(seg_idx)
        return confidence,consistency_list

In [15]:
def check_shift_invariance(X,y,test_key,save_root = "./Animations"):
    if not os.path.exists(save_root):
        os.mkdir(save_root)
    save_dir = os.path.join(save_root,test_key)

    start = dataset["start"]
    end = dataset["end"]
    confidences = []
    shifted_segments = []
    shifted_labels = []
    consistency = []
    if len(X)==1:
        print("cosistency of ",test_key,"can not calculate cosistency")
        return []
        
    for seg_idx in tqdm_notebook(range(len(X)-1)):
        seg = X[seg_idx]
        next_seg = X[seg_idx+1]
        seg_end = end[seg_idx]
        next_seg_start = start[seg_idx+1]
        label = np.argmax(y[seg_idx])
        next_label = np.argmax(y[seg_idx+1])
        if not (seg_end == next_seg_start and label==next_label):continue
        confidence,consistency_list =  check_shift_invariance_once(X,label,seg_idx,save_dir,confidences=None,return_confidence=True)
        confidences.append(confidence)
        consistency.append(consistency_list)
        
    confidences = np.concatenate(confidences,axis=0)
    consistency = np.concatenate(consistency,axis=0)
    consistency = np.mean(consistency)
    print("cosistency of ",test_key,consistency)
    
    for seg_idx in tqdm_notebook(range(len(X)-1)):
        seg = X[seg_idx]
        next_seg = X[seg_idx+1]
        seg_end = end[seg_idx]
        next_seg_start = start[seg_idx+1]
        label = np.argmax(y[seg_idx])
        next_label = np.argmax(y[seg_idx+1])
        if not (seg_end == next_seg_start and label==next_label):continue
        check_shift_invariance_once(X,label,seg_idx,save_dir,confidences,return_confidence=False)
    return consistency

In [16]:
for cross_idx in range(5):
    if cross_idx < 2:continue
    with open(os.path.join(dataset_root,"dataset-cross"+str(cross_idx)+".pickle"),"rb") as f:
            dataset = pickle.load(f)
            test_keys = dataset["test_key"]
            X_test = dataset["X_test"]
            y_test = dataset["y_test"]
    model = load_model(os.path.join(model_root,str(model_idx[cross_idx])+"-model.h5"))
    min_norm,max_norm = get_normalization_param(test_keys)
    consistencys = []

    for test_key in test_keys:
        with open(os.path.join(dataset_root,test_key+".pickle"),"rb") as f:
            dataset = pickle.load(f)
        X = np.expand_dims( (dataset["X"] - min_norm) / (max_norm - min_norm) ,axis=2)
        y = to_categorical (dataset["y"],num_classes=2)

        consistency = check_shift_invariance(X,y,test_key,save_root = os.path.join( "./Animations","cross-"+str(cross_idx)))
        consistencys.append(consistency)

    consis = np.mean(consistencys)
    print("consistency for cross",cross_idx,"=",consis)

cosistency of  07879 can not calculate cosistency


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

3
4
107
216
220
224
225
226
227
235
248
256
265
271
283
298
306
308
320
321
326
329
330
331
332
333
335
338
339
342
347
350
351
353
354
355
356
357
359
360
362
366
367
368
369
370
371
373
374
375
376
378
379
380
381
382

cosistency of  07162 0.9923531995143249


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))




HBox(children=(IntProgress(value=0, max=71), HTML(value='')))


cosistency of  08455 1.0


HBox(children=(IntProgress(value=0, max=71), HTML(value='')))


consistency for cross 2 = []


HBox(children=(IntProgress(value=0, max=54), HTML(value='')))

0
1
2
3
21
29
30
31
32
33
34
35
49
53

cosistency of  05261 0.8991769547325102


HBox(children=(IntProgress(value=0, max=54), HTML(value='')))




HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

14
16

cosistency of  06453 0.9781144781144782


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))




HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

4
5
6
11
14
16
17
18
24
30
34
38
67
77
84
85
94
95
96
100
107
112
130
137
152
163
186
189
190
194
195
196
197
200
205
210
215
216
217
218
219
220
221
222
223
224
225
226
227
229
230
231
232
233
234
235
236
253
254
256
262
264
265
266
269
270
271
272
273
274
275
276
283
284
286
301
302
304
305
306
307
308
309
311
313
316
317
318
326
327
328
329
330
331
332
333
334
335
338
339
341
342
343
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
383
384
385
386
387
388
389
390

cosistency of  07162 0.9429590017825312


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))


consistency for cross 3 = 0.9400834782098398


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

134
195
224
226
230
232
269
270
273
274
275
276
286
308
329
332
333
335
338
350
351
353
354
356
357
361
362
363
364
365
366
367
368
369
370
371
373
374
376
378
379
380
381

cosistency of  07162 0.9909323413159731


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


cosistency of  04043 1.0


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))




HBox(children=(IntProgress(value=0, max=65), HTML(value='')))


cosistency of  08405 1.0


HBox(children=(IntProgress(value=0, max=65), HTML(value='')))


consistency for cross 4 = 0.9969774471053244


## Summary of consistency

consistency for cross 0 = 0.9759052740751349

consistency for cross 1 = 0.9351165329017288

consistency for cross 2 = []

consistency for cross 3 = 0.9400834782098398

consistency for cross 4 = 0.9969774471053244