In [1]:
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path

import torch
from torch.utils.data import Dataset, DataLoader

In [2]:
data_dir = Path(r'..\dataset\del_wake')
data_json = Path(r'..\dataset\data_list')
mesa_mros_shhs = 'mesa_mros_shhs_20230215.json'
train_name_lst = ['mesa', 'MrOS_visit1', 'MrOS_visit2']

In [3]:
# MESA + MrOS + SHHS
with open(data_json/mesa_mros_shhs) as f:
    all_npz = json.load(f)

data_num = [[len(i) for i in v.values()] for v in all_npz.values()]
df = pd.DataFrame(data_num).transpose()
df.index = ['normal', 'mild', 'moderate', 'severe']
df.columns = all_npz.keys()
df.loc['total'] = df.sum(axis=0)
df.loc[:, 'total'] = df.sum(axis=1)
df

Unnamed: 0,mesa,MrOS_visit1,MrOS_visit2,shhs1,shhs2,total
normal,414,502,154,1766,629,3465
mild,643,1029,360,2031,960,5023
moderate,518,792,283,1237,637,3467
severe,481,583,229,759,425,2477
total,2056,2906,1026,5793,2651,14432


In [4]:
test_list = {k: [] for k in all_npz.keys() if k not in train_name_lst}

for k, v in all_npz.items():
    if k not in train_name_lst:
        for j in v.values():
            test_list[k] += j

print('test list:')
for k, v in test_list.items():
    print(f'\t{k}: {len(v)}')

test list:
	shhs1: 5793
	shhs2: 2651


In [5]:
SR = 1 # sample rate
segment_2d = False # segment signal to 2D
segment_len = 60*SR if segment_2d else None # length of each segment
signal_len = (8*3600*SR)//segment_len if segment_2d else 8*3600*SR # length of signal

def normalize(signal):
    return signal/100.

def AHI_class(AHI, classes=4, cutoff=15):
    if classes == 2:
        if AHI < cutoff:
            label = 0
        else:
            label = 1
    elif classes == 4:
        if AHI < 5:
            label = 0
        elif 5 <= AHI < 15:
            label = 1
        elif 15 <= AHI < 30:
            label = 2
        else: label = 3
    
    return label

def read_npz_file(file):
    npz_data = np.load(file, allow_pickle=True)
    signal = normalize(npz_data['SpO2'])
    
    if segment_2d:
        segment_num = len(signal)//segment_len
        signal = signal[:segment_num*segment_len].reshape((segment_num, segment_len))
        
    label = AHI_class(float(npz_data['csv_data'].item()['ahi_a0h3']))
        
    return signal, label

def cut_pad_signal(signal, length=signal_len, mode='middle'):
    if len(signal) > length:
        cut_len = len(signal)-length
        if mode == 'middle': # cut from middle
            mid = cut_len//2
            return signal[mid:mid+length]
        elif mode == 'random': # random cut
            rand = np.random.randint(cut_len)
            return signal[rand:rand+length]
    
    else:
        pad_len = length-len(signal)
        if segment_2d:
            signal = np.pad(signal, ((0, pad_len), (0, 0)))
        else:
            signal = np.pad(signal, (0, pad_len))
        return signal

In [6]:
test_info = {j: {i: [] for i in ['signal', 'label']} for j in test_list.keys()}

for k, v in test_list.items():
    for file in tqdm(v):
        path = str(data_dir/file)
        signal, label = read_npz_file(path)
        test_info[k]['signal'].append(signal)
        test_info[k]['label'].append(label)

100%|████████████████████████████████████████████████████████████████████████████| 5793/5793 [00:05<00:00, 1051.65it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 2651/2651 [00:03<00:00, 870.81it/s]


In [7]:
def startEvaluation(model):
    model.eval()
    
    batch_data = []
    for v in test_data.values():
        batch_data += v
        
    _ = model(batch_data[0][0]) # initialization
    
    pbar = tqdm(batch_data, unit='batch')
    for b, batch in enumerate(pbar):
        signals, labels = batch

        with torch.no_grad():
            outputs = model(signals)

In [8]:
class EvaluationDataset(Dataset):
    def __init__(self, data_type, signal, label):
        self.data_type = data_type
        self.signal = signal
        self.label = label
        
    def __getitem__(self, index):
        if self.data_type == 'fixed': # fixed length
            signal = torch.tensor(cut_pad_signal(self.signal[index]), dtype=torch.float32).cuda().unsqueeze(0)
        elif self.data_type == 'original': # original length
            signal = torch.tensor(self.signal[index], dtype=torch.float32).cuda().unsqueeze(0)
        label = torch.tensor(self.label[index], dtype=torch.int64).cuda()
        return signal, label
    
    def __len__(self):
        return len(self.signal)

In [9]:
# original length
batch_size = 1
test_data = {}
for k in test_list.keys():
    test_data[k] = DataLoader(EvaluationDataset('original', test_info[k]['signal'], test_info[k]['label']),
                              batch_size=batch_size)

In [10]:
from models import m004, m006, m008, m009, m204_ReLU2, m205, m206, m207

In [11]:
def speedTest(model, model_name):
    weights_dir = Path('weights', model_name)
    weights_file = 'best_loss'
    model = model().cuda()
    model.load_state_dict(torch.load(weights_dir/f'{weights_file}.pth'))

    startEvaluation(model)

In [12]:
model_name = 'm004_20230215_152506'
speedTest(m004, model_name)

100%|██████████████████████████████████████████████████████████████████████████| 8444/8444 [00:19<00:00, 424.00batch/s]


In [13]:
model_name = 'm006_20230216_104829'
speedTest(m006, model_name)

100%|██████████████████████████████████████████████████████████████████████████| 8444/8444 [00:22<00:00, 367.40batch/s]


In [14]:
model_name = 'm008_20230216_134648'
speedTest(m008, model_name)

100%|██████████████████████████████████████████████████████████████████████████| 8444/8444 [00:26<00:00, 322.64batch/s]


In [15]:
model_name = 'm009_20230220_150313'
speedTest(m009, model_name)

100%|██████████████████████████████████████████████████████████████████████████| 8444/8444 [00:32<00:00, 263.86batch/s]


In [16]:
model_name = 'm204_ReLU2_20230302_150035'
speedTest(m204_ReLU2, model_name)

100%|██████████████████████████████████████████████████████████████████████████| 8444/8444 [00:21<00:00, 391.90batch/s]


In [17]:
model_name = 'm205_20230307_122555'
speedTest(m205, model_name)

100%|██████████████████████████████████████████████████████████████████████████| 8444/8444 [00:24<00:00, 340.31batch/s]


In [18]:
model_name = 'm206_20230307_133016'
speedTest(m206, model_name)

100%|██████████████████████████████████████████████████████████████████████████| 8444/8444 [00:25<00:00, 327.35batch/s]


In [19]:
model_name = 'm207_20230307_151901'
speedTest(m207, model_name)

100%|██████████████████████████████████████████████████████████████████████████| 8444/8444 [00:32<00:00, 256.56batch/s]
