In [2]:
import numpy as np
import os
import pyedflib
import random
from matplotlib import pyplot as plt
from scipy.interpolate import interp1d
import xml.etree.ElementTree as ET
import sys
import scipy.stats as stats
from sklearn.metrics import roc_auc_score
import sys
import pandas as pd
sys.path.append('/home/linzenghui/ECG_code/HeartRateVariability_220217')
import FrequencyDomain as fd
import TimeDomain as td
import NonLinear as nl
from common import *
from Rpeaks import *
# Sleep-stage annotation files ("*.annot") found in the PSG directory, sorted by name.
files = os.listdir('../polysomnography')
names = sorted(name for name in files if name.split('.')[-1] == 'annot')
# Per-subject metadata table; the find_* helpers look rows up by its 'id' column.
data_frame = pd.read_csv('/data/0shared/linzenghui/ECG_data/msp/datasets/msp-harmonized-dataset-0.1.1.csv')
# Subjects 97 and 116 have no .annot file.

In [2]:
def resample_interp(ts, fs_in, fs_out):
    """
    Resample a single-lead signal by linear interpolation.

    Input and output sample positions are both mapped onto the normalised
    interval [0, 1] (endpoints included), and output values are linearly
    interpolated from the input samples.

    :param ts: single-lead signal, 1-D float array
    :param fs_in: original sampling rate (cast to int for the equality check)
    :param fs_out: target sampling rate (cast to int)
    :return: a NumPy array with int(len(ts) / fs_in * int(fs_out)) samples,
             or a copy of ts when the int-cast rates are equal
    """
    # Duration uses the uncast input rate, exactly as given by the caller.
    duration_s = len(ts) / fs_in
    rate_in = int(fs_in)
    rate_out = int(fs_out)
    if rate_in == rate_out:
        return np.array(ts)
    grid_in = np.linspace(0, 1, num=len(ts), endpoint=True)
    grid_out = np.linspace(0, 1, num=int(duration_s * rate_out), endpoint=True)
    interpolator = interp1d(grid_in, ts, kind='linear')
    return interpolator(grid_out)

In [3]:
def get_wake_time(anno_path):
    '''
    Read a sleep-stage annotation file and return (wake_seconds, total_seconds).

    The first tab-separated field of every line is taken as a stage label.
    Labels longer than two characters are discarded (presumably header or
    non-stage lines). Each remaining label counts as one 30-second epoch.

    :param anno_path: path to the annotation file
    :return: (30 * number of consecutive 'W' labels at the start,
              30 * number of labels kept); (0, 0) if no label is kept
    '''
    with open(anno_path, 'r') as f:
        lines = f.readlines()
    stages = [line.split('\t')[0] for line in lines]
    stages = [sta for sta in stages if len(sta) <= 2]

    # Count leading wake epochs explicitly, so a recording that is entirely
    # 'W' yields len(stages) epochs, not len(stages) - 1.
    n_leading_wake = 0
    for sta in stages:
        if sta != 'W':
            break
        n_leading_wake += 1
    return (n_leading_wake * 30, len(stages) * 30)

def get_ecgpath_from_name(name, ecg_folder='/data/0shared/linzenghui/ECG_data/msp/polysomnography/'):
    """
    Map a recording/annotation file name to the path of its EDF file.

    The stem is everything before the first '.' in `name`; the result is
    `<ecg_folder>/<stem>.edf`.

    :param name: file name such as 'xxx.annot' (no directory component expected)
    :param ecg_folder: directory holding the EDF files; defaults to the
                       original hard-coded location
    :return: path string of the EDF file
    """
    stem = name.split('.')[0]
    # os.path.join adds a separator only when ecg_folder lacks one, so the
    # default (trailing '/') produces the same string as plain concatenation.
    return os.path.join(ecg_folder, stem + '.edf')

In [4]:
def getidx_from_name(name):
    """
    Extract the numeric subject id encoded in a file name.

    The stem (text before the first '.') is split on '-', and the last three
    characters of the second field are parsed as an integer,
    e.g. 'msp-S001.annot' -> 1.
    """
    stem = name.split('.')[0]
    second_field = stem.split('-')[1]
    return int(second_field[-3:])

