In [31]:
import os
import wfdb
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from skimage.restoration import denoise_wavelet

In [67]:
data_directory = "./data/CEBS/m001"
sampling_rate = '1S'
target_snr_db = 3

In [33]:
def signaltonoise_dB(signal_clean, signal_noisy):
    signal_clean_watts = signal_clean ** 2
    signal_clean_avg_watts = np.mean(signal_clean_watts)
    signal_clean_avg_db = 10 * np.log10(signal_clean_avg_watts)

    noise = signal_noisy - signal_clean
    noise_watts = noise ** 2
    noise_avg_watts = np.mean(noise_watts)
    noise_avg_db = 10 * np.log10(noise_avg_watts)

    return signal_clean_avg_db - noise_avg_db

In [34]:
def load_dataset(directory):
    record = wfdb.rdsamp(directory)
    data = record[0]
    metadata = record[1]

    frequency = metadata['fs']
    column_names = metadata['sig_name']
    length = metadata['sig_len']

    data = pd.DataFrame(data, columns=column_names)

    frequency_string = str(int(1/frequency*1000000)) + 'U'
    index = pd.date_range(start='1/1/1970', periods=length, freq=frequency_string)
    data.set_index(index, inplace=True)
    data.drop(['I', 'II', 'RESP'], axis=1, inplace=True)
    
    return data

In [35]:
def add_gaussian(data, target_snr_db):
    data_watts = data ** 2
    data_avg_watts = np.mean(data_watts)
    data_avg_db = 10 * np.log10(data_avg_watts)

    noise_avg_db = data_avg_db - target_snr_db
    noise_avg_watts = 10 ** (noise_avg_db / 10)

    mean_noise = 0
    noise = np.random.normal(mean_noise, np.sqrt(noise_avg_watts), len(data_watts))

    data_noisy = data + noise
    return data_noisy

In [68]:
class DeNoise(nn.Module):
    def __init__(self):
        super(DeNoise, self).__init__()

        self.lin1 = nn.Linear(600, 400)
        self.lin_t1 = nn.Linear(400, 600)

    def forward(self, x):
        x = torch.tanh(self.lin1(x))
        x = self.lin_t1(x)
        return x

In [70]:
data = load_dataset(data_directory)
data = data.resample(sampling_rate).mean().SCG

data_noisy = add_gaussian(data, target_snr_db)

data = pd.DataFrame({'clean': data, 'noise': data_noisy})
denoised = denoise_wavelet(data.noise, method='BayesShrink', mode='soft', wavelet_levels=1, wavelet='sym8', rescale_sigma='True')
data['denoised'] = denoised

data_groups = data.groupby((data.index - data.index[0]).total_seconds() // 600) # 600 sec

"""
print(signaltonoise_dB(data.clean, data.clean))
print(signaltonoise_dB(data.clean, data.noise))
print(signaltonoise_dB(data.clean, data.denoised))

data.clean[:180].plot()
plt.show()
data.noise[:180].plot()
plt.show()
data.denoised[:180].plot()
plt.show()
"""

'\nprint(signaltonoise_dB(data.clean, data.clean))\nprint(signaltonoise_dB(data.clean, data.noise))\nprint(signaltonoise_dB(data.clean, data.denoised))\n\ndata.clean[:180].plot()\nplt.show()\ndata.noise[:180].plot()\nplt.show()\ndata.denoised[:180].plot()\nplt.show()\n'

In [72]:
def train(n_epochs, model):
    training_loss = []
    group_size = data_groups.size()[0]
    
    for epoch in range(n_epochs):
        trainloss = 0.0
        for _, dat in data_groups:
            
            if dat.shape[0] != group_size:
                continue

            sig = torch.Tensor(dat.clean)
            noise = torch.Tensor(dat.noise)
            
            optimizer.zero_grad()
            output = model(noise)
            loss = criterion(output, sig)
            loss.backward()
            optimizer.step()
            trainloss += loss.item()

            if (epoch + 1) % 10 == 0:
                print(f'epoch {epoch + 1} / {n_epochs}, loss = {loss.item():.4f}')

        training_loss.append(trainloss / data.shape[0])

model = DeNoise()
train(100, model)

epoch 10 / 100, loss = 0.0294
epoch 10 / 100, loss = 0.0272
epoch 10 / 100, loss = 0.0244
epoch 10 / 100, loss = 0.0242
epoch 20 / 100, loss = 0.0294
epoch 20 / 100, loss = 0.0272
epoch 20 / 100, loss = 0.0244
epoch 20 / 100, loss = 0.0242
epoch 30 / 100, loss = 0.0294
epoch 30 / 100, loss = 0.0272
epoch 30 / 100, loss = 0.0244
epoch 30 / 100, loss = 0.0242
epoch 40 / 100, loss = 0.0294
epoch 40 / 100, loss = 0.0272
epoch 40 / 100, loss = 0.0244
epoch 40 / 100, loss = 0.0242
epoch 50 / 100, loss = 0.0294
epoch 50 / 100, loss = 0.0272
epoch 50 / 100, loss = 0.0244
epoch 50 / 100, loss = 0.0242
epoch 60 / 100, loss = 0.0294
epoch 60 / 100, loss = 0.0272
epoch 60 / 100, loss = 0.0244
epoch 60 / 100, loss = 0.0242
epoch 70 / 100, loss = 0.0294
epoch 70 / 100, loss = 0.0272
epoch 70 / 100, loss = 0.0244
epoch 70 / 100, loss = 0.0242
epoch 80 / 100, loss = 0.0294
epoch 80 / 100, loss = 0.0272
epoch 80 / 100, loss = 0.0244
epoch 80 / 100, loss = 0.0242
epoch 90 / 100, loss = 0.0294
epoch 90 /