In [31]:
# Import libraries and packages
import numpy as np
import argparse
import time
import gc
import random
from math import ceil

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.nn.functional import cosine_similarity

from torch import Tensor
from torch.nn import init
from torch.nn.parameter import Parameter

from torch.utils.data import TensorDataset

# Seeds we tried: 42, 1992, 250, 213, 1709
SEED = 42

# Seed every RNG library imported above so a fresh "Restart & Run All" is repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [32]:
# Load the held-out test arrays (memory-mapped, copy-on-write).
X_test = np.load('Data/eICU Data/X_test_multilabel_full_42.npy', mmap_mode='c')
y_test = np.load('Data/eICU Data/y_test_multilabel_full_42.npy', mmap_mode='c')

# Swap axes 1 and 2, then insert a singleton channel axis at position 1.
X_test = torch.from_numpy(np.swapaxes(X_test, 1, 2))
y_test = torch.from_numpy(y_test)

X_test = X_test[:, None, :, :]
assert X_test.shape == (230, 1, 80, 24)

# init [num_variables, seq_length, num_classes]
# These are read from the test tensor itself; no validation split is loaded here.
num_nodes = X_test.size(-2)
seq_length = X_test.size(-1)
num_classes = 5

# X_test / y_test are already tensors, so they are wrapped directly
# (torch.tensor(tensor) would copy them again and emit a UserWarning).
test_dataset = TensorDataset(X_test, y_test)

# A single batch holding the whole test split, in file order.
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=len(test_dataset),
                                          shuffle=False,
                                          pin_memory=torch.cuda.is_available())


In [25]:
# Define the model

# %% [code]
class multi_shallow_embedding(nn.Module):
    """Builds `num_graphs` binary adjacency matrices from learnable rank-1 embeddings.

    For every graph the score matrix is the outer product `emb_s @ emb_t`
    ([G, N, 1] x [G, 1, N] -> [G, N, N]). The diagonal is masked out and the
    `k_neighs` highest-scoring entries of each graph are set to 1, all others to 0.

    Args:
        num_nodes: number of nodes N in every graph.
        k_neighs: number of edges kept per graph (top-k over the flattened N*N scores).
        num_graphs: number of graphs G.
    """

    def __init__(self, num_nodes, k_neighs, num_graphs):
        super().__init__()

        self.num_nodes = num_nodes
        self.k = k_neighs
        self.num_graphs = num_graphs

        self.emb_s = Parameter(Tensor(num_graphs, num_nodes, 1))
        self.emb_t = Parameter(Tensor(num_graphs, 1, num_nodes))

        # Tensor(...) only allocates storage; without this call the embeddings
        # would start from whatever bytes happen to be in memory.
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw both embedding tensors with Xavier-uniform initialisation."""
        init.xavier_uniform_(self.emb_s)
        init.xavier_uniform_(self.emb_t)

    def forward(self, device):
        """Return the [G, N, N] 0/1 adjacency tensor on `device`."""

        # adj: [G, N, N]
        adj = torch.matmul(self.emb_s, self.emb_t).to(device)

        # remove self-loop: -inf on the diagonal keeps it out of the top-k
        adj = adj.clone()
        node_idx = torch.arange(self.num_nodes, dtype=torch.long, device=device)
        adj[:, node_idx, node_idx] = float('-inf')

        # top-k-edge adj
        adj_flat = adj.reshape(self.num_graphs, -1)
        indices = adj_flat.topk(k=self.k)[1].reshape(-1)

        # Row (graph) index for each selected entry: 0 repeated k times, then 1, ...
        graph_idx = torch.arange(self.num_graphs, device=device).repeat_interleave(self.k)

        adj_flat = torch.zeros_like(adj_flat)
        adj_flat[graph_idx, indices] = 1.

        return adj_flat.reshape_as(adj)


class Group_Linear(nn.Module):
    """Per-group linear map over the channel axis, implemented as a grouped 1x1 Conv2d.

    Args:
        in_channels: channels per group on the input side.
        out_channels: channels per group on the output side.
        groups: number of groups G; each group has its own weights.
        bias: whether the underlying convolution has a bias term.
    """

    def __init__(self, in_channels, out_channels, groups=1, bias=False):
        super().__init__()

        self.out_channels = out_channels
        self.groups = groups

        self.group_mlp = nn.Conv2d(in_channels * groups, out_channels * groups, kernel_size=(1, 1), groups=groups, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the grouped convolution with its default scheme."""
        self.group_mlp.reset_parameters()

    def forward(self, x: Tensor, is_reshape: bool = False):
        """
        Args:
            x (Tensor): [B, C, N, F] (if not is_reshape), [B, C, G, N, F//G] (if is_reshape)
            is_reshape (bool): True when `x` is already split into groups.
                Defaults to False (the original signature annotated the
                parameter with the value `False` instead of giving a default).

        Returns:
            Tensor: [B, C_out, G, N, F//G]
        """
        B = x.size(0)
        C = x.size(1)
        N = x.size(-2)
        G = self.groups

        if not is_reshape:
            # x: [B, C_in, G, N, F//G]
            x = x.reshape(B, C, N, G, -1).transpose(2, 3)
        # Fold the group axis into the channel axis so the grouped conv sees
        # G contiguous blocks of C_in channels. x: [B, G*C_in, N, F//G]
        x = x.transpose(1, 2).reshape(B, G*C, N, -1)

        out = self.group_mlp(x)
        # Unfold the groups again. out: [B, C_out, G, N, F//G]
        out = out.reshape(B, G, self.out_channels, N, -1).transpose(1, 2)

        return out