def _lookup_subject_value(data_frame, id, column):
    """Return `column` of the unique row whose 'id' equals `id`, as a Python float.

    Raises ValueError unless exactly one row matches. This reports a missing
    or duplicated subject clearly instead of relying on float() of a NumPy
    array, which is deprecated for non-0-d arrays and fails opaquely when the
    array does not have exactly one element.
    """
    # Boolean mask instead of a string-built DataFrame.query expression.
    values = data_frame.loc[data_frame['id'] == id, column].to_numpy()
    if values.size != 1:
        raise ValueError(
            f'expected exactly one row with id=={id} for column {column!r}, found {values.size}'
        )
    return float(values[0])

def find_ahi(data_frame,id):
    """AHI ('nsrr_ahi_hp3u') of subject `id`."""
    return _lookup_subject_value(data_frame, id, 'nsrr_ahi_hp3u')

def find_age(data_frame,id):
    """Age ('nsrr_age') of subject `id`."""
    return _lookup_subject_value(data_frame, id, 'nsrr_age')

def find_sex(data_frame,id):
    """Value of 'nsrr_gatest' for subject `id` (used as the sex/gender column)."""
    return _lookup_subject_value(data_frame, id, 'nsrr_gatest')

def find_bmi(data_frame,id):
    """BMI ('nsrr_bmi') of subject `id`."""
    return _lookup_subject_value(data_frame, id, 'nsrr_bmi')

def my_random_split(list1,len,shuffle=False):
    """
    Split a list into its first `len[0]` items and its last `len[1]` items.

    :param list1: list to split (not modified)
    :param len: pair (n_head, n_tail). The name shadows the builtin inside
                this function; it is kept so keyword callers keep working.
    :param shuffle: if True, shuffle a copy first with a fixed seed (100)
    :return: (head, tail) lists
    """
    n_head, n_tail = len[0], len[1]
    temp_list = list1.copy()
    if shuffle:
        # A private generator seeded with 100 produces the same permutation as
        # random.seed(100); random.shuffle(...), without resetting the global RNG.
        random.Random(100).shuffle(temp_list)
    head = temp_list[0:n_head]
    # temp_list[-0:] would return the whole list, so an empty tail needs its own case.
    tail = temp_list[-n_tail:] if n_tail > 0 else temp_list[0:0]
    return (head, tail)


def _read_ecg_mv(edf_path):
    """Read the 'ECG' channel of an EDF file; return (signal_in_mV, sampling_rate)."""
    f = pyedflib.EdfReader(edf_path)
    try:
        ind = f.getSignalLabels().index('ECG')
        unit = f.getSignalHeader(ind)['dimension']
        if unit == 'uV':
            ecg_data = f.readSignal(ind) / 1000
        elif unit == 'mV':
            ecg_data = f.readSignal(ind)
        else:
            # Without this, ecg_data would be undefined or silently left over
            # from a previous file.
            raise ValueError(f"{edf_path}: unsupported ECG unit {unit!r}; expected 'uV' or 'mV'")
        fs = f.getSampleFrequency(ind)
    finally:
        f.close()
    return ecg_data, fs


