In [0]:
# Pipeline switches:
#   REBUILD_DATA - re-run preprocessing and overwrite the saved .npy datasets
#   GITHUB_DATA  - choose the download source used by the download cell
REBUILD_DATA, GITHUB_DATA = True, False

In [0]:
import os

# Downloading raw data
if GITHUB_DATA == False and os.path.exists("physionet.org/files/challenge-2017/1.0.0/training/") == False:
  ! wget -r -N -c -np -nv -q https://physionet.org/files/challenge-2017/1.0.0/
  print("Raw data")
# Downloading preprocessed data from Github
elif GITHUB_DATA == False and os.path.exists("ECG_DATA/") == False:
  ! git clone https://github.com/kendreaditya/ECG_DATA.git
  print("Preprocessed data from Github")

from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
# Imports and Requirements
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import loadmat
from scipy import signal
from tqdm import tqdm
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import time
from matplotlib import style

# Enabling CUDA: use the first GPU when one is visible, otherwise the CPU.
_has_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if _has_gpu else "cpu")
_device_label = "Running on GPU -" if _has_gpu else "Running on CPU -"
print(_device_label, device )

Running on GPU - cuda:0


In [0]:
"""
Copyright (c) 2013 Jami Pekkanen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import sys
import numpy as np
import scipy.signal
import scipy.ndimage

def detect_beats(
		ecg,	# The raw ECG signal (any array-like accepted by filtfilt)
		rate,	# Sampling rate in HZ
		# Window size in seconds used for the robust (median over
		# windows) threshold and normalizer estimation
		ransac_window_size=5.0,
		# Low frequency of the band pass filter (cutoff of the high-pass stage)
		lowfreq=5.0,
		# High frequency of the band pass filter (cutoff of the low-pass stage)
		highfreq=15.0,
		):
	"""
	ECG heart beat detection based on
	http://link.springer.com/article/10.1007/s13239-011-0065-3/fulltext.html
	with some tweaks (mainly robust estimation of the rectified signal
	cutoff threshold).

	Pipeline: zero-phase (filtfilt) 1st-order Butterworth low-pass then
	high-pass, squared first difference, robust thresholding and
	normalization, Shannon energy, moving-average + Gaussian smoothing,
	then local maxima of the smoothed energy envelope.

	Args:
		ecg: raw ECG samples.
		rate: sampling rate in Hz.
		ransac_window_size: window length in seconds for the robust estimates.
		lowfreq: high-pass cutoff in Hz.
		highfreq: low-pass cutoff in Hz.

	Returns:
		Integer numpy array of indices derived from the local maxima of the
		smoothed energy envelope (presumably beat locations): a maximum at
		lp_energy[j] is reported as j - 2 (flatnonzero yields j - 1, then an
		explicit -1 is applied).
		If the signal is no longer than one window, no window statistics are
		collected, np.median of the empty lists is nan, all power values become
		non-finite and are zeroed, so the result is empty.
	"""

	# Convert the window length from seconds to samples.
	ransac_window_size = int(ransac_window_size*rate)

	lowpass = scipy.signal.butter(1, highfreq/(rate/2.0), 'low')
	highpass = scipy.signal.butter(1, lowfreq/(rate/2.0), 'high')
	# TODO: Could use an actual bandpass filter
	ecg_low = scipy.signal.filtfilt(*lowpass, x=ecg)
	ecg_band = scipy.signal.filtfilt(*highpass, x=ecg_low)
	
	# Square (=signal power) of the first difference of the signal
	decg = np.diff(ecg_band)
	decg_power = decg**2
	
	# Robust threshold and normalizer estimation: per non-overlapping window
	# take 0.5*std as a threshold and the max as a normalizer; the trailing
	# partial window is ignored.
	thresholds = []
	max_powers = []
	for i in range(int(len(decg_power)/ransac_window_size)):
		sample = slice(i*ransac_window_size, (i+1)*ransac_window_size)
		d = decg_power[sample]
		thresholds.append(0.5*np.std(d))
		max_powers.append(np.max(d))

	# Medians over windows make the estimates robust to outlier windows.
	threshold = np.median(thresholds)
	max_power = np.median(max_powers)
	decg_power[decg_power < threshold] = 0

	# Normalize and clip to [0, 1] so the Shannon energy below is well defined.
	decg_power /= max_power
	decg_power[decg_power > 1.0] = 1.0
	square_decg_power = decg_power**2

	# Shannon energy; 0*log(0) (and nan/inf from a nan normalizer) become 0.
	shannon_energy = -square_decg_power*np.log(square_decg_power)
	shannon_energy[~np.isfinite(shannon_energy)] = 0.0

	# Moving average over int(rate*0.125+1) samples (about 125 ms).
	mean_window_len = int(rate*0.125+1)
	lp_energy = np.convolve(shannon_energy, [1.0/mean_window_len]*mean_window_len, mode='same')
	#lp_energy = scipy.signal.filtfilt(*lowpass2, x=shannon_energy)
	
	# Further Gaussian smoothing with sigma = rate/8 samples (1/8 s).
	lp_energy = scipy.ndimage.gaussian_filter1d(lp_energy, rate/8.0)
	lp_energy_diff = np.diff(lp_energy)

	# Local maxima: derivative goes from positive to negative.
	zero_crossings = (lp_energy_diff[:-1] > 0) & (lp_energy_diff[1:] < 0)
	zero_crossings = np.flatnonzero(zero_crossings)
	zero_crossings -= 1
	return zero_crossings

In [0]:
class Data_Preprocessing():
  """Turn PhysioNet/CinC Challenge 2017 training records into labelled,
  fixed-length ECG windows plus one augmented copy per window.

  Each stored sample is ``[ecg_window, one_hot_label]``: ``ecg_window`` has
  exactly ECG_LENGTH values and ``one_hot_label`` has one entry per class in
  LABELS.
  """
  # Data locations
  DATA = "physionet.org/files/challenge-2017/1.0.0/training/"
  NORMAL = "physionet.org/files/challenge-2017/1.0.0/training/RECORDS-normal"
  AF = "physionet.org/files/challenge-2017/1.0.0/training/RECORDS-af"
  OTHER = "physionet.org/files/challenge-2017/1.0.0/training/RECORDS-other"
  NOISY = "physionet.org/files/challenge-2017/1.0.0/training/RECORDS-noisy"

  # Class labels: record-list file -> class index (= position in the one-hot vector)
  LABELS = {NORMAL: 0, AF: 1, OTHER: 2, NOISY: 3}

  # ECG structure
  ECG_LENGTH = 600                          # samples per stored window
  # Number of detected beats spanned by one window.
  # NOTE(review): 188 looks like an assumed samples-per-beat figure -- confirm.
  ECG_PER_SAMPLE = int(ECG_LENGTH/188)+1

  def __init__(self):
    # Per-instance storage: a class-level list would be shared by (and keep
    # growing across) every Data_Preprocessing instance.
    self.data = []

  def process_data(self):
    """Read every record listed in the LABELS files, cut beat-aligned windows
    and append ``[window, one_hot]`` pairs (original + augmented) to self.data.

    Windows that cannot be built (negative start, empty/flat signal, too short
    for the augmentation) are skipped instead of aborting the whole run.
    """
    half_window = int(self.ECG_LENGTH/2)
    # Zero margin on both sides so windows around the first/last beats fit.
    padding = np.zeros(int(self.ECG_LENGTH+1))
    n_classes = len(self.LABELS)

    for records, label in self.LABELS.items():
      one_hot = np.eye(n_classes)[label]
      with open(records) as record:
        for ecg_file in tqdm(record):
          name = ecg_file.strip()
          if not name:                                    # tolerate blank lines
            continue
          path = self.DATA + name                         # Path of data file
          with open(path + ".hea", "r") as header:        # Metadata of data file
            metadata = header.read().split(" ")
          rate = float(metadata[2])                       # passed to detect_beats as the sampling rate

          # float dtype avoids integer overflow in max-min for narrow int types
          raw = np.asarray(loadmat(path)['val'][0], dtype=float)
          ECGs = np.concatenate([padding, raw, padding])
          peaks = detect_beats(ECGs, rate)

          # A window runs from half a window before peak k to half a window
          # after peak k+ECG_PER_SAMPLE, so stop once that end peak is missing.
          last_start = len(peaks) - self.ECG_PER_SAMPLE
          for peak in range(0, last_start, self.ECG_PER_SAMPLE):
            start = peaks[peak] - half_window
            end = peaks[peak + self.ECG_PER_SAMPLE] + half_window
            if start < 0:
              # A negative slice start would wrap around to the signal's end.
              continue
            window = ECGs[start:end]
            if len(window) == 0:
              continue
            span = np.ptp(window)
            if span == 0:
              # Flat window: normalising would divide by zero.
              continue
            try:
              ECG = self.zero_padding(self.rnd_zero(window / span))
              self.data.append([np.array(ECG), one_hot.copy()])

              # Augmented ECG
              aug_ECG = self.zero_padding(self.rnd_zero(self.resampling(ECG)))
              aug_ECG = aug_ECG / (np.amax(aug_ECG) - np.amin(aug_ECG))
              self.data.append([np.array(aug_ECG), one_hot.copy()])
            except ValueError:
              # e.g. rnd_zero on a window of ~11 samples; best-effort skip.
              continue

  def balance_data(self):
    """Return a class-balanced subset of self.data.

    Each class keeps its first ``m`` samples in the current order of
    self.data, where ``m`` is the size of the smallest class (so the result
    is empty if any class has no samples).
    """
    n_classes = len(self.LABELS)
    dist = [0] * n_classes
    for ECG in self.data:
      dist[np.argmax(ECG[1])] += 1
    target = min(dist)

    balanced_data = []
    dist_count = [0] * n_classes
    for ECG in self.data:
      label = np.argmax(ECG[1])
      # `<` (not `<=`) so every class contributes exactly `target` samples.
      if dist_count[label] < target:
        balanced_data.append(ECG)
        dist_count[label] += 1
    return balanced_data

  def split(self, ECGs):
    """Cut ECGs into consecutive ECG_LENGTH chunks.

    A final chunk shorter than a quarter window is dropped; a longer but
    incomplete one is zero-padded. Empty input yields an empty list.
    """
    ECG = [ECGs[i:i+self.ECG_LENGTH] for i in range(0, len(ECGs), self.ECG_LENGTH)]
    if not ECG:
      return ECG

    # Removes ECG if little data in last ECG
    if len(ECG[-1]) < int(self.ECG_LENGTH/4):
      ECG.pop(-1)
    # If ECG is less than suggested length, send to get zero padding
    elif len(ECG[-1]) < self.ECG_LENGTH:
      ECG[-1] = self.zero_padding(ECG[-1])
    return ECG

  def zero_padding(self, ECG):
    """Return ECG as a list of exactly ECG_LENGTH values: truncated when
    longer, right-padded with zeros when shorter."""
    ECG = list(ECG)[:self.ECG_LENGTH]
    return ECG + [0] * (self.ECG_LENGTH - len(ECG))

  def rnd_zero(self, ecg):
    """Augmentation: zero out 0-6 random runs of 0-6 consecutive samples.

    Modifies ``ecg`` in place and returns it. May raise ValueError for very
    short inputs (around 11 samples or fewer).
    """
    for _ in range(np.random.randint(7)):
      pos = np.random.randint(abs(len(ecg)-11))
      dist = np.random.randint(7)
      ecg[pos:pos+dist] = [0]*dist
    return ecg

  def resampling(self, ecg):
    """Augmentation: resample ecg (scipy.signal.resample) to a random length
    in [ECG_LENGTH-MARGIN, ECG_LENGTH-1] samples."""
    MARGIN = 60
    return signal.resample(ecg, np.random.randint(MARGIN) + (self.ECG_LENGTH - MARGIN))

  def get_data(self):
    """Return the list of processed samples (not a copy)."""
    return self.data

  def save_data(self, data, file_name):
    """Shuffle ``data`` in place, then save it with np.save.

    Rows are ragged ([window, one_hot]), so an explicit object array is built;
    recent NumPy refuses to create ragged arrays implicitly. Load the file
    with ``np.load(file_name, allow_pickle=True)``.
    """
    np.random.shuffle(data)                             # Shuffles the data
    np.save(file_name, np.asarray(data, dtype=object), allow_pickle=True)

In [0]:
# Rebuild the full and the class-balanced datasets from the raw records
# (slow) only when REBUILD_DATA is set; otherwise reuse the saved .npy files.
data = Data_Preprocessing()

if REBUILD_DATA:
  data.process_data()
  all_ecgs = data.get_data()
  data.save_data(all_ecgs, "ECG_data.npy")
  balanced_ecgs = data.balance_data()
  data.save_data(balanced_ecgs, "ECG_BAL_data.npy")

1457it [00:30, 48.87it/s]

In [0]:
# Class distribution of the balanced dataset written by the preprocessing cell.
classes = ["Normal", "Atrial Fibrillation", "Other", "Noisy"]
dist = [0]*len(classes)          # one counter per class name, not a magic 4
for ecg in np.load("ECG_BAL_data.npy", allow_pickle=True):
  dist[np.argmax(ecg[1])] += 1
fig1, ax1 = plt.subplots()
ax1.pie(dist, labels=classes, autopct='%1.1f%%',
        shadow=True, startangle=90)
ax1.axis('equal')
ax1.set_title("Class distribution of ECG_BAL_data.npy");

In [0]:
# Load the balanced dataset and eyeball one stored window (row 300).
data = np.load("ECG_BAL_data.npy", allow_pickle=True)
example_ecg = data[300][0]
print(len(example_ecg))
plt.plot(example_ecg)

In [0]:
# Stack the saved rows into float32 tensors (one row per sample):
# X holds the ECG windows, y the one-hot labels.
X = torch.tensor(np.stack([i[0] for i in data]), dtype=torch.float32)
y = torch.tensor(np.stack([i[1] for i in data]), dtype=torch.float32)

VAL_PCT = 0.1
val_size = int(len(X)*VAL_PCT)
# Split at an explicit index: `[:-val_size]` / `[-val_size:]` would give an
# empty training set and a full validation set when val_size == 0.
split_idx = len(X) - val_size

train_X = X[:split_idx]
train_y = y[:split_idx]

test_X = X[split_idx:]
test_y = y[split_idx:]

train_data = data[:split_idx]

test_data = data[split_idx:]

print(f"Validation Set Size: {len(test_X)} \nTraining Set Size: {len(train_X)}")

In [0]:
class Net(nn.Module):
    """1-D CNN classifier for single-channel ECG windows.

    Five Conv1d(kernel 5, padding 2) + ReLU + max-pool stages followed by two
    fully-connected layers. Input: (batch, 1, input_length); output:
    (batch, n_classes) unnormalised logits (no softmax is applied).
    """

    def __init__(self, input_length=None, n_classes=4):
        """
        Args:
            input_length: samples per ECG window; defaults to
                Data_Preprocessing.ECG_LENGTH.
            n_classes: number of output logits (default 4).
        """
        super().__init__()
        if input_length is None:
            input_length = Data_Preprocessing.ECG_LENGTH

        self.conv1 = nn.Conv1d(1, 180, 5, padding=2)
        self.conv2 = nn.Conv1d(180, 150, 5, padding=2)
        self.conv3 = nn.Conv1d(150, 120, 5, padding=2)
        self.conv4 = nn.Conv1d(120, 90, 5, padding=2)
        self.conv5 = nn.Conv1d(90, 45, 5, padding=2)

        # Push one dummy window through the conv stack so convs() records the
        # flattened feature size fc1 needs; no autograd graph is required.
        self._to_linear = None
        with torch.no_grad():
            self.convs(torch.randn(1, 1, input_length))

        self.fc1 = nn.Linear(self._to_linear, 64)
        self.fc2 = nn.Linear(64, n_classes)

    def convs(self, x):
        """Apply the conv/ReLU/max-pool stages to x of shape (batch, 1, length).

        On the first call, stores channels * length of one sample in
        self._to_linear.
        """
        x = F.max_pool1d(F.relu(self.conv1(x)), 3)
        x = F.max_pool1d(F.relu(self.conv2(x)), 3)
        x = F.max_pool1d(F.relu(self.conv3(x)), 2)
        x = F.max_pool1d(F.relu(self.conv4(x)), 3)
        x = F.max_pool1d(F.relu(self.conv5(x)), 3)

        if self._to_linear is None:
            self._to_linear = x[0].shape[0]*x[0].shape[1]
        return x

    def forward(self, x):
        """Return (batch, n_classes) logits for x of shape (batch, 1, length)."""
        x = self.convs(x)
        # flatten per sample: a wrong input length now fails in fc1 instead of
        # silently being reshaped into a different batch size by view(-1, ...)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Instantiate the model on the selected device and show its layer summary.
net = Net()
net = net.to(device)
print(net)

In [0]:
class training():
  """Train a Net on the module-level train/validation tensors, log metrics,
  and save the trained model.

  Reads the module-level ``train_X``, ``train_y``, ``test_X``, ``test_y``,
  ``device`` and ``Net``. Artifacts are written under PATH into the
  ``data/`` (metric log), ``model_params/`` (pickled model) and
  ``model_data/`` (text summary) sub-directories.
  """
  loss_function = nn.CrossEntropyLoss()
  PATH = "/content/drive/My Drive/ECG_MODELS/"
  LEARNING_RATE = 0.001

  def __init__(self, model=None):
    """
    Args:
      model: network used by fwd_pass/test; a new Net on ``device`` if omitted.
    """
    self.net = model if model is not None else Net().to(device)
    # The optimizer must own the parameters of the network that is trained.
    self.optimizer = optim.Adam(self.net.parameters(), lr=self.LEARNING_RATE)

  def train(self):
    """Train a freshly initialised Net for EPOCHS epochs.

    Appends "time,acc,loss,val_acc,val_loss,epoch" lines to
    PATH/data/<MODEL_NAME>.log, then saves the trained network and a text
    summary.

    Returns:
      (MODEL_NAME, EPOCHS)
    """
    import os

    # Fresh network AND a matching optimizer, so the network that is trained
    # is the one that gets saved.
    self.net = Net().to(device).apply(self.weight_reset)
    self.optimizer = optim.Adam(self.net.parameters(), lr=self.LEARNING_RATE)

    layers, params = self.net_info(self.net)
    MODEL_NAME = f"layers-{layers}_parms-{str(params)[1:-1].replace(', ', '_')}_model-{int(time.time())}"
    BATCH_SIZE = 1000
    EPOCHS = 100

    for sub_dir in ("data/", "model_params/", "model_data/"):
      os.makedirs(self.PATH + sub_dir, exist_ok=True)

    with open(self.PATH+"data/"+f"{MODEL_NAME}.log", "a") as f:
      for epoch in tqdm(range(EPOCHS)):
        for i in range(0, len(train_X), BATCH_SIZE):
          batch_X = train_X[i:i+BATCH_SIZE].view(-1,1,Data_Preprocessing.ECG_LENGTH)
          batch_y = train_y[i:i+BATCH_SIZE]
          batch_X, batch_y = batch_X.to(device), batch_y.to(device)

          acc, loss = self.fwd_pass(batch_X, batch_y, train=True)

          # NOTE(review): i advances in steps of BATCH_SIZE (1000), so this
          # holds for every batch; validation is logged after each step.
          if i % 50 == 0:
            val_acc, val_loss = self.test(size=100)
            f.write(f"{round(time.time(),3)},{round(float(acc),2)},{round(float(loss), 4)},{round(float(val_acc),2)},{round(float(val_loss),4)},{epoch}\n")

    torch.save(self.net, self.PATH+"model_params/"+f'{MODEL_NAME}.pt')

    with open(self.PATH+"model_data/"+f"{MODEL_NAME}.txt", "w") as model_data:
      model_data.write(str(self.net))
      model_data.write("\n")
      model_data.write(str(self.optimizer))
      model_data.write("\n")
      model_data.write(str(self.loss_function))

    return MODEL_NAME, EPOCHS

  def fwd_pass(self, X, y, train=False):
    """Run self.net on one batch and optionally take an optimizer step.

    Args:
      X: input batch for self.net.
      y: one-hot targets, one row per sample.
      train: when True, backpropagate and step self.optimizer; otherwise the
        forward pass runs without building an autograd graph.

    Returns:
      (acc, loss): fraction of rows whose argmax matches y's argmax (a Python
      float), and the cross-entropy loss tensor.
    """
    if train:
      self.net.zero_grad()
      outputs = self.net(X)
    else:
      with torch.no_grad():
        outputs = self.net(X)

    targets = torch.argmax(y, 1)
    matches = torch.argmax(outputs, 1) == targets
    acc = matches.sum().item() / len(matches)
    loss = self.loss_function(outputs, targets)
    if train:
      loss.backward()
      self.optimizer.step()
    return acc, loss

  def test(self, size=100):
    """Evaluate on the first ``size`` validation samples; returns (acc, loss)."""
    X, y = test_X[:size], test_y[:size]
    return self.fwd_pass(X.view(-1, 1, Data_Preprocessing.ECG_LENGTH).to(device), y.to(device))

  def net_info(self, net):
    """Summarise ``net`` for the model file name.

    Returns:
      layers: index of the last module yielded by net.modules() (net itself
        is included, so this is the module count minus one).
      params: for each module, all digits in the second comma-separated field
        of its repr joined into one int (e.g. out_features for Linear).
        NOTE(review): raises IndexError for a module whose repr has no comma.
    """
    params = []
    for layers, m in enumerate(net.modules()):
      params.append(int(''.join(filter(lambda x: x.isdigit(),str((str(m).split(',')[1]))))))

    return layers, params

  def weight_reset(self, m):
    """Re-initialise Conv1d and Linear layers (use with ``net.apply``)."""
    if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
      m.reset_parameters()

In [0]:
# Train a fresh model; keep its run name and epoch count for plotting.
train_model = training()
run_info = train_model.train()
MODEL_NAME, EPOCHS = run_info

In [0]:
def graph_data(MODEL_NAME, EPOCHS):
  style.use("ggplot")
  PATH = "/content/drive/My Drive/ECG_MODELS/"

  contents = open(PATH+"data/"+f"{MODEL_NAME}.log", "r").read().split("\n")
  times = []
  accuracies = []
  losses = []
  mean_loss = []
  mean_acc = []
  val_accs = []
  val_losses = []
  
  contents = (c.split(",") for c in contents[:-1])
  for temp in list(contents):

    times.append(float(temp[0]))
    accuracies.append(float(temp[1]))
    losses.append(float(temp[2]))
    
    val_accs.append(float(temp[3]))
    val_losses.append(float(temp[4]))


  print("Best Valid Accuracy:", max(val_accs))
  times = (np.asarray(times) - min(times))
  times = times/np.amax(times)
  times = times * EPOCHS
  fig = plt.figure(figsize=(20, 10))
  ax1 = plt.subplot2grid((2,1), (0,0))
  ax2 = plt.subplot2grid((2,1), (1,0), sharex=ax1)

  ax1.set_ylim([0,1])
  ax1.set_ylabel('Accuracy (0.0-1.0)')
  ax1.plot(times, accuracies, label="Training Set Accuracy")
  ax1.plot(times, val_accs, label="Validation Set Accuracy")
  ax1.plot(times, abs(np.array(val_accs)-np.array(accuracies)), label="Δ Accuracy (between Training & Validation Set)")
  ax1.legend(loc=2)

  ax2.set_ylim([0,2])
  ax2.set_ylabel('Loss (0.0-2.0)')
  ax2.set_xlabel(f'Epochs (0-{EPOCHS})')
  ax2.plot(times, losses, label="Training Set Loss")
  ax2.plot(times, val_losses, label="Validation Set Loss")
  ax2.plot(times, abs(np.array(val_losses)-np.array(losses)), label="Δ Loss (between Training & Validation Set)")
  ax2.legend(loc=2)
  plt.savefig(PATH+"graphs/"+MODEL_NAME)
  plt.show()

In [0]:
# Plot the metrics of the run that was just trained.
graph_data(MODEL_NAME=MODEL_NAME, EPOCHS=EPOCHS)