In [1]:
import warnings
warnings.filterwarnings('ignore')

import glob
import numpy as np
import torch
import matplotlib.pyplot as plt

from models import GeneratorResNet
from dataset import PPG2ECG_Dataset_Eval
from tqdm import tqdm
from make_args import Args

from sklearn.metrics import mean_absolute_error
import dtw # pip install dtw-python

Importing the dtw module. When using in academic works please cite:
  T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.
  J. Stat. Soft., doi:10.18637/jss.v031.i07.



- load args

In [2]:
args = Args('./config/CycleGAN_PPG2ECG.json')

- 필요 함수

In [3]:
def MAE(y, y_pred):
    return mean_absolute_error(y, y_pred)
    
def NMAE(y, y_pred):
    numerator = np.abs(y-y_pred).sum()
    denominator = np.abs(y_pred).sum()
    
    return numerator / (denominator + 1e-10)
    
def RMSE(y, pred):
    return np.sqrt(np.mean(np.square(y-pred)))
    
def NRMSE(y, y_pred):
    numerator = RMSE(y, y_pred)
    denominator = y.max() - y.min()
    
    return numerator / (denominator + 1e-10)

def PRD(y, y_pred):
    numerator = ((y-y_pred)**2).sum()
    denominator = (y**2).sum()
    
    return np.sqrt((numerator / (denominator + 1e-10))*100)

def DTW(y, y_pred):
    return dtw.dtw(y, y_pred, keep_internals=True).distance

In [4]:
def load_model(weights_path, input_shape, n_residual_blocks, DEVICE):
    G_AB = GeneratorResNet(input_shape, n_residual_blocks)
    weights = torch.load(weights_path, map_location=DEVICE)
    G_AB.load_state_dict(weights['G_AB'])
    G_AB.to(DEVICE)
    G_AB.eval()
    
    return G_AB

- Define Device

In [5]:
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

print("Using Pytorch Versions:", torch.__version__, ' Device:', DEVICE)

Using Pytorch Versions: 2.1.1+cu118  Device: cuda


- get dataloader

In [6]:
partition = np.load(args.partition_path, allow_pickle=True).item()

valset = partition['valset']

In [7]:
batch_size = 100
num_worker = args.num_workers

trainloader_instance = PPG2ECG_Dataset_Eval(valset, sampling_rate=args.target_sampling_rate, 
                                           min_max_norm=args.min_max_norm, z_score_norm=args.z_score_norm, interp=args.interp_method)
train_dataloader = torch.utils.data.DataLoader(trainloader_instance,
                                                   batch_size = batch_size,
                                                   shuffle = None,
                                                   num_workers = num_worker,
                                                   drop_last = True,
                                                   pin_memory = True)

- evaluation

In [8]:
model_list = glob.glob('./model_result/CycleGAN/*.pth')
input_shape = (None, 1, int(args.target_sampling_rate * args.sig_time_len))
n_residual_blocks=args.n_residual_blocks

model_perf_list = []

In [11]:
for model_path in model_list:
    G_AB = load_model(model_path, input_shape, n_residual_blocks, DEVICE)
    
    # metric list
    MAE_temp= []
    NMAE_temp = []
    RMSE_temp = []
    NRMSE_temp = []
    PRD_temp = []
    DTW_temp = []
    
    for input_data in tqdm(train_dataloader, total=len(valset)//batch_size):
        # prepare data
        input_ppg, ref_ecg = input_data['ppg'], input_data['ecg']
        
        # inference
        syn_ecg = G_AB(input_ppg.to(DEVICE))
        syn_ecg = syn_ecg.data.cpu().numpy()[:,0,:]
        ref_ecg = ref_ecg.data.cpu().numpy()[:,0,:]
        
        # performance
        for ref, syn in zip(ref_ecg, syn_ecg):
            MAE_temp.append(MAE(ref, syn))
            NMAE_temp.append(NMAE(ref, syn))
            RMSE_temp.append(RMSE(ref, syn))
            NRMSE_temp.append(NRMSE(ref, syn))
            PRD_temp.append(PRD(ref, syn))
            DTW_temp.append(DTW(ref, syn))
        
    model_perf_dict = {}
    model_perf_dict['model_name'] = model_path
    model_perf_dict['MAE'] = np.array(MAE_temp).mean()
    model_perf_dict['NAME'] = np.array(NMAE_temp).mean()
    model_perf_dict['RMSE'] = np.array(RMSE_temp).mean()
    model_perf_dict['NRMSE'] = np.array(NRMSE_temp).mean()
    model_perf_dict['PRD'] = np.array(PRD_temp).mean()
    model_perf_dict['DTW'] = np.array(DTW_temp).mean()
    
    model_perf_list.append(model_perf_dict)

100%|██████████████████████████████████████████████████████████████████████████████████| 38/38 [01:02<00:00,  1.64s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 38/38 [00:57<00:00,  1.50s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 38/38 [00:56<00:00,  1.48s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 38/38 [00:56<00:00,  1.49s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 38/38 [00:57<00:00,  1.51s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 38/38 [00:58<00:00,  1.53s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 38/38 [00:57<00:00,  1.51s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 38/38 [00:57<00:00,  1.50s/it]
100%|███████████████████████████████████

- dtw 기준으로 performance 정렬

In [12]:
dtw_list = []

for i, perf in enumerate(model_perf_list):
    dtw_list.append(perf['DTW'])
    
sort_idx = np.argsort(dtw_list)

In [26]:
top_5_model_perf = [model_perf_list[sort_idx[i]] for i in range(5)]

In [27]:
top_5_model_perf

[{'model_name': './model_result/CycleGAN\\PPG2ECG_CycleGAN_3Epochs.pth',
  'MAE': 0.19828902,
  'NAME': 0.643274364180949,
  'RMSE': 0.32285342,
  'NRMSE': 0.16142671444283388,
  'PRD': 9.806734820984314,
  'DTW': 47.1278363565298},
 {'model_name': './model_result/CycleGAN\\PPG2ECG_CycleGAN_1Epochs.pth',
  'MAE': 0.21024105,
  'NAME': 0.5887596906764797,
  'RMSE': 0.3214796,
  'NRMSE': 0.16073979359884974,
  'PRD': 9.833386328101025,
  'DTW': 50.14073166563581},
 {'model_name': './model_result/CycleGAN\\PPG2ECG_CycleGAN_4Epochs.pth',
  'MAE': 0.2441236,
  'NAME': 0.5962719871713457,
  'RMSE': 0.35292417,
  'NRMSE': 0.17646208662245613,
  'PRD': 10.86567972818998,
  'DTW': 58.72793744724941},
 {'model_name': './model_result/CycleGAN\\PPG2ECG_CycleGAN_10Epochs.pth',
  'MAE': 0.24439129,
  'NAME': 0.580678726286529,
  'RMSE': 0.32982665,
  'NRMSE': 0.16491332637399925,
  'PRD': 10.194100932128572,
  'DTW': 61.91622464640105},
 {'model_name': './model_result/CycleGAN\\PPG2ECG_CycleGAN_7Epo