def get_wake_ecg_from_edf(name_list, bag_path, datasetnumber=10):
    """
    Build one row per recording: 6 metadata columns followed by 5 minutes of ECG at 200 Hz.

    Row layout:
      col 0  datasetnumber
      col 1  AHI ('nsrr_ahi_hp3u')
      col 2  1.0 if AHI >= 5 else 0.0
      col 3  age, col 4 'nsrr_gatest', col 5 BMI
      col 6: the 300 s of ECG (mV) ending at the end of the initial wake
             period reported by get_wake_time, resampled to 200 Hz if needed.
             NOTE(review): assumes annotation epochs start at the EDF start.

    Metadata is looked up in the module-level `data_frame`.

    :param name_list: annotation file names (located in bag_path)
    :param bag_path: directory prefix for the annotation files
    :param datasetnumber: value written to column 0
    :return: float array of shape (len(name_list), 6 + 300 * 200)
    """
    fs_out = 200
    seg_sec = 300
    ecg_database = np.zeros(shape=(len(name_list), 6 + seg_sec * fs_out))
    ecg_database[:, 0] = datasetnumber
    for idx, name in enumerate(name_list):
        wake_time = get_wake_time(bag_path + name)[0]
        assert wake_time >= seg_sec, f'{name}: only {wake_time}s of initial wake'
        subject_id = getidx_from_name(name)
        ahi = find_ahi(data_frame, subject_id)
        ecg_database[idx, 1] = ahi
        ecg_database[idx, 2] = float(ahi >= 5)
        ecg_database[idx, 3] = find_age(data_frame, subject_id)
        ecg_database[idx, 4] = find_sex(data_frame, subject_id)
        ecg_database[idx, 5] = find_bmi(data_frame, subject_id)

        ecg_data, fs = _read_ecg_mv(get_ecgpath_from_name(name))
        segment = ecg_data[int((wake_time - seg_sec) * fs):int(fs * wake_time)]
        if fs != fs_out:
            segment = resample_interp(segment, fs_in=fs, fs_out=fs_out)
        ecg_database[idx, 6:] = segment
    return ecg_database

def cut_data(data,window_size=30*200,step=30*200,datasetnumber=10):
    """
    Slice every row's signal into fixed-length windows, copying its metadata.

    Each input row is [6 metadata columns | signal]. Output rows are
    [datasetnumber, data[r, 1:6] | window]. Only complete windows are kept,
    starting every `step` samples.

    :param data: 2-D array whose first 6 columns are metadata
    :param window_size: window length in samples (default 30 s at 200 Hz)
    :param step: hop between window starts in samples
    :param datasetnumber: value written to column 0 of every output row
    :return: array of shape (n_rows * n_windows, 6 + window_size)
    :raises ValueError: if window_size or step is not positive
    """
    if window_size <= 0 or step <= 0:
        raise ValueError(f'window_size and step must be positive, got {window_size}, {step}')
    n_meta = 6
    n_samples = data.shape[1] - n_meta
    # Integer count of complete windows. The previous float formula plus a
    # loop over every start produced partial windows (shape-mismatch errors)
    # whenever step != window_size or the length was not an exact multiple.
    n_windows = (n_samples - window_size) // step + 1 if n_samples >= window_size else 0

    database = np.zeros(shape=(data.shape[0] * n_windows, n_meta + window_size))
    database[:, 0] = datasetnumber
    count = 0
    for row in data:
        for w in range(n_windows):
            start = n_meta + w * step
            database[count, 1:n_meta] = row[1:n_meta]
            database[count, n_meta:] = row[start:start + window_size]
            count += 1
    assert count == database.shape[0]
    return database

def check_edf(edf):
    """
    Print the ECG channel header of an EDF file and its effective sampling rate.

    Uses the channel labelled 'ECG', falling back to 'ECG1'. Prints the
    pyedflib signal header, then the file duration and
    len(signal) / duration.

    :param edf: path to the EDF file
    """
    f = pyedflib.EdfReader(edf)
    try:
        labels = f.getSignalLabels()
        try:
            ind = labels.index('ECG')
        except ValueError:  # list.index raises ValueError when the label is absent
            ind = labels.index('ECG1')
        header = f.getSignalHeader(ind)
        dura = f.getFileDuration()
        ecg = f.readSignal(ind)
        print(header)
        print(dura, len(ecg) / dura)
    finally:
        # Close the reader even when the channel lookup or read fails.
        f.close()

In [8]:
def r_peaks(sig, fs=200):
    """Detect R peaks: coarse QRS detection, then per-peak refinement.

    :param sig: ECG samples
    :param fs: sampling rate in Hz, forwarded to the coarse detector
    :return: whatever R_Wave_finetune returns (refined R-peak positions)
    """
    return R_Wave_finetune(sig, simple_qrs_detector(sig, fs=fs))

