In [None]:
import numpy as np
import argparse
import time
import gc
import random
from math import ceil

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.functional as F

from torch import Tensor
from torch.nn import init
from torch.nn.parameter import Parameter

from torch.utils.data import TensorDataset

# Seed every RNG the notebook draws from (Python `random`, NumPy and PyTorch) so
# runs are repeatable. torch.manual_seed seeds the CPU and all CUDA devices.
SEED = 1992
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the first GPU; fall back to CPU. To force CPU, set DEVICE = torch.device("cpu").
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Mounted at /content/gdrive


In [None]:
# Load your relevant numpy dataset files.
# NOTE(review): X_train, X_val, X_test, y_train, y_val and y_test are not defined
# anywhere in this notebook; they must already exist in the kernel. Load them
# explicitly in this cell so that Restart & Run All works.

data_train = X_train
data_val = X_val
data_test = X_test
label_train = y_train
label_val = y_val
label_test = y_test

# Model dimensions, read from the last two axes of the validation data
# (assumed [num_samples, 1, num_variables, seq_length] — TODO confirm).
# `.shape` works for both numpy arrays and torch tensors; on numpy `.size` is
# an int attribute, so `.size(-2)` would fail.
num_nodes = data_val.shape[-2]
seq_length = data_val.shape[-1]

num_classes = 6  # Define number of classes

TRAIN_BATCH_SIZE = 8
EVAL_BATCH_SIZE = 476
# Pinned host memory only helps when batches are copied to a GPU.
PIN_MEMORY = torch.cuda.is_available()

# Convert data & labels to TensorDataset. torch.as_tensor accepts numpy arrays and
# existing tensors alike; torch.tensor(existing_tensor) makes a copy and emits a
# UserWarning recommending clone().detach().
train_dataset = TensorDataset(torch.as_tensor(data_train), torch.as_tensor(label_train))
val_dataset = TensorDataset(torch.as_tensor(data_val), torch.as_tensor(label_val))
test_dataset = TensorDataset(torch.as_tensor(data_test), torch.as_tensor(label_test))


# Data loaders. Only training data is shuffled; evaluation order is fixed so
# per-batch metrics are reproducible.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=TRAIN_BATCH_SIZE,
                                           shuffle=True,
                                           pin_memory=PIN_MEMORY)

val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=EVAL_BATCH_SIZE,
                                         shuffle=False,
                                         pin_memory=PIN_MEMORY)

test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=EVAL_BATCH_SIZE,
                                          shuffle=False,
                                          pin_memory=PIN_MEMORY)

  train_dataset = TensorDataset(torch.tensor(data_train), torch.tensor(label_train))
  val_dataset = TensorDataset(torch.tensor(data_val), torch.tensor(label_val))
  test_dataset = TensorDataset(torch.tensor(data_test), torch.tensor(label_test))


In [None]:
class multi_shallow_embedding(nn.Module):
    """Builds ``num_graphs`` binary adjacency matrices from rank-1 node embeddings.

    Each graph's dense score matrix is the outer product ``emb_s @ emb_t``
    ([N, 1] x [1, N]). Self-loops are masked out, and the ``k_neighs`` highest
    scores over the flattened N*N matrix (not per node) are set to 1.

    Args:
        num_nodes: number of nodes N.
        k_neighs: number of edges k kept per graph.
        num_graphs: number of graphs G.

    NOTE(review): the returned adjacency is built from ``zeros_like`` + constant
    assignment, so no gradient flows back to ``emb_s`` / ``emb_t``.
    """

    def __init__(self, num_nodes, k_neighs, num_graphs):
        super().__init__()

        self.num_nodes = num_nodes
        self.k = k_neighs
        self.num_graphs = num_graphs

        # Parameter(Tensor(...)) allocates uninitialized memory, so initialize it
        # right away instead of relying on an external reset_parameters() call.
        self.emb_s = Parameter(Tensor(num_graphs, num_nodes, 1))
        self.emb_t = Parameter(Tensor(num_graphs, 1, num_nodes))
        self.reset_parameters()

    def reset_parameters(self):
        init.xavier_uniform_(self.emb_s)
        init.xavier_uniform_(self.emb_t)

    def forward(self, device):
        """Return a [G, N, N] 0/1 adjacency with exactly k ones per graph.

        The diagonal is excluded as long as ``k <= N * (N - 1)``.
        """
        # adj: [G, N, N]
        adj = torch.matmul(self.emb_s, self.emb_t).to(device)

        # Remove self-loops so they can never be picked by top-k.
        adj = adj.clone()
        diag = torch.arange(self.num_nodes, dtype=torch.long, device=device)
        adj[:, diag, diag] = float('-inf')

        # Top-k edges over each graph's flattened N*N scores.
        adj_flat = adj.reshape(self.num_graphs, -1)
        indices = adj_flat.topk(k=self.k)[1].reshape(-1)

        # Graph index for every selected edge: [0]*k + [1]*k + ... + [G-1]*k.
        rows = torch.arange(self.num_graphs, device=device).repeat_interleave(self.k)

        adj_flat = torch.zeros_like(adj_flat)
        adj_flat[rows, indices] = 1.
        return adj_flat.reshape_as(adj)


class Group_Linear(nn.Module):
    """Grouped 1x1 convolution: an independent linear map per feature group.

    The feature axis F is split into ``groups`` slices of width F//groups, and each
    slice gets its own ``in_channels -> out_channels`` projection.

    Args:
        in_channels: input channels per group.
        out_channels: output channels per group.
        groups: number of groups G.
        bias: whether the underlying Conv2d has a bias.
    """

    def __init__(self, in_channels, out_channels, groups=1, bias=False):
        super().__init__()

        self.out_channels = out_channels
        self.groups = groups

        self.group_mlp = nn.Conv2d(in_channels * groups, out_channels * groups, kernel_size=(1, 1), groups=groups, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        self.group_mlp.reset_parameters()

    def forward(self, x: Tensor, is_reshape: bool = False):
        """
        Args:
            x (Tensor): [B, C, N, F] (if not is_reshape), [B, C, G, N, F//G] (if is_reshape)
            is_reshape: True when ``x`` is already split into groups. Previously
                this was annotated with the value ``False`` rather than given a
                default, which made the argument mandatory.

        Returns:
            Tensor: [B, C_out, G, N, F//G]
        """
        B = x.size(0)
        C = x.size(1)
        N = x.size(-2)
        G = self.groups

        if not is_reshape:
            # x: [B, C_in, G, N, F//G]
            x = x.reshape(B, C, N, G, -1).transpose(2, 3)
        # x: [B, G*C_in, N, F//G] — group-major channel layout expected by the grouped conv.
        x = x.transpose(1, 2).reshape(B, G*C, N, -1)

        out = self.group_mlp(x)
        out = out.reshape(B, G, self.out_channels, N, -1).transpose(1, 2)

        # out: [B, C_out, G, N, F//G]
        return out


class DenseGCNConv2d(nn.Module):
    """Dense grouped GCN layer: normalized adjacency times a grouped linear map.

    Args:
        in_channels: input channels C_in.
        out_channels: output channels C_out.
        groups: number of feature groups / graphs G (the feature axis is split into G slices).
        bias: if True, add a learnable per-output-channel bias.
    """

    def __init__(self, in_channels, out_channels, groups=1, bias=True):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.lin = Group_Linear(in_channels, out_channels, groups, bias=False)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        # bias is None when bias=False; init.zeros_(None) would raise.
        if self.bias is not None:
            init.zeros_(self.bias)

    def norm(self, adj: Tensor, add_loop):
        """Return D^-1/2 (A [+ I]) D^-1/2 over the last two axes (degrees clamped to >= 1)."""
        if add_loop:
            adj = adj.clone()
            idx = torch.arange(adj.size(-1), dtype=torch.long, device=adj.device)
            # `...` targets the diagonal of the last two axes for 2-D, 3-D and 4-D adj
            # alike; `adj[:, idx, idx]` hit the wrong axes for a [B, G, N, N] input.
            adj[..., idx, idx] += 1

        deg_inv_sqrt = adj.sum(-1).clamp(min=1).pow(-0.5)

        adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)

        return adj

    def forward(self, x: Tensor, adj: Tensor, add_loop=True):
        """
        Args:
            x (Tensor): [B, C, N, F]
            adj (Tensor): [B, G, N, N] (per-sample graphs) or [G, N, N] (shared graphs)
            add_loop: add self-loops before normalization.

        Returns:
            Tensor: [B, C_out, N, G * (F//G)]
        """
        adj = self.norm(adj, add_loop)
        if adj.dim() == 4:
            # [B, 1, G, N, N] broadcasts over x's channel axis.
            adj = adj.unsqueeze(1)
        # A [G, N, N] adj already broadcasts against [B, C, G, N, F//G].

        # x: [B, C, G, N, F//G]
        x = self.lin(x, False)

        out = torch.matmul(adj, x)

        # out: [B, C, N, F]
        B, C, _, N, _ = out.size()
        out = out.transpose(2, 3).reshape(B, C, N, -1)

        if self.bias is not None:
            # Move channels last so the per-channel bias broadcasts, then move back.
            out = out.transpose(1, -1) + self.bias
            out = out.transpose(1, -1)

        return out

class DenseGINConv2d(nn.Module):
    """Dense grouped GIN-style graph convolution with a cross-graph residual and
    variational mean / log-variance heads.

    The feature axis F is split into G = adj.size(0) slices, one per graph. Slice g
    aggregates neighbours through graph g, and graphs g >= 1 also receive the raw
    features of slice g-1.

    Args:
        in_channels: input channels C_in.
        out_channels: output channels C_out of the node update and both heads.
        groups: number of graphs G (must match adj.size(0)).
        eps: initial GIN self-weight epsilon.
        train_eps: learn eps as a Parameter (True) or keep it as a buffer.
    """

    def __init__(self, in_channels, out_channels, groups=1, eps=0, train_eps=True):
        super().__init__()

        # TODO: Multi-layer model
        # Node-update MLP applied to the aggregated features (C_in -> C_out).
        # Previously self.mlp was assigned twice, and the surviving copy mapped
        # out_channels -> in_channels, which only worked when in == out.
        self.mlp = Group_Linear(in_channels, out_channels, groups, bias=False)

        # Variational encoder heads, applied to the grouped input features.
        self.encoder_mean = Group_Linear(in_channels, out_channels, groups, bias=False)
        self.encoder_logvar = Group_Linear(in_channels, out_channels, groups, bias=False)

        self.init_eps = eps
        if train_eps:
            self.eps = Parameter(Tensor([eps]))
        else:
            self.register_buffer('eps', Tensor([eps]))

        self.reset_parameters()

    def reset_parameters(self):
        self.mlp.reset_parameters()
        self.encoder_mean.reset_parameters()
        self.encoder_logvar.reset_parameters()
        self.eps.data.fill_(self.init_eps)

    def norm(self, adj: Tensor, add_loop):
        """Return D^-1/2 (A [+ I]) D^-1/2 over the last two axes (degrees clamped to >= 1)."""
        if add_loop:
            adj = adj.clone()
            idx = torch.arange(adj.size(-1), dtype=torch.long, device=adj.device)
            adj[..., idx, idx] += 1

        deg_inv_sqrt = adj.sum(-1).clamp(min=1).pow(-0.5)

        adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)

        return adj

    def reparameterize(self, mean, logvar):
        """Sample mean + std * N(0, 1) with std = exp(logvar / 2) + 1e-7."""
        # Add an epsilon to prevent very large values
        epsilon = 1e-7
        std = torch.exp(0.5 * logvar) + epsilon
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x: Tensor, adj: Tensor, add_loop=True):
        """
        Args:
            x (Tensor): [B, C, N, F] with F divisible by G
            adj (Tensor): [G, N, N]
            add_loop: add the (1 + eps) * x self term.

        Returns:
            (out, logvar, mean): out is [B, C_out, N, F]; logvar and mean are
            [B, C_out, G, N, F//G] from the two encoder heads.
        """
        B, C, N, _ = x.size()
        G = adj.size(0)

        # Normalize without self-loops; the self term is added via (1 + eps) * x.
        adj = self.norm(adj, add_loop=False)

        # x: [B, C, G, N, F//G]
        x = x.reshape(B, C, N, G, -1).transpose(2, 3)

        out = torch.matmul(adj, x)

        # DYNAMIC: slice g (g >= 1) also receives the raw features of slice g-1.
        x_pre = x[:, :, :-1, ...]
        out[:, :, 1:, ...] = out[:, :, 1:, ...] + x_pre

        if add_loop:
            out = (1 + self.eps) * x + out

        # out: [B, C_out, G, N, F//G]
        out = self.mlp(out, True)

        # out: [B, C_out, N, F]
        C = out.size(1)
        out2 = out.transpose(2, 3).reshape(B, C, N, -1)

        # Variational encoding: separate heads so mean and logvar differ
        # (both previously reused self.mlp and were therefore identical).
        mean = self.encoder_mean(x, True)
        logvar = self.encoder_logvar(x, True)

        return out2, logvar, mean

class Dense_TimeDiffPool2d(nn.Module):
    """Temporal diff-pool that coarsens N^l nodes to N^(l+1) nodes.

    A 1xK convolution over the node axis pools the features. The same convolution
    weights, mixed over their K taps by ``re_param``, form the assignment matrix
    ``s`` ([N^(l+1), N^l]) that pools the adjacency: ``s @ adj @ s.T``.

    Args:
        pre_nodes: number of input nodes N^l.
        pooled_nodes: number of output nodes N^(l+1).
        kern_size: temporal kernel width K.
        padding: zero padding on each side of the time axis.
    """

    def __init__(self, pre_nodes, pooled_nodes, kern_size, padding):
        super().__init__()

        # TODO: add Normalization
        self.time_conv = nn.Conv2d(pre_nodes, pooled_nodes, (1, kern_size), padding=(0, padding))

        self.re_param = Parameter(Tensor(kern_size, 1))

        # Parameter(Tensor(...)) is uninitialized memory; initialize it here so the
        # module is usable without an external reset_parameters() call.
        self.reset_parameters()

    def reset_parameters(self):
        self.time_conv.reset_parameters()
        init.kaiming_uniform_(self.re_param, nonlinearity='relu')

    def forward(self, x: Tensor, adj: Tensor):
        """
        Args:
            x (Tensor): [B, C, N, F]
            adj (Tensor): [G, N, N]

        Returns:
            (out, out_adj): out is [B, C, N^(l+1), F'] (F' = F for odd K with
            padding (K-1)//2); out_adj is [G, N^(l+1), N^(l+1)].
        """
        # Put nodes on the channel axis so the conv mixes nodes.
        x = x.transpose(1, 2)
        out = self.time_conv(x)
        out = out.transpose(1, 2)

        # s: [ N^(l+1), N^l ] (weight [N^(l+1), N^l, 1, K] @ re_param [K, 1])
        s = torch.matmul(self.time_conv.weight, self.re_param).view(out.size(-2), -1)

        # TODO: fully-connect, how to decrease time complexity
        out_adj = torch.matmul(torch.matmul(s, adj), s.transpose(0, 1))

        return out, out_adj



class GNNStack(nn.Module):
    """ The stack layers of GNN.

    Each layer applies, in order:
      - an RNN run along the node axis, with a tiled-input residual;
      - a dynamic graph convolution;
      - a temporal diff-pool that coarsens both the features and the adjacency;
      - BatchNorm -> activation -> dropout.
    A global average pool and a linear head then produce the logits.

    Args:
        gnn_model_type: 'dyGCN2d' or 'dyGIN2d'. NOTE(review): ``forward`` unpacks
            three outputs from the graph layer, which only DenseGINConv2d returns.
        num_layers: requested depth (must equal ``len(kern_size)``).
        groups: number of graphs G; the time axis is padded to a multiple of G.
        pool_ratio: fraction of the original node count removed per layer.
        kern_size: temporal kernel width per layer.
        in_dim, hidden_dim, out_dim: channel widths.
        seq_len: input time length before padding.
        num_nodes: number of input variables N.
        num_classes: number of output logits.
        dropout: dropout probability applied after every layer.
        activation: activation module applied after BatchNorm.

    NOTE(review): the layers are iterated with ``zip`` over module lists of
    different lengths, so only the shortest list's length is run (one layer for
    ``num_layers=1``). ``tconvs`` is built but never applied in ``forward``.
    """

    def __init__(self, gnn_model_type, num_layers, groups, pool_ratio, kern_size,
                 in_dim, hidden_dim, out_dim,
                 seq_len, num_nodes, num_classes, dropout=0.7, activation=nn.ReLU()):

        super().__init__()

        # Learnable [G, N, N] logits; their row-softmax re-weights the learned graphs.
        self.attention_matrix = nn.Parameter(torch.randn(groups, num_nodes, num_nodes))
        nn.init.xavier_uniform_(self.attention_matrix)

        # TODO: Sparsity Analysis
        k_neighs = self.num_nodes = num_nodes

        self.num_graphs = groups

        # Time length after zero-padding up to a multiple of `groups`.
        self.num_feats = seq_len
        if seq_len % groups:
            self.num_feats += ( groups - seq_len % groups )
        self.g_constr = multi_shallow_embedding(num_nodes, k_neighs, self.num_graphs)

        gnn_model, heads = self.build_gnn_model(gnn_model_type)

        self.heads = heads
        self.out_dim = out_dim

        assert num_layers >= 1, 'Error: Number of layers is invalid.'
        assert num_layers == len(kern_size), 'Error: Number of kernel_size should equal to number of layers.'
        paddings = [ (k - 1) // 2 for k in kern_size ]

        self.tconvs = nn.ModuleList(
            [nn.Conv2d(1, in_dim, (1, kern_size[0]), padding=(0, paddings[0]))] +
            [nn.Conv2d(heads * in_dim, hidden_dim, (1, kern_size[layer+1]), padding=(0, paddings[layer+1])) for layer in range(num_layers - 2)] +
            [nn.Conv2d(heads * hidden_dim, out_dim, (1, kern_size[-1]), padding=(0, paddings[-1]))]
        )

        # RNNs run along the node axis. Each step consumes one node's channels
        # flattened with its padded time axis, hence the `* self.num_feats` sizes
        # (previously hard-coded as 24). The first RNN expects a single input channel.
        self.lstm_layers = nn.ModuleList()
        self.lstm_layers.append(nn.RNN(input_size=self.num_feats, hidden_size=in_dim * self.num_feats, batch_first=True))
        for _ in range(1, num_layers - 2):
            self.lstm_layers.append(nn.RNN(input_size=in_dim * self.num_feats, hidden_size=hidden_dim * self.num_feats, batch_first=True))
        self.lstm_layers.append(nn.RNN(input_size=hidden_dim * self.num_feats, hidden_size=out_dim * self.num_feats, batch_first=True))

        self.gconvs = nn.ModuleList(
            [gnn_model(in_dim, heads * in_dim, groups)] +
            [gnn_model(hidden_dim, heads * hidden_dim, groups) for _ in range(num_layers - 2)] +
            [gnn_model(out_dim, heads * out_dim, groups)]
        )

        self.bns = nn.ModuleList(
            [nn.BatchNorm2d(heads * in_dim)] +
            [nn.BatchNorm2d(heads * hidden_dim) for _ in range(num_layers - 2)] +
            [nn.BatchNorm2d(heads * out_dim)]
        )

        # Node count after each pooling stage (never below 1).
        self.left_num_nodes = []
        for layer in range(num_layers + 1):
            left_node = round( num_nodes * (1 - (pool_ratio*layer)) )
            if left_node > 0:
                self.left_num_nodes.append(left_node)
            else:
                self.left_num_nodes.append(1)
        self.diffpool = nn.ModuleList(
            [Dense_TimeDiffPool2d(self.left_num_nodes[layer], self.left_num_nodes[layer+1], kern_size[layer], paddings[layer]) for layer in range(num_layers - 1)] +
            [Dense_TimeDiffPool2d(self.left_num_nodes[-2], self.left_num_nodes[-1], kern_size[-1], paddings[-1])]
        )

        self.num_layers = num_layers
        self.dropout = dropout
        self.activation = activation

        self.softmax = nn.Softmax(dim=-1)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        self.linear = nn.Linear(heads * out_dim, num_classes)

        self.reset_parameters()

    def reset_parameters(self):
        for lstm, gconv, bn, pool in zip(self.lstm_layers, self.gconvs, self.bns, self.diffpool):
            lstm.reset_parameters()
            gconv.reset_parameters()
            bn.reset_parameters()
            pool.reset_parameters()

        self.linear.reset_parameters()
        # The graph constructor's embeddings are allocated with Parameter(Tensor(...))
        # (uninitialized memory), so initialize them here too.
        self.g_constr.reset_parameters()

    def build_gnn_model(self, model_type):
        """Return ``(layer_class, heads)`` for ``model_type``; raise ValueError if unknown."""
        if model_type == 'dyGCN2d':
            return DenseGCNConv2d, 1
        if model_type == 'dyGIN2d':
            return DenseGINConv2d, 1
        raise ValueError(f"Unknown gnn_model_type {model_type!r}; expected 'dyGCN2d' or 'dyGIN2d'.")

    def forward(self, inputs: Tensor):
        """
        Args:
            inputs (Tensor): [B, 1, N, T] — assumed single-channel input, as
                required by the first RNN (TODO confirm with the data cell).

        Returns:
            (logits [B, num_classes], logvar, mean from the last graph layer,
            pooled adjacency, attention_scores [G, N, N])
        """
        # Zero-pad the time axis (split left/right) to a multiple of num_graphs.
        if inputs.size(-1) % self.num_graphs:
            pad_size = (self.num_graphs - inputs.size(-1) % self.num_graphs) / 2
            x = F.pad(inputs, (int(pad_size), ceil(pad_size)), mode='constant', value=0.0)
        else:
            x = inputs

        adj = self.g_constr(x.device)

        # Attention layer: re-weight the learned graphs by a row-softmax.
        attention_scores = F.softmax(self.attention_matrix, dim=-1)
        adj = adj * attention_scores

        for lstm, tconv, gconv, bn, pool in zip(self.lstm_layers, self.tconvs, self.gconvs, self.bns, self.diffpool):
            batch_size, channels, num_rows, features = x.size()

            # Residual branch: tile the input along channels up to the RNN's output
            # channel count (hidden_size / features). This replaces the hard-coded
            # repeat factors 128 / 2.
            rnn_channels = lstm.hidden_size // features
            residual = x.repeat(1, rnn_channels // channels, 1, 1)

            # [B, C, N, F] -> [B, N, C*F]: one RNN step per node.
            x = x.transpose(1, 2).reshape(batch_size, num_rows, channels * features)
            x, _ = lstm(x)
            # [B, N, C_rnn*F] -> [B, C_rnn, N, F]. A plain reshape to (B, -1, N, F)
            # would interleave the node and channel axes.
            x = x.reshape(batch_size, num_rows, rnn_channels, features).transpose(1, 2)
            x = x + residual

            temp, logvar, mean = gconv(x, adj)

            x, adj = pool(temp, adj)

            x = self.activation(bn(x))
            x = F.dropout(x, p=self.dropout, training=self.training)

        out = self.global_pool(x)
        out = out.view(out.size(0), -1)
        out = self.linear(out)

        return out, logvar, mean, adj, attention_scores


model = GNNStack(gnn_model_type='dyGIN2d', num_layers=1,
                 groups=6, pool_ratio=0.2, kern_size=[11],
                 in_dim=128, hidden_dim=128, out_dim=128,
                 seq_len=seq_length, num_nodes=num_nodes, num_classes=num_classes)


# Only select a CUDA device when DEVICE is one; DEVICE falls back to CPU.
if DEVICE.type == 'cuda':
    torch.cuda.set_device(DEVICE)

# collect cache
gc.collect()
# torch.cuda.empty_cache()

# .to(DEVICE) works for CPU and CUDA alike, unlike .cuda(DEVICE).
model = model.to(DEVICE)


# %% [code]
# Multi-label objective: one independent sigmoid/BCE term per class logit.
criterion = nn.BCEWithLogitsLoss().to(DEVICE)


# %% [code]
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


# %% [code]
# mode='max' because the scheduler is stepped with a balanced accuracy (higher is better).
# NOTE(review): `verbose` is deprecated in recent PyTorch releases.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5,
                                                          patience=10, verbose=True)


# %% [code]
class AverageMeter(object):
    """Keeps the latest reading of a metric together with its sample-weighted mean.

    Args:
        name: label shown when the meter is converted to a string.
        fmt: format-spec suffix (for example ``':.4e'``) used for both numbers.
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Store ``val`` as the current reading and fold it in with weight ``n``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**self.__dict__)



# %% [code]
def balanced_accuracy_multi(inp, targ, thresh=0.5, sigmoid=True):
    """Computes balanced accuracy when `inp` and `targ` are the same size.

    Every element is treated as an independent binary label (the tensors are not
    split per class). The result is the mean of the true-positive rate and the
    true-negative rate. If one of the two classes is absent from `targ`, its rate
    is undefined, and the mean is taken over the defined rate only, as
    scikit-learn's ``balanced_accuracy_score`` does. Previously this case produced
    NaN. An empty `targ` returns NaN.

    Args:
        inp: logits (or probabilities if ``sigmoid=False``).
        targ: 0/1 targets with the same shape as ``inp``.
        thresh: decision threshold on the probabilities.
        sigmoid: apply a sigmoid to ``inp`` first.

    Returns:
        0-dim float tensor.
    """
    if sigmoid:
        inp = inp.sigmoid()
    pred = inp > thresh

    positive = (targ == 1).bool()
    negative = (targ == 0).bool()
    correct = pred == targ.bool()
    TP = torch.logical_and(correct, positive).sum()
    TN = torch.logical_and(correct, negative).sum()
    FN = torch.logical_and(~correct, positive).sum()
    FP = torch.logical_and(~correct, negative).sum()

    rates = []
    if TP + FN > 0:
        rates.append(TP / (TP + FN))   # true-positive rate
    if TN + FP > 0:
        rates.append(TN / (TN + FP))   # true-negative rate
    if not rates:
        return torch.tensor(float('nan'), device=inp.device)
    return sum(rates) / len(rates)


# %% [code]
def Fbeta_multi(inp, targ, beta=1.0, thresh=0.5, sigmoid=True):
    """Computes Fbeta when `inp` and `targ` are the same size.

    Every element is treated as an independent binary label. This uses the
    count form ``(1+b^2)TP / ((1+b^2)TP + b^2 FN + FP)``, which equals
    ``(1+b^2)PR / (b^2 P + R)``. The count form never forms an undefined
    precision or recall. When there are no true positives, false negatives or
    false positives, the score is defined as 0.

    Args:
        inp: logits (or probabilities if ``sigmoid=False``).
        targ: 0/1 targets with the same shape as ``inp``.
        beta: weight of recall relative to precision.
        thresh: decision threshold on the probabilities.
        sigmoid: apply a sigmoid to ``inp`` first.

    Returns:
        0-dim float tensor (previously a plain int 0 in the degenerate case).
    """
    if sigmoid:
        inp = inp.sigmoid()
    pred = inp > thresh

    positive = (targ == 1).bool()
    negative = (targ == 0).bool()
    correct = pred == targ.bool()
    TP = torch.logical_and(correct, positive).sum()
    FN = torch.logical_and(~correct, positive).sum()
    FP = torch.logical_and(~correct, negative).sum()

    beta2 = beta * beta
    denom = (1 + beta2) * TP + beta2 * FN + FP
    if denom == 0:
        return torch.zeros((), device=inp.device)
    return (1 + beta2) * TP / denom


# %% [code]
def recall_multi(inp, targ, thresh=0.5, sigmoid=True):
    """Computes recall when `inp` and `targ` are the same size.

    Every element is treated as an independent binary label. If `targ` has no
    positives, recall is undefined and 0 is returned (previously NaN). This
    matches the zero-score convention used by ``Fbeta_multi``.

    Returns:
        0-dim float tensor.
    """
    if sigmoid:
        inp = inp.sigmoid()
    pred = inp > thresh

    positive = (targ == 1).bool()
    correct = pred == targ.bool()
    TP = torch.logical_and(correct, positive).sum()
    FN = torch.logical_and(~correct, positive).sum()

    n_positive = TP + FN
    if n_positive == 0:
        return torch.zeros((), device=inp.device)
    return TP / n_positive

per = 1e-08  # NOTE(review): not referenced anywhere in the visible notebook.
# %% [code]
def regularization_loss(adj_original, adj_learned, lambda_reg):
    """Scaled squared Frobenius distance between two stacks of adjacency matrices.

    Whichever input has fewer nodes (last-axis size) is zero-padded on the
    bottom/right of its last two axes to match the other before comparing.

    Returns:
        ``lambda_reg * sum((adj_original - adj_learned) ** 2)``
    """
    size_gap = adj_learned.shape[-1] - adj_original.shape[-1]
    if size_gap > 0:
        adj_original = F.pad(adj_original, (0, size_gap, 0, size_gap), mode='constant', value=0)
    elif size_gap < 0:
        adj_learned = F.pad(adj_learned, (0, -size_gap, 0, -size_gap), mode='constant', value=0)

    difference = adj_original - adj_learned
    return lambda_reg * difference.pow(2).sum()

def structural_loss(adj_original, adj_learned, mu):
    """
    Calculate the structural loss as the cosine distance between the original
    and learned adjacency matrices.

    The input with fewer nodes is zero-padded on its last two axes to match the
    other, and both are flattened. The loss is ``mu * (1 - cos)``, with cos taken
    as 0 when either input is all zeros.

    Returns:
        A 0-dim tensor, or the float ``mu`` when either norm is zero.
    """

    # Get dimensions and pad the smaller matrix
    n_orig = adj_original.shape[-1]
    n_learn = adj_learned.shape[-1]

    if n_orig < n_learn:
        pad_size = (0, n_learn - n_orig, 0, n_learn - n_orig)
        adj_original = F.pad(adj_original, pad_size, mode='constant', value=0)
    elif n_learn < n_orig:
        pad_size = (0, n_orig - n_learn, 0, n_orig - n_learn)
        adj_learned = F.pad(adj_learned, pad_size, mode='constant', value=0)

    # reshape (not view) so non-contiguous inputs, e.g. transposed matrices, work.
    adj_original_flat = adj_original.reshape(-1)
    adj_learned_flat = adj_learned.reshape(-1)

    # Compute cosine similarity
    dot_product = torch.dot(adj_original_flat, adj_learned_flat)
    norm_orig = torch.norm(adj_original_flat)
    norm_learn = torch.norm(adj_learned_flat)

    # Avoid division by zero
    if norm_orig * norm_learn == 0:
        cos_sim = 0
    else:
        cos_sim = dot_product / (norm_orig * norm_learn)

    # Structural loss based on cosine distance
    return mu * (1 - cos_sim)

def initialize_adj_original(num_graphs, num_nodes):
    """Create a random reference adjacency of shape [num_graphs, num_nodes, num_nodes] on DEVICE.

    Entries are drawn uniformly from [0, 1) with ``torch.rand``. Every graph's
    diagonal is then zeroed so the reference has no self-loops.
    """
    adj = torch.rand(num_graphs, num_nodes, num_nodes, device=DEVICE)
    # diagonal() is a view, so zeroing it clears all graphs' diagonals in place.
    adj.diagonal(dim1=-2, dim2=-1).zero_()
    return adj

def train(train_loader, model, criterion, optimizer, lr_scheduler, adj_original, use_graph_losses=False):
    """Train `model` for one epoch, then step `lr_scheduler` on the epoch's balanced accuracy.

    Args:
        train_loader: iterable of ``(data, label)`` batches.
        model: network returning ``(logits, logvar, mean, adj, attention)``.
        criterion: loss applied to ``(logits, label)``.
        optimizer: optimizer over ``model``'s parameters.
        lr_scheduler: scheduler whose ``step`` takes a metric (e.g. ReduceLROnPlateau).
        adj_original: reference adjacency. It is used by the optional graph losses
            and as the initial "previous" adjacency.
        use_graph_losses: when True, add ``regularization_loss(adj_original, adj, 0.001)``
            and ``structural_loss(adj_previous, adj, 0.001)`` to the criterion loss.
            Defaults to False, which matches the previous behaviour: these terms were
            computed but commented out of the optimised loss.

    Returns:
        ``(acc, loss, f1, recall, attn)``. The first four are sample-weighted epoch
        means. ``attn`` is the attention scores from the last batch, or None if the
        loader yielded no batches.

    NOTE: random-masking / time-shuffling augmentations (``augment_data``) used to
    be built here but were never fed to the loss. They were removed to avoid
    per-batch work; reintroduce them together with a contrastive term.
    """
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc', ':6.2f')
    f_1 = AverageMeter('F1', ':6.2f')
    sensitivity = AverageMeter('Sens', ':6.2f')

    # switch to train mode
    model.train()
    adj_previous = adj_original
    attn = None
    for data, label in train_loader:
        data = data.to(DEVICE).type(torch.float)
        label = label.to(DEVICE).type(torch.float)

        # Forward pass
        output, logvar, mean, adj_learned, attn = model(data)

        total_loss = criterion(output, label)
        if use_graph_losses:
            total_loss = total_loss + regularization_loss(adj_original, adj_learned, 0.001)
            if adj_previous is not None:
                total_loss = total_loss + structural_loss(adj_previous, adj_learned, 0.001)
        losses.update(total_loss.item(), data.size(0))

        acc1 = balanced_accuracy_multi(output, label)
        f1 = Fbeta_multi(output, label)
        sens = recall_multi(output, label)
        top1.update(acc1, data.size(0))
        f_1.update(f1, data.size(0))
        sensitivity.update(sens, data.size(0))

        # compute gradient and do Adam step
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        # Detached so the next batch's structural term does not backprop into this graph.
        adj_previous = adj_learned.detach()

    lr_scheduler.step(top1.avg)

    return top1.avg, losses.avg, f_1.avg, sensitivity.avg, attn
from torch.nn.functional import cosine_similarity

def augment_data(data):
    """Return a copy of `data` with axis 2 independently permuted for every sample.

    Within one sample, the same permutation (from ``torch.randperm``) is applied to
    all channels and to the last axis. The input tensor is not modified. Axis 2 is
    called the time axis here — TODO confirm it is time rather than nodes for the
    [B, 1, N, T] inputs used by the model.
    """
    shuffled = data.clone()
    n_samples, _, n_steps, _ = data.shape
    for sample_idx in range(n_samples):
        order = torch.randperm(n_steps)
        # data[sample_idx] is [C, T, F]; reorder its T axis.
        shuffled[sample_idx] = data[sample_idx][:, order]
    return shuffled


# %% [code]
def validate(val_loader, model, criterion, adj_original, include_aux_losses=False):
    """Evaluate `model` on `val_loader` without gradient tracking.

    By default the reported loss is the primary criterion only. This is what the
    original comments said was intended, and it keeps the value comparable with the
    loss reported by ``train``. Set ``include_aux_losses=True`` to add:
      - the KL term (summed over all elements),
      - the adjacency regularization (weight 0.1),
      - the structural loss (weight 0.01).
    That is what this function previously always reported.

    Returns:
        ``(acc, loss, f1, recall, attn)``. The first four are sample-weighted means.
        ``attn`` is the attention scores from the last batch, or None if the loader
        was empty.
    """
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    f_1 = AverageMeter('F1', ':6.2f')
    sensitivity = AverageMeter('Sens', ':6.2f')

    # Switch to evaluate mode
    model.eval()
    attn = None

    with torch.no_grad():
        for data, label in val_loader:
            data = data.to(DEVICE).type(torch.float)
            label = label.to(DEVICE).type(torch.float)

            # Compute output
            output, logvar, mean, adj_learned, attn = model(data)

            loss = criterion(output, label)
            if include_aux_losses:
                kl_divergence = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
                loss_reg = regularization_loss(adj_original, adj_learned, 0.1)
                loss_struct = structural_loss(adj_original, adj_learned, 0.01)
                loss = loss + loss_reg + kl_divergence + loss_struct

            # Measure accuracy and record loss
            acc1 = balanced_accuracy_multi(output, label)
            f1 = Fbeta_multi(output, label)
            sens = recall_multi(output, label)

            losses.update(loss.item(), data.size(0))
            top1.update(acc1, data.size(0))
            f_1.update(f1, data.size(0))
            sensitivity.update(sens, data.size(0))

    # Clean-up once per evaluation rather than after every batch.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return top1.avg, losses.avg, f_1.avg, sensitivity.avg, attn

In [None]:
import os

loss_train = []
acc_train = []
loss_val = []
acc_val = []
epoches = []
f1_train = []
f1_val = []
sens_train = []
sens_val = []

# init acc
best_acc1 = 0
best_f1 = 0
best_sens = 0
# One random reference graph per learned graph: match the model's group count
# instead of hard-coding 6.
adj_original = initialize_adj_original(model.num_graphs, num_nodes=num_nodes)

CHECKPOINT_PATH = 'models/EHRSHOT_1992.pt'
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

for epoch in range(100):
    epoches += [epoch]

    # train for one epoch
    acc_train_per, loss_train_per, f1_train_per, sens_train_per, temp = train(train_loader, model, criterion, optimizer, lr_scheduler, adj_original)

    acc_train += [acc_train_per]
    loss_train += [loss_train_per]
    f1_train += [f1_train_per]
    sens_train += [sens_train_per]

    msg = f'TRAIN, epoch {epoch}, loss {loss_train_per}, acc {acc_train_per}'
    print(msg)

    # evaluate on validation set
    acc_val_per, loss_val_per, f1_val_per, sens_val_per, attention = validate(val_loader, model, criterion, adj_original)

    acc_val += [acc_val_per]
    loss_val += [loss_val_per]
    f1_val += [f1_val_per]
    sens_val += [sens_val_per]

    print(f'VAL, loss {loss_val_per}, acc {acc_val_per}, f1 {f1_val_per}, sens {sens_val_per}')

    # Checkpoint on a strictly better validation balanced accuracy.
    if acc_val_per > best_acc1:
        torch.save(model, CHECKPOINT_PATH)
        attention_scores_train = temp

    # remember best acc
    best_acc1 = max(acc_val_per, best_acc1)
    best_f1 = max(f1_val_per, best_f1)
    best_sens = max(sens_val_per, best_sens)

  augmented_data_pos = torch.tensor(data * mask)


VAL, loss 0.19888468086719513, acc 0.5, f1 0.0, sens 0.0
VAL, loss 0.20049485564231873, acc 0.5, f1 0.0, sens 0.0
VAL, loss 0.20112231373786926, acc 0.5, f1 0.0, sens 0.0
VAL, loss 0.20223067700862885, acc 0.5, f1 0.0, sens 0.0
VAL, loss 0.20105482637882233, acc 0.5, f1 0.0, sens 0.0
VAL, loss 0.20301832258701324, acc 0.5, f1 0.0, sens 0.0
VAL, loss 0.2020978331565857, acc 0.5, f1 0.0, sens 0.0
VAL, loss 0.20419423282146454, acc 0.5, f1 0.0, sens 0.0
VAL, loss 0.2023783028125763, acc 0.5, f1 0.0, sens 0.0
VAL, loss 0.20403751730918884, acc 0.5, f1 0.0, sens 0.0
VAL, loss 0.20352806150913239, acc 0.5, f1 0.0, sens 0.0
VAL, loss 0.20310641825199127, acc 0.5, f1 0.0, sens 0.0
VAL, loss 0.2028382122516632, acc 0.5, f1 0.0, sens 0.0
VAL, loss 0.20276802778244019, acc 0.5, f1 0.0, sens 0.0
VAL, loss 0.20237518846988678, acc 0.5, f1 0.0, sens 0.0
VAL, loss 0.20436765253543854, acc 0.5, f1 0.0, sens 0.0
VAL, loss 0.2037253975868225, acc 0.5, f1 0.0, sens 0.0
VAL, loss 0.2035953402519226, acc 0

In [None]:
# Reload the best checkpoint saved during training.
# The whole pickled module was saved with torch.save(model, ...), so
# weights_only=False is required on newer PyTorch (default True from 2.6).
# Unpickling can execute arbitrary code: only load checkpoints you produced yourself.
model = torch.load('models/EHRSHOT_1992.pt', map_location=DEVICE, weights_only=False)
model.eval()

# Held-out evaluation on the test split. Previously this iterated val_loader, so
# test_loader was never used and the "test" attention came from validation data.
with torch.no_grad():
    for data, label in test_loader:

        data = data.to(DEVICE).type(torch.float)
        label = label.to(DEVICE).type(torch.float)

        # compute output
        output, _, _, _, attention_scores_test = model(data)

        # measure accuracy and record loss
        acc1 = balanced_accuracy_multi(output, label)
        f1 = Fbeta_multi(output, label)
        sens = recall_multi(output, label)

        print(acc1, f1, sens)

  model = torch.load('models/EHRSHOT_1992.pt')


tensor(0.5000, device='cuda:0') 0 tensor(0., device='cuda:0')


In [None]:
# Best validation balanced accuracy, F1 and recall over all epochs. Each is a
# running max tracked independently, so they may come from different epochs.
best_acc1, best_f1, best_sens

(tensor(0.5000, device='cuda:0'), 0.0, tensor(0., device='cuda:0'))