In [2]:
import matplotlib.pyplot as plt

In [3]:
import numpy as np
import torch_geometric
import networkx as nx
import os
import torch
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from torch_geometric.data import Data
import pandas as pd

In [4]:
#>>> pd.read_excel('data/RhythmNames.xlsx')
#   Acronym Name                                     Full Name
#0          SB                               Sinus Bradycardia
#1          SR                                    Sinus Rhythm
#2         AFIB                            Atrial Fibrillation
#3         ST                                Sinus Tachycardia
#4          AF                                  Atrial Flutter
#5          SI                              Sinus Irregularity
#6          SVT                   Supraventricular Tachycardia
#7          AT                              Atrial Tachycardia
#8         AVNRT  Atrioventricular  Node Reentrant Tachycardia
#9          AVRT        Atrioventricular Reentrant Tachycardia
#10        SAAWR       Sinus Atrium to Atrial Wandering Rhythm

In [5]:
#>>> pd.read_excel('data/ConditionNames.xlsx')
#   Acronym Name                                  Full Name
#0          1AVB            1 degree atrioventricular block
#1          2AVB            2 degree atrioventricular block
#2         2AVB1  2 degree atrioventricular block(Type one)
#3         2AVB2  2 degree atrioventricular block(Type two)
#4          3AVB            3 degree atrioventricular block
#5           ABI                            atrial bigeminy
#6           ALS                            Axis left shift
#7           APB                     atrial premature beats
#8           AQW                            abnormal Q wave
#9           ARS                           Axis right shift
#10          AVB                     atrioventricular block
#11          CCR                 countercolockwise rotation
#12           CR                        colockwise rotation
#13          ERV     Early repolarization of the ventricles
#14         FQRS                                  fQRS Wave
#15          IDC            Interior differences conduction
#16          IVB                     Intraventricular block
#17          JEB                     junctional escape beat
#18          JPS                              J point shift
#19          JPT                  junctional premature beat
#20         LBBB                   left bundle branch block
#21        LBBBB              left back bundle branch block
#22        LFBBB             left front bundle branch block
#23         LRRI                           Long RR interval
#24          LVH                 left ventricle hypertrophy
#25         LVHV                left ventricle high voltage
#26      LVQRSAL              lower voltage QRS in all lead
#27      LVQRSCL            lower voltage QRS in chest lead
#28      LVQRSLL             lower voltage QRS in limb lead
#29           MI                      myocardial infarction
#30         MIBW         myocardial infraction in back wall
#31         MIFW   Myocardial infgraction in the front wall
#32         MILW    Myocardial infraction in the lower wall
#33         MISW     Myocardial infraction in the side wall
#34         PRIE                      PR interval extension
#35          PWC                              P wave Change
#36         QTIE                      QT interval extension
#37          RAH                   right atrial hypertrophy
#38         RAHV                  right atrial high voltage
#39         RBBB                  right bundle branch block
#40          RVH                right ventricle hypertrophy
#41         STDD                               ST drop down
#42          STE                               ST extension
#43         STTC                                ST-T Change
#44         STTU                                 ST tilt up
#45          TWC                              T wave Change
#46          TWO                            T wave opposite
#47           UW                                     U wave
#48           VB                       ventricular bigeminy
#49          VEB                    ventricular escape beat
#50          VFW                    ventricular fusion wave
#51          VPB                 ventricular premature beat
#52          VPE                  ventricular preexcitation
#53          VET               ventricular escape trigeminy
#54         WAVN   Wandering in the atrioventricalualr node
#55          WPW                                        WPW

In [9]:
def get_edge_index(node_num=12):
    src_nodes = [idx for idx in range(node_num) for _ in range(node_num)]
    tgt_nodes = [idx for _ in range(node_num) for idx in range(node_num)]
    edge_index = torch.tensor(np.array([src_nodes, tgt_nodes]), dtype=torch.long)
    return edge_index

In [10]:
from tqdm import tqdm

In [11]:
diagnostics_df = pd.read_excel('data/Diagnostics.xlsx')

In [12]:
from sklearn.preprocessing import LabelEncoder

In [13]:
le = LabelEncoder()
le.fit(diagnostics_df.Rhythm.unique())

LabelEncoder()

In [14]:
le.classes_

array(['AF', 'AFIB', 'AT', 'AVNRT', 'AVRT', 'SA', 'SAAWR', 'SB', 'SR',
       'ST', 'SVT'], dtype=object)

In [15]:
diagnostics_df['label'] = le.transform(diagnostics_df.Rhythm.values)

In [16]:
diagnostics_df['label']

0         1
1         7
2         5
3         7
4         0
         ..
10641    10
10642    10
10643    10
10644    10
10645    10
Name: label, Length: 10646, dtype: int64

In [29]:
data_list = []
for file_name in tqdm(os.listdir('data/')):
    if file_name.endswith('.csv'):
        sub_name = file_name.split('.')[0]
        y = torch.tensor(diagnostics_df.loc[diagnostics_df.FileName == sub_name].label.values[0])
        df = pd.read_csv(f'data/{file_name}', header=None)
        if len(df) != 5000:
            print(file_name)
            continue
        x = torch.Tensor(np.array([df[col].values for col in df.columns]))
        edge_index = get_edge_index(node_num=12)
        data = Data(x=x, edge_index=edge_index, y=y)
        data_list.append(data)

 53%|█████▎    | 5675/10653 [03:20<04:37, 17.96it/s]

