In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import train
import models

# Default (width, height) in inches for every figure created in this notebook.
plt.rcParams.update({'figure.figsize': [16, 8]})

In [None]:
# Load Test Dataset
dataset = train.Dataset('./data/test.npy')


def load_eval_model(build_model, ckpt_path):
    """Build a network, load its saved weights and switch it to eval mode.

    Parameters
    ----------
    build_model : callable
        Zero-argument factory that returns a ``torch.nn.Module``.
    ckpt_path : str
        Path to a ``state_dict`` saved with ``torch.save``.

    Returns
    -------
    torch.nn.Module
        The module with the checkpoint weights loaded, in eval mode.
    """
    net = build_model()
    # map_location='cpu': every inference cell below feeds CPU tensors and calls
    # .numpy() on the outputs, so weights saved from a GPU run must still load
    # on a CPU-only machine.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    net.load_state_dict(state_dict)
    return net.eval()


# Load Signal Reconstruction Model
net_sr = load_eval_model(models.get_signal_reconstruction_model, 'logs/model_sr.pth')

# Load Frequency Estimation Model
net_fe = load_eval_model(models.get_frequency_estimation_model, 'logs/model_fe.pth')

In [None]:
# Inspect one test sample: run both models on it and compare with ground truth.
# Set SAMPLE_IDX to an int (e.g. 777) to inspect a fixed sample; None picks one at random.
SAMPLE_IDX = None
idx = np.random.randint(len(dataset)) if SAMPLE_IDX is None else SAMPLE_IDX
print('Data Index : ', idx)
x, f, s = dataset[idx]

# Add a leading batch axis of size 1. Converting the ndarray directly gives the
# same float32 tensor as torch.FloatTensor([x]) without torch's slow
# list-of-ndarrays construction path.
xt = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0)

with torch.no_grad():
    # Signal Reconstruction Model Inference
    s_pred = net_sr(xt).squeeze().numpy()
    # Frequency Estimation Model Inference
    f_pred = net_fe(xt).squeeze().numpy()

# The x60 below is reported as bpm, so f is presumably in cycles per second.
mae = np.mean(np.abs(f_pred - f))
print('Mean Absolute Error (bpm) = %.4f' % (mae * 60))

# Figure 0: the first (up to) 16 rows of the model input, one panel each.
plt.figure(0)
for i in range(min(16, len(x))):
    plt.subplot(4, 4, i + 1)
    plt.plot(x[i])

# Figure 1: ground-truth signal vs. the model's reconstruction.
plt.figure(1)
plt.plot(s, label='Original Signal')
plt.plot(s_pred, label='Reconstructed Signal')
plt.legend()
plt.show()

In [None]:
# Sweep the whole test set in consecutive chunks of `div` samples and record the
# mean error of each chunk for both models.
# NOTE(review): the 'Noise Level' x-axis label below assumes data/test.npy is
# ordered by noise level in groups of `div` samples -- confirm against how the
# test set was generated.
results_sr = []
results_fe = []
div = 100
for i in tqdm(range(0, len(dataset), div)):
    x, f, s = dataset[i:i+div]
    n_batch = len(x)

    # Cast to float32, matching the single-sample path (torch.FloatTensor), so
    # the input dtype agrees with the model weights even if the data is float64.
    xt = torch.from_numpy(np.asarray(x, dtype=np.float32))

    with torch.no_grad():
        # Signal Reconstruction Model Inference
        s_pred = net_sr(xt).numpy()
        # Frequency Estimation Model Inference
        f_pred = net_fe(xt).numpy()

    # Keep the batch axis and flatten the rest instead of using .squeeze().
    # If the frequency head emits one value per sample, shape (B, 1), squeeze()
    # would give (B,), and subtracting f.reshape(-1, 1) would broadcast to a
    # (B, B) matrix of cross-sample errors. (B, -1) keeps rows aligned per sample.
    # NOTE(review): exact output shapes of net_sr / net_fe are not visible here.
    s_pred = s_pred.reshape(n_batch, -1)
    f_pred = f_pred.reshape(n_batch, -1)

    mse_sr = np.mean(np.square(s_pred - np.reshape(s, (n_batch, -1))))
    mae_fe = np.mean(np.abs(f_pred - f.reshape(-1, 1)))

    results_sr.append(mse_sr)
    results_fe.append(mae_fe)

fig, (ax_sr, ax_fe) = plt.subplots(1, 2, figsize=(16, 6))
ax_sr.plot(results_sr)
ax_sr.set(title='Signal Reconstruction', xlabel='Noise Level', ylabel='Error (MSE)')
ax_sr.grid()
ax_fe.plot(results_fe)
ax_fe.set(title='Frequency Estimation', xlabel='Noise Level', ylabel='Error (MAE)')
ax_fe.grid()
plt.show()