class DenseGCNConv2d(nn.Module):
    """Dense GCN layer applied group-wise: per-group linear map, then adjacency product.

    Args:
        in_channels: input channels C_in.
        out_channels: output channels C_out.
        groups: number of graphs/groups G the feature axis is split into.
        bias: if True, a learnable per-output-channel bias is added.
    """

    def __init__(self, in_channels, out_channels, groups=1, bias=True):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.lin = Group_Linear(in_channels, out_channels, groups, bias=False)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear map and zero the bias (when one exists)."""
        self.lin.reset_parameters()
        # bias is None when the layer was built with bias=False
        if self.bias is not None:
            init.zeros_(self.bias)

    def norm(self, adj: Tensor, add_loop):
        """Symmetrically normalise `adj` (D^-1/2 A D^-1/2), optionally adding self-loops.

        Degrees are clamped to at least 1 before the inverse square root.
        """
        if add_loop:
            adj = adj.clone()
            idx = torch.arange(adj.size(-1), dtype=torch.long, device=adj.device)
            # `...` targets the last two (node) axes regardless of leading batch/graph axes
            adj[..., idx, idx] += 1

        deg_inv_sqrt = adj.sum(-1).clamp(min=1).pow(-0.5)

        adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)

        return adj

    def forward(self, x: Tensor, adj: Tensor, add_loop=True):
        """
        Args:
            x (Tensor): [B, C, N, F]
            adj (Tensor): [B, G, N, N], or [G, N, N] shared by the whole batch
            add_loop (bool): add self-loops before normalising.

        Returns:
            Tensor: [B, C_out, N, F]
        """
        adj = self.norm(adj, add_loop)
        if adj.dim() == 4:
            # [B, G, N, N] -> [B, 1, G, N, N] so it broadcasts over the channel axis.
            # A 3-D [G, N, N] adjacency already broadcasts against [B, C, G, N, F//G].
            adj = adj.unsqueeze(1)

        # x: [B, C, G, N, F//G]
        x = self.lin(x, False)

        out = torch.matmul(adj, x)

        # out: [B, C, N, F]
        B, C, _, N, _ = out.size()
        out = out.transpose(2, 3).reshape(B, C, N, -1)

        if self.bias is not None:
            # broadcast the per-channel bias over batch, node and feature axes
            out = out + self.bias.view(1, -1, 1, 1)

        return out

class DenseGINConv2d(nn.Module):
    """Dense GIN-style layer over G graphs with two extra variational heads.

    Args:
        in_channels: input channels C_in.
        out_channels: output channels C_out.
        groups: number of graphs/groups G the feature axis is split into.
        eps: initial value of the self-loop weight epsilon.
        train_eps: if True epsilon is a learnable parameter, otherwise a buffer.
    """

    def __init__(self, in_channels, out_channels, groups=1, eps=0, train_eps=True):
        super().__init__()

        # Update MLP applied to the aggregated features, which carry `in_channels`
        # channels. The original code re-assigned `self.mlp` to an
        # (out_channels -> in_channels) layer, which only fits when both sizes are equal.
        self.mlp = Group_Linear(in_channels, out_channels, groups, bias=False)

        # Encoder heads producing the mean / log-variance returned by forward()
        self.encoder_mean = Group_Linear(in_channels, out_channels, groups, bias=False)
        self.encoder_logvar = Group_Linear(in_channels, out_channels, groups, bias=False)

        self.init_eps = eps
        if train_eps:
            self.eps = Parameter(Tensor([eps]))
        else:
            self.register_buffer('eps', Tensor([eps]))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the MLP, both encoder heads and epsilon."""
        self.mlp.reset_parameters()
        self.encoder_mean.reset_parameters()
        self.encoder_logvar.reset_parameters()
        self.eps.data.fill_(self.init_eps)

    def norm(self, adj: Tensor, add_loop):
        """Symmetrically normalise `adj` (D^-1/2 A D^-1/2), optionally adding self-loops.

        Degrees are clamped to at least 1 before the inverse square root.
        """
        if add_loop:
            adj = adj.clone()
            idx = torch.arange(adj.size(-1), dtype=torch.long, device=adj.device)
            adj[..., idx, idx] += 1

        deg_inv_sqrt = adj.sum(-1).clamp(min=1).pow(-0.5)

        adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)

        return adj

    def reparameterize(self, mean, logvar):
        """Draw `mean + eps * std` with `eps ~ N(0, 1)` and `std = exp(0.5 * logvar)`."""
        # Small constant added to the standard deviation
        epsilon = 1e-7
        std = torch.exp(0.5 * logvar) + epsilon
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x: Tensor, adj: Tensor, add_loop=True):
        """
        Args:
            x (Tensor): [B, C, N, F]
            adj (Tensor): [G, N, N]
            add_loop (bool): add the (1 + eps)-weighted input to the aggregation.

        Returns:
            tuple: (out [B, C_out, N, F], logvar [B, C_out, G, N, F//G],
            mean [B, C_out, G, N, F//G])
        """
        B, C, N, _ = x.size()
        G = adj.size(0)

        # adj-norm (self-loops are handled through eps below, not in the adjacency)
        adj = self.norm(adj, add_loop=False)

        # x: [B, C, G, N, F//G]
        x = x.reshape(B, C, N, G, -1).transpose(2, 3)

        out = torch.matmul(adj, x)

        # DYNAMIC: every group g >= 1 additionally receives the raw features of group g-1
        x_pre = x[:, :, :-1, ...]
        out[:, :, 1:, ...] = out[:, :, 1:, ...] + x_pre

        if add_loop:
            out = (1 + self.eps) * x + out

        # out: [B, C_out, G, N, F//G]
        out = self.mlp(out, True)

        # out2: [B, C_out, N, F]
        C = out.size(1)
        out2 = out.transpose(2, 3).reshape(B, C, N, -1)

        # Variational encoding: separate heads, so mean and logvar are no longer
        # the same tensor computed twice.
        mean = self.encoder_mean(x, True)
        logvar = self.encoder_logvar(x, True)

        return out2, logvar, mean