MUSE_20180113_124215_52000.csv


100%|██████████| 10653/10653 [06:21<00:00, 27.95it/s]


In [30]:
train_loader = DataLoader(data_list[:8000], batch_size=32, shuffle=True)
test_loader = DataLoader(data_list[8000:], batch_size=32)

In [31]:
instance = next(iter(train_loader))

In [32]:
instance

DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

In [33]:
#g = torch_geometric.utils.to_networkx(instance, to_undirected=True)
#nx.draw(g)

In [34]:
instance.y[0]

tensor(8)

In [35]:
for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()

Step 1:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 2:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 3:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 4:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 5:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 6:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 7:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 8:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], 

Step 71:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 72:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 73:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 74:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 75:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 76:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 77:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 78:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2,

Step 137:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 138:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 139:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 140:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 141:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 142:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 143:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 144:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_i

Step 207:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 208:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 209:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 210:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 211:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 212:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 213:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_index=[2, 4608], y=[32], batch=[384], ptr=[33])

Step 214:
Number of graphs in the current batch: 32
DataBatch(x=[384, 5000], edge_i

In [36]:
hidden_channels = 250

In [37]:
# Training Graph Neural Networks for Graph Classification
# Embed each node by performing multiple rounds of message passing
# Aggregate node embeddings into a unified graph embedding (readout layer)
# Train a final classifier on the graph embedding
 
# There exists multiple readout layers in literature, but the most common one is to simply take the average of node embeddings

# For ex: In GCNConv we use the  ReLU(𝑥)=max(𝑥,0)  activation for obtaining localized node embeddings,
# before we apply our final classifier on top of a graph readout layer.

# PyTorch Geometric provides this functionality via torch_geometric.nn.global_mean_pool, 
# which takes in the node embeddings of all nodes in the mini-batch and the assignment vector batch 
# to compute a graph embedding of size [batch_size, hidden_channels] for each graph in the batch.

# The final architecture for applying GNNs to the task of graph classification then looks as follows and allows for complete end-to-end training:

from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn import GraphConv


class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GNN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GraphConv(-1, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)
        self.conv3 = GraphConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, len(le.classes_))

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        x = global_mean_pool(x, batch)

        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(-1, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, len(le.classes_))

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x


class GAT(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GAT, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GATConv(-1, hidden_channels)
        self.conv2 = GATConv(hidden_channels, hidden_channels)
        self.conv3 = GATConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, len(le.classes_))

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x
       
# Define the models
model1 = GCN(hidden_channels=hidden_channels)
model2= GAT(hidden_channels=hidden_channels)
model3 = GNN(hidden_channels=hidden_channels)

# Print them 
print(model1)
print(model2)
print(model3)

GCN(
  (conv1): GCNConv(-1, 250)
  (conv2): GCNConv(250, 250)
  (conv3): GCNConv(250, 250)
  (lin): Linear(in_features=250, out_features=11, bias=True)
)
GAT(
  (conv1): GATConv(-1, 250, heads=1)
  (conv2): GATConv(250, 250, heads=1)
  (conv3): GATConv(250, 250, heads=1)
  (lin): Linear(in_features=250, out_features=11, bias=True)
)
GNN(
  (conv1): GraphConv(-1, 250)
  (conv2): GraphConv(250, 250)
  (conv3): GraphConv(250, 250)
  (lin): Linear(in_features=250, out_features=11, bias=True)
)


In [38]:
# Set model paramters and model type
def set_model_parameters(model_type, lr=0.01):
    model = model_type
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()
    return model, optimizer, criterion

# Train the model
def train(model, optimizer,criterion):
    model.train()

    for data in train_loader:  # Iterate in batches over the training dataset.
         out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
         loss = criterion(out, data.y)  # Compute the loss.
         loss.backward()  # Derive gradients.
         optimizer.step()  # Update parameters based on gradients.
         optimizer.zero_grad()  # Clear gradients.

# Test the model 
def test(loader, model):
     model.eval()

     correct = 0
     for data in loader:  # Iterate in batches over the training/test dataset.
         out = model(data.x, data.edge_index, data.batch)  
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         correct += int((pred == data.y).sum())  # Check against ground-truth labels.
     return correct / len(loader.dataset)  # Derive ratio of correct predictions.

# Training and Testing Pipeline 
def running_epochs(model,optimizer,criterion):
    for epoch in range(1, 10):
        train(model,optimizer,criterion)
        train_acc = test(train_loader, model)
        test_acc = test(test_loader, model)
        print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

In [40]:
# Experiment 1: GCN Baseline 
model, optimizer, criterion = set_model_parameters(model1, lr=0.001)
running_epochs(model,optimizer,criterion)

Epoch: 001, Train Acc: 0.0428, Test Acc: 0.0389
Epoch: 002, Train Acc: 0.0428, Test Acc: 0.0389


KeyboardInterrupt: 