In [9]:
#!/usr/bin/env python
# coding: utf-8

##### libraries
## 3rd party
from __future__ import print_function
import pywt
import math
import numpy as np
import matplotlib.pyplot as plt
import copy
import pdb
import csv
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
torch.set_default_tensor_type(torch.DoubleTensor)

# ignore package deprecation warnings
import warnings
warnings.filterwarnings("ignore")

## custom
from utils.data_formatting import HDF5Dataset
from utils.common import *

In [10]:
##### parameters
# index of the fold held out for validation; also selects which saved
# model file (snn_ref/model_fold{val_fold}.pt) is loaded further down
val_fold = 4
# batch size handed to both DataLoaders
bs = 32

In [11]:
# Train and validation datasets read the same HDF5 file and each get their
# own DWT(4) transform instance; only the `train` flag differs.
hdf5_path = '../data/custom/physionet2017_4classes.hdf5'
dset_train = CINC7Dataset(hdf5_path, train=True, transform=DWT(4))
dset_val = CINC7Dataset(hdf5_path, train=False, transform=DWT(4))

# Identical loader settings for both splits: fixed order, loading in the
# main process.
loader_kwargs = dict(batch_size=bs, shuffle=False, num_workers=0)
dl_train = DataLoader(dset_train, **loader_kwargs)
dl_val = DataLoader(dset_val, **loader_kwargs)

In [12]:
##### main code
# load networks
# NOTE: torch.load unpickles a whole saved model object here (its
# state_dict() is read right away), so only open trusted checkpoint files.
model_path = f'snn_ref/model_fold{val_fold}.pt'
ref_state = torch.load(model_path).state_dict()

# reference network: weights copied non-strictly from the saved model
cnn = WVCNN4()
cnn.load_state_dict(ref_state, strict=False)

# second instance with its `normed` attribute set before it receives the
# same weights as `cnn` (strict load, so the key report is the cell output)
cnn_bninteg = WVCNN4()
cnn_bninteg.normed = True
cnn_bninteg.load_state_dict(cnn.state_dict())

<All keys matched successfully>

In [13]:
#"""
# import and format physionet data
data_raw = HDF5Dataset('../data/custom/physionet2017_4classes.hdf5', True)

# convert to slicable data: one fixed-length row per record
n_records = len(data_raw)
ecgs = torch.empty(n_records,18000)
dwts = torch.empty(n_records,2,1127)
lbls = torch.empty(n_records)
for ii,data in enumerate(tqdm(data_raw)):
    n_samples = len(data[0][0])
    if n_samples < 18000:
        # short record: copy into the front of a zero buffer
        data_in = torch.zeros((1,18000))
        data_in[0,0:n_samples] = data[0]
    else:
        # long enough: keep the first 18000 samples
        data_in = data[0][:,0:18000]
    ecgs[ii,:] = data_in
    # reference 4-level db2 decomposition; the first two returned arrays
    # are stacked as the two rows stored for this record
    dwt_tmp = pywt.wavedec(data_in,'db2',level=4)
    dwts[ii,:,:] = torch.tensor(np.concatenate(dwt_tmp[:2],0))
    lbls[ii] = data[1]
#"""

100%|██████████| 8528/8528 [00:19<00:00, 426.80it/s]


In [14]:
#"""
# get train/test for normalization
fold_idx = split_into_5folds(lbls)
val_idx = (fold_idx==val_fold)
train_idx = (fold_idx!=val_fold)

# training split: every fold except val_fold (singleton dim inserted at 1)
dwt_train = dwts[train_idx].unsqueeze(1)
lbl_train = lbls[train_idx].long()

# validation split: the held-out fold only
dwt_val = dwts[val_idx].unsqueeze(1)
lbl_val = lbls[val_idx].long()
#"""

In [15]:
# number of fractional bits used when q() is called with this global
b = 62

def q(x,b):
    """Round every element of ``x`` to the nearest multiple of ``2**-b``.

    Args:
        x: tensor to quantize.
        b: number of fractional bits to keep.

    Returns:
        Tensor shaped like ``x`` holding the rounded values. The range is
        not clipped (no saturation), so large magnitudes pass through.
    """
    scale = 2**b
    return torch.round(x * scale) / scale

db2wvlt = pywt.Wavelet('db2')
# floating point values in python are stored as double precision values
print(type(db2wvlt.dec_lo[0]))

# db2 decomposition filter taps as float64 tensors
dlo_fp64 = torch.from_numpy(np.array(db2wvlt.dec_lo,dtype=np.float64))
dhi_fp64 = torch.from_numpy(np.array(db2wvlt.dec_hi,dtype=np.float64))
print(dlo_fp64.dtype)
print(dlo_fp64)
print(dhi_fp64)

# the same taps after rounding to b fractional bits
dlo, dhi = q(dlo_fp64,b), q(dhi_fp64,b)
print(dlo.dtype)
print(dlo)
print(dhi)

# absolute and RMS error introduced by the rounding
edlo_abs = (dlo-dlo_fp64).abs()
edhi_abs = (dhi-dhi_fp64).abs()
edlo_rmse = edlo_abs.pow(2).mean().sqrt()
edhi_rmse = edhi_abs.pow(2).mean().sqrt()
print(edlo_rmse)
print(edhi_rmse)
# cell output: dtype that torch.zeros produces under the current default
torch.zeros((1,18000)).dtype

