In [1]:
import numpy as np
import torch
import glob
import sklearn.preprocessing as skp

from dataset import PPG_Dataset
from models import vgg16_bn
from gan_models import GeneratorResNet
from sklearn.metrics import classification_report
from scipy.interpolate import splrep, splev
from tqdm import tqdm
from make_args import Args

- args

In [2]:
args = Args('./config/VGG16_PPGECG.json')

- 필요 함수

In [3]:
def make_syn_ecg(G_AB, input_ppg):
    with torch.no_grad():
        syn_ecg = G_AB(input_ppg)

    return syn_ecg

In [4]:
def normalize_data(data, z_score_norm=True, min_max_norm=False):
    norm_data = []
    data = data.data.cpu().numpy()

    for i in range(len(data)):
        target_data = data[i].copy()
        if z_score_norm:
            target_data = (target_data - target_data.mean()) / (target_data.std() + 1e-17)
        if min_max_norm:
            target_data = skp.minmax_scale(target_data, (-1, 1), axis=1)
        norm_data.append(target_data)

    return torch.from_numpy(np.array(norm_data)).type(torch.FloatTensor)

In [5]:
def interp_spline(ecg, step=1, k=3):
    x_new = np.arange(0, ecg.shape[0], ecg.shape[0]/step)
    interp_spline_method = splrep(np.arange(0, ecg.shape[0], 1), ecg, k=k)
    return splev(x_new, interp_spline_method)

In [6]:
def batch_signal_downsample(batch_signal, fs, target_fs):
    batch_signal = batch_signal.data.cpu().numpy()
    sig_seconds = batch_signal.shape[2] // fs
    
    interp_signals = []
    for i in range(len(batch_signal)):
        interp_signal = interp_spline(batch_signal[i][0], step=target_fs*sig_seconds, k=5)
        interp_signals.append(interp_signal)
    interp_signals = np.array(interp_signals)
    
    return torch.from_numpy(interp_signals).type(torch.FloatTensor).unsqueeze(1)

- Define Device

In [7]:
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

print("Using Pytorch Versions:", torch.__version__, ' Device:', DEVICE)

Using Pytorch Versions: 2.1.1+cu118  Device: cuda


- load partition

In [8]:
partition = np.load(args.partition_path, allow_pickle=True).item()

testset = partition['testset']

- dataloader

In [9]:
batch_size = 33 # testset
num_workers = args.num_workers

In [10]:
dataloader_instance = PPG_Dataset(filepaths=testset, sampling_rate=args.target_sampling_rate, 
                                  min_max_norm=args.min_max_norm, z_score_norm=args.z_score_norm, interp=args.interp_method)
dataloader = torch.utils.data.DataLoader(dataloader_instance,
                                             batch_size = batch_size,
                                             shuffle = None,
                                             num_workers = num_workers,
                                             drop_last = True,
                                             pin_memory = True)

- load model

In [11]:
cls_weight_path = args.model_save_path

cls_model = vgg16_bn(in_channels=args.in_channels, num_classes=args.num_classes)
cls_weights = torch.load(cls_weight_path)
cls_model.load_state_dict(cls_weights)
cls_model.to(DEVICE)
cls_model.eval()

VGG(
  (features): Sequential(
    (0): Conv1d(2, 32, kernel_size=(3,), stride=(1,), padding=(1,))
    (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))
    (4): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
    (7): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(1,))
    (8): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
    (11): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
    (14): Conv1d(64, 128, kernel_size=(3,), stride=(1,

- load PPG2ECG model

In [12]:
weights_path = './gan_weights/PPG2ECG_CycleGAN_1Epochs.pth'
input_shape = (None, 1, int(args.target_sampling_rate * args.sig_time_len))
n_residual_blocks = 6

G_AB = GeneratorResNet(input_shape, n_residual_blocks)
weights = torch.load(weights_path)
G_AB.load_state_dict(weights['G_AB'])
G_AB.to(DEVICE)
G_AB.eval()

GeneratorResNet(
  (model): Sequential(
    (0): ReflectionPad1d((1, 1))
    (1): Conv1d(1, 64, kernel_size=(7,), stride=(1,), padding=(1,))
    (2): InstanceNorm1d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Conv1d(64, 128, kernel_size=(3,), stride=(2,), padding=(2,))
    (5): InstanceNorm1d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (6): LeakyReLU(negative_slope=0.2, inplace=True)
    (7): Conv1d(128, 256, kernel_size=(3,), stride=(2,), padding=(2,))
    (8): InstanceNorm1d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (9): LeakyReLU(negative_slope=0.2, inplace=True)
    (10): ResidualBlock(
      (block): Sequential(
        (0): ReflectionPad1d((1, 1))
        (1): Conv1d(256, 256, kernel_size=(3,), stride=(1,))
        (2): InstanceNorm1d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (3): LeakyReLU(neg

- inference

In [13]:
z_score_norm = args.z_score_norm_rescale
min_max_norm = args.min_max_norm_rescale

In [14]:
y_trues_list = []
y_preds_list = []

for data in tqdm(dataloader, total=len(dataloader)):    
    with torch.no_grad():
        # data extraction
        ppg, label = data
        ppg = ppg.to(DEVICE)
        ecg = make_syn_ecg(G_AB, ppg)   
        
        ppg = batch_signal_downsample(ppg, 128, args.downsample_fs)
        ecg = batch_signal_downsample(ecg, 128, args.downsample_fs)
        
        # normalize
        if z_score_norm == True and min_max_norm == False:
            ppg = normalize_data(ppg, z_score_norm=True, min_max_norm=False)
            ecg = normalize_data(ecg, z_score_norm=True, min_max_norm=False)

        elif z_score_norm == False and min_max_norm == True:
            ppg = normalize_data(ppg, z_score_norm=False, min_max_norm=True)
            ecg = normalize_data(ecg, z_score_norm=False, min_max_norm=True)

        elif z_score_norm == True and min_max_norm == True:
            ppg = normalize_data(ppg, z_score_norm=True, min_max_norm=True)
            ecg = normalize_data(ecg, z_score_norm=True, min_max_norm=True)

        elif z_score_norm == False and min_max_norm == False:
            ppg = normalize_data(ppg, z_score_norm=False, min_max_norm=False)
            ecg = normalize_data(ecg, z_score_norm=False, min_max_norm=False)

        input_data = torch.cat((ppg, ecg), 1).to(DEVICE)
        
        y_pred = cls_model(input_data).data.cpu().numpy()
        y_pred_max = list(np.argmax(y_pred, -1))
        label = list(label.data.cpu().numpy())
        
        y_trues_list.extend(label)
        y_preds_list.extend(y_pred_max)

100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:04<00:00,  1.13it/s]


- score

In [16]:
print(classification_report(y_trues_list, y_preds_list, digits=4))

              precision    recall  f1-score   support

           0     0.9688    0.9612    0.9650       129
           1     0.8649    0.8889    0.8767        36

    accuracy                         0.9455       165
   macro avg     0.9168    0.9251    0.9208       165
weighted avg     0.9461    0.9455    0.9457       165

