In [1]:
import matplotlib.pyplot as plt

In [2]:
import numpy as np
import torch_geometric
import networkx as nx
import os
import torch
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from torch_geometric.data import Data
import pandas as pd
import torch.nn as nn

In [3]:
#>>> pd.read_excel('data/RhythmNames.xlsx')
#   Acronym Name                                     Full Name
#0          SB                               Sinus Bradycardia
#1          SR                                    Sinus Rhythm
#2         AFIB                            Atrial Fibrillation
#3         ST                                Sinus Tachycardia
#4          AF                                  Atrial Flutter
#5          SI                              Sinus Irregularity
#6          SVT                   Supraventricular Tachycardia
#7          AT                              Atrial Tachycardia
#8         AVNRT  Atrioventricular  Node Reentrant Tachycardia
#9          AVRT        Atrioventricular Reentrant Tachycardia
#10        SAAWR       Sinus Atrium to Atrial Wandering Rhythm

In [4]:
#>>> pd.read_excel('data/ConditionNames.xlsx')
#   Acronym Name                                  Full Name
#0          1AVB            1 degree atrioventricular block
#1          2AVB            2 degree atrioventricular block
#2         2AVB1  2 degree atrioventricular block(Type one)
#3         2AVB2  2 degree atrioventricular block(Type two)
#4          3AVB            3 degree atrioventricular block
#5           ABI                            atrial bigeminy
#6           ALS                            Axis left shift
#7           APB                     atrial premature beats
#8           AQW                            abnormal Q wave
#9           ARS                           Axis right shift
#10          AVB                     atrioventricular block
#11          CCR                 countercolockwise rotation
#12           CR                        colockwise rotation
#13          ERV     Early repolarization of the ventricles
#14         FQRS                                  fQRS Wave
#15          IDC            Interior differences conduction
#16          IVB                     Intraventricular block
#17          JEB                     junctional escape beat
#18          JPS                              J point shift
#19          JPT                  junctional premature beat
#20         LBBB                   left bundle branch block
#21        LBBBB              left back bundle branch block
#22        LFBBB             left front bundle branch block
#23         LRRI                           Long RR interval
#24          LVH                 left ventricle hypertrophy
#25         LVHV                left ventricle high voltage
#26      LVQRSAL              lower voltage QRS in all lead
#27      LVQRSCL            lower voltage QRS in chest lead
#28      LVQRSLL             lower voltage QRS in limb lead
#29           MI                      myocardial infarction
#30         MIBW         myocardial infraction in back wall
#31         MIFW   Myocardial infgraction in the front wall
#32         MILW    Myocardial infraction in the lower wall
#33         MISW     Myocardial infraction in the side wall
#34         PRIE                      PR interval extension
#35          PWC                              P wave Change
#36         QTIE                      QT interval extension
#37          RAH                   right atrial hypertrophy
#38         RAHV                  right atrial high voltage
#39         RBBB                  right bundle branch block
#40          RVH                right ventricle hypertrophy
#41         STDD                               ST drop down
#42          STE                               ST extension
#43         STTC                                ST-T Change
#44         STTU                                 ST tilt up
#45          TWC                              T wave Change
#46          TWO                            T wave opposite
#47           UW                                     U wave
#48           VB                       ventricular bigeminy
#49          VEB                    ventricular escape beat
#50          VFW                    ventricular fusion wave
#51          VPB                 ventricular premature beat
#52          VPE                  ventricular preexcitation
#53          VET               ventricular escape trigeminy
#54         WAVN   Wandering in the atrioventricalualr node
#55          WPW                                        WPW

In [5]:
diagnostics_df = pd.read_excel('data/Diagnostics.xlsx')

In [6]:
from sklearn.preprocessing import LabelEncoder

In [7]:
le = LabelEncoder()
le.fit(diagnostics_df.Rhythm.unique())

LabelEncoder()

In [8]:
le.classes_

array(['AF', 'AFIB', 'AT', 'AVNRT', 'AVRT', 'SA', 'SAAWR', 'SB', 'SR',
       'ST', 'SVT'], dtype=object)

In [9]:
diagnostics_df['label'] = le.transform(diagnostics_df.Rhythm.values)