class Dense_TimeDiffPool2d(nn.Module):
    """Pools the node axis with a temporal convolution and coarsens the adjacency.

    Args:
        pre_nodes: number of nodes before pooling.
        pooled_nodes: number of nodes after pooling.
        kern_size: width of the temporal kernel.
        padding: padding applied along the feature (time) axis.
    """

    def __init__(self, pre_nodes, pooled_nodes, kern_size, padding):
        super().__init__()

        # TODO: add Normalization
        # Nodes act as the convolution's channels: pre_nodes in, pooled_nodes out.
        self.time_conv = nn.Conv2d(pre_nodes, pooled_nodes, (1, kern_size), padding=(0, padding))

        self.re_param = Parameter(Tensor(kern_size, 1))

        # Tensor(...) leaves re_param uninitialised; initialise here so the layer
        # is usable on its own, like the other layers in this file.
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the convolution and the kernel-reduction vector."""
        self.time_conv.reset_parameters()
        init.kaiming_uniform_(self.re_param, nonlinearity='relu')

    def forward(self, x: Tensor, adj: Tensor):
        """
        Args:
            x (Tensor): [B, C, N, F]
            adj (Tensor): [G, N, N]

        Returns:
            tuple: pooled features [B, C, N', F'] and pooled adjacency [G, N', N'].
        """
        # Move nodes to the channel axis, convolve, then move them back.
        x = x.transpose(1, 2)
        out = self.time_conv(x)
        out = out.transpose(1, 2)

        # conv weight [N', N, 1, K] @ re_param [K, 1] -> assignment matrix s: [N', N]
        s = torch.matmul(self.time_conv.weight, self.re_param).view(out.size(-2), -1)

        # coarsened adjacency: s A s^T
        out_adj = torch.matmul(torch.matmul(s, adj), s.transpose(0, 1))

        return out, out_adj

class GNNStack(nn.Module):
    """Stack of (RNN -> graph conv -> temporal node pooling) layers with a linear head.

    Args:
        gnn_model_type: 'dyGCN2d' or 'DenseGINConv2d'.
        num_layers: number of stacked layers (>= 1).
        groups: number of graphs G; the feature axis is padded to a multiple of it.
        pool_ratio: fraction of the original nodes removed per layer.
        kern_size: list with one temporal kernel width per layer.
        in_dim, hidden_dim, out_dim: channel widths of the first / middle / last layer.
        seq_len: length of the last input axis.
        num_nodes: length of the second-to-last input axis.
        num_classes: size of the output vector.
        dropout: dropout probability applied after every layer.
        activation: activation module applied after batch norm.
    """

    def __init__(self, gnn_model_type, num_layers, groups, pool_ratio, kern_size,
                 in_dim, hidden_dim, out_dim,
                 seq_len, num_nodes, num_classes, dropout=0.7, activation=nn.ReLU()):

        super().__init__()

        # Learnable per-graph weights that rescale the constructed adjacency in forward()
        self.attention_matrix = nn.Parameter(torch.randn(groups, num_nodes, num_nodes))
        nn.init.xavier_uniform_(self.attention_matrix)

        k_neighs = self.num_nodes = num_nodes

        self.num_graphs = groups

        # Feature length after padding to a multiple of `groups`
        self.num_feats = seq_len
        if seq_len % groups:
            self.num_feats += ( groups - seq_len % groups )
        self.g_constr = multi_shallow_embedding(num_nodes, k_neighs, self.num_graphs)

        gnn_model, heads = self.build_gnn_model(gnn_model_type)

        assert num_layers >= 1, 'Error: Number of layers is invalid.'
        assert num_layers == len(kern_size), 'Error: Number of kernel_size should equal to number of layers.'
        paddings = [ (k - 1) // 2 for k in kern_size ]

        # NOTE(review): tconvs is not used by forward(); it is kept so the set of
        # registered parameters stays the same as before.
        self.tconvs = nn.ModuleList(
            [nn.Conv2d(1, in_dim, (1, kern_size[0]), padding=(0, paddings[0]))] +
            [nn.Conv2d(heads * in_dim, hidden_dim, (1, kern_size[layer+1]), padding=(0, paddings[layer+1])) for layer in range(num_layers - 2)] +
            [nn.Conv2d(heads * hidden_dim, out_dim, (1, kern_size[-1]), padding=(0, paddings[-1]))]
        )

        # Recurrent layers (plain nn.RNN despite the attribute name). Sizes are
        # expressed in units of the padded feature length instead of a literal 24,
        # and there is one middle layer per middle graph conv.
        feats = self.num_feats
        self.lstm_layers = nn.ModuleList()
        self.lstm_layers.append(nn.RNN(input_size=feats, hidden_size=in_dim * feats, batch_first=True))
        for _ in range(num_layers - 2):
            self.lstm_layers.append(nn.RNN(input_size=in_dim * feats, hidden_size=hidden_dim * feats, batch_first=True))
        self.lstm_layers.append(nn.RNN(input_size=hidden_dim * feats, hidden_size=out_dim * feats, batch_first=True))

        self.gconvs = nn.ModuleList(
            [gnn_model(in_dim, heads * in_dim, groups)] +
            [gnn_model(hidden_dim, heads * hidden_dim, groups) for _ in range(num_layers - 2)] +
            [gnn_model(out_dim, heads * out_dim, groups)]
        )

        self.bns = nn.ModuleList(
            [nn.BatchNorm2d(heads * in_dim)] +
            [nn.BatchNorm2d(heads * hidden_dim) for _ in range(num_layers - 2)] +
            [nn.BatchNorm2d(heads * out_dim)]
        )

        # Node count remaining after each layer, never below 1
        self.left_num_nodes = []
        for layer in range(num_layers + 1):
            left_node = round( num_nodes * (1 - (pool_ratio*layer)) )
            if left_node > 0:
                self.left_num_nodes.append(left_node)
            else:
                self.left_num_nodes.append(1)
        self.diffpool = nn.ModuleList(
            [Dense_TimeDiffPool2d(self.left_num_nodes[layer], self.left_num_nodes[layer+1], kern_size[layer], paddings[layer]) for layer in range(num_layers - 1)] +
            [Dense_TimeDiffPool2d(self.left_num_nodes[-2], self.left_num_nodes[-1], kern_size[-1], paddings[-1])]
        )

        self.num_layers = num_layers
        self.dropout = dropout
        self.activation = activation

        self.softmax = nn.Softmax(dim=-1)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        self.linear = nn.Linear(heads * out_dim, num_classes)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every stacked layer, the classifier and the graph embeddings."""
        stacks = (self.lstm_layers, self.gconvs, self.bns, self.diffpool)
        # Walk the stacks position by position (same order as zipping them) but
        # without stopping at the shortest list, so no layer is skipped.
        for position in range(max(len(stack) for stack in stacks)):
            for stack in stacks:
                if position < len(stack):
                    stack[position].reset_parameters()

        self.linear.reset_parameters()

        # The graph-construction embeddings are allocated with Tensor(...) and
        # must be initialised explicitly.
        self.g_constr.reset_parameters()

    def build_gnn_model(self, model_type):
        """Map a model-type string to `(layer class, heads)`.

        Raises:
            ValueError: if `model_type` is not a supported name.
        """
        if model_type == 'dyGCN2d':
            return DenseGCNConv2d, 1
        if model_type == 'DenseGINConv2d':
            return DenseGINConv2d, 1
        raise ValueError(f"Unknown gnn_model_type: {model_type!r}")

    def forward(self, inputs: Tensor):
        """Run the stack on `inputs` ([B, 1, N, F]).

        Returns:
            tuple: (class scores [B, num_classes], logvar, mean, pooled adjacency).
            `logvar` / `mean` come from the last graph conv and are None when the
            conv type does not produce them.
        """

        # Zero-pad the last axis (split left/right) up to a multiple of the graph count
        if inputs.size(-1) % self.num_graphs:
            pad_size = (self.num_graphs - inputs.size(-1) % self.num_graphs) / 2
            x = F.pad(inputs, (int(pad_size), ceil(pad_size)), mode='constant', value=0.0)
        else:
            x = inputs

        adj = self.g_constr(x.device)

        # Attention layer: row-wise softmax weights rescale the 0/1 adjacency
        attention_scores = F.softmax(self.attention_matrix, dim=-1)
        adj = adj * attention_scores

        logvar = mean = None
        for lstm, gconv, bn, pool in zip(self.lstm_layers, self.gconvs, self.bns, self.diffpool):
            batch_size, channels, seq_len, features = x.size()

            # Residual branch: tile the input along the channel axis up to the
            # channel count the recurrent layer produces (128x for the 1-channel
            # input and 2x for the second layer in the default configuration).
            rnn_channels = lstm.hidden_size // features
            x1 = x.repeat(1, rnn_channels // channels, 1, 1)

            x = x.reshape(batch_size, seq_len, channels * features)
            # Apply the recurrent layer
            x, _ = lstm(x)
            x = torch.reshape(x, (batch_size, -1, seq_len, features))
            x = x + x1

            conv_out = gconv(x, adj)
            if isinstance(conv_out, tuple):
                temp, logvar, mean = conv_out
            else:
                # conv types that return only the feature tensor
                temp = conv_out

            x, adj = pool(temp, adj)

            x = self.activation(bn(x))
            x = F.dropout(x, p=self.dropout, training=self.training)

        out = self.global_pool(x)
        out = out.view(out.size(0), -1)
        out = self.linear(out)

        return out, logvar, mean, adj

In [26]:
# Metrics
# %% [code]
class AverageMeter(object):
    """Tracks the latest value and the running (weighted) mean of a metric."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val`, counted `n` times, and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        # Template looks like "{name} {val:f} ({avg:f})" for the default fmt
        template = f"{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})"
        return template.format(**vars(self))

# %% [code]
def balanced_accuracy_multi(inp, targ, thresh=0.5, sigmoid=True):
    """Computes balanced accuracy when `inp` and `targ` are the same size.

    Args:
        inp: scores (logits when `sigmoid` is True, probabilities otherwise).
        targ: 0/1 targets with the same shape as `inp`.
        thresh: a score strictly above this counts as a positive prediction.
        sigmoid: apply a sigmoid to `inp` before thresholding.

    Returns:
        Scalar tensor: the mean of the true-positive and true-negative rates,
        pooled over all elements. If the batch holds no positives (or no
        negatives) the undefined rate is left out instead of turning the result
        into NaN; an empty batch still yields NaN.
    """
    if sigmoid:
        inp = inp.sigmoid()
    pred = inp > thresh

    correct = pred == targ.bool()
    is_pos = targ == 1
    is_neg = targ == 0

    TP = torch.logical_and(correct, is_pos).sum()
    TN = torch.logical_and(correct, is_neg).sum()
    FN = torch.logical_and(~correct, is_pos).sum()
    FP = torch.logical_and(~correct, is_neg).sum()

    n_pos = TP + FN
    n_neg = TN + FP

    if n_pos > 0 and n_neg > 0:
        return (TP / n_pos + TN / n_neg) / 2
    if n_pos > 0:
        return TP / n_pos
    if n_neg > 0:
        return TN / n_neg
    return torch.full((), float('nan'), device=TP.device)

# %% [code]
def Fbeta_multi(inp, targ, beta=1.0, thresh=0.5, sigmoid=True):
    """Computes Fbeta when `inp` and `targ` are the same size.

    Args:
        inp: scores (logits when `sigmoid` is True, probabilities otherwise).
        targ: 0/1 targets with the same shape as `inp`.
        beta: weight of recall relative to precision.
        thresh: a score strictly above this counts as a positive prediction.
        sigmoid: apply a sigmoid to `inp` before thresholding.

    Returns:
        Scalar tensor. It is 0 when precision + recall is not positive, which
        includes the cases where either of them is undefined (NaN); that case
        previously returned the Python int 0 rather than a tensor.
    """
    if sigmoid:
        inp = inp.sigmoid()
    pred = inp > thresh

    correct = pred == targ.bool()
    is_pos = targ == 1
    is_neg = targ == 0

    TP = torch.logical_and(correct, is_pos).sum()
    FN = torch.logical_and(~correct, is_pos).sum()
    FP = torch.logical_and(~correct, is_neg).sum()

    precision = TP / (TP + FP)
    recall = TP / (TP + FN)
    beta2 = beta * beta

    # A NaN on either side makes the comparison False, so those cases fall through to 0.
    if precision + recall > 0:
        return (1 + beta2) * precision * recall / (beta2 * precision + recall)
    return torch.zeros((), dtype=precision.dtype, device=precision.device)


# %% [code]
def recall_multi(inp, targ, thresh=0.5, sigmoid=True):
    """Computes recall when `inp` and `targ` are the same size.

    Args:
        inp: scores (logits when `sigmoid` is True, probabilities otherwise).
        targ: 0/1 targets with the same shape as `inp`.
        thresh: a score strictly above this counts as a positive prediction.
        sigmoid: apply a sigmoid to `inp` before thresholding.

    Returns:
        Scalar tensor TP / (TP + FN), or 0 when `targ` holds no positives
        (instead of NaN), matching how `Fbeta_multi` treats that case.
    """
    if sigmoid:
        inp = inp.sigmoid()
    pred = inp > thresh

    correct = pred == targ.bool()
    is_pos = targ == 1

    TP = torch.logical_and(correct, is_pos).sum()
    FN = torch.logical_and(~correct, is_pos).sum()

    n_pos = TP + FN
    if n_pos > 0:
        return TP / n_pos
    return torch.zeros((), dtype=torch.get_default_dtype(), device=TP.device)

In [33]:
# Instantiate the network; node count, sequence length and class count come
# from the data-loading cell.
# NOTE(review): the next cell rebinds `model` to the object returned by torch.load.
model = GNNStack(
    gnn_model_type='DenseGINConv2d',
    num_layers=2,
    groups=6,
    pool_ratio=0.2,
    kern_size=[11, 3],
    in_dim=128,
    hidden_dim=128,
    out_dim=256,
    seq_len=seq_length,
    num_nodes=num_nodes,
    num_classes=num_classes,
)

In [34]:
# SECURITY: torch.load on a whole-model file unpickles arbitrary Python objects,
# so only load checkpoints from a trusted source. Recent PyTorch releases default
# to weights_only=True, in which case this call needs weights_only=False.
# NOTE(review): the checkpoint name carries 1992 while the data files carry 42 —
# confirm this pairing is intended.
model = torch.load('models/eICU_1992.pt', map_location=DEVICE)
# Keep the model on the same device the batches are moved to below.
model = model.to(DEVICE)
model.eval()

with torch.no_grad():
    for data, label in test_loader:

        data = data.to(DEVICE).type(torch.float)
        label = label.to(DEVICE).type(torch.float)

        # compute output (first element of the model's return tuple: class scores)
        output = model(data)[0]

        # balanced accuracy, F1 and recall for this batch
        acc1 = balanced_accuracy_multi(output, label)
        f1 = Fbeta_multi(output, label)
        sens = recall_multi(output, label)

        print(acc1, f1, sens)

tensor(0.6721) tensor(0.4581) tensor(0.4128)
