In [None]:
# From HGNN_conv.ipynb

from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor, scatter
from torch.nn import Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import softmax

class HypergraphConv(MessagePassing):
    r"""The hypergraph convolutional operator from the `"Hypergraph Convolution
    and Hypergraph Attention" <https://arxiv.org/abs/1901.08150>`_ paper

    .. math::
        \mathbf{X}^{\prime} = \mathbf{D}^{-1} \mathbf{H} \mathbf{W}
        \mathbf{B}^{-1} \mathbf{H}^{\top} \mathbf{X} \mathbf{\Theta}

    where :math:`\mathbf{H} \in {\{ 0, 1 \}}^{N \times M}` is the incidence
    matrix, :math:`\mathbf{W} \in \mathbb{R}^M` is the diagonal hyperedge
    weight matrix, and
    :math:`\mathbf{D}` and :math:`\mathbf{B}` are the corresponding degree
    matrices.

    For example, in the hypergraph scenario
    :math:`\mathcal{G} = (\mathcal{V}, \mathcal{E})` with
    :math:`\mathcal{V} = \{ 0, 1, 2, 3 \}` and
    :math:`\mathcal{E} = \{ \{ 0, 1, 2 \}, \{ 1, 2, 3 \} \}`, the
    :obj:`hyperedge_index` is represented as:

    .. code-block:: python

        hyperedge_index = torch.tensor([
            [0, 1, 2, 1, 2, 3],
            [0, 0, 0, 1, 1, 1],
        ])

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        out_channels (int): Size of each output sample.
        use_attention (bool, optional): If set to :obj:`True`, attention
            will be added to this layer. (default: :obj:`False`)
        heads (int, optional): Number of multi-head-attentions. Only used
            when :obj:`use_attention=True`; otherwise a single head is used.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated. Only used when
            :obj:`use_attention=True`. (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})`,
          hyperedge indices :math:`(|\mathcal{V}|, |\mathcal{E}|)`,
          hyperedge weights :math:`(|\mathcal{E}|)` *(optional)*
          hyperedge features :math:`(|\mathcal{E}|, D)` *(optional)*
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
    """
    def __init__(self, in_channels, out_channels, use_attention=False, heads=1,
                 concat=True, negative_slope=0.2, dropout=0, bias=True,
                 **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(flow='source_to_target', node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_attention = use_attention

        if self.use_attention:
            self.heads = heads
            self.concat = concat
            self.negative_slope = negative_slope
            self.dropout = dropout
            self.lin = Linear(in_channels, heads * out_channels, bias=False,
                              weight_initializer='glorot')
            self.att = Parameter(torch.empty(1, heads, 2 * out_channels))
        else:
            # Without attention the layer always runs as a single,
            # concatenated head; `heads` and `concat` are ignored.
            self.heads = 1
            self.concat = True
            self.lin = Linear(in_channels, out_channels, bias=False,
                              weight_initializer='glorot')

        # Size the bias from the *effective* head settings so that it always
        # matches the output of forward(), even if `heads > 1` was passed
        # together with `use_attention=False`.
        if not bias:
            self.register_parameter('bias', None)
        elif self.concat:
            self.bias = Parameter(torch.empty(self.heads * out_channels))
        else:
            self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initializes the linear map, the attention vector (if any) and
        the bias."""
        super().reset_parameters()
        self.lin.reset_parameters()
        if self.use_attention:
            glorot(self.att)
        zeros(self.bias)

    @staticmethod
    def _scatter_sum(src: Tensor, index: Tensor, dim_size: int) -> Tensor:
        """Sums the 1-D ``src`` values into ``dim_size`` buckets selected by
        ``index``."""
        # `torch.scatter` takes (input, dim, index, src) and has no
        # `dim_size` argument, so the degree sums are computed with
        # `index_add_` on a zero-initialized buffer instead.
        return src.new_zeros(dim_size).index_add_(0, index, src)

    def forward(self, x: Tensor, hyperedge_index: Tensor,
                hyperedge_weight: Optional[Tensor] = None,
                hyperedge_attr: Optional[Tensor] = None) -> Tensor:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor): Node feature matrix
                :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
            hyperedge_index (torch.Tensor): The hyperedge indices, *i.e.*
                the sparse incidence matrix
                :math:`\mathbf{H} \in {\{ 0, 1 \}}^{N \times M}` mapping from
                nodes to edges.
            hyperedge_weight (torch.Tensor, optional): Hyperedge weights
                :math:`\mathbf{W} \in \mathbb{R}^M`. (default: :obj:`None`)
            hyperedge_attr (torch.Tensor, optional): Hyperedge feature matrix
                in :math:`\mathbb{R}^{M \times F}`.
                These features only need to get passed in case
                :obj:`use_attention=True`. (default: :obj:`None`)

        Returns:
            torch.Tensor: The updated node features, with
            ``heads * out_channels`` columns when heads are concatenated and
            ``out_channels`` columns when they are averaged.
        """
        # The number of hyperedges is inferred from the largest hyperedge id.
        num_nodes, num_edges = x.size(0), 0
        if hyperedge_index.numel() > 0:
            num_edges = int(hyperedge_index[1].max()) + 1

        if hyperedge_weight is None:
            hyperedge_weight = x.new_ones(num_edges)

        x = self.lin(x)

        alpha = None
        if self.use_attention:
            assert hyperedge_attr is not None, (
                'hyperedge_attr is required when use_attention=True')
            x = x.view(-1, self.heads, self.out_channels)
            # Hyperedge features share the node projection, so they must
            # have the same number of input channels as the nodes.
            hyperedge_attr = self.lin(hyperedge_attr)
            hyperedge_attr = hyperedge_attr.view(-1, self.heads,
                                                 self.out_channels)
            x_i = x[hyperedge_index[0]]
            x_j = hyperedge_attr[hyperedge_index[1]]
            alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
            alpha = F.leaky_relu(alpha, self.negative_slope)
            alpha = softmax(alpha, hyperedge_index[0], num_nodes=x.size(0))
            alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # Inverse weighted node degree; nodes without hyperedges get 0.
        D = self._scatter_sum(hyperedge_weight[hyperedge_index[1]],
                              hyperedge_index[0], num_nodes)
        D = 1.0 / D
        D[D == float("inf")] = 0

        # Inverse hyperedge degree; empty hyperedges get 0.
        B = self._scatter_sum(x.new_ones(hyperedge_index.size(1)),
                              hyperedge_index[1], num_edges)
        B = 1.0 / B
        B[B == float("inf")] = 0

        # First pass aggregates nodes into hyperedges, second pass sends the
        # hyperedge messages back to the nodes.
        out = self.propagate(hyperedge_index, x=x, norm=B, alpha=alpha,
                             size=(num_nodes, num_edges))
        out = self.propagate(hyperedge_index.flip([0]), x=out, norm=D,
                             alpha=alpha, size=(num_edges, num_nodes))

        if self.concat is True:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out = out + self.bias

        return out

    def message(self, x_j: Tensor, norm_i: Tensor, alpha: Tensor) -> Tensor:
        """Scales each source feature by the target's normalization and, if
        given, by the attention coefficient (``alpha`` may be ``None``)."""
        H, F = self.heads, self.out_channels

        out = norm_i.view(-1, 1, 1) * x_j.view(-1, H, F)

        if alpha is not None:
            out = alpha.view(-1, self.heads, 1) * out

        return out
    

In [None]:
# From HGNN_dataset.ipynb

from typing import Optional
import os
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import Tensor, scatter
from torch.nn import Parameter
from csv import writer

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import softmax

from networkx.convert_matrix import from_numpy_array
from torch_geometric.utils import from_networkx
from torch_geometric.data import InMemoryDataset
from torch_geometric.nn import HypergraphConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import InMemoryDataset, Data, DataLoader
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split

# Helpers for locating the matrix files, filtering subjects and logging results
def matrix_loader(root):
    """Return the path of every entry in ``root``, ordered by entry name.

    Args:
        root: Directory whose entries are listed.

    Returns:
        A list of ``os.path.join(root, name)`` strings, sorted by ``name``.
    """
    return [os.path.join(root, entry) for entry in sorted(os.listdir(root))]

def filter_SMC_patient_info(
        csv_path='/Users/georgepulickal/Documents/ADNI_FULL/patient_info.csv'):
    """Return the row positions of all patients that are not in group 'SMC'.

    Args:
        csv_path: Path of the patient-info CSV. It must contain a
            'Research Group' column. Defaults to the location that was
            previously hard-coded, so existing zero-argument calls keep
            working.

    Returns:
        A list of 0-based row positions (in file order) whose
        'Research Group' value differs from 'SMC'.
    """
    research_groups = pd.read_csv(csv_path)['Research Group']
    # enumerate() yields positions, independent of the frame's index labels.
    return [pos for pos, group in enumerate(research_groups) if group != 'SMC']

def store_results(List, path='results.csv'):
    """Append one row to a CSV results file.

    Args:
        List: Sequence of values written as a single CSV row.
        path: File to append to; created if missing. Defaults to
            'results.csv' in the current working directory.
    """
    # newline='' lets the csv module control line endings itself (otherwise
    # extra blank rows appear on platforms that translate '\n'). The `with`
    # block closes the file, so no explicit close() is needed.
    with open(path, 'a', newline='') as f_object:
        writer(f_object).writerow(List)

# Exploratory sanity check: build a single Data object from the first
# correlation matrix and the first hypergraph matrix.
# NOTE(review): the correlation files come from 'ADNI_gsr_172' while the
# hypergraph files come from 'ADNI_gsr_full' -- confirm that the first entries
# of both listings belong to the same subject.
corr_list = matrix_loader('ADNI_gsr_172/corr_matrices')
hg_list = matrix_loader('ADNI_gsr_full/hypergraphs/cluster/thresh_0.6')
# Both files are parsed as comma-separated numeric text.
corr_test = np.loadtxt(corr_list[0], delimiter=',')
hg_test = np.loadtxt(hg_list[0], delimiter=',')
# Interpret the hypergraph matrix as a networkx graph, then convert it to a
# PyG Data object.
hg_nx = from_numpy_array(hg_test)
hg_matrix_data = from_networkx(hg_nx)
# The correlation matrix is attached as the float node-feature matrix.
hg_matrix_data.x = torch.tensor(corr_test).float()

class HGNN_ADNI_dataset(InMemoryDataset):
    """In-memory dataset of ADNI subjects, one Data object per subject.

    Each item combines a subject's correlation matrix (stored as node
    features ``x``) with the graph structure built from the subject's
    hypergraph matrix, plus an integer class label ``y``. Subjects whose
    research group is 'SMC' are excluded.

    Args:
        root: Dataset root directory passed to ``InMemoryDataset``; the
            processed file 'data.pt' is stored beneath it.
        transform: Forwarded to ``InMemoryDataset``.
        pre_transform: Forwarded to ``InMemoryDataset``.
        hg_data_path: Directory containing one hypergraph matrix file per
            subject (comma-separated text).
    """

    def __init__(self, root, transform=None, pre_transform=None, hg_data_path = 'hypergraphs/cluster/thresh_0.6'):
        # Must be assigned before the base-class constructor runs, because
        # that constructor invokes process() when 'data.pt' does not exist.
        self.hg_data_path = hg_data_path
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """Names of the files process() writes into the processed directory."""
        return ['data.pt']

    def process(self):
        """Build one Data object per retained subject and save the collated
        result to ``self.processed_paths[0]``.

        Raises:
            ValueError: If the numbers of correlation files, hypergraph files
                and labels do not agree after filtering.
        """
        # NOTE(review): correlation matrices are read from 'ADNI_gsr_full'
        # while labels are read from 'ADNI_gsr_172' -- confirm these describe
        # the same, identically ordered subjects.
        full_corr_list = matrix_loader('ADNI_gsr_full/corr_matrices')
        # Read the patient table once and reuse the selection for both file
        # lists, so they are guaranteed to be filtered identically.
        keep_idx = filter_SMC_patient_info()
        corr_list = [full_corr_list[i] for i in keep_idx]

        hg_list = matrix_loader(self.hg_data_path)
        new_hg_list = [hg_list[i] for i in keep_idx]

        labels = torch.from_numpy(np.loadtxt('ADNI_gsr_172/labels.csv', delimiter=','))

        # Explicit checks (rather than assert) so that a mismatch is still
        # caught when Python runs with -O; a silent mismatch would attach
        # labels to the wrong subjects.
        if len(corr_list) != len(new_hg_list):
            raise ValueError(
                f'{len(corr_list)} correlation matrices but '
                f'{len(new_hg_list)} hypergraph matrices after filtering')
        if len(labels) != len(corr_list):
            raise ValueError(
                f'{len(labels)} labels but {len(corr_list)} subjects '
                f'after filtering')

        graphs = []
        for corr_path, hg_path, label in zip(corr_list, new_hg_list, labels):
            corr_array = np.loadtxt(corr_path, delimiter=',')
            hg_array = np.loadtxt(hg_path, delimiter=',')

            # Graph structure comes from the hypergraph matrix, node
            # features from the correlation matrix.
            hg_matrix_data = from_networkx(from_numpy_array(hg_array))
            hg_matrix_data.x = torch.tensor(corr_array).float()
            hg_matrix_data.y = label.type(torch.LongTensor)

            graphs.append(hg_matrix_data)

        data, slices = self.collate(graphs)
        torch.save((data, slices), self.processed_paths[0])

# Instantiate the dataset with the default hypergraph directory; processed
# data is stored under the given root folder.
dataset = HGNN_ADNI_dataset('ADNI_gsr_pyg_test_hypergraph_cluster')

# Print dataset-level summary statistics.
print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of hypergraphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # First item of the dataset, used for the inspection below.

print()
print(data)
print('=============================================================')

# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
# Ratio of stored edge entries to nodes.
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')


#hypergraph convolution
class HyperGraph1(torch.nn.Module):
    """Graph classifier: two hypergraph convolutions without attention,
    mean-pool readout, dropout and a linear output layer.

    Input and output sizes are taken from the module-level ``dataset``.

    Args:
        hidden_channels: Width of both convolution layers.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        # Fixed seed so that layer initialization is reproducible.
        torch.manual_seed(12345)
        self.hconv1 = HypergraphConv(
            dataset.num_node_features,
            hidden_channels,
            use_attention=False,
            heads=1,
        )
        self.hconv2 = HypergraphConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        """Return one row of class scores per graph in the batch."""
        # Node embeddings: conv -> ReLU -> conv.
        hidden = self.hconv1(x, edge_index).relu()
        hidden = self.hconv2(hidden, edge_index)

        # Readout: average node embeddings per graph
        # -> [batch_size, hidden_channels].
        pooled = global_mean_pool(hidden, batch)

        # Classifier head with dropout that is only active in training mode.
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        return self.lin(pooled)
    
class HyperGraph2(torch.nn.Module):
    """Graph classifier: an 8-head attention hypergraph convolution followed
    by a plain hypergraph convolution, mean-pool readout, dropout and a
    linear output layer.

    Input and output sizes are taken from the module-level ``dataset``.

    Args:
        hidden_channels: Per-head width of the first convolution and width
            of the second convolution.
    """

    def __init__(self, hidden_channels):
        super(HyperGraph2, self).__init__()
        # Fixed seed so that layer initialization is reproducible.
        torch.manual_seed(12345)
        self.hconv1 = HypergraphConv(dataset.num_node_features, hidden_channels, use_attention=True, heads=8, concat=True,bias=True,dropout=0.6)
        # The 8 concatenated heads of hconv1 produce hidden_channels * 8
        # features per node.
        self.hconv2 = HypergraphConv(hidden_channels * 8, hidden_channels,  use_attention=False, heads=1, concat=True,bias=True,dropout=0.6)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    @staticmethod
    def _mean_hyperedge_attr(x, edge_index):
        """Average the features of each hyperedge's member nodes."""
        if edge_index.numel() == 0:
            return x.new_zeros((0, x.size(-1)))
        # Same convention as HypergraphConv: the number of hyperedges is
        # the largest hyperedge id plus one.
        num_hyperedges = int(edge_index[1].max()) + 1
        return global_mean_pool(x[edge_index[0]], edge_index[1],
                                size=num_hyperedges)

    def forward(self, x, edge_index, batch, hyperedge_attr=None):
        """Return one row of class scores per graph in the batch.

        Args:
            x: Node feature matrix.
            edge_index: Hyperedge index; row 0 holds node ids, row 1 holds
                hyperedge ids.
            batch: Graph id of every node.
            hyperedge_attr: Optional hyperedge features with as many
                columns as ``x``. The attention layer requires them; when
                omitted, each hyperedge gets the mean of its member nodes'
                features (a modelling default -- pass explicit features to
                override it).
        """
        # The attention convolution asserts that hyperedge features are
        # present, so calling it without them always failed.
        if hyperedge_attr is None:
            hyperedge_attr = self._mean_hyperedge_attr(x, edge_index)

        # 1. Obtain node embeddings
        x = self.hconv1(x, edge_index, hyperedge_attr=hyperedge_attr)
        x = x.relu()
        x = self.hconv2(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x
    
def train():
    """Run one optimization pass over ``train_loader``.

    Uses the module-level ``model``, ``criterion`` and ``optimizer``.
    """
    model.train()
    for batch in train_loader:
        # Forward pass and loss for the current mini-batch.
        logits = model(batch.x, batch.edge_index, batch.batch)
        batch_loss = criterion(logits, batch.y)
        # Backpropagate, apply the update, then reset gradients so the next
        # batch starts clean.
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()


def test(loader):
    """Return the accuracy of the module-level ``model`` on ``loader``.

    Args:
        loader: Iterable of batches exposing ``x``, ``edge_index``,
            ``batch`` and ``y``; must also expose ``dataset`` with a length.

    Returns:
        Fraction of items in ``loader.dataset`` that were classified
        correctly.
    """
    model.eval()

    correct = 0
    # Evaluation needs no gradients; skipping autograd bookkeeping saves
    # memory and time without changing the predictions.
    with torch.no_grad():
        for data in loader:
            out = model(data.x, data.edge_index, data.batch)
            pred = out.argmax(dim=1)  # Class with the highest score.
            correct += int((pred == data.y).sum())
    return correct / len(loader.dataset)


# 5-fold stratified cross-validation; every fold trains a fresh model.
tot_test_acc = []
dataset = dataset.shuffle()

# `dataset.data.y` keeps the labels in storage order and ignores the
# permutation applied by shuffle(), so it does not line up with `dataset[i]`.
# Collect the labels item by item to stratify on the labels of the graphs
# that actually end up in each split.
y_all = np.array([int(dataset[i].y) for i in range(len(dataset))])

kf = StratifiedKFold(n_splits=5, shuffle=False)
# Only the labels determine the split; X is a placeholder of matching length.
for train_val_idx, test_idx in kf.split(np.zeros(len(y_all)), y_all):
    X_train_val = [dataset[i] for i in train_val_idx]
    X_test      = [dataset[i] for i in test_idx]
    Y_train_val = y_all[train_val_idx]
    Y_test      = y_all[test_idx]

    # Hold out 12.5% of the train/validation part for validation.
    X_train, X_valid, Y_train, Y_valid = train_test_split(X_train_val, Y_train_val , test_size=0.125,
                                                    random_state=42, stratify=Y_train_val)

    print(f'Number of training graphs: {len(X_train)}')
    print(f'Number of validation graphs: {len(X_valid)}')
    print(f'Number of test graphs: {len(X_test)}')

    train_loader = DataLoader(X_train, batch_size=64, shuffle=True)
    valid_loader = DataLoader(X_valid, batch_size=32, shuffle=True)
    test_loader = DataLoader(X_test, batch_size=32, shuffle=False)

    # train() and test() read model/optimizer/criterion as module globals.
    model = HyperGraph1(hidden_channels=8)
    #model = HyperGraph2(hidden_channels=8)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()

    # range(1, 30) runs 29 epochs (epochs 1..29).
    for epoch in range(1, 30):
        train()
        train_acc = test(train_loader)
        valid_acc = test(valid_loader)
        print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Valid Acc: {valid_acc:.4f}')
        #wandb.log({"val_acc": valid_acc , "train_acc": train_acc})

    test_acc = test(test_loader)
    print(f'Test Acc: {test_acc: .4f}')
    tot_test_acc.append(test_acc)

# Summarize across folds and append the summary to the results CSV.
print(f'Average Test Accuracy: {sum(tot_test_acc) / len(tot_test_acc)}')
print(f'Max test accuracy: {max(tot_test_acc)}')
print(f'Standard Deviation: {np.std(tot_test_acc)}')
results = [dataset.hg_data_path, (sum(tot_test_acc) / len(tot_test_acc)) , np.std(tot_test_acc)]
store_results(results)