def basic_screen(sig,rpos,fs=200):
    """
    Rule-based quality screen for a single-lead ECG segment.

    Checks run in order; the first failing check decides the result
    (English translation of each reason in parentheses):
      1. fewer than 150 R peaks                        -> '峰值太少' (too few peaks)
      2. fewer than 5 samples with |sig| > 0.1         -> '电压值过低' (voltage too low)
      3. first peak later than 5 s, or last peak more
         than 5 s before the end                       -> '前方或后方有空缺' (gap at start/end)
      4. longest RR interval > 3 x mean RR interval    -> 'rri max值过大' (max RR too large)
      5. longest RR interval > 5 s                     -> '有超过5秒的空白' (blank longer than 5 s)

    :param sig: ECG samples (1-D)
    :param rpos: R-peak sample indices (presumably ascending)
    :param fs: sampling rate in Hz
    :return: (True, 'pass') or (False, reason)
    """
    if len(rpos) < 150:
        return (False, '峰值太少')

    n_above = np.count_nonzero(np.abs(sig) > 0.1)
    if n_above < 5:
        return (False, '电压值过低')

    margin = fs * 5
    if rpos[0] > margin or rpos[-1] < (len(sig) - margin):
        return (False, '前方或后方有空缺')

    rri = np.diff(rpos)
    longest = np.max(rri)
    if longest > np.mean(rri) * 3:
        return (False, 'rri max值过大')
    if longest > margin:
        return (False, '有超过5秒的空白')
    return (True, 'pass')

def cal_corr_coeff_lst(sig, rpos):
    """
    Correlate each beat with the average beat.

    A beat is the 400-sample window sig[r - 200 : r + 200] around every peak r
    with more than 200 samples on both sides. The template is the mean of all
    beats; both template and beat are mean-centred before np.corrcoef.

    :param sig: ECG samples
    :param rpos: R-peak sample indices
    :return: list of Pearson correlation coefficients, one per usable beat
    """
    half = 200
    n = len(sig)
    beats = np.array([sig[r - half:r + half] for r in rpos if r > half and (n - r) > half])
    template = np.mean(beats, axis=0)
    template = template - np.mean(template)
    return [np.corrcoef(beat - np.mean(beat), template)[0, 1] for beat in beats]

def sqi(sig, rpos, fs=200):
    """
    Signal-quality index of an ECG segment.

    :param sig: ECG samples
    :param rpos: R-peak sample indices
    :param fs: sampling rate in Hz, forwarded to basic_screen (default 200)
    :return: (False, 0.0) if basic_screen rejects the segment (call
             basic_screen directly to get the rejection reason); otherwise
             (mean_corr > 0.6, mean_corr), where mean_corr is the mean of
             cal_corr_coeff_lst(sig, rpos)
    """
    passed, _reason = basic_screen(sig, rpos, fs=fs)
    if not passed:
        # Return a real boolean: returning the (non-empty, hence truthy) reason
        # string here would make `if sqi(...)[0]:` accept rejected segments.
        return (False, 0.0)

    coeff = float(np.mean(cal_corr_coeff_lst(sig, rpos)))
    # NaN (e.g. no usable beats) compares False, so such segments fail.
    return (coeff > 0.6, coeff)

def show_5min_ecg(ecg_data,title,r_peaks=None):
    """
    Plot a 5-minute ECG sampled at 200 Hz as ten stacked 30-second panels.

    :param ecg_data: at least 60000 samples (10 panels x 6000)
    :param title: title shown above the first panel
    :param r_peaks: optional sequence/array of R-peak sample indices to mark.
                    (The parameter name shadows the r_peaks() function
                    inside this body; it is kept for keyword callers.)
    """
    panel_len = 6000  # 30 s at 200 Hz
    n_panels = 10
    fig, ax = plt.subplots(nrows=n_panels, ncols=1, sharex=True)
    fig.set(figwidth=15, figheight=20, dpi=300, facecolor='bisque')
    # `is not None` instead of truthiness: `if array:` raises for NumPy arrays.
    peaks = None if r_peaks is None else np.asarray(r_peaks, dtype=int)
    ecg_arr = np.asarray(ecg_data)
    for i in range(n_panels):
        lo, hi = i * panel_len, (i + 1) * panel_len
        if i == 0:
            ax[i].set_title(title)
        ax[i].plot(ecg_arr[lo:hi], color='red')
        ax[i].set_xticks(range(0, 6200, 200))
        ax[i].set_xticklabels(range(0, 31, 1))
        ax[i].set_xlabel('s')
        ax[i].set_ylabel('mv')
        ax[i].patch.set_facecolor('linen')
        if peaks is not None:
            # Half-open [lo, hi) matches the plotted slice, so peak 0 is shown
            # and a peak at hi is neither drawn off-panel nor indexed past the end.
            in_panel = peaks[(peaks >= lo) & (peaks < hi) & (peaks < len(ecg_arr))]
            if in_panel.size:
                ax[i].scatter(in_panel - lo, ecg_arr[in_panel], color='blue')
    plt.subplots_adjust(hspace=0)