<class 'float'>
torch.float64
tensor([-0.1294,  0.2241,  0.8365,  0.4830])
tensor([-0.4830,  0.8365, -0.2241, -0.1294])
torch.float64
tensor([-0.1294,  0.2241,  0.8365,  0.4830])
tensor([-0.4830,  0.8365, -0.2241, -0.1294])
tensor(0.)
tensor(0.)


torch.float64

In [16]:
dwt_lvl=4

# quantized db2 analysis taps and quantized input signals
db2wvlt = pywt.Wavelet('db2')
dlo = q(torch.Tensor(db2wvlt.dec_lo),b)
dhi = q(torch.Tensor(db2wvlt.dec_hi),b)
din_q = q(ecgs,b)

# bring the input to the 3-D layout conv1d expects, with one input channel
din_q = din_q.unsqueeze(0).unsqueeze(0) if din_q.dim() == 1 else din_q.unsqueeze(1)

# conv1d kernel of shape (2, 1, taps): row 0 is the reversed low-pass
# filter, row 1 the reversed high-pass filter
dcoeff = torch.stack((dlo.flip(0),dhi.flip(0)),0).unsqueeze(1)

# cascade: filter without padding, keep the even-indexed outputs, quantize,
# then feed the low-pass channel into the next level
dwt_tmp = []
xin = din_q
for ii in range(dwt_lvl):
    coeff = F.conv1d(xin,dcoeff,padding='valid')
    coeffs = coeff[:,:,0::2]
    coeffq = q(coeffs,b)

    xin = coeffq[:,0:1,:]

    # high-pass channel of each level, deepest level first
    dwt_tmp.insert(0,coeffq[:,1:2,:])


# both channels of the deepest level
dwts = coeffq
#dwts_normed = coeffq * self.lmbda + self.dlta



# pywt reference, computed on whatever `data_in` currently holds (it is
# not assigned in this cell)
dwt_tmp2 = pywt.wavedec(data_in,'db2',level=dwt_lvl)
ref_rows = dwt_tmp2[0].shape[0]
ref_flat = np.concatenate((dwt_tmp2[0],dwt_tmp2[1]),1)
dwts2 = torch.tensor(ref_flat.reshape((ref_rows,2,1127)))
#dwts_normed = dwts * self.lmbda + self.dlta

In [17]:
# Four hand-unrolled analysis steps on `data_in` (not passed through q):
# filter with dcoeff without padding, decimate by two, and pass channel 0
# on to the next level.
coeff = F.conv1d(data_in.unsqueeze(0),dcoeff,padding='valid')
coeffs = coeff[...,0::2]
coeff2 = F.conv1d(coeffs[:,:1],dcoeff,padding='valid')
# NOTE(review): level 2 keeps the odd-indexed samples while the other
# levels keep the even ones. With this phase the printed values line up
# with the pywt level-2 output shown further down, so it looks deliberate
# -- confirm.
coeffs2 = coeff2[...,1::2]
coeff3 = F.conv1d(coeffs2[:,:1],dcoeff,padding='valid')
coeffs3 = coeff3[...,0::2]
coeff4 = F.conv1d(coeffs3[:,:1],dcoeff,padding='valid')
coeffs4 = coeff4[...,0::2]

# print channel 0 of every level with 16 digits
torch.set_printoptions(precision=16)
for lvl_out in (coeffs, coeffs2, coeffs3, coeffs4):
    print(lvl_out[:,:1])

tensor([[[-0.3789222866689697, -0.5280587494520761, -0.6125757266505413,
           ...,  0.0000000000000000,  0.0000000000000000,
           0.0000000000000000]]])
tensor([[[-0.8283945583940165, -0.8802224679852876, -0.8298870670433200,
           ...,  0.0000000000000000,  0.0000000000000000,
           0.0000000000000000]]])
tensor([[[-1.2217629776138441, -1.1300250905552371, -0.8820119697674462,
           ...,  0.0000000000000000,  0.0000000000000000,
           0.0000000000000000]]])
tensor([[[-1.6466345070674784, -1.0158996906375446, -1.0924075096805570,
           ...,  0.0000000000000000,  0.0000000000000000,
           0.0000000000000000]]])


In [18]:
# pywt reference at level 1: first 9 values of row 0 of the first
# returned array
tmp = pywt.wavedec(data_in, 'db2', level=1)
tmp[0][0, :9]

array([-0.35251759, -0.37892229, -0.52805875, -0.61257573, -0.62906239,
       -0.61872026, -0.59751339, -0.58025005, -0.56243327])

In [19]:
# pywt reference at level 2: first 9 values of row 0 of the first
# returned array
tmp = pywt.wavedec(data_in, 'db2', level=2)
tmp[0][0, :9]

array([-0.50787063, -0.52631559, -0.82839456, -0.88022247, -0.82988707,
       -0.77780531, -0.69275403, -0.59271511, -0.49676922])

In [20]:
# pywt reference at level 3: first 9 values of row 0 of the first
# returned array
tmp = pywt.wavedec(data_in, 'db2', level=3)
tmp[0][0, :9]

array([-0.72475881, -0.75732464, -1.22176298, -1.13002509, -0.88201197,
       -0.66775369, -0.6203374 , -0.83232711, -1.09599978])

In [21]:
# pywt reference at level 4: first 9 values of row 0 of the first
# returned array
tmp = pywt.wavedec(data_in, 'db2', level=4)
tmp[0][0, :9]

array([-1.0364775 , -1.11116071, -1.64663451, -1.01589969, -1.09240751,
       -1.60833762, -1.51961904, -1.3363575 , -1.30799303])