In [10]:
diagnostics_df['label']

0         1
1         7
2         5
3         7
4         0
         ..
10641    10
10642    10
10643    10
10644    10
10645    10
Name: label, Length: 10646, dtype: int64

In [11]:
data_list = []
for file_name in tqdm(os.listdir('data/')):
    if file_name.endswith('.csv'):
        sub_name = file_name.split('.')[0]
        y = torch.tensor(diagnostics_df.loc[diagnostics_df.FileName == sub_name].label.values[0])
        df = pd.read_csv(f'data/{file_name}', header=None)
        if len(df) != 5000:
            print(file_name)
            continue
        x = torch.Tensor(np.array([df[col].values for col in df.columns])).t()
        data = [x, y]
        data_list.append(data)

 53%|█████▎    | 5690/10653 [01:01<00:50, 99.12it/s] 

MUSE_20180113_124215_52000.csv


100%|██████████| 10653/10653 [01:52<00:00, 94.39it/s]


In [12]:
n_epochs = 250
lr = 0.01
n_folds = 5
lstm_input_size = 12
hidden_state_size = 30
batch_size = 32
num_sequence_layers = 2
output_dim = 11
num_time_steps = 5000
rnn_type = 'LSTM'

In [13]:
train_loader = DataLoader(data_list[:8000], batch_size=batch_size, shuffle=True)
test_loader = DataLoader(data_list[8000:], batch_size=batch_size)

In [14]:
instance = next(iter(train_loader))

In [15]:
instance[0].shape

torch.Size([32, 5000, 12])

In [16]:
#g = torch_geometric.utils.to_networkx(instance, to_undirected=True)
#nx.draw(g)

In [17]:
for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Batch {step}: {data[0].shape}')
    print()

Step 1:
Batch 0: torch.Size([32, 5000, 12])

Step 2:
Batch 1: torch.Size([32, 5000, 12])

Step 3:
Batch 2: torch.Size([32, 5000, 12])

Step 4:
Batch 3: torch.Size([32, 5000, 12])

Step 5:
Batch 4: torch.Size([32, 5000, 12])

Step 6:
Batch 5: torch.Size([32, 5000, 12])

Step 7:
Batch 6: torch.Size([32, 5000, 12])

Step 8:
Batch 7: torch.Size([32, 5000, 12])

Step 9:
Batch 8: torch.Size([32, 5000, 12])

Step 10:
Batch 9: torch.Size([32, 5000, 12])

Step 11:
Batch 10: torch.Size([32, 5000, 12])

Step 12:
Batch 11: torch.Size([32, 5000, 12])

Step 13:
Batch 12: torch.Size([32, 5000, 12])

Step 14:
Batch 13: torch.Size([32, 5000, 12])

Step 15:
Batch 14: torch.Size([32, 5000, 12])

Step 16:
Batch 15: torch.Size([32, 5000, 12])

Step 17:
Batch 16: torch.Size([32, 5000, 12])

Step 18:
Batch 17: torch.Size([32, 5000, 12])

Step 19:
Batch 18: torch.Size([32, 5000, 12])

Step 20:
Batch 19: torch.Size([32, 5000, 12])

Step 21:
Batch 20: torch.Size([32, 5000, 12])

Step 22:
Batch 21: torch.Size([3

Step 154:
Batch 153: torch.Size([32, 5000, 12])

Step 155:
Batch 154: torch.Size([32, 5000, 12])

Step 156:
Batch 155: torch.Size([32, 5000, 12])

Step 157:
Batch 156: torch.Size([32, 5000, 12])

Step 158:
Batch 157: torch.Size([32, 5000, 12])

Step 159:
Batch 158: torch.Size([32, 5000, 12])

Step 160:
Batch 159: torch.Size([32, 5000, 12])

Step 161:
Batch 160: torch.Size([32, 5000, 12])

Step 162:
Batch 161: torch.Size([32, 5000, 12])

Step 163:
Batch 162: torch.Size([32, 5000, 12])

Step 164:
Batch 163: torch.Size([32, 5000, 12])

Step 165:
Batch 164: torch.Size([32, 5000, 12])

Step 166:
Batch 165: torch.Size([32, 5000, 12])

Step 167:
Batch 166: torch.Size([32, 5000, 12])

Step 168:
Batch 167: torch.Size([32, 5000, 12])

Step 169:
Batch 168: torch.Size([32, 5000, 12])

Step 170:
Batch 169: torch.Size([32, 5000, 12])

Step 171:
Batch 170: torch.Size([32, 5000, 12])

Step 172:
Batch 171: torch.Size([32, 5000, 12])

Step 173:
Batch 172: torch.Size([32, 5000, 12])

Step 174:
Batch 173:

In [18]:
class Bi_RNN(nn.Module):

    def __init__(self, input_dim, hidden_dim, batch_size, output_dim=11, num_layers=2, rnn_type='LSTM'):
        super(Bi_RNN, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.num_layers = num_layers

        #Define the initial linear hidden layer
        self.init_linear = nn.Linear(self.input_dim, self.input_dim)

        # Define the LSTM layer
        self.lstm = eval('nn.' + rnn_type)(self.input_dim, self.hidden_dim, self.num_layers, batch_first=True, bidirectional=True)

        # Define the output layer
        self.linear = nn.Linear(self.hidden_dim * 2, output_dim)

    def init_hidden(self):
        # This is what we'll initialise our hidden state as
        return (torch.zeros(self.num_layers, self.batch_size, self.hidden_dim),
                torch.zeros(self.num_layers, self.batch_size, self.hidden_dim))

    def forward(self, input):
        #Forward pass through initial hidden layer
        linear_input = self.init_linear(input)

        # Forward pass through LSTM layer
        # shape of lstm_out: [batch_size, input_size ,hidden_dim]
        # shape of self.hidden: (a, b), where a and b both
        # have shape (batch_size, num_layers, hidden_dim).
        lstm_out, self.hidden = self.lstm(linear_input)
        out_reduced = torch.cat((self.hidden[0][-2,:,:], self.hidden[0][-1,:,:]), dim = 1)
        
        
        # Can pass on the entirety of lstm_out to the next layer if it is a seq2seq prediction
        y_pred = self.linear(out_reduced)
        # text_fea = torch.squeeze(text_fea, 1)
        return y_pred

In [19]:
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [20]:
model1 = Bi_RNN(lstm_input_size, hidden_state_size, batch_size=batch_size, output_dim=output_dim, num_layers=num_sequence_layers, rnn_type=rnn_type).to(device)

In [21]:
print(model1)

Bi_RNN(
  (init_linear): Linear(in_features=12, out_features=12, bias=True)
  (lstm): LSTM(12, 30, num_layers=2, batch_first=True, bidirectional=True)
  (linear): Linear(in_features=60, out_features=11, bias=True)
)


In [22]:
# Set model paramters and model type
def set_model_parameters(model_type, lr=0.01):
    model = model_type
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()
    return model, optimizer, criterion

# Train the model
def train(model, optimizer,criterion):
    model.train()

    for idx, (x_batch, y_batch) in enumerate(train_loader):  # Iterate in batches over the training dataset.
        x_batch = x_batch.view(-1, num_time_steps, lstm_input_size).to(device)
        y_batch = y_batch.to(device)
        optimizer.zero_grad()
        y_pred = model(x_batch)
        loss = criterion(y_pred, y_batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()  # Clear gradients.

# Test the model 
def test(loader, model):
     model.eval()

     correct = 0
     for data in loader:  # Iterate in batches over the training/test dataset.
         out = model(data[0].to(device))  
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         correct += int((pred == data[1].to(device)).sum())  # Check against ground-truth labels.
     return correct / len(loader.dataset)  # Derive ratio of correct predictions.

# Training and Testing Pipeline 
def running_epochs(model,optimizer,criterion):
    for epoch in range(12):
        train(model,optimizer,criterion)
        train_acc = test(train_loader, model)
        test_acc = test(test_loader, model)
        print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

In [None]:
model, optimizer, criterion = set_model_parameters(model1, lr=0.001)
running_epochs(model,optimizer,criterion)

Epoch: 000, Train Acc: 0.0428, Test Acc: 0.0389
Epoch: 001, Train Acc: 0.0428, Test Acc: 0.0389