In [9]:
##baseline check
psg_dir = '../polysomnography/'
# Read each annotation file once and reuse both values (wake, whole).
wake_times = [get_wake_time(psg_dir + name) for name in names]
wake_list = [wake_sec / 60 for wake_sec, _ in wake_times]
whole_len = [total_sec / 60 for _, total_sec in wake_times]
df = pd.DataFrame({'name': names, 'wake': wake_list, 'whole': whole_len})
name_list = list(df[df['wake'] >= 5]['name'])  # at least 5 minutes of initial wake

wake_ecg = get_wake_ecg_from_edf(name_list, bag_path=psg_dir, datasetnumber=7)
# Basic quality screen on each 5-minute ECG segment (columns 6: of every row).
sqi_list = [basic_screen(seg, rpos=r_peaks(seg), fs=200)[0] for seg in wake_ecg[:, 6:]]
# An explicit bool dtype keeps this a row mask even when sqi_list is empty.
cleaned_wake_ecg = wake_ecg[np.array(sqi_list, dtype=bool)]

In [10]:
# Persist the screened segments. Mode 'xb' refuses to overwrite an existing msp.npy
# (re-running this cell raises FileExistsError until the file is removed).
with open('msp.npy', mode='xb') as f:
    np.save(f, cleaned_wake_ecg)

In [None]:
# NOTE(review): msp_data is not referenced in the visible code below; the dataset
# class reads train_data / test_data, which are defined elsewhere — confirm intent.
msp_data = np.load(file='msp.npy')
class Psg_dataset(Dataset):
    """
    Map-style dataset over rows laid out as [6 metadata columns | ECG samples].

    `state` selects the module-level `train_data` or `test_data` array (defined
    outside this cell — TODO confirm where). With `use_cut`, the array is first
    split into fixed windows by `cut_data`. Each item is (row[6:], row[2]);
    in arrays built by get_wake_ecg_from_edf, column 2 is the AHI >= 5 flag.
    """

    def __init__(self, state, use_cut):
        if state == 'train':
            source = train_data
        elif state == 'test':
            source = test_data
        else:
            # Otherwise data_all would be left unset and fail later with AttributeError.
            raise ValueError(f"state must be 'train' or 'test', got {state!r}")
        self.data_all = source
        if use_cut:
            self.data_all = cut_data(source)
            print(f'{state}_use_cut')

    def __getitem__(self, index):
        return self.data_all[index, 6:], self.data_all[index, 2]

    def __len__(self):
        return len(self.data_all)
# Evaluate the (externally defined) `model` on the test split and report ROC-AUC.
test_psg_data = Psg_dataset(state='test', use_cut=False)
test_iter = DataLoader(test_psg_data, batch_size=64, shuffle=False)
model.eval()
batch_probs, batch_labels = [], []
with torch.no_grad():
    for val_x, val_y in test_iter:
        # DataLoader already yields tensors: add the channel axis directly
        # (batch, n_samples) -> (batch, 1, n_samples) instead of a numpy round trip.
        val_x = val_x.unsqueeze(1).to(device_str, torch.float32)
        batch_probs.append(torch.sigmoid(model(val_x)))
        batch_labels.append(val_y)
# Concatenate once: repeated torch.cat in the loop is quadratic, and seeding with an
# empty float32 tensor relied on dtype promotion against the float64 labels.
pro = torch.cat(batch_probs)
true_y = torch.cat(batch_labels)
pre = pro.to('cpu')
value = roc_auc_score(true_y, pre)
print(f"auc:{value}")