In [1]:
import torch
from torch import nn
import numpy as np
from chbmit import chbmit
from torch_geometric.utils.convert import to_networkx
from torch_geometric.nn.conv import GATv2Conv
import scipy as sp
from torch_geometric.utils import dense_to_sparse

In [2]:
class GraphLearner(nn.Module):
    def __init__(self, depth, threshold=0.1):
        super().__init__()
        self.depth = depth
        self.threshold= threshold
        self.Q = nn.Linear(depth,depth, bias=False)
        self.K = nn.Linear(depth,depth, bias=False)

    def forward(self,tensor):
        queries = self.Q(tensor)
        keys = torch.transpose(self.K(tensor), -1,-2)

        # this matrix multiplication results in the shape of the adj. matrix
        adj = torch.matmul(queries,keys) / np.sqrt(self.depth) 
        
        # return the softmax of the learned adjacency matrix
        adj = torch.softmax(adj,dim=-1)

        # Average the adj matrix with the transpose so that the graph will be bidirectional
        # e.g the connection from [1,2] should be the same as [2,1]
        adj = (adj + adj.transpose(-1,-2))/2

        # Set values below specified threshold to 0
        adj[adj<=self.threshold] = 0

        edge_indexes, edge_weights = dense_to_sparse(adj)
        return edge_indexes, edge_weights

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [4]:
patient_0 = chbmit.CHB_MIT_PAITENT(tok_len=60, ctx_len=1, path="../physionet.org/files/chbmit/1.0.0/chb01/")

In [5]:
batch_size = 16
dataloader = torch.utils.data.DataLoader(patient_0, batch_size=batch_size, shuffle=True)

In [6]:
class GraphAttentionNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.gsl = GraphLearner(60*256,threshold = 0.04)
        self.layer1 = GATv2Conv(60*256, 30*256)
        self.activation = torch.nn.LeakyReLU()
    
    def forward(self, x):

        edge_idx, edge_weights = self.gsl(x)

        x = self.layer1(x, edge_index=edge_idx)

        x = self.activation(x)

        return x

In [8]:
input = next(iter(dataloader))[0]
GAT = GraphAttentionNetwork()
output = GAT(input)

AssertionError: 

In [None]:
network = to_networkx(data=data)
kx.draw(network)