In [1]:

"""
TUH Epilepsy Dataset
"""

from tqdm import tqdm
import glob, os
import numpy as np
import torch

from IPython.display import clear_output
from torch.utils.data import TensorDataset,DataLoader
from sklearn.preprocessing import OneHotEncoder

DEVICE = torch.device("cuda") #torch.device("mps")
BATCH_SIZE = 32
NUM_WINDOWS = 60


def norm_adj(train_graphs, test_graphs):
    for i in range(train_graphs.shape[0]):
        for j in range(train_graphs.shape[1]):
            min_ = (train_graphs[i, j, :, :]).min()
            max_ = (train_graphs[i, j, :, :]).max()
            train_graphs[i, j, :,  :] = (train_graphs[i, j, :,  :] - min_)/(max_ - min_)
                
    for i in range(test_graphs.shape[0]):
        for j in range(test_graphs.shape[1]):
            min_ = (test_graphs[i, j, :, :]).min()
            max_ = (test_graphs[i, j, :, :]).max()
            test_graphs[i, j, :, :] = (test_graphs[i, j, :,  :] - min_)/(max_ - min_)
            
    return train_graphs, test_graphs

def norm_feat(train_X, test_X):
    for i in range(train_X.shape[1]):
        min_ = (train_X[:, i, :]).min()
        max_ = (train_X[:, i, :]).max()
        train_X[:, i, :] = (train_X[:, i, :] - min_)/(max_ - min_)
        test_X[:, i, :] = (test_X[:, i, :] - min_)/(max_ - min_)

    return train_X, test_X


from sklearn.model_selection import train_test_split

def stratify_train_test_split(conn="coh"):
    train_X = np.load("/kaggle/input/tuh-epilepsy/train_X.npy", mmap_mode="c")*1e3
    train_y = np.load("/kaggle/input/tuh-epilepsy/train_y.npy", mmap_mode="c")
    train_graphs = np.load("/kaggle/input/tuh-epilepsy/train_graphs_"+conn+".npy", mmap_mode="c")
    test_y = np.load("/kaggle/input/tuh-epilepsy/test_y.npy", mmap_mode="c")
    test_X = np.load("/kaggle/input/tuh-epilepsy/test_X.npy", mmap_mode="c")*1e3
    test_graphs = np.load("/kaggle/input/tuh-epilepsy/test_graphs_"+conn+".npy", mmap_mode="c")
    train_X = np.vstack((train_X, test_X))
    train_graphs = np.vstack((train_graphs, test_graphs))
    train_y = np.vstack((train_y, test_y))
    train_graphs, test_graphs = norm_adj(train_graphs, test_graphs)
    train_X = np.moveaxis(train_X, 1, 2)
    test_X = np.moveaxis(test_X, 1, 2)
    target = train_y
    train_X, test_X, train_y, test_y = train_test_split(train_X, train_y, test_size=0.1, random_state=42, stratify=target)
    train_graphs, test_graphs = train_test_split(train_graphs, test_size=0.1, random_state=42, stratify=target)
    return train_X, train_graphs, train_y, test_X, test_graphs, test_y

In [3]:
train_X.shape

(1992, 60, 19, 200)

In [4]:
train_graphs = np.vstack((train_graphs, test_graphs))
train_graphs.shape

(1992, 60, 19, 19)

In [5]:
train_y = np.vstack((train_y, test_y))
train_y.shape

(1992, 1)

In [6]:
target = train_y
train_X, test_X, train_y, test_y = train_test_split(train_X, train_y, test_size=0.1, random_state=42, stratify=target)
train_graphs, test_graphs = train_test_split(train_graphs, test_size=0.1, random_state=42, stratify=target)

In [9]:
list((test_y.squeeze())).count(0), list((test_y.squeeze())).count(1)

(37, 163)

In [11]:
list((train_y.squeeze())).count(0), list((train_y.squeeze())).count(1)

(332, 1460)

In [10]:
163+37

200

In [2]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 30 07:34:26 2025

@author: mohamedr
"""

import scipy.sparse as sp
import numpy as np

from torch_geometric.data import Data
import torch
from torch_geometric.utils import dense_to_sparse
from sklearn.preprocessing import StandardScaler
from sklearn.base import TransformerMixin,BaseEstimator


def normalize_adj(adj):
    """Symmetrically normalize adjacency matrix."""
    adj = (adj - adj.min())/(adj.max() - adj.min())
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1)) # D
    d_inv_sqrt = np.power(rowsum, -0.5).flatten() # D^-0.5
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt) # D^-0.5
    return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() # D^-0.5AD^0.5

def preprocess_adj(adj):
    """Preprocessing of adjacency matrix for simple GCN model 
    and conversion to tuple representation."""
    adj_normalized = normalize_adj(adj + sp.eye(adj.shape[0])).toarray()
    return adj_normalized

def normalize_all_adj(graphs):
    for graph_idx in range(graphs.shape[0]):
        for band_idx in range(graphs.shape[-1]):
            graphs[graph_idx, :, :, band_idx] = preprocess_adj(graphs[graph_idx, :, :, band_idx])
    return graphs


def build_pyg_dl(x, a, y, time_points, device):
    """Convert features and adjacency to PyTorch Geometric Dataloader"""
    a = torch.from_numpy(a)
    a = a + 1e-10 
    edge_attr = []
    
    for edge_time_idx in range(time_points):
        Af = a[edge_time_idx, :, :]
        Af.fill_diagonal_(1)
        edge_index, attrf = dense_to_sparse(Af)
        edge_attr.append(attrf)
    
    edge_attr = torch.stack(edge_attr)
    edge_attr = torch.moveaxis(edge_attr, 0, 1).to(device)
    edge_index = edge_index.to(device)
    x = torch.from_numpy(x).to(device)
    y = torch.tensor([y], dtype=torch.float).to(device)
    data = Data(x=x, edge_index=edge_index, 
                edge_attr=edge_attr, 
                y=y)
    return data


#https://stackoverflow.com/questions/50125844/how-to-standard-scale-a-3d-matrix
class StandardScaler3D(BaseEstimator,TransformerMixin):
    #batch, sequence, channels
    def __init__(self):
        self.scaler = StandardScaler()

    def fit(self,X,y=None):
        X = X.reshape(-1, X.shape[1]*X.shape[3])
        self.scaler.fit(X)
        return self

    def transform(self,X):
        return self.scaler.transform(X.reshape(-1, X.shape[1]*X.shape[3])).reshape(X.shape)
        
def std_data(train_X, val_X):
    scaler = StandardScaler3D()
    train_X = scaler.fit_transform(train_X)
    val_X = scaler.transform(val_X)
    return train_X, val_X


def fill_diag(x):
    num_channels = x.shape[1]
    num_bands = x.shape[0]
    x_diag = np.zeros((num_bands, num_channels, num_channels))
    for band_idx, band in enumerate(x):
        for idx, i in enumerate(band):
            x_diag[band_idx, idx, idx] = 0
            if idx == 0:
                x_diag[band_idx, idx, 1:] = i
            elif idx > 0 and idx < num_channels-1:
                x_diag[band_idx, idx, idx+1:] = i[idx:]
                x_diag[band_idx, idx, :idx] = i[:idx]
            elif idx == num_channels-1:
                x_diag[band_idx, idx, :-1] = i
    return x_diag


import torch, gc
from torch_geometric.loader import DataLoader
import random
from sklearn.model_selection import KFold    
from sklearn.preprocessing import OneHotEncoder


def loaders(train_X, train_graphs, train_y, test_X, test_graphs, test_y, device, batch_size, num_windows):
    #ohe
    ohe = OneHotEncoder()
    train_y_ohe = ohe.fit_transform(train_y).toarray()
    test_y_ohe = ohe.transform(test_y).toarray()
    
    # build pyg dataloader
    train_dataset = [build_pyg_dl(x, g, y, num_windows, device) for x, g, y in zip(train_X, train_graphs, train_y_ohe)]
    test_dataset = [build_pyg_dl(x, g, y, num_windows, device) for x, g, y in zip(test_X, test_graphs, test_y_ohe)]
    train_iter = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_iter = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_iter, test_iter

In [3]:
import torch
#from .temporalgcn import TGCN
#from .temporalgcn import TGCN2
from torch_geometric.nn import GCNConv
from torch_geometric_temporal.nn.recurrent import TGCN


class A3TGCN(torch.nn.Module):
    r"""An implementation of the Attention Temporal Graph Convolutional Cell.
    For details see this paper: `"A3T-GCN: Attention Temporal Graph Convolutional
    Network for Traffic Forecasting." <https://arxiv.org/abs/2006.11583>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        periods (int): Number of time periods.
        improved (bool): Stronger self loops (default :obj:`False`).
        cached (bool): Caching the message weights (default :obj:`False`).
        add_self_loops (bool): Adding self-loops for smoothing (default :obj:`True`).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        periods: int,
        improved: bool = False,
        cached: bool = False,
        add_self_loops: bool = True
    ):
        super(A3TGCN, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.periods = periods
        self.improved = improved
        self.cached = cached
        self.add_self_loops = add_self_loops
        self._setup_layers()

    def _setup_layers(self):
        self._base_tgcn = TGCN(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            improved=self.improved,
            cached=self.cached,
            add_self_loops=self.add_self_loops,
        )
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self._attention = torch.nn.Parameter(torch.empty(self.periods, device=device))
        torch.nn.init.uniform_(self._attention)

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
        H: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        """
        Making a forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph. If the hidden state matrix is not present
        when the forward pass is called it is initialized with zeros.

        Arg types:
            * **X** (PyTorch Float Tensor): Node features for T time periods.
            * **edge_index** (PyTorch Long Tensor): Graph edge indices.
            * **edge_weight** (PyTorch Long Tensor, optional)*: Edge weight vector.
            * **H** (PyTorch Float Tensor, optional): Hidden state matrix for all nodes.

        Return types:
            * **H** (PyTorch Float Tensor): Hidden state matrix for all nodes.
        """
        H_accum = 0
        probs = torch.nn.functional.softmax(self._attention, dim=0)
        for period in range(self.periods):
            Xt = X[:, period, :]
            edge_weight_t = edge_weight[:, period]
            H_accum = H_accum + probs[period] * self._base_tgcn(Xt, edge_index, edge_weight_t, H)
            
        return H_accum

In [4]:
import torch
from torch_geometric.nn import GCNConv


class TGCN(torch.nn.Module):
    r"""An implementation of the Temporal Graph Convolutional Gated Recurrent Cell.
    For details see this paper: `"T-GCN: A Temporal Graph ConvolutionalNetwork for
    Traffic Prediction." <https://arxiv.org/abs/1811.05320>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        improved (bool): Stronger self loops. Default is False.
        cached (bool): Caching the message weights. Default is False.
        add_self_loops (bool): Adding self-loops for smoothing. Default is True.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        improved: bool = False,
        cached: bool = False,
        add_self_loops: bool = True,
    ):
        super(TGCN, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.add_self_loops = add_self_loops

        self._create_parameters_and_layers()

    def _create_update_gate_parameters_and_layers(self):

        self.conv_z = GCNConv(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            improved=self.improved,
            cached=self.cached,
            add_self_loops=self.add_self_loops,
        )

        self.linear_z = torch.nn.Linear(2 * self.out_channels, self.out_channels)

    def _create_reset_gate_parameters_and_layers(self):

        self.conv_r = GCNConv(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            improved=self.improved,
            cached=self.cached,
            add_self_loops=self.add_self_loops,
        )

        self.linear_r = torch.nn.Linear(2 * self.out_channels, self.out_channels)

    def _create_candidate_state_parameters_and_layers(self):

        self.conv_h = GCNConv(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            improved=self.improved,
            cached=self.cached,
            add_self_loops=self.add_self_loops,
        )

        self.linear_h = torch.nn.Linear(2 * self.out_channels, self.out_channels)

    def _create_parameters_and_layers(self):
        self._create_update_gate_parameters_and_layers()
        self._create_reset_gate_parameters_and_layers()
        self._create_candidate_state_parameters_and_layers()

    def _set_hidden_state(self, X, H):
        if H is None:
            H = torch.zeros(X.shape[0], self.out_channels).to(X.device)
        return H

    def _calculate_update_gate(self, X, edge_index, edge_weight, H):
        Z = torch.cat([self.conv_z(X, edge_index, edge_weight), H], axis=1)
        Z = self.linear_z(Z)
        Z = torch.sigmoid(Z)
        return Z

    def _calculate_reset_gate(self, X, edge_index, edge_weight, H):
        R = torch.cat([self.conv_r(X, edge_index, edge_weight), H], axis=1)
        R = self.linear_r(R)
        R = torch.sigmoid(R)
        return R

    def _calculate_candidate_state(self, X, edge_index, edge_weight, H, R):
        H_tilde = torch.cat([self.conv_h(X, edge_index, edge_weight), H * R], axis=1)
        H_tilde = self.linear_h(H_tilde)
        H_tilde = torch.tanh(H_tilde)
        return H_tilde

    def _calculate_hidden_state(self, Z, H, H_tilde):
        H = Z * H + (1 - Z) * H_tilde
        return H

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
        H: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        """
        Making a forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph. If the hidden state matrix is not present
        when the forward pass is called it is initialized with zeros.

        Arg types:
            * **X** *(PyTorch Float Tensor)* - Node features.
            * **edge_index** *(PyTorch Long Tensor)* - Graph edge indices.
            * **edge_weight** *(PyTorch Long Tensor, optional)* - Edge weight vector.
            * **H** *(PyTorch Float Tensor, optional)* - Hidden state matrix for all nodes.

        Return types:
            * **H** *(PyTorch Float Tensor)* - Hidden state matrix for all nodes.
        """
        H = self._set_hidden_state(X, H)
        Z = self._calculate_update_gate(X, edge_index, edge_weight, H)
        R = self._calculate_reset_gate(X, edge_index, edge_weight, H)
        H_tilde = self._calculate_candidate_state(X, edge_index, edge_weight, H, R)
        H = self._calculate_hidden_state(Z, H, H_tilde)
        return H


In [5]:
# Temporal operators......etc.
from torch_geometric.nn import GCNConv, GATConv, ChebConv, TransformerConv, GINConv, SGConv, GeneralConv, SAGEConv
from torch_geometric_temporal.nn.recurrent import DCRNN, GConvGRU, GConvLSTM, GCLSTM
from torch_geometric.nn import global_mean_pool
import torch.nn as nn


class GraphTemporal(nn.Module):
    def __init__(self, num_ch, num_t, op):
        super(GraphTemporal, self).__init__()
        self.lstm = nn.LSTM(input_size=32, hidden_size=10, num_layers=1, bidirectional=True, batch_first=True)
        self.linear = nn.Linear(32, 2)
        self.BN2 = nn.BatchNorm1d(56)
        self.BN1 = nn.BatchNorm1d(19)
        
        self.conv1 = nn.Conv2d(num_ch, num_ch, (3, 3), stride=(1, 2), padding=(1, 0), dilation=(1, 3))
        self.conv2 = nn.Conv2d(num_ch, num_ch, (3, 5), stride=(1, 2), padding=(1, 1), dilation=(1, 3))
        self.conv3 = nn.Conv2d(num_ch, num_ch, (3, 10), stride=(1, 2), padding=(1, 3), dilation=(1, 3))
        
        self.maxpool = nn.MaxPool2d((1, 5))
        self.relu = nn.ReLU()
        
        if op=="GCLSTM":
            self.GraphOp = GCLSTM(in_channels=56, out_channels=32, K=2)
        elif op=="GConvLSTM":
            self.GraphOp = GConvLSTM(in_channels=56, out_channels=32, K=2)
        elif op=="A3TGCN":
            self.GraphOp = A3TGCN(in_channels=56, out_channels=32, periods=60)
        elif op=="TGCN":
            self.GraphOp = TGCN(in_channels=56, out_channels=32)
        elif op=="DCRNN":
            self.GraphOp = DCRNN(in_channels=56, out_channels=32, K=2)
        elif op=="GConvGRU":
            self.GraphOp = GConvGRU(in_channels=56, out_channels=32, K=2)
            
        self.num_ch = num_ch
        self.num_t = num_t
        self.op = op
        #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        #self._attention = torch.nn.Parameter(torch.empty(self.num_t, device=device))
        #torch.nn.init.uniform_(self._attention)

    
    def forward(self, x, idx, attr, batch=False):
        batch_size = int(x.shape[0]/self.num_ch)
        
        x = x.reshape(batch_size, self.num_ch, x.shape[-2], x.shape[-1])
        x1 = self.conv1(x)
        x1 = self.relu(x1)
        x2 = self.conv2(x)
        x2 = self.relu(x2)
        x3 = self.conv3(x)
        x3 = self.relu(x3)
        x = torch.cat([x1, x2, x3], dim=-1)
        
        x = self.maxpool(x)
        #x = x.permute(0, 1, 3, 2)
        x = x.reshape(x.shape[0]*x.shape[1], x.shape[2], x.shape[3])

        x = x.permute(0, 2, 1)
        x = self.BN2(x)
        x = x.permute(0, 2, 1)
        if self.op == "A3TGCN":
            HS = self.GraphOp(x, idx, attr)
        else:
            for t_idx in range(self.num_t):
                attr_t = attr[:, t_idx]
                x_t = x[:, t_idx]
                if t_idx == 0:
                    HS = self.GraphOp(x_t, idx, attr_t)
                else:
                    #HS = HS + probs[t_idx]*self.GraphOp(x_t, idx, attr_t, HS)
                    HS = self.GraphOp(x_t, idx, attr_t, HS)
                if type(HS) == tuple:
                    HS = HS[0]
        out = self.relu(HS)
        out = global_mean_pool(out, batch)
        out = self.linear(out)
        return out


In [18]:
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, precision_score, recall_score, auc, roc_auc_score
from tqdm import tqdm
import torch
import numpy as np


def print_acc(model, data_iter):
    outs= []
    ys = []
    
    model.eval()
    with torch.no_grad():
        for batch in data_iter:
            x = batch.x.float()
            idx = batch.edge_index.long()
            attr = batch.edge_attr.float()
            y = batch.y
            batch = batch.batch
            y = torch.argmax(y, -1)
            out = model(x, idx, attr, batch)
            out = torch.exp(out)
            outs.extend(out.cpu().detach().numpy())
            ys.extend(y.cpu().detach().numpy())
    
    outs = np.array(outs)
    ys = np.array(ys)
    outs = np.argmax(outs, -1)

    metrics = [accuracy_score(outs, ys), f1_score(outs, ys), 
               precision_score(outs, ys), recall_score(outs, ys)]
    return metrics

    
def train_model(model, num_epochs, data_iter, val_iter=None, weight=None):
    criterion = torch.nn.CrossEntropyLoss(weight = weight)
    optimizer = torch.optim.AdamW(model.parameters(),lr=1e-3)

    model.train()
    for epoch in tqdm(range(num_epochs)): 
        model.train()
        losses = 0
        for _, batch in enumerate(data_iter):
            x = batch.x.float()
            idx = batch.edge_index.long()
            attr = batch.edge_attr.float()
            y = batch.y
            batch = batch.batch
            y = torch.argmax(y, -1)
            optimizer.zero_grad()
            out = model(x, idx, attr, batch)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            losses += loss.item() 

    return model            

In [23]:
train_X, train_graphs, train_y, test_X, test_graphs, test_y = stratify_train_test_split(conn="coh")
train_iter, test_iter = loaders(train_X, train_graphs, train_y, test_X, test_graphs, test_y, DEVICE, BATCH_SIZE, NUM_WINDOWS)
num_epochs = 15
results = []
gt_operators = ["A3TGCN", "GCLSTM", "TGCN", "DCRNN", "GConvGRU"]
for op in gt_operators:
    for i in range(5):
        model = GraphTemporal(num_ch=19, num_t=60, op=op).to(DEVICE)
        class_weight = torch.tensor([6., 1.]).to(DEVICE)
        model = train_model(model, num_epochs, train_iter, weight=class_weight)
        test_res = print_acc(model, test_iter)
        results.append(test_res)

100%|██████████| 15/15 [04:24<00:00, 17.65s/it]
100%|██████████| 15/15 [04:28<00:00, 17.90s/it]
100%|██████████| 15/15 [04:27<00:00, 17.84s/it]
100%|██████████| 15/15 [04:26<00:00, 17.80s/it]
100%|██████████| 15/15 [04:27<00:00, 17.82s/it]
100%|██████████| 15/15 [06:27<00:00, 25.80s/it]
100%|██████████| 15/15 [06:27<00:00, 25.83s/it]
100%|██████████| 15/15 [06:26<00:00, 25.77s/it]
100%|██████████| 15/15 [06:26<00:00, 25.74s/it]
100%|██████████| 15/15 [06:25<00:00, 25.71s/it]
100%|██████████| 15/15 [04:15<00:00, 17.05s/it]
100%|██████████| 15/15 [04:15<00:00, 17.03s/it]
100%|██████████| 15/15 [04:15<00:00, 17.01s/it]
100%|██████████| 15/15 [04:15<00:00, 17.04s/it]
100%|██████████| 15/15 [04:17<00:00, 17.15s/it]
100%|██████████| 15/15 [06:34<00:00, 26.29s/it]
100%|██████████| 15/15 [06:34<00:00, 26.29s/it]
100%|██████████| 15/15 [06:35<00:00, 26.35s/it]
100%|██████████| 15/15 [06:34<00:00, 26.32s/it]
100%|██████████| 15/15 [06:34<00:00, 26.33s/it]
100%|██████████| 15/15 [08:21<00:00, 33.

In [25]:
acc = ["accuracy", "f1", "precision", "recall"]
results = np.array(results).reshape(5, 5, 4)
for idx, res in enumerate(results):
    mean = np.mean(res, axis=0)
    std = np.std(res, axis=0)
    for mi, _ in enumerate(mean):
        print(gt_operators[idx], acc[mi], mean[mi], std[mi])

A3TGCN accuracy 0.8310000000000001 0.018275666882497058
A3TGCN f1 0.8940223106786573 0.013635002200660587
A3TGCN precision 0.8773006134969326 0.033151548677561475
A3TGCN recall 0.9124923519567038 0.012223926917704908
GCLSTM accuracy 0.8219999999999998 0.014352700094407285
GCLSTM f1 0.8924432268120622 0.009387066549634391
GCLSTM precision 0.9067484662576687 0.020677668155034017
GCLSTM recall 0.8790191748351315 0.012475334365817855
TGCN accuracy 0.7699999999999999 0.020000000000000018
TGCN f1 0.8534204574004705 0.017880748616214043
TGCN precision 0.8257668711656441 0.04840020349280006
TGCN recall 0.885858459142989 0.0197238166351186
DCRNN accuracy 0.825 0.012649110640673502
DCRNN f1 0.8934275327767482 0.008328499393054503
DCRNN precision 0.9006134969325152 0.01993629056352384
DCRNN recall 0.8868208651756936 0.012751590671834944
GConvGRU accuracy 0.827 0.031080540535840095
GConvGRU f1 0.8939095275765532 0.020632578849823335
GConvGRU precision 0.8969325153374234 0.03627422208366744
GConvGR

In [27]:
train_X, train_graphs, train_y, test_X, test_graphs, test_y = stratify_train_test_split(conn="plv")
train_iter, test_iter = loaders(train_X, train_graphs, train_y, test_X, test_graphs, test_y, DEVICE, BATCH_SIZE, NUM_WINDOWS)
num_epochs = 15
results = []
gt_operators = ["A3TGCN", "GCLSTM", "TGCN", "DCRNN", "GConvGRU"]
for op in gt_operators:
    for i in range(5):
        model = GraphTemporal(num_ch=19, num_t=60, op=op).to(DEVICE)
        class_weight = torch.tensor([6., 1.]).to(DEVICE)
        model = train_model(model, num_epochs, train_iter, weight=class_weight)
        test_res = print_acc(model, test_iter)
        results.append(test_res)

100%|██████████| 15/15 [04:27<00:00, 17.83s/it]
100%|██████████| 15/15 [04:27<00:00, 17.84s/it]
100%|██████████| 15/15 [04:27<00:00, 17.82s/it]
100%|██████████| 15/15 [04:26<00:00, 17.78s/it]
100%|██████████| 15/15 [04:27<00:00, 17.82s/it]
100%|██████████| 15/15 [06:26<00:00, 25.74s/it]
100%|██████████| 15/15 [06:26<00:00, 25.78s/it]
100%|██████████| 15/15 [06:25<00:00, 25.72s/it]
100%|██████████| 15/15 [06:26<00:00, 25.77s/it]
100%|██████████| 15/15 [06:25<00:00, 25.71s/it]
100%|██████████| 15/15 [04:15<00:00, 17.04s/it]
100%|██████████| 15/15 [04:15<00:00, 17.00s/it]
100%|██████████| 15/15 [04:15<00:00, 17.03s/it]
100%|██████████| 15/15 [04:15<00:00, 17.03s/it]
100%|██████████| 15/15 [04:15<00:00, 17.02s/it]
100%|██████████| 15/15 [06:35<00:00, 26.35s/it]
100%|██████████| 15/15 [06:35<00:00, 26.36s/it]
100%|██████████| 15/15 [06:35<00:00, 26.35s/it]
100%|██████████| 15/15 [06:35<00:00, 26.38s/it]
100%|██████████| 15/15 [06:35<00:00, 26.38s/it]
100%|██████████| 15/15 [08:22<00:00, 33.

In [28]:
acc = ["accuracy", "f1", "precision", "recall"]
results = np.array(results).reshape(5, 5, 4)
for idx, res in enumerate(results):
    mean = np.mean(res, axis=0)
    std = np.std(res, axis=0)
    for mi, _ in enumerate(mean):
        print(gt_operators[idx], acc[mi], mean[mi], std[mi])

A3TGCN accuracy 0.841 0.03397057550292605
A3TGCN f1 0.9010422546661955 0.023190321855283184
A3TGCN precision 0.8932515337423312 0.04434186748366292
A3TGCN recall 0.910235590116463 0.013803056648968236
GCLSTM accuracy 0.795 0.04037325847637269
GCLSTM f1 0.8735125551001428 0.02496005602656005
GCLSTM precision 0.8687116564417178 0.026487157233034263
GCLSTM recall 0.8784732497384239 0.02524086474680049
TGCN accuracy 0.8029999999999999 0.03655133376499412
TGCN f1 0.8787379606559979 0.02533209137055949
TGCN precision 0.8797546012269939 0.04189784313856169
TGCN recall 0.8784908973874593 0.01134365889705781
DCRNN accuracy 0.8089999999999999 0.03440930106817049
DCRNN f1 0.8832796637466943 0.022871347301263997
DCRNN precision 0.8895705521472393 0.03721656659657204
DCRNN recall 0.8777007855807077 0.0148014094375863
GConvGRU accuracy 0.813 0.015999999999999973
GConvGRU f1 0.8856538243205841 0.010825143779517439
GConvGRU precision 0.8895705521472392 0.02194913352147031
GConvGRU recall 0.88211771523