In [4]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import random
import pandas as pd
import torch
visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
print(f"Using GPU(s): {visible_devices}")
print(torch.cuda.is_available())
num_gpus = torch.cuda.device_count()
print(f'Available GPUs: {num_gpus}')
import torch_geometric.transforms as T
from typing import Optional
import torch
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import to_undirected, add_self_loops
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.datasets import WebKB
from torch_geometric.datasets import Actor
from torch_geometric.datasets import CitationFull
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset_sparse = Planetoid(root="/data/ /Pooling", name='Cora')
dataset_sparse = dataset_sparse[0]
num_nodes = dataset_sparse.num_nodes
edge_index = dataset_sparse.edge_index
num_edges = edge_index.size(1)
num_new_edges = int(0.5 * num_edges)
new_edges = set()
while len(new_edges) < num_new_edges:
    i = torch.randint(0, num_nodes, (1,))
    j = torch.randint(0, num_nodes, (1,))
    if i != j:  
        edge = (i.item(), j.item())
        reverse_edge = (j.item(), i.item())
        if edge not in new_edges and reverse_edge not in new_edges:
            new_edges.add(edge)
new_edges = torch.tensor(list(new_edges)).t().contiguous()
new_edge_index = torch.cat([edge_index, new_edges], dim=1)
new_edge_index = to_undirected(new_edge_index)
dataset_sparse.edge_index = new_edge_index
print(f"Original number of edges: {num_edges}")
print(f"New number of edges: {dataset_sparse.edge_index.size(1)}")

Using GPU(s): 1
True
Available GPUs: 1
Original number of edges: 10556
New number of edges: 21092


### TopKPooling with HierarchicalGCN (2019)

In [10]:
import time
import tracemalloc
import torch
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import torch.nn as nn
from sklearn.model_selection import KFold
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TopKPooling
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import ToUndirected
from torch.nn import Linear
import torch.optim as optim
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import to_dense_batch
from sklearn.model_selection import KFold
import numpy as np
import random
from typing import Callable, Optional, Union
dataset = Planetoid(root="/data/ /Pooling", name='Cora')
graph = dataset_sparse
num_classes = dataset.num_classes
in_channels = dataset.num_features
hidden_channels = 64
out_channels = num_classes
depth = 2
pool_ratios = [0.7, 0.7]  
class HierarchicalGCN_TOPK(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, depth, pool_ratios, act=F.relu, sum_res=False):
        super(HierarchicalGCN_TOPK, self).__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = pool_ratios
        self.act = act
        self.sum_res = sum_res
        channels = self.hidden_channels
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(self.in_channels, channels))
        for i in range(self.depth):
            self.pools.append(TopKPooling(channels, ratio=pool_ratios[i]))
            self.down_convs.append(GCNConv(channels, channels))
        in_channels = channels if sum_res else 2 * channels
        self.up_convs = torch.nn.ModuleList()
        for i in range(self.depth):
            self.up_convs.append(GCNConv(in_channels, channels))
        self.up_convs.append(GCNConv(channels, self.out_channels))
    def forward(self, x, edge_index, batch=None):
        x, edge_index = x.to(device), edge_index.to(device)
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        if batch is not None:
            batch = batch.to(device)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.down_convs[0](x, edge_index)
        x = self.act(x)
        xs = [x]
        edge_indices = [edge_index]
        perms = []
        for i in range(1, self.depth + 1):
            x, edge_index, _, batch, perm, _ = self.pools[i - 1](x, edge_index, batch=batch)
            x = self.down_convs[i](x, edge_index)
            x = self.act(x)
            if i < self.depth:
                xs.append(x)
                edge_indices.append(edge_index)
            perms.append(perm)
        for i in range(self.depth):
            j = self.depth - 1 - i
            res = xs[j]
            edge_index = edge_indices[j]
            perm = perms[j]
            up = torch.zeros_like(res)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)
            x = self.up_convs[i](x, edge_index)
            x = self.act(x)
        x = self.up_convs[-1](x, edge_index)
        return x
def train_node_classifier(model, graph, optimizer, criterion, train_mask, val_mask, n_epochs=200, patience=150, min_delta=0.0001):
    best_val_acc = 0
    patience_counter = 0
    model.to(device)
    graph = graph.to(device)  
    for epoch in range(1, n_epochs + 1):
        model.train()
        optimizer.zero_grad()
        out = model(graph.x, graph.edge_index)
        loss = criterion(out[train_mask], graph.y[train_mask])
        loss.backward()
        optimizer.step()
        val_acc = eval_node_classifier(model, graph, val_mask)
        if val_acc > best_val_acc + min_delta:
            best_val_acc = val_acc
            patience_counter = 0  
        else:
            patience_counter += 1  
        if patience_counter >= patience:
            break
    return model, best_val_acc
def eval_node_classifier(model, graph, mask):
    model.eval()
    pred = model(graph.x, graph.edge_index).argmax(dim=1)
    correct = (pred[mask] == graph.y[mask]).sum()
    acc = int(correct) / int(mask.sum())
    return acc
kf = KFold(n_splits=5, shuffle=True)
seeds = [42, 123, 456]
results = []
val_accuracies_list = []
times = []
memories = []
gpu_memories = []
for seed in seeds:
    graph = graph.to(device)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    val_accuracies = []
    start_time = time.time()
    tracemalloc.start()
    for fold, (train_index, test_index) in enumerate(kf.split(graph.x)):
        model = HierarchicalGCN_TOPK(in_channels, hidden_channels, out_channels, depth, pool_ratios).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()
        train_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        test_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        train_mask[train_index] = True
        test_mask[test_index] = True
        val_mask = test_mask  
        model, best_val_acc = train_node_classifier(model, graph, optimizer, criterion, train_mask, val_mask)
        val_accuracies.append(best_val_acc)
        print(f'Seed {seed}, Fold {fold + 1} Val Acc: {best_val_acc:.3f}')
    mean_val_acc = np.mean(val_accuracies)
    end_time = time.time()
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    memory_usage = peak / 10**6  
    if torch.cuda.is_available():
        gpu_memory_usage = torch.cuda.memory_reserved(device) / 10**6  
    else:
        gpu_memory_usage = 0
    elapsed_time = end_time - start_time
    results.append({
        'seed': seed,
        'mean_val_acc': mean_val_acc,
        'time': elapsed_time,
        'memory': memory_usage,
        'gpu_memory': gpu_memory_usage
    })
    print(f'Seed {seed} Results: Mean Val Acc: {mean_val_acc:.3f}, Time: {elapsed_time:.3f} seconds, Memory: {memory_usage:.3f} MB, GPU Memory: {gpu_memory_usage:.3f} MB')
for result in results:
    print(result)
mean_val_acc_values = [result['mean_val_acc'] for result in results]
total_mean_val_acc = np.mean(mean_val_acc_values) * 100
standard_deviation = np.std(mean_val_acc_values) * 100
print(f"Total Mean Val Acc: {total_mean_val_acc:.2f}$\\pm${standard_deviation:.2f}")

Seed 42, Fold 1 Val Acc: 0.762
Seed 42, Fold 2 Val Acc: 0.731
Seed 42, Fold 3 Val Acc: 0.731
Seed 42, Fold 4 Val Acc: 0.723
Seed 42, Fold 5 Val Acc: 0.763
Seed 42 Results: Mean Val Acc: 0.742, Time: 26.687 seconds, Memory: 0.359 MB, GPU Memory: 94.372 MB
Seed 123, Fold 1 Val Acc: 0.766
Seed 123, Fold 2 Val Acc: 0.745
Seed 123, Fold 3 Val Acc: 0.721
Seed 123, Fold 4 Val Acc: 0.738
Seed 123, Fold 5 Val Acc: 0.726
Seed 123 Results: Mean Val Acc: 0.739, Time: 25.960 seconds, Memory: 0.333 MB, GPU Memory: 94.372 MB
Seed 456, Fold 1 Val Acc: 0.732
Seed 456, Fold 2 Val Acc: 0.736
Seed 456, Fold 3 Val Acc: 0.751
Seed 456, Fold 4 Val Acc: 0.734
Seed 456, Fold 5 Val Acc: 0.726
Seed 456 Results: Mean Val Acc: 0.736, Time: 26.387 seconds, Memory: 0.330 MB, GPU Memory: 94.372 MB
{'seed': 42, 'mean_val_acc': 0.7418768032412302, 'time': 26.68696689605713, 'memory': 0.358538, 'gpu_memory': 94.37184}
{'seed': 123, 'mean_val_acc': 0.7392855924862392, 'time': 25.959501028060913, 'memory': 0.333399, 'gpu_

### SAGPooling with HierarchicalGCN (2019)

In [11]:
import time
import tracemalloc
import torch
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import torch.nn as nn
from sklearn.model_selection import KFold
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGPooling
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import ToUndirected
from torch.nn import Linear
import torch.optim as optim
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import to_dense_batch
from sklearn.model_selection import KFold
import numpy as np
import random
from typing import Callable, Optional, Union
dataset = Planetoid(root="/data/ /Pooling", name='Cora')
graph = dataset_sparse
num_classes = dataset.num_classes
in_channels = dataset.num_features
hidden_channels = 64
out_channels = num_classes
depth = 2
pool_ratios = [0.7, 0.7]  
class HierarchicalGCN_SAG(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, depth, pool_ratios, act=F.relu, sum_res=False):
        super(HierarchicalGCN_SAG, self).__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = pool_ratios
        self.act = act
        self.sum_res = sum_res
        channels = self.hidden_channels
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(self.in_channels, channels))
        for i in range(self.depth):
            self.pools.append(SAGPooling(channels, ratio=pool_ratios[i]))
            self.down_convs.append(GCNConv(channels, channels))
        in_channels = channels if sum_res else 2 * channels
        self.up_convs = torch.nn.ModuleList()
        for i in range(self.depth):
            self.up_convs.append(GCNConv(in_channels, channels))
        self.up_convs.append(GCNConv(channels, self.out_channels))
    def forward(self, x, edge_index, batch=None):
        x, edge_index = x.to(device), edge_index.to(device)
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        if batch is not None:
            batch = batch.to(device)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.down_convs[0](x, edge_index)
        x = self.act(x)
        xs = [x]
        edge_indices = [edge_index]
        perms = []
        for i in range(1, self.depth + 1):
            x, edge_index, _, batch, perm, _ = self.pools[i - 1](x, edge_index, None, batch)
            x = self.down_convs[i](x, edge_index)
            x = self.act(x)
            if i < self.depth:
                xs.append(x)
                edge_indices.append(edge_index)
            perms.append(perm)
        for i in range(self.depth):
            j = self.depth - 1 - i
            res = xs[j]
            edge_index = edge_indices[j]
            perm = perms[j]
            up = torch.zeros_like(res)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)
            x = self.up_convs[i](x, edge_index)
            x = self.act(x)
        x = self.up_convs[-1](x, edge_index)
        return x
def train_node_classifier(model, graph, optimizer, criterion, train_mask, val_mask, n_epochs=200, patience=150, min_delta=0.0001):
    best_val_acc = 0
    patience_counter = 0
    model.to(device)
    graph = graph.to(device)  
    for epoch in range(1, n_epochs + 1):
        model.train()
        optimizer.zero_grad()
        out = model(graph.x, graph.edge_index)
        loss = criterion(out[train_mask], graph.y[train_mask])
        loss.backward()
        optimizer.step()
        val_acc = eval_node_classifier(model, graph, val_mask)
        if val_acc > best_val_acc + min_delta:
            best_val_acc = val_acc
            patience_counter = 0  
        else:
            patience_counter += 1  
        if patience_counter >= patience:
            break
    return model, best_val_acc
def eval_node_classifier(model, graph, mask):
    model.eval()
    pred = model(graph.x, graph.edge_index).argmax(dim=1)
    correct = (pred[mask] == graph.y[mask]).sum()
    acc = int(correct) / int(mask.sum())
    return acc
kf = KFold(n_splits=5, shuffle=True)
seeds = [42, 123, 456]
results = []
val_accuracies_list = []
times = []
memories = []
gpu_memories = []
for seed in seeds:
    graph = graph.to(device)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    val_accuracies = []
    start_time = time.time()
    tracemalloc.start()
    for fold, (train_index, test_index) in enumerate(kf.split(graph.x)):
        model = HierarchicalGCN_SAG(in_channels, hidden_channels, out_channels, depth, pool_ratios).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()
        train_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        test_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        train_mask[train_index] = True
        test_mask[test_index] = True
        val_mask = test_mask  
        model, best_val_acc = train_node_classifier(model, graph, optimizer, criterion, train_mask, val_mask)
        val_accuracies.append(best_val_acc)
        print(f'Seed {seed}, Fold {fold + 1} Val Acc: {best_val_acc:.3f}')
    mean_val_acc = np.mean(val_accuracies)
    end_time = time.time()
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    memory_usage = peak / 10**6  
    if torch.cuda.is_available():
        gpu_memory_usage = torch.cuda.memory_reserved(device) / 10**6  
    else:
        gpu_memory_usage = 0
    elapsed_time = end_time - start_time
    results.append({
        'seed': seed,
        'mean_val_acc': mean_val_acc,
        'time': elapsed_time,
        'memory': memory_usage,
        'gpu_memory': gpu_memory_usage
    })
    print(f'Seed {seed} Results: Mean Val Acc: {mean_val_acc:.3f}, Time: {elapsed_time:.3f} seconds, Memory: {memory_usage:.3f} MB, GPU Memory: {gpu_memory_usage:.3f} MB')
for result in results:
    print(result)
mean_val_acc_values = [result['mean_val_acc'] for result in results]
total_mean_val_acc = np.mean(mean_val_acc_values) * 100
standard_deviation = np.std(mean_val_acc_values) * 100
print(f"Total Mean Val Acc: {total_mean_val_acc:.2f}$\\pm${standard_deviation:.2f}")

Seed 42, Fold 1 Val Acc: 0.744
Seed 42, Fold 2 Val Acc: 0.749
Seed 42, Fold 3 Val Acc: 0.734
Seed 42, Fold 4 Val Acc: 0.734
Seed 42, Fold 5 Val Acc: 0.760
Seed 42 Results: Mean Val Acc: 0.744, Time: 29.773 seconds, Memory: 0.960 MB, GPU Memory: 96.469 MB
Seed 123, Fold 1 Val Acc: 0.753
Seed 123, Fold 2 Val Acc: 0.753
Seed 123, Fold 3 Val Acc: 0.727
Seed 123, Fold 4 Val Acc: 0.758
Seed 123, Fold 5 Val Acc: 0.738
Seed 123 Results: Mean Val Acc: 0.746, Time: 29.040 seconds, Memory: 0.404 MB, GPU Memory: 96.469 MB
Seed 456, Fold 1 Val Acc: 0.731
Seed 456, Fold 2 Val Acc: 0.745
Seed 456, Fold 3 Val Acc: 0.751
Seed 456, Fold 4 Val Acc: 0.730
Seed 456, Fold 5 Val Acc: 0.756
Seed 456 Results: Mean Val Acc: 0.743, Time: 28.802 seconds, Memory: 0.404 MB, GPU Memory: 96.469 MB
{'seed': 42, 'mean_val_acc': 0.7440935536896959, 'time': 29.772623777389526, 'memory': 0.960185, 'gpu_memory': 96.468992}
{'seed': 123, 'mean_val_acc': 0.7455702505269045, 'time': 29.039822816848755, 'memory': 0.403968, 'gp

### ASAPooling with HierarchicalGCN (2020)

In [12]:
import time
import tracemalloc
import torch
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import torch.nn as nn
from sklearn.model_selection import KFold
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, ASAPooling
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import ToUndirected
from torch.nn import Linear
import torch.optim as optim
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import to_dense_batch
from sklearn.model_selection import KFold
import numpy as np
import random
from typing import Callable, Optional, Union
dataset = Planetoid(root="/data/ /Pooling", name='Cora')
graph = dataset_sparse
num_classes = dataset.num_classes
in_channels = dataset.num_features
hidden_channels = 64
out_channels = num_classes
depth = 2
pool_ratios = [0.7, 0.7]  
class HierarchicalGCN_ASA(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, depth, pool_ratios, act=F.relu, sum_res=False):
        super(HierarchicalGCN_ASA, self).__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = pool_ratios
        self.act = act
        self.sum_res = sum_res
        channels = self.hidden_channels
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(self.in_channels, channels))
        for i in range(self.depth):
            self.pools.append(ASAPooling(channels, ratio=pool_ratios[i]))
            self.down_convs.append(GCNConv(channels, channels))
        in_channels = channels if sum_res else 2 * channels
        self.up_convs = torch.nn.ModuleList()
        for i in range(self.depth):
            self.up_convs.append(GCNConv(in_channels, channels))
        self.up_convs.append(GCNConv(channels, self.out_channels))
    def forward(self, x, edge_index, batch=None):
        x, edge_index = x.to(device), edge_index.to(device)
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        if batch is not None:
            batch = batch.to(device)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.down_convs[0](x, edge_index)
        x = self.act(x)
        xs = [x]
        edge_indices = [edge_index]
        perms = []
        for i in range(1, self.depth + 1):
            x, edge_index, _, batch, perm = self.pools[i - 1](x, edge_index, batch=batch)
            x = self.down_convs[i](x, edge_index)
            x = self.act(x)
            if i < self.depth:
                xs.append(x)
                edge_indices.append(edge_index)
            perms.append(perm)
        for i in range(self.depth):
            j = self.depth - 1 - i
            res = xs[j]
            edge_index = edge_indices[j]
            perm = perms[j]
            up = torch.zeros_like(res)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)
            x = self.up_convs[i](x, edge_index)
            x = self.act(x)
        x = self.up_convs[-1](x, edge_index)
        return x
def train_node_classifier(model, graph, optimizer, criterion, train_mask, val_mask, n_epochs=200, patience=150, min_delta=0.0001):
    best_val_acc = 0
    patience_counter = 0
    model.to(device)
    graph = graph.to(device)  
    for epoch in range(1, n_epochs + 1):
        model.train()
        optimizer.zero_grad()
        out = model(graph.x, graph.edge_index)
        loss = criterion(out[train_mask], graph.y[train_mask])
        loss.backward()
        optimizer.step()
        val_acc = eval_node_classifier(model, graph, val_mask)
        if val_acc > best_val_acc + min_delta:
            best_val_acc = val_acc
            patience_counter = 0  
        else:
            patience_counter += 1  
        if patience_counter >= patience:
            break
    return model, best_val_acc
def eval_node_classifier(model, graph, mask):
    model.eval()
    pred = model(graph.x, graph.edge_index).argmax(dim=1)
    correct = (pred[mask] == graph.y[mask]).sum()
    acc = int(correct) / int(mask.sum())
    return acc
kf = KFold(n_splits=5, shuffle=True)
seeds = [42, 123, 456]
results = []
val_accuracies_list = []
times = []
memories = []
gpu_memories = []
for seed in seeds:
    graph = graph.to(device)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    val_accuracies = []
    start_time = time.time()
    tracemalloc.start()
    for fold, (train_index, test_index) in enumerate(kf.split(graph.x)):
        model = HierarchicalGCN_ASA(in_channels, hidden_channels, out_channels, depth, pool_ratios).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()
        train_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        test_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        train_mask[train_index] = True
        test_mask[test_index] = True
        val_mask = test_mask  
        model, best_val_acc = train_node_classifier(model, graph, optimizer, criterion, train_mask, val_mask)
        val_accuracies.append(best_val_acc)
        print(f'Seed {seed}, Fold {fold + 1} Val Acc: {best_val_acc:.3f}')
    mean_val_acc = np.mean(val_accuracies)
    end_time = time.time()
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    memory_usage = peak / 10**6  
    if torch.cuda.is_available():
        gpu_memory_usage = torch.cuda.memory_reserved(device) / 10**6  
    else:
        gpu_memory_usage = 0
    elapsed_time = end_time - start_time
    results.append({
        'seed': seed,
        'mean_val_acc': mean_val_acc,
        'time': elapsed_time,
        'memory': memory_usage,
        'gpu_memory': gpu_memory_usage
    })
    print(f'Seed {seed} Results: Mean Val Acc: {mean_val_acc:.3f}, Time: {elapsed_time:.3f} seconds, Memory: {memory_usage:.3f} MB, GPU Memory: {gpu_memory_usage:.3f} MB')
for result in results:
    print(result)
mean_val_acc_values = [result['mean_val_acc'] for result in results]
total_mean_val_acc = np.mean(mean_val_acc_values) * 100
standard_deviation = np.std(mean_val_acc_values) * 100
print(f"Total Mean Val Acc: {total_mean_val_acc:.2f}$\\pm${standard_deviation:.2f}")

  adj = torch.sparse_csr_tensor(


Seed 42, Fold 1 Val Acc: 0.762


RuntimeError: CUDA error: insufficient resources when calling `cusparseSpGEMM_compute( handle, opA, opB, &alpha_, descA.descriptor(), descB.descriptor(), &beta_, descC.descriptor(), compute_type, CUSPARSE_SPGEMM_DEFAULT, spgemm_desc.descriptor(), &buffer_size2, buffer2.get())`

### PANPooling with HierarchicalGCN (2020)

In [13]:
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_sparse import spspmm
from torch_sparse import coalesce
from torch_sparse import eye
from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_scatter import scatter_max
class PANPooling(torch.nn.Module):
    r""" General Graph pooling layer based on PAN, which can work with all layers.
    """
    def __init__(self, in_channels, ratio=0.5, pan_pool_weight=None, min_score=None, multiplier=1,
                 nonlinearity=torch.tanh, filter_size=3, panpool_filter_weight=None):
        super(PANPooling, self).__init__()
        self.in_channels = in_channels
        self.ratio = ratio
        self.min_score = min_score
        self.multiplier = multiplier
        self.nonlinearity = nonlinearity
        self.filter_size = filter_size
        if panpool_filter_weight is None:
            self.panpool_filter_weight = torch.nn.Parameter(0.5 * torch.ones(filter_size), requires_grad=True)
        self.transform = Parameter(torch.ones(in_channels), requires_grad=True)
        if pan_pool_weight is None:
            self.pan_pool_weight = torch.nn.Parameter(0.5 * torch.ones(2), requires_grad=True)
        else:
            self.pan_pool_weight = pan_pool_weight
    def forward(self, x, edge_index, M=None, batch=None, num_nodes=None):
        """"""
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        num_nodes = maybe_num_nodes(edge_index, num_nodes)
        edge_index, edge_weight = self.panentropy_sparse(edge_index, num_nodes)
        num_nodes = x.size(0)
        degree = torch.zeros(num_nodes, device=edge_index.device)
        degree = scatter_add(edge_weight, edge_index[0], out=degree)
        xtransform = torch.matmul(x, self.transform)
        x_transform_norm = xtransform 
        degree_norm = degree 
        score = self.pan_pool_weight[0] * x_transform_norm + self.pan_pool_weight[1] * degree_norm
        if self.min_score is None:
            score = self.nonlinearity(score)
        else:
            score = softmax(score, batch)
        perm = self.topk(score, self.ratio, batch, self.min_score)
        x = x[perm] * score[perm].view(-1, 1)
        x = self.multiplier * x if self.multiplier != 1 else x
        batch = batch[perm]
        edge_index, edge_weight = self.filter_adj(edge_index, edge_weight, perm, num_nodes=score.size(0))
        return x, edge_index, edge_weight, batch, perm, score[perm]
    def topk(self, x, ratio, batch, min_score=None, tol=1e-7):
        if min_score is not None:
            scores_max = scatter_max(x, batch)[0][batch] - tol
            scores_min = scores_max.clamp(max=min_score)
            perm = torch.nonzero(x > scores_min).view(-1)
        else:
            num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
            batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()
            cum_num_nodes = torch.cat(
                [num_nodes.new_zeros(1),
                 num_nodes.cumsum(dim=0)[:-1]], dim=0)
            index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
            index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)
            dense_x = x.new_full((batch_size * max_num_nodes, ), -2)
            dense_x[index] = x
            dense_x = dense_x.view(batch_size, max_num_nodes)
            _, perm = dense_x.sort(dim=-1, descending=True)
            perm = perm + cum_num_nodes.view(-1, 1)
            perm = perm.view(-1)
            k = (ratio * num_nodes.to(torch.float)).ceil().to(torch.long)
            mask = [
                torch.arange(k[i], dtype=torch.long, device=x.device) +
                i * max_num_nodes for i in range(batch_size)
            ]
            mask = torch.cat(mask, dim=0)
            perm = perm[mask]
        return perm
    def filter_adj(self, edge_index, edge_weight, perm, num_nodes=None):
        num_nodes = maybe_num_nodes(edge_index, num_nodes)
        mask = perm.new_full((num_nodes, ), -1)
        i = torch.arange(perm.size(0), dtype=torch.long, device=perm.device)
        mask[perm] = i
        row, col = edge_index
        row, col = mask[row], mask[col]
        mask = (row >= 0) & (col >= 0)
        row, col = row[mask], col[mask]
        if edge_weight is not None:
            edge_weight = edge_weight[mask]
        return torch.stack([row, col], dim=0), edge_weight
    def panentropy_sparse(self, edge_index, num_nodes):
        edge_value = torch.ones(edge_index.size(1), device=edge_index.device)
        edge_index, edge_value = coalesce(edge_index, edge_value, num_nodes, num_nodes)
        pan_index, pan_value = eye(num_nodes, device=edge_index.device)
        indextmp = pan_index.clone().to(edge_index.device)
        valuetmp = pan_value.clone().to(edge_index.device)
        pan_value = self.panpool_filter_weight[0] * pan_value
        for i in range(self.filter_size - 1):
            indextmp, valuetmp = spspmm(indextmp, valuetmp, edge_index, edge_value, num_nodes, num_nodes, num_nodes)
            valuetmp = valuetmp * self.panpool_filter_weight[i+1]
            indextmp, valuetmp = coalesce(indextmp, valuetmp, num_nodes, num_nodes)
            pan_index = torch.cat((pan_index, indextmp), 1)
            pan_value = torch.cat((pan_value, valuetmp))
        return coalesce(pan_index, pan_value, num_nodes, num_nodes, op='add')

In [14]:
import time
import tracemalloc
import torch
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import torch.nn as nn
from sklearn.model_selection import KFold
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, ASAPooling
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import ToUndirected
from torch.nn import Linear
import torch.optim as optim
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import to_dense_batch
from sklearn.model_selection import KFold
import numpy as np
import random
from typing import Callable, Optional, Union
dataset = Planetoid(root="/data/ /Pooling", name='Cora')
graph = dataset_sparse
num_classes = dataset.num_classes
in_channels = dataset.num_features
hidden_channels = 64
out_channels = num_classes
depth = 2
pool_ratios = [0.7, 0.7]  
class HierarchicalGCN_PAN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, depth, pool_ratios, act=F.relu, sum_res=False):
        super(HierarchicalGCN_PAN, self).__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = pool_ratios
        self.act = act
        self.sum_res = sum_res
        channels = self.hidden_channels
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(self.in_channels, channels))
        for i in range(self.depth):
            self.pools.append(PANPooling(channels, ratio=pool_ratios[i]))
            self.down_convs.append(GCNConv(channels, channels))
        in_channels = channels if sum_res else 2 * channels
        self.up_convs = torch.nn.ModuleList()
        for i in range(self.depth):
            self.up_convs.append(GCNConv(in_channels, channels))
        self.up_convs.append(GCNConv(channels, self.out_channels))
    def forward(self, x, edge_index, batch=None):
        x, edge_index = x.to(device), edge_index.to(device)
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        if batch is not None:
            batch = batch.to(device)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.down_convs[0](x, edge_index)
        x = self.act(x)
        xs = [x]
        edge_indices = [edge_index]
        perms = []
        for i in range(1, self.depth + 1):
            x, edge_index, _, batch, perm, score_perm = self.pools[i - 1](x, edge_index, batch=batch, M=None)
            x = self.down_convs[i](x, edge_index)
            x = self.act(x)
            if i < self.depth:
                xs.append(x)
                edge_indices.append(edge_index)
            perms.append(perm)
        for i in range(self.depth):
            j = self.depth - 1 - i
            res = xs[j]
            edge_index = edge_indices[j]
            perm = perms[j]
            up = torch.zeros_like(res)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)
            x = self.up_convs[i](x, edge_index)
            x = self.act(x)
        x = self.up_convs[-1](x, edge_index)
        return x
def train_node_classifier(model, graph, optimizer, criterion, train_mask, val_mask, n_epochs=200, patience=150, min_delta=0.0001):
    best_val_acc = 0
    patience_counter = 0
    model.to(device)
    graph = graph.to(device)  
    for epoch in range(1, n_epochs + 1):
        model.train()
        optimizer.zero_grad()
        out = model(graph.x, graph.edge_index)
        loss = criterion(out[train_mask], graph.y[train_mask])
        loss.backward()
        optimizer.step()
        val_acc = eval_node_classifier(model, graph, val_mask)
        if val_acc > best_val_acc + min_delta:
            best_val_acc = val_acc
            patience_counter = 0  
        else:
            patience_counter += 1  
        if patience_counter >= patience:
            break
    return model, best_val_acc
def eval_node_classifier(model, graph, mask):
    model.eval()
    pred = model(graph.x, graph.edge_index).argmax(dim=1)
    correct = (pred[mask] == graph.y[mask]).sum()
    acc = int(correct) / int(mask.sum())
    return acc
kf = KFold(n_splits=5, shuffle=True)
seeds = [42, 123, 456]
results = []
val_accuracies_list = []
times = []
memories = []
gpu_memories = []
for seed in seeds:
    graph = graph.to(device)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    val_accuracies = []
    start_time = time.time()
    tracemalloc.start()
    for fold, (train_index, test_index) in enumerate(kf.split(graph.x)):
        model = HierarchicalGCN_PAN(in_channels, hidden_channels, out_channels, depth, pool_ratios).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()
        train_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        test_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        train_mask[train_index] = True
        test_mask[test_index] = True
        val_mask = test_mask  
        model, best_val_acc = train_node_classifier(model, graph, optimizer, criterion, train_mask, val_mask)
        val_accuracies.append(best_val_acc)
        print(f'Seed {seed}, Fold {fold + 1} Val Acc: {best_val_acc:.3f}')
    mean_val_acc = np.mean(val_accuracies)
    end_time = time.time()
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    memory_usage = peak / 10**6  
    if torch.cuda.is_available():
        gpu_memory_usage = torch.cuda.memory_reserved(device) / 10**6  
    else:
        gpu_memory_usage = 0
    elapsed_time = end_time - start_time
    results.append({
        'seed': seed,
        'mean_val_acc': mean_val_acc,
        'time': elapsed_time,
        'memory': memory_usage,
        'gpu_memory': gpu_memory_usage
    })
    print(f'Seed {seed} Results: Mean Val Acc: {mean_val_acc:.3f}, Time: {elapsed_time:.3f} seconds, Memory: {memory_usage:.3f} MB, GPU Memory: {gpu_memory_usage:.3f} MB')
for result in results:
    print(result)
mean_val_acc_values = [result['mean_val_acc'] for result in results]
total_mean_val_acc = np.mean(mean_val_acc_values) * 100
standard_deviation = np.std(mean_val_acc_values) * 100
print(f"Total Mean Val Acc: {total_mean_val_acc:.2f}$\\pm${standard_deviation:.2f}")

Seed 42, Fold 1 Val Acc: 0.769
Seed 42, Fold 2 Val Acc: 0.751
Seed 42, Fold 3 Val Acc: 0.755
Seed 42, Fold 4 Val Acc: 0.738
Seed 42, Fold 5 Val Acc: 0.780
Seed 42 Results: Mean Val Acc: 0.758, Time: 63.517 seconds, Memory: 10.927 MB, GPU Memory: 55807.312 MB
Seed 123, Fold 1 Val Acc: 0.782
Seed 123, Fold 2 Val Acc: 0.762
Seed 123, Fold 3 Val Acc: 0.740
Seed 123, Fold 4 Val Acc: 0.760
Seed 123, Fold 5 Val Acc: 0.749
Seed 123 Results: Mean Val Acc: 0.758, Time: 63.952 seconds, Memory: 0.309 MB, GPU Memory: 55807.312 MB
Seed 456, Fold 1 Val Acc: 0.753
Seed 456, Fold 2 Val Acc: 0.760
Seed 456, Fold 3 Val Acc: 0.766
Seed 456, Fold 4 Val Acc: 0.747
Seed 456, Fold 5 Val Acc: 0.752
Seed 456 Results: Mean Val Acc: 0.756, Time: 63.655 seconds, Memory: 0.309 MB, GPU Memory: 55807.312 MB
{'seed': 42, 'mean_val_acc': 0.7584935646029289, 'time': 63.51655697822571, 'memory': 10.926841, 'gpu_memory': 55807.311872}
{'seed': 123, 'mean_val_acc': 0.7584901542176234, 'time': 63.95180559158325, 'memory': 0

### CoPooling with HierarchicalGCN (2023)

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.utils import add_remaining_self_loops, to_dense_adj, add_self_loops
from typing import Callable, Optional, Union
from torch_sparse import coalesce, transpose
from torch_scatter import scatter
def cumsum(x: Tensor, dim: int = 0) -> Tensor:
    r"""Returns the cumulative sum of elements of :obj:`x`.
    In contrast to :meth:`torch.cumsum`, prepends the output with zero.
    Args:
        x (torch.Tensor): The input tensor.
        dim (int, optional): The dimension to do the operation over.
            (default: :obj:`0`)
    Example:
        >>> x = torch.tensor([2, 4, 1])
        >>> cumsum(x)
        tensor([0, 2, 6, 7])
    """
    size = x.size()[:dim] + (x.size(dim) + 1, ) + x.size()[dim + 1:]
    out = x.new_empty(size)
    out.narrow(dim, 0, 1).zero_()
    torch.cumsum(x, dim=dim, out=out.narrow(dim, 1, x.size(dim)))
    return out
def maybe_num_nodes(edge_index, num_nodes=None):
    if num_nodes is not None:
        return num_nodes
    elif isinstance(edge_index, Tensor):
        return int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    else:
        return max(edge_index.size(0), edge_index.size(1))
def maybe_num_nodes(edge_index, num_nodes=None):
    if num_nodes is not None:
        return num_nodes
    elif isinstance(edge_index, Tensor):
        return int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    else:
        return max(edge_index.size(0), edge_index.size(1))
def filter_adj(edge_index, edge_attr, perm, num_nodes=None):
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    mask = perm.new_full((num_nodes, ), -1)
    i = torch.arange(perm.size(0), dtype=torch.long, device=perm.device)
    mask[perm] = i
    row, col = edge_index
    row, col = mask[row], mask[col]
    mask = (row >= 0) & (col >= 0)
    row, col = row[mask], col[mask]
    if edge_attr is not None:
        edge_attr = edge_attr[mask]
    return torch.stack([row, col], dim=0), edge_attr
def topk(
    x: Tensor,
    ratio: Optional[Union[float, int]],
    batch: Tensor,
    min_score: Optional[float] = None,
    tol: float = 1e-7,
) -> Tensor:
    if min_score is not None:
        scores_max = scatter(x, batch, reduce='max')[batch] - tol
        scores_min = scores_max.clamp(max=min_score)
        perm = (x > scores_min).nonzero().view(-1)
        return perm
    if ratio is not None:
        num_nodes = scatter(batch.new_ones(x.size(0)), batch, reduce='sum')
        if ratio >= 1:
            k = num_nodes.new_full((num_nodes.size(0), ), int(ratio))
        else:
            k = (float(ratio) * num_nodes.to(x.dtype)).ceil().to(torch.long)
        x, x_perm = torch.sort(x.view(-1), descending=True)
        batch = batch[x_perm]
        batch, batch_perm = torch.sort(batch, descending=False, stable=True)
        arange = torch.arange(x.size(0), dtype=torch.long, device=x.device)
        ptr = cumsum(num_nodes)
        batched_arange = arange - ptr[batch]
        mask = batched_arange < k[batch]
        return x_perm[batch_perm[mask]]
    raise ValueError("At least one of the 'ratio' and 'min_score' parameters "
                     "must be specified")
class GPR_prop(MessagePassing):
    '''
    propagation class for GPR_GNN
    '''
    def __init__(self, K, alpha, Init, Gamma=None, bias=True, **kwargs):
        super(GPR_prop, self).__init__(aggr='add', **kwargs)
        self.K = K
        self.Init = Init
        self.alpha = alpha
        assert Init in ['SGC', 'PPR', 'NPPR', 'Random', 'WS']
        if Init == 'SGC':
            TEMP = 0.0*np.ones(K+1)
            TEMP[alpha] = 1.0
        elif Init == 'PPR':
            TEMP = alpha*(1-alpha)**np.arange(K+1)
            TEMP[-1] = (1-alpha)**K
        elif Init == 'NPPR':
            TEMP = (alpha)**np.arange(K+1)
            TEMP = TEMP/np.sum(np.abs(TEMP))
        elif Init == 'Random':
            bound = np.sqrt(3/(K+1))
            TEMP = np.random.uniform(-bound, bound, K+1)
            TEMP = TEMP/np.sum(np.abs(TEMP))
        elif Init == 'WS':
            TEMP = Gamma
        self.temp = Parameter(torch.tensor(TEMP))
    def reset_parameters(self):
        torch.nn.init.zeros_(self.temp)
        for k in range(self.K+1):
            self.temp.data[k] = self.alpha*(1-self.alpha)**k
        self.temp.data[-1] = (1-self.alpha)**self.K
    def forward(self, x, edge_index, edge_weight=None):
        edge_index, norm = gcn_norm(
            edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype)
        hidden = x*(self.temp[0])
        for k in range(self.K):
            x = self.propagate(edge_index, x=x, norm=norm)
            gamma = self.temp[k+1]
            hidden = hidden + gamma*x
        return hidden
    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j
    def __repr__(self):
        return '{}(K={}, temp={})'.format(self.__class__.__name__, self.K,
                                           self.temp)
class NodeInformationScore(MessagePassing):
    def __init__(self, improved=False, cached=False, **kwargs):
        super(NodeInformationScore, self).__init__(aggr='add', **kwargs)
        self.improved = improved
        self.cached = cached
        self.cached_result = None
        self.cached_num_edges = None
    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, dtype=None):
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device)
        edge_index, edge_weight = add_remaining_self_loops(edge_index, edge_weight, 0, num_nodes) 
        edge_index = edge_index.type(torch.long)
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        expand_deg = torch.zeros((edge_weight.size(0),), dtype=dtype, device=edge_index.device)
        expand_deg[-num_nodes:] = torch.ones((num_nodes,), dtype=dtype, device=edge_index.device)
        return edge_index, expand_deg - deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
    def forward(self, x, edge_index, edge_weight):
        if self.cached and self.cached_result is not None:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    'Cached {} number of edges, but found {}'.format(self.cached_num_edges, edge_index.size(1)))
        if not self.cached or self.cached_result is None:
            self.cached_num_edges = edge_index.size(1)
            edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, x.dtype)
            self.cached_result = edge_index, norm
        edge_index, norm = self.cached_result
        return self.propagate(edge_index, x=x, norm=norm)
    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j
    def update(self, aggr_out):
        return aggr_out
class graph_attention(torch.nn.Module):
    src_nodes_dim = 0  
    trg_nodes_dim = 1  
    nodes_dim = 0      
    head_dim = 1       
    def __init__(self, num_in_features, num_out_features, num_of_heads, dropout_prob=0.6, log_attention_weights=False):
        super().__init__()
        self.num_of_heads = num_of_heads
        self.num_out_features = num_out_features
        self.linear_proj = nn.Linear(num_in_features, num_of_heads * num_out_features, bias=False)
        self.scoring_fn_target = nn.Parameter(torch.Tensor(1, num_of_heads, num_out_features))
        self.scoring_fn_source = nn.Parameter(torch.Tensor(1, num_of_heads, num_out_features))
        self.init_params()
    def init_params(self):
        """
        The reason we're using Glorot (aka Xavier uniform) initialization is because it's a default TF initialization:
            https://stackoverflow.com/questions/37350131/what-is-the-default-variable-initializer-in-tensorflow
        The original repo was developed in TensorFlow (TF) and they used the default initialization.
        Feel free to experiment - there may be better initializations depending on your problem.
        """
        nn.init.xavier_uniform_(self.linear_proj.weight)
        nn.init.xavier_uniform_(self.scoring_fn_target)
        nn.init.xavier_uniform_(self.scoring_fn_source)
    def forward(self, x, edge_index):
        in_nodes_features = x  
        num_of_nodes = in_nodes_features.shape[self.nodes_dim]
        nodes_features_proj = self.linear_proj(in_nodes_features).view(-1, self.num_of_heads, self.num_out_features)
        scores_source = (nodes_features_proj * self.scoring_fn_source).sum(dim=-1)
        scores_target = (nodes_features_proj * self.scoring_fn_target).sum(dim=-1)
        scores_source_lifted, scores_target_lifted, nodes_features_proj_lifted = self.lift(scores_source, scores_target, nodes_features_proj, edge_index)
        scores_per_edge = scores_source_lifted + scores_target_lifted
        return torch.sigmoid(scores_per_edge)
    def lift(self, scores_source, scores_target, nodes_features_matrix_proj, edge_index):
        """
        Lifts i.e. duplicates certain vectors depending on the edge index.
        One of the tensor dims goes from N -> E (that's where the "lift" comes from).
        """
        src_nodes_index = edge_index[self.src_nodes_dim]
        trg_nodes_index = edge_index[self.trg_nodes_dim]
        scores_source = scores_source.index_select(self.nodes_dim, src_nodes_index)
        scores_target = scores_target.index_select(self.nodes_dim, trg_nodes_index)
        nodes_features_matrix_proj_lifted = nodes_features_matrix_proj.index_select(self.nodes_dim, src_nodes_index)
        return scores_source, scores_target, nodes_features_matrix_proj_lifted
class CoPooling(torch.nn.Module):
    def __init__(self, ratio=0.5, K=0.05, edge_ratio=0.6, nhid=64, alpha=0.1, Init='Random', Gamma=None):
        super(CoPooling, self).__init__()
        self.ratio = ratio
        self.calc_information_score = NodeInformationScore()
        self.edge_ratio = edge_ratio
        self.prop1 = GPR_prop(K, alpha, Init, Gamma)
        score_dim = 32
        self.G_att = graph_attention(num_in_features=nhid, num_out_features=score_dim, num_of_heads=1)
        self.weight = Parameter(torch.Tensor(2*nhid, nhid))
        nn.init.xavier_uniform_(self.weight.data)
        self.bias = Parameter(torch.Tensor(nhid))
        nn.init.zeros_(self.bias.data)
        self.reset_parameters()
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight.data)
        nn.init.zeros_(self.bias.data)
        self.prop1.reset_parameters()
        self.G_att.init_params()
    def forward(self, x, edge_index, edge_attr, batch=None, nodes_index=None, node_attr=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        ori_batch = batch.clone()
        device = x.device
        num_nodes = x.shape[0]
        x_cut = self.prop1(x, edge_index) 
        attention = self.G_att(x_cut, edge_index) 
        attention = attention.sum(dim=1) 
        edge_index, attention = add_self_loops(edge_index, attention, 1.0, num_nodes) 
        edge_index_t, attention_t = transpose(edge_index, attention, num_nodes, num_nodes)
        edge_tmp = torch.cat((edge_index, edge_index_t), 1)
        att_tmp = torch.cat((attention, attention_t),0)
        edge_index, attention = coalesce(edge_tmp, att_tmp, num_nodes, num_nodes, 'mean')
        attention_np = attention.cpu().data.numpy()
        cut_val = np.percentile(attention_np, int(100*(1-self.edge_ratio))) 
        attention = attention * (attention >= cut_val) 
        kep_idx = attention > 0.0
        cut_edge_index, cut_edge_attr = edge_index[:, kep_idx], attention[kep_idx]
        x_information_score = self.calc_information_score(x, cut_edge_index, cut_edge_attr)
        score = torch.sum(torch.abs(x_information_score), dim=1)
        perm = topk(score, self.ratio, batch)
        x_topk = x[perm]
        batch = batch[perm]
        if nodes_index is not None:
            nodes_index = nodes_index[perm]
        if node_attr is not None:
            node_attr = node_attr[perm]
        if cut_edge_index is not None or cut_edge_index.nelement() != 0:
            induced_edge_index, induced_edge_attr = filter_adj(cut_edge_index, cut_edge_attr, perm, num_nodes=num_nodes)
        else:
            print('All edges are cut!')
            induced_edge_index, induced_edge_attr = cut_edge_index, cut_edge_attr
        attention_dense = (to_dense_adj(cut_edge_index, edge_attr=cut_edge_attr, max_num_nodes=num_nodes)).squeeze()
        x = F.relu(torch.matmul(torch.cat((x_topk, torch.matmul(attention_dense[perm],x)), 1), self.weight) + self.bias)
        return x, induced_edge_index, perm, induced_edge_attr, batch, nodes_index, node_attr, attention_dense

In [16]:
import time
import tracemalloc
import torch
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import torch.nn as nn
from sklearn.model_selection import KFold
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, ASAPooling
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import ToUndirected
from torch.nn import Linear
import torch.optim as optim
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import to_dense_batch
from sklearn.model_selection import KFold
import numpy as np
import random
from typing import Callable, Optional, Union
dataset = Planetoid(root="/data/ /Pooling", name='Cora')
graph = dataset_sparse
num_classes = dataset.num_classes
in_channels = dataset.num_features
hidden_channels = 64
out_channels = num_classes
depth = 2
pool_ratios = [0.7, 0.7]  
class HierarchicalGCN_CO(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, depth, pool_ratios, act=F.relu, sum_res=False):
        super(HierarchicalGCN_CO, self).__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = pool_ratios
        self.act = act
        self.sum_res = sum_res
        channels = self.hidden_channels
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(self.in_channels, channels))
        for i in range(self.depth):
            self.pools.append(CoPooling(ratio=pool_ratios[i], K=1, edge_ratio=0.6, nhid=64, alpha=0.1, Init='Random', Gamma=1.0))
            self.down_convs.append(GCNConv(channels, channels))
        in_channels = channels if sum_res else 2 * channels
        self.up_convs = torch.nn.ModuleList()
        for i in range(self.depth):
            self.up_convs.append(GCNConv(in_channels, channels))
        self.up_convs.append(GCNConv(channels, self.out_channels))
    def forward(self, x, edge_index, batch=None):
        x, edge_index = x.to(device), edge_index.to(device)
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        if batch is not None:
            batch = batch.to(device)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.down_convs[0](x, edge_index)
        x = self.act(x)
        xs = [x]
        edge_indices = [edge_index]
        perms = []
        for i in range(1, self.depth + 1):
            x, edge_index, perm, _, batch, _, _, _ = self.pools[i - 1](x, edge_index, edge_attr=None, batch=batch)
            x = self.down_convs[i](x, edge_index)
            x = self.act(x)
            if i < self.depth:
                xs.append(x)
                edge_indices.append(edge_index)
            perms.append(perm)
        for i in range(self.depth):
            j = self.depth - 1 - i
            res = xs[j]
            edge_index = edge_indices[j]
            perm = perms[j]
            up = torch.zeros_like(res)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)
            x = self.up_convs[i](x, edge_index)
            x = self.act(x)
        x = self.up_convs[-1](x, edge_index)
        return x
def train_node_classifier(model, graph, optimizer, criterion, train_mask, val_mask, n_epochs=200, patience=150, min_delta=0.0001):
    best_val_acc = 0
    patience_counter = 0
    model.to(device)
    graph = graph.to(device)  
    for epoch in range(1, n_epochs + 1):
        model.train()
        optimizer.zero_grad()
        out = model(graph.x, graph.edge_index)
        loss = criterion(out[train_mask], graph.y[train_mask])
        loss.backward()
        optimizer.step()
        val_acc = eval_node_classifier(model, graph, val_mask)
        if val_acc > best_val_acc + min_delta:
            best_val_acc = val_acc
            patience_counter = 0  
        else:
            patience_counter += 1  
        if patience_counter >= patience:
            break
    return model, best_val_acc
def eval_node_classifier(model, graph, mask):
    model.eval()
    pred = model(graph.x, graph.edge_index).argmax(dim=1)
    correct = (pred[mask] == graph.y[mask]).sum()
    acc = int(correct) / int(mask.sum())
    return acc
kf = KFold(n_splits=5, shuffle=True)
seeds = [42, 123, 456]
results = []
val_accuracies_list = []
times = []
memories = []
gpu_memories = []
for seed in seeds:
    graph = graph.to(device)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    val_accuracies = []
    start_time = time.time()
    tracemalloc.start()
    for fold, (train_index, test_index) in enumerate(kf.split(graph.x)):
        model = HierarchicalGCN_CO(in_channels, hidden_channels, out_channels, depth, pool_ratios).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()
        train_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        test_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        train_mask[train_index] = True
        test_mask[test_index] = True
        val_mask = test_mask  
        model, best_val_acc = train_node_classifier(model, graph, optimizer, criterion, train_mask, val_mask)
        val_accuracies.append(best_val_acc)
        print(f'Seed {seed}, Fold {fold + 1} Val Acc: {best_val_acc:.3f}')
    mean_val_acc = np.mean(val_accuracies)
    end_time = time.time()
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    memory_usage = peak / 10**6  
    if torch.cuda.is_available():
        gpu_memory_usage = torch.cuda.memory_reserved(device) / 10**6  
    else:
        gpu_memory_usage = 0
    elapsed_time = end_time - start_time
    results.append({
        'seed': seed,
        'mean_val_acc': mean_val_acc,
        'time': elapsed_time,
        'memory': memory_usage,
        'gpu_memory': gpu_memory_usage
    })
    print(f'Seed {seed} Results: Mean Val Acc: {mean_val_acc:.3f}, Time: {elapsed_time:.3f} seconds, Memory: {memory_usage:.3f} MB, GPU Memory: {gpu_memory_usage:.3f} MB')
for result in results:
    print(result)
mean_val_acc_values = [result['mean_val_acc'] for result in results]
total_mean_val_acc = np.mean(mean_val_acc_values) * 100
standard_deviation = np.std(mean_val_acc_values) * 100
print(f"Total Mean Val Acc: {total_mean_val_acc:.2f}$\\pm${standard_deviation:.2f}")

Seed 42, Fold 1 Val Acc: 0.679
Seed 42, Fold 2 Val Acc: 0.688
Seed 42, Fold 3 Val Acc: 0.686
Seed 42, Fold 4 Val Acc: 0.697
Seed 42, Fold 5 Val Acc: 0.704
Seed 42 Results: Mean Val Acc: 0.691, Time: 60.191 seconds, Memory: 0.534 MB, GPU Memory: 55813.603 MB
Seed 123, Fold 1 Val Acc: 0.690
Seed 123, Fold 2 Val Acc: 0.708
Seed 123, Fold 3 Val Acc: 0.672
Seed 123, Fold 4 Val Acc: 0.725
Seed 123, Fold 5 Val Acc: 0.662
Seed 123 Results: Mean Val Acc: 0.691, Time: 59.388 seconds, Memory: 0.450 MB, GPU Memory: 55813.603 MB
Seed 456, Fold 1 Val Acc: 0.723
Seed 456, Fold 2 Val Acc: 0.720
Seed 456, Fold 3 Val Acc: 0.729
Seed 456, Fold 4 Val Acc: 0.640
Seed 456, Fold 5 Val Acc: 0.706
Seed 456 Results: Mean Val Acc: 0.703, Time: 60.791 seconds, Memory: 0.451 MB, GPU Memory: 55813.603 MB
{'seed': 42, 'mean_val_acc': 0.690922918471329, 'time': 60.19131898880005, 'memory': 0.53436, 'gpu_memory': 55813.603328}
{'seed': 123, 'mean_val_acc': 0.6912864655448773, 'time': 59.3875253200531, 'memory': 0.4499

### CGIPooling with HierarchicalGCN (2021)

In [17]:
from torch_scatter import scatter_add, scatter
from torch_geometric.nn.inits import uniform
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.nn import GCNConv, GATConv, LEConv, GCNConv, GraphConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from dataclasses import dataclass
from typing import Optional
import torch
from torch import Tensor
@dataclass(init=False)
class SelectOutput:
    r"""The output of the :class:`Select` method, which holds an assignment
    from selected nodes to their respective cluster(s).
    Args:
        node_index (torch.Tensor): The indices of the selected nodes.
        num_nodes (int): The number of nodes.
        cluster_index (torch.Tensor): The indices of the clusters each node in
            :obj:`node_index` is assigned to.
        num_clusters (int): The number of clusters.
        weight (torch.Tensor, optional): A weight vector, denoting the strength
            of the assignment of a node to its cluster. (default: :obj:`None`)
    """
    node_index: Tensor
    num_nodes: int
    cluster_index: Tensor
    num_clusters: int
    weight: Optional[Tensor] = None
    def __init__(
        self,
        node_index: Tensor,
        num_nodes: int,
        cluster_index: Tensor,
        num_clusters: int,
        weight: Optional[Tensor] = None,
    ):
        if node_index.dim() != 1:
            raise ValueError(f"Expected 'node_index' to be one-dimensional "
                             f"(got {node_index.dim()} dimensions)")
        if cluster_index.dim() != 1:
            raise ValueError(f"Expected 'cluster_index' to be one-dimensional "
                             f"(got {cluster_index.dim()} dimensions)")
        if node_index.numel() != cluster_index.numel():
            raise ValueError(f"Expected 'node_index' and 'cluster_index' to "
                             f"hold the same number of values (got "
                             f"{node_index.numel()} and "
                             f"{cluster_index.numel()} values)")
        if weight is not None and weight.dim() != 1:
            raise ValueError(f"Expected 'weight' vector to be one-dimensional "
                             f"(got {weight.dim()} dimensions)")
        if weight is not None and weight.numel() != node_index.numel():
            raise ValueError(f"Expected 'weight' to hold {node_index.numel()} "
                             f"values (got {weight.numel()} values)")
        self.node_index = node_index
        self.num_nodes = num_nodes
        self.cluster_index = cluster_index
        self.num_clusters = num_clusters
        self.weight = weight
class Select(torch.nn.Module):
    r"""An abstract base class for implementing custom node selections as
    described in the `"Understanding Pooling in Graph Neural Networks"
    <https://arxiv.org/abs/1905.05178>`_ paper, which maps the nodes of an
    input graph to supernodes in the coarsened graph.
    Specifically, :class:`Select` returns a :class:`SelectOutput` output, which
    holds a (sparse) mapping :math:`\mathbf{C} \in {[0, 1]}^{N \times C}` that
    assigns selected nodes to one or more of :math:`C` super nodes.
    """
    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        pass
    def forward(self, *args, **kwargs) -> SelectOutput:
        raise NotImplementedError
    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'
def cumsum(x: Tensor, dim: int = 0) -> Tensor:
    r"""Returns the cumulative sum of elements of :obj:`x`.
    In contrast to :meth:`torch.cumsum`, prepends the output with zero.
    Args:
        x (torch.Tensor): The input tensor.
        dim (int, optional): The dimension to do the operation over.
            (default: :obj:`0`)
    Example:
        >>> x = torch.tensor([2, 4, 1])
        >>> cumsum(x)
        tensor([0, 2, 6, 7])
    """
    size = x.size()[:dim] + (x.size(dim) + 1, ) + x.size()[dim + 1:]
    out = x.new_empty(size)
    out.narrow(dim, 0, 1).zero_()
    torch.cumsum(x, dim=dim, out=out.narrow(dim, 1, x.size(dim)))
    return out
def maybe_num_nodes(edge_index, num_nodes=None):
    if num_nodes is not None:
        return num_nodes
    elif isinstance(edge_index, Tensor):
        return int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    else:
        return max(edge_index.size(0), edge_index.size(1))
def maybe_num_nodes(edge_index, num_nodes=None):
    if num_nodes is not None:
        return num_nodes
    elif isinstance(edge_index, Tensor):
        return int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    else:
        return max(edge_index.size(0), edge_index.size(1))
def filter_adj(edge_index, edge_attr, perm, num_nodes=None):
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    mask = perm.new_full((num_nodes, ), -1)
    i = torch.arange(perm.size(0), dtype=torch.long, device=perm.device)
    mask[perm] = i
    row, col = edge_index
    row, col = mask[row], mask[col]
    mask = (row >= 0) & (col >= 0)
    row, col = row[mask], col[mask]
    if edge_attr is not None:
        edge_attr = edge_attr[mask]
    return torch.stack([row, col], dim=0), edge_attr
def topk(
    x: Tensor,
    ratio: Optional[Union[float, int]],
    batch: Tensor,
    min_score: Optional[float] = None,
    tol: float = 1e-7,
) -> Tensor:
    if min_score is not None:
        scores_max = scatter(x, batch, reduce='max')[batch] - tol
        scores_min = scores_max.clamp(max=min_score)
        perm = (x > scores_min).nonzero().view(-1)
        return perm
    if ratio is not None:
        num_nodes = scatter(batch.new_ones(x.size(0)), batch, reduce='sum')
        if ratio >= 1:
            k = num_nodes.new_full((num_nodes.size(0), ), int(ratio))
        else:
            k = (float(ratio) * num_nodes.to(x.dtype)).ceil().to(torch.long)
        x, x_perm = torch.sort(x.view(-1), descending=True)
        batch = batch[x_perm]
        batch, batch_perm = torch.sort(batch, descending=False, stable=True)
        arange = torch.arange(x.size(0), dtype=torch.long, device=x.device)
        ptr = cumsum(num_nodes)
        batched_arange = arange - ptr[batch]
        mask = batched_arange < k[batch]
        return x_perm[batch_perm[mask]]
    raise ValueError("At least one of the 'ratio' and 'min_score' parameters "
                     "must be specified")
class Discriminator(torch.nn.Module):
    def __init__(self, in_channels):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(in_channels * 2, in_channels)
        self.fc2 = nn.Linear(in_channels, 1)
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.sigmoid(self.fc2(x))
        return x
class CGIPool(torch.nn.Module):
    def __init__(self, in_channels, ratio=0.5, non_lin=torch.tanh):
        super(CGIPool, self).__init__()
        self.in_channels = in_channels
        self.ratio = ratio
        self.non_lin = non_lin
        self.hidden_dim = in_channels
        self.transform = GraphConv(in_channels, self.hidden_dim)
        self.pp_conv = GraphConv(self.hidden_dim, self.hidden_dim)
        self.np_conv = GraphConv(self.hidden_dim, self.hidden_dim)
        self.positive_pooling = GraphConv(self.hidden_dim, 1)
        self.negative_pooling = GraphConv(self.hidden_dim, 1)
        self.discriminator = Discriminator(self.hidden_dim)
        self.loss_fn = torch.nn.BCELoss()
    def forward(self, x, edge_index, edge_attr=None, batch=None):
        device = x.device  
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        x_transform = F.leaky_relu(self.transform(x, edge_index), 0.2)
        x_tp = F.leaky_relu(self.pp_conv(x, edge_index), 0.2)
        x_tn = F.leaky_relu(self.np_conv(x, edge_index), 0.2)
        s_pp = self.positive_pooling(x_tp, edge_index).squeeze()
        s_np = self.negative_pooling(x_tn, edge_index).squeeze()
        perm_positive = topk(s_pp, 1, batch)
        perm_negative = topk(s_np, 1, batch)
        x_pp = x_transform[perm_positive] * self.non_lin(s_pp[perm_positive]).view(-1, 1)
        x_np = x_transform[perm_negative] * self.non_lin(s_np[perm_negative]).view(-1, 1)
        x_pp_readout = gap(x_pp, batch[perm_positive])
        x_np_readout = gap(x_np, batch[perm_negative])
        x_readout = gap(x_transform, batch)
        positive_pair = torch.cat([x_pp_readout, x_readout], dim=1)
        negative_pair = torch.cat([x_np_readout, x_readout], dim=1)
        real = torch.ones(positive_pair.shape[0], device=device)  
        fake = torch.zeros(negative_pair.shape[0], device=device)  
        score = (s_pp - s_np)
        perm = topk(score, self.ratio, batch)
        x = x_transform[perm] * self.non_lin(score[perm]).view(-1, 1)
        batch = batch[perm]
        filter_edge_index, filter_edge_attr = filter_adj(edge_index, edge_attr, perm, num_nodes=score.size(0))
        return x, filter_edge_index, filter_edge_attr, batch, perm

In [18]:
import time
import tracemalloc
import torch
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import torch.nn as nn
from sklearn.model_selection import KFold
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TopKPooling
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import ToUndirected
from torch.nn import Linear
import torch.optim as optim
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import to_dense_batch
from sklearn.model_selection import KFold
import numpy as np
import random
from typing import Callable, Optional, Union
dataset = Planetoid(root="/data/ /Pooling", name='Cora')
graph = dataset_sparse
num_classes = dataset.num_classes
in_channels = dataset.num_features
hidden_channels = 64
out_channels = num_classes
depth = 2
pool_ratios = [0.7, 0.7]  
class HierarchicalGCN_CGI(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, depth, pool_ratios, act=F.relu, sum_res=False):
        super(HierarchicalGCN_CGI, self).__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = pool_ratios
        self.act = act
        self.sum_res = sum_res
        channels = self.hidden_channels
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(self.in_channels, channels))
        for i in range(self.depth):
            self.pools.append(CGIPool(channels, ratio=pool_ratios[i]))
            self.down_convs.append(GCNConv(channels, channels))
        in_channels = channels if sum_res else 2 * channels
        self.up_convs = torch.nn.ModuleList()
        for i in range(self.depth):
            self.up_convs.append(GCNConv(in_channels, channels))
        self.up_convs.append(GCNConv(channels, self.out_channels))
    def forward(self, x, edge_index, batch=None):
        x, edge_index = x.to(device), edge_index.to(device)
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        if batch is not None:
            batch = batch.to(device)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.down_convs[0](x, edge_index)
        x = self.act(x)
        xs = [x]
        edge_indices = [edge_index]
        perms = []
        for i in range(1, self.depth + 1):
            x, edge_index, _, batch, perm = self.pools[i - 1](x, edge_index, None, batch)
            x = self.down_convs[i](x, edge_index)
            x = self.act(x)
            if i < self.depth:
                xs.append(x)
                edge_indices.append(edge_index)
            perms.append(perm)
        for i in range(self.depth):
            j = self.depth - 1 - i
            res = xs[j]
            edge_index = edge_indices[j]
            perm = perms[j]
            up = torch.zeros_like(res)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)
            x = self.up_convs[i](x, edge_index)
            x = self.act(x)
        x = self.up_convs[-1](x, edge_index)
        return x
def train_node_classifier(model, graph, optimizer, criterion, train_mask, val_mask, n_epochs=200, patience=150, min_delta=0.0001):
    best_val_acc = 0
    patience_counter = 0
    model.to(device)
    graph = graph.to(device)  
    for epoch in range(1, n_epochs + 1):
        model.train()
        optimizer.zero_grad()
        out = model(graph.x, graph.edge_index)
        loss = criterion(out[train_mask], graph.y[train_mask])
        loss.backward()
        optimizer.step()
        val_acc = eval_node_classifier(model, graph, val_mask)
        if val_acc > best_val_acc + min_delta:
            best_val_acc = val_acc
            patience_counter = 0  
        else:
            patience_counter += 1  
        if patience_counter >= patience:
            break
    return model, best_val_acc
def eval_node_classifier(model, graph, mask):
    model.eval()
    pred = model(graph.x, graph.edge_index).argmax(dim=1)
    correct = (pred[mask] == graph.y[mask]).sum()
    acc = int(correct) / int(mask.sum())
    return acc
kf = KFold(n_splits=5, shuffle=True)
seeds = [42, 123, 456]
results = []
val_accuracies_list = []
times = []
memories = []
gpu_memories = []
for seed in seeds:
    graph = graph.to(device)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    val_accuracies = []
    start_time = time.time()
    tracemalloc.start()
    for fold, (train_index, test_index) in enumerate(kf.split(graph.x)):
        model = HierarchicalGCN_CGI(in_channels, hidden_channels, out_channels, depth, pool_ratios).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()
        train_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        test_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        train_mask[train_index] = True
        test_mask[test_index] = True
        val_mask = test_mask  
        model, best_val_acc = train_node_classifier(model, graph, optimizer, criterion, train_mask, val_mask)
        val_accuracies.append(best_val_acc)
        print(f'Seed {seed}, Fold {fold + 1} Val Acc: {best_val_acc:.3f}')
    mean_val_acc = np.mean(val_accuracies)
    end_time = time.time()
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    memory_usage = peak / 10**6  
    if torch.cuda.is_available():
        gpu_memory_usage = torch.cuda.memory_reserved(device) / 10**6  
    else:
        gpu_memory_usage = 0
    elapsed_time = end_time - start_time
    results.append({
        'seed': seed,
        'mean_val_acc': mean_val_acc,
        'time': elapsed_time,
        'memory': memory_usage,
        'gpu_memory': gpu_memory_usage
    })
    print(f'Seed {seed} Results: Mean Val Acc: {mean_val_acc:.3f}, Time: {elapsed_time:.3f} seconds, Memory: {memory_usage:.3f} MB, GPU Memory: {gpu_memory_usage:.3f} MB')
for result in results:
    print(result)
mean_val_acc_values = [result['mean_val_acc'] for result in results]
total_mean_val_acc = np.mean(mean_val_acc_values) * 100
standard_deviation = np.std(mean_val_acc_values) * 100
print(f"Total Mean Val Acc: {total_mean_val_acc:.2f}$\\pm${standard_deviation:.2f}")

Seed 42, Fold 1 Val Acc: 0.716
Seed 42, Fold 2 Val Acc: 0.688
Seed 42, Fold 3 Val Acc: 0.712
Seed 42, Fold 4 Val Acc: 0.699
Seed 42, Fold 5 Val Acc: 0.701
Seed 42 Results: Mean Val Acc: 0.703, Time: 51.433 seconds, Memory: 0.746 MB, GPU Memory: 55815.700 MB
Seed 123, Fold 1 Val Acc: 0.731
Seed 123, Fold 2 Val Acc: 0.696
Seed 123, Fold 3 Val Acc: 0.686
Seed 123, Fold 4 Val Acc: 0.684
Seed 123, Fold 5 Val Acc: 0.704
Seed 123 Results: Mean Val Acc: 0.700, Time: 50.921 seconds, Memory: 0.717 MB, GPU Memory: 55815.700 MB
Seed 456, Fold 1 Val Acc: 0.696
Seed 456, Fold 2 Val Acc: 0.734
Seed 456, Fold 3 Val Acc: 0.697
Seed 456, Fold 4 Val Acc: 0.695
Seed 456, Fold 5 Val Acc: 0.664
Seed 456 Results: Mean Val Acc: 0.697, Time: 49.904 seconds, Memory: 0.714 MB, GPU Memory: 55815.700 MB
{'seed': 42, 'mean_val_acc': 0.7030993581654854, 'time': 51.43271255493164, 'memory': 0.745922, 'gpu_memory': 55815.70048}
{'seed': 123, 'mean_val_acc': 0.7001432361828239, 'time': 50.92082858085632, 'memory': 0.71

### KMISPooling with HierarchicalGCN (2023)

In [19]:
from typing import Callable, Optional, Tuple, Union
from torch_geometric.typing import Adj, OptTensor, PairTensor, Tensor
Scorer = Callable[[Tensor, Adj, OptTensor, OptTensor], Tensor]
from torch_sparse import SparseTensor, remove_diag
from torch_geometric.nn.aggr import Aggregation
from torch_geometric.nn.dense import Linear
from torch.nn import Module
from torch_scatter import scatter_max, scatter_min
def maximal_independent_set(edge_index: Adj, k: int = 1,
                            perm: OptTensor = None) -> Tensor:
    r"""Returns a Maximal :math:`k`-Independent Set of a graph, i.e., a set of
    nodes (as a :class:`ByteTensor`) such that none of them are :math:`k`-hop
    neighbors, and any node in the graph has a :math:`k`-hop neighbor in the
    returned set.
    The algorithm greedily selects the nodes in their canonical order. If a
    permutation :obj:`perm` is provided, the nodes are extracted following
    that permutation instead.
    This method follows `Blelloch's Alogirithm
    <https://arxiv.org/abs/1202.3205>`_ for :math:`k = 1`, and its
    generalization by `Bacciu et al. <https://arxiv.org/abs/2208.03523>`_ for
    higher values of :math:`k`.
    Args:
        edge_index (Tensor or SparseTensor): The graph connectivity.
        k (int): The :math:`k` value (defaults to 1).
        perm (LongTensor, optional): Permutation vector. Must be of size
            :obj:`(n,)` (defaults to :obj:`None`).
    :rtype: :class:`ByteTensor`
    """
    if isinstance(edge_index, SparseTensor):
        row, col, _ = edge_index.coo()
        device = edge_index.device()
        n = edge_index.size(0)
    else:
        row, col = edge_index[0], edge_index[1]
        device = row.device
        n = edge_index.max().item() + 1
    if perm is None:
        rank = torch.arange(n, dtype=torch.long, device=device)
    else:
        rank = torch.zeros_like(perm)
        rank[perm] = torch.arange(n, dtype=torch.long, device=device)
    mis = torch.zeros(n, dtype=torch.bool, device=device)
    mask = mis.clone()
    min_rank = rank.clone()
    while not mask.all():
        for _ in range(k):
            min_neigh = torch.full_like(min_rank, fill_value=n)
            scatter_min(min_rank[row], col, out=min_neigh)
            torch.minimum(min_neigh, min_rank, out=min_rank)  
        mis = mis | torch.eq(rank, min_rank)
        mask = mis.clone().byte()
        for _ in range(k):
            max_neigh = torch.full_like(mask, fill_value=0)
            scatter_max(mask[row], col, out=max_neigh)
            torch.maximum(max_neigh, mask, out=mask)  
        mask = mask.to(dtype=torch.bool)
        min_rank = rank.clone()
        min_rank[mask] = n
    return mis
def maximal_independent_set_cluster(edge_index: Adj, k: int = 1,
                                    perm: OptTensor = None) -> PairTensor:
    r"""Computes the Maximal :math:`k`-Independent Set (:math:`k`-MIS)
    clustering of a graph, as defined in `"Generalizing Downsampling from
    Regular Data to Graphs" <https://arxiv.org/abs/2208.03523>`_.
    The algorithm greedily selects the nodes in their canonical order. If a
    permutation :obj:`perm` is provided, the nodes are extracted following
    that permutation instead.
    This method returns both the :math:`k`-MIS and the clustering, where the
    :math:`c`-th cluster refers to the :math:`c`-th element of the
    :math:`k`-MIS.
    Args:
        edge_index (Tensor or SparseTensor): The graph connectivity.
        k (int): The :math:`k` value (defaults to 1).
        perm (LongTensor, optional): Permutation vector. Must be of size
            :obj:`(n,)` (defaults to :obj:`None`).
    :rtype: (:class:`ByteTensor`, :class:`LongTensor`)
    """
    mis = maximal_independent_set(edge_index=edge_index, k=k, perm=perm)
    n, device = mis.size(0), mis.device
    if isinstance(edge_index, SparseTensor):
        row, col, _ = edge_index.coo()
    else:
        row, col = edge_index[0], edge_index[1]
    if perm is None:
        rank = torch.arange(n, dtype=torch.long, device=device)
    else:
        rank = torch.zeros_like(perm)
        rank[perm] = torch.arange(n, dtype=torch.long, device=device)
    min_rank = torch.full((n, ), fill_value=n, dtype=torch.long, device=device)
    rank_mis = rank[mis]
    min_rank[mis] = rank_mis
    for _ in range(k):
        min_neigh = torch.full_like(min_rank, fill_value=n)
        scatter_min(min_rank[row], col, out=min_neigh)
        torch.minimum(min_neigh, min_rank, out=min_rank)
    _, clusters = torch.unique(min_rank, return_inverse=True)
    perm = torch.argsort(rank_mis)
    return mis, perm[clusters]
class KMISPooling(Module):
    _heuristics = {None, 'greedy', 'w-greedy'}
    _passthroughs = {None, 'before', 'after'}
    _scorers = {
        'linear',
        'random',
        'constant',
        'canonical',
        'first',
        'last',
    }
    def __init__(self, in_channels: Optional[int] = None, k: int = 1,
                 scorer: Union[Scorer, str] = 'linear',
                 score_heuristic: Optional[str] = 'greedy',
                 score_passthrough: Optional[str] = 'before',
                 aggr_x: Optional[Union[str, Aggregation]] = None,
                 aggr_edge: str = 'sum',
                 aggr_score: Callable[[Tensor, Tensor], Tensor] = torch.mul,
                 remove_self_loops: bool = True) -> None:
        super(KMISPooling, self).__init__()
        assert score_heuristic in self._heuristics, \
            "Unrecognized `score_heuristic` value."
        assert score_passthrough in self._passthroughs, \
            "Unrecognized `score_passthrough` value."
        if not callable(scorer):
            assert scorer in self._scorers, \
                "Unrecognized `scorer` value."
        self.k = k
        self.scorer = scorer
        self.score_heuristic = score_heuristic
        self.score_passthrough = score_passthrough
        self.aggr_x = aggr_x
        self.aggr_edge = aggr_edge
        self.aggr_score = aggr_score
        self.remove_self_loops = remove_self_loops
        if scorer == 'linear':
            assert self.score_passthrough is not None, \
                "`'score_passthrough'` must not be `None`" \
                " when using `'linear'` scorer"
            self.lin = Linear(in_features=in_channels, out_features=1)
    def _apply_heuristic(self, x: Tensor, adj: SparseTensor) -> Tensor:
        if self.score_heuristic is None:
            return x
        row, col, _ = adj.coo()
        x = x.view(-1)
        if self.score_heuristic == 'greedy':
            k_sums = torch.ones_like(x)
        else:
            k_sums = x.clone()
        for _ in range(self.k):
            scatter_add(k_sums[row], col, out=k_sums)
        return x / k_sums
    def _scorer(self, x: Tensor, edge_index: Adj, edge_attr: OptTensor = None,
                batch: OptTensor = None) -> Tensor:
        if self.scorer == 'linear':
            return self.lin(x).sigmoid()
        if self.scorer == 'random':
            return torch.rand((x.size(0), 1), device=x.device)
        if self.scorer == 'constant':
            return torch.ones((x.size(0), 1), device=x.device)
        if self.scorer == 'canonical':
            return -torch.arange(x.size(0), device=x.device).view(-1, 1)
        if self.scorer == 'first':
            return x[..., [0]]
        if self.scorer == 'last':
            return x[..., [-1]]
        return self.scorer(x, edge_index, edge_attr, batch)
    def forward(self, x: Tensor, edge_index: Adj,
                edge_attr: OptTensor = None,
                batch: OptTensor = None) \
            -> Tuple[Tensor, Adj, OptTensor, OptTensor, Tensor, Tensor]:
        """"""
        edge_index = edge_index.long()
        adj, n = edge_index, x.size(0)
        if not isinstance(edge_index, SparseTensor):
            adj = SparseTensor.from_edge_index(edge_index, edge_attr, (n, n))
        score = self._scorer(x, edge_index, edge_attr, batch)
        updated_score = self._apply_heuristic(score, adj)
        perm = torch.argsort(updated_score.view(-1), 0, descending=True)
        mis, cluster = maximal_independent_set_cluster(adj, self.k, perm)
        row, col, val = adj.coo()
        c = mis.sum()
        if val is None:
            val = torch.ones_like(row, dtype=torch.float)
        adj = SparseTensor(row=cluster[row], col=cluster[col], value=val,
                           is_sorted=False,
                           sparse_sizes=(c, c)).coalesce(self.aggr_edge)
        if self.remove_self_loops:
            adj = remove_diag(adj)
        if self.score_passthrough == 'before':
            x = self.aggr_score(x, score)
        if self.aggr_x is None:
            x = x[mis]
        elif isinstance(self.aggr_x, str):
            x = scatter(x, cluster, dim=0, dim_size=mis.sum(),
                        reduce=self.aggr_x)
        else:
            x = self.aggr_x(x, cluster, dim_size=c)
        if self.score_passthrough == 'after':
            x = self.aggr_score(x, score[mis])
        if isinstance(edge_index, SparseTensor):
            edge_index, edge_attr = adj, None
        else:
            row, col, edge_attr = adj.coo()
            edge_index = torch.stack([row, col])
        if batch is not None:
            batch = batch[mis]
        perm = perm[mis]
        return x, edge_index, edge_attr, batch, mis, cluster, perm
    def __repr__(self):
        if self.scorer == 'linear':
            channels = f"in_channels={self.lin.in_channels}, "
        else:
            channels = ""
        return f'{self.__class__.__name__}({channels}k={self.k})'

In [20]:
import time
import tracemalloc
import torch
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import torch.nn as nn
from sklearn.model_selection import KFold
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TopKPooling
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import ToUndirected
from torch.nn import Linear
import torch.optim as optim
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import to_dense_batch
from sklearn.model_selection import KFold
import numpy as np
import random
from typing import Callable, Optional, Union
dataset = Planetoid(root="/data/ /Pooling", name='Cora')
graph = dataset_sparse
num_classes = dataset.num_classes
in_channels = dataset.num_features
hidden_channels = 64
out_channels = num_classes
depth = 2
pool_ratios = [0.7, 0.7]  
class HierarchicalGCN_KMIS(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, depth, pool_ratios, act=F.relu, sum_res=False):
        super(HierarchicalGCN_KMIS, self).__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = pool_ratios
        self.act = act
        self.sum_res = sum_res
        channels = self.hidden_channels
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(self.in_channels, channels))
        for i in range(self.depth):
            self.pools.append(KMISPooling(64, k=1, aggr_x='sum'))
            self.down_convs.append(GCNConv(channels, channels))
        in_channels = channels if sum_res else 2 * channels
        self.up_convs = torch.nn.ModuleList()
        for i in range(self.depth):
            self.up_convs.append(GCNConv(in_channels, channels))
        self.up_convs.append(GCNConv(channels, self.out_channels))
    def forward(self, x, edge_index, batch=None):
        x, edge_index = x.to(device), edge_index.to(device)
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        if batch is not None:
            batch = batch.to(device)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.down_convs[0](x, edge_index)
        x = self.act(x)
        xs = [x]
        edge_indices = [edge_index]
        perms = []
        for i in range(1, self.depth + 1):
            x, edge_index, _, batch, _, cluster, perm = self.pools[i - 1](x, edge_index, batch=batch)
            x = self.down_convs[i](x, edge_index)
            x = self.act(x)
            if i < self.depth:
                xs.append(x)
                edge_indices.append(edge_index)
            perms.append(perm)
        for i in range(self.depth):
            j = self.depth - 1 - i
            res = xs[j]
            edge_index = edge_indices[j]
            perm = perms[j]
            up = torch.zeros_like(res)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)
            x = self.up_convs[i](x, edge_index)
            x = self.act(x)
        x = self.up_convs[-1](x, edge_index)
        return x
def train_node_classifier(model, graph, optimizer, criterion, train_mask, val_mask, n_epochs=200, patience=150, min_delta=0.0001):
    best_val_acc = 0
    patience_counter = 0
    model.to(device)
    graph = graph.to(device)  
    for epoch in range(1, n_epochs + 1):
        model.train()
        optimizer.zero_grad()
        out = model(graph.x, graph.edge_index)
        loss = criterion(out[train_mask], graph.y[train_mask])
        loss.backward()
        optimizer.step()
        val_acc = eval_node_classifier(model, graph, val_mask)
        if val_acc > best_val_acc + min_delta:
            best_val_acc = val_acc
            patience_counter = 0  
        else:
            patience_counter += 1  
        if patience_counter >= patience:
            break
    return model, best_val_acc
def eval_node_classifier(model, graph, mask):
    model.eval()
    pred = model(graph.x, graph.edge_index).argmax(dim=1)
    correct = (pred[mask] == graph.y[mask]).sum()
    acc = int(correct) / int(mask.sum())
    return acc
kf = KFold(n_splits=5, shuffle=True)
seeds = [42, 123, 456]
results = []
val_accuracies_list = []
times = []
memories = []
gpu_memories = []
for seed in seeds:
    graph = graph.to(device)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    val_accuracies = []
    start_time = time.time()
    tracemalloc.start()
    for fold, (train_index, test_index) in enumerate(kf.split(graph.x)):
        model = HierarchicalGCN_KMIS(in_channels, hidden_channels, out_channels, depth, pool_ratios).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()
        train_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        test_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        train_mask[train_index] = True
        test_mask[test_index] = True
        val_mask = test_mask  
        model, best_val_acc = train_node_classifier(model, graph, optimizer, criterion, train_mask, val_mask)
        val_accuracies.append(best_val_acc)
        print(f'Seed {seed}, Fold {fold + 1} Val Acc: {best_val_acc:.3f}')
    mean_val_acc = np.mean(val_accuracies)
    end_time = time.time()
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    memory_usage = peak / 10**6  
    if torch.cuda.is_available():
        gpu_memory_usage = torch.cuda.memory_reserved(device) / 10**6  
    else:
        gpu_memory_usage = 0
    elapsed_time = end_time - start_time
    results.append({
        'seed': seed,
        'mean_val_acc': mean_val_acc,
        'time': elapsed_time,
        'memory': memory_usage,
        'gpu_memory': gpu_memory_usage
    })
    print(f'Seed {seed} Results: Mean Val Acc: {mean_val_acc:.3f}, Time: {elapsed_time:.3f} seconds, Memory: {memory_usage:.3f} MB, GPU Memory: {gpu_memory_usage:.3f} MB')
for result in results:
    print(result)
mean_val_acc_values = [result['mean_val_acc'] for result in results]
total_mean_val_acc = np.mean(mean_val_acc_values) * 100
standard_deviation = np.std(mean_val_acc_values) * 100
print(f"Total Mean Val Acc: {total_mean_val_acc:.2f}$\\pm${standard_deviation:.2f}")

Seed 42, Fold 1 Val Acc: 0.768
Seed 42, Fold 2 Val Acc: 0.751
Seed 42, Fold 3 Val Acc: 0.734
Seed 42, Fold 4 Val Acc: 0.741
Seed 42, Fold 5 Val Acc: 0.786
Seed 42 Results: Mean Val Acc: 0.756, Time: 39.296 seconds, Memory: 0.349 MB, GPU Memory: 55815.700 MB
Seed 123, Fold 1 Val Acc: 0.777
Seed 123, Fold 2 Val Acc: 0.766
Seed 123, Fold 3 Val Acc: 0.738
Seed 123, Fold 4 Val Acc: 0.758
Seed 123, Fold 5 Val Acc: 0.745
Seed 123 Results: Mean Val Acc: 0.757, Time: 38.617 seconds, Memory: 0.314 MB, GPU Memory: 55815.700 MB
Seed 456, Fold 1 Val Acc: 0.753
Seed 456, Fold 2 Val Acc: 0.762
Seed 456, Fold 3 Val Acc: 0.762
Seed 456, Fold 4 Val Acc: 0.756
Seed 456, Fold 5 Val Acc: 0.750
Seed 456 Results: Mean Val Acc: 0.757, Time: 36.226 seconds, Memory: 0.343 MB, GPU Memory: 55815.700 MB
{'seed': 42, 'mean_val_acc': 0.7559139491579758, 'time': 39.296337366104126, 'memory': 0.349017, 'gpu_memory': 55815.70048}
{'seed': 123, 'mean_val_acc': 0.7566430895362558, 'time': 38.616806507110596, 'memory': 0.

### GSAPooling with HierarchicalGCN (2021)

In [21]:
import math
from typing import Union, Optional, Callable
from torch_scatter import scatter_add, scatter_max
from torch_geometric.utils import softmax
from torch_geometric.nn import GCNConv, GCNConv, GATConv, ChebConv, GraphConv
def uniform(size, tensor):
    if tensor is not None:
        bound = 1.0 / math.sqrt(size)
        tensor.data.uniform_(-bound, bound)
def maybe_num_nodes(edge_index, num_nodes=None):
    if num_nodes is not None:
        return num_nodes
    elif isinstance(edge_index, Tensor):
        return int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    else:
        return max(edge_index.size(0), edge_index.size(1))
def topk(x, ratio, batch, min_score=None, tol=1e-7):
    if min_score is not None:
        scores_max = scatter_max(x, batch)[0][batch] - tol
        scores_min = scores_max.clamp(max=min_score)
        perm = (x > scores_min).nonzero(as_tuple=False).view(-1)
    else:
        num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
        batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()
        cum_num_nodes = torch.cat(
            [num_nodes.new_zeros(1),
             num_nodes.cumsum(dim=0)[:-1]], dim=0)
        index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
        index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)
        dense_x = x.new_full((batch_size * max_num_nodes, ),
                             torch.finfo(x.dtype).min)
        dense_x[index] = x
        dense_x = dense_x.view(batch_size, max_num_nodes)
        _, perm = dense_x.sort(dim=-1, descending=True)
        perm = perm + cum_num_nodes.view(-1, 1)
        perm = perm.view(-1)
        if isinstance(ratio, int):
            k = num_nodes.new_full((num_nodes.size(0), ), ratio)
            k = torch.min(k, num_nodes)
        else:
            k = (ratio * num_nodes.to(torch.float)).ceil().to(torch.long)
        mask = [
            torch.arange(k[i], dtype=torch.long, device=x.device) +
            i * max_num_nodes for i in range(batch_size)
        ]
        mask = torch.cat(mask, dim=0)
        perm = perm[mask]
    return perm
def filter_adj(edge_index, edge_attr, perm, num_nodes=None):
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    mask = perm.new_full((num_nodes, ), -1)
    i = torch.arange(perm.size(0), dtype=torch.long, device=perm.device)
    mask[perm] = i
    row, col = edge_index
    row, col = mask[row], mask[col]
    mask = (row >= 0) & (col >= 0)
    row, col = row[mask], col[mask]
    if edge_attr is not None:
        edge_attr = edge_attr[mask]
    return torch.stack([row, col], dim=0), edge_attr
class GSAPool(torch.nn.Module):
    def __init__(self, in_channels, pooling_ratio=0.5, alpha=0.6,
                        min_score=None, multiplier=1,
                        non_linearity=torch.tanh,
                        cus_drop_ratio =0):
        super(GSAPool,self).__init__()
        self.in_channels = in_channels
        self.ratio = pooling_ratio
        self.alpha = alpha
        self.sbtl_layer = GCNConv(in_channels,1)
        self.fbtl_layer = nn.Linear(in_channels, 1)
        self.fusion = GCNConv(in_channels,in_channels)
        self.min_score = min_score
        self.multiplier = multiplier
        self.fusion_flag = 0
        self.non_linearity = non_linearity
        self.dropout = torch.nn.Dropout(cus_drop_ratio)
    def conv_selection(self, conv, in_channels, conv_type=0):
        if(conv_type == 0):
            out_channels = 1
        elif(conv_type == 1):
            out_channels = in_channels
        if(conv == "GCNConv"):
            return GCNConv(in_channels,out_channels)
        elif(conv == "ChebConv"):
            return ChebConv(in_channels,out_channels,1)
        elif(conv == "GCNConv"):
            return GCNConv(in_channels,out_channels)
        elif(conv == "GATConv"):
            return GATConv(in_channels,out_channels, heads=1, concat=True)
        elif(conv == "GraphConv"):
            return GraphConv(in_channels,out_channels)
        else:
            raise ValueError
    def forward(self, x, edge_index, edge_attr=None, batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        x = x.unsqueeze(-1) if x.dim() == 1 else x
        score_s = self.sbtl_layer(x,edge_index).squeeze()
        score_f = self.fbtl_layer(x).squeeze()
        score = score_s*self.alpha + score_f*(1-self.alpha)
        score = score.unsqueeze(-1) if score.dim()==0 else score
        if self.min_score is None:
            score = self.non_linearity(score)
        else:
            score = softmax(score, batch)
        sc = self.dropout(score)
        perm = topk(sc, self.ratio, batch)
        if(self.fusion_flag == 1):
            x = self.fusion(x, edge_index)
        x_ae = x[perm]
        x = x[perm] * score[perm].view(-1, 1)
        x = self.multiplier * x if self.multiplier != 1 else x
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(
            edge_index, edge_attr, perm, num_nodes=score.size(0))
        return x, edge_index, edge_attr, batch, perm, x_ae

In [22]:
import time
import tracemalloc
import torch
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import torch.nn as nn
from sklearn.model_selection import KFold
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TopKPooling
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import ToUndirected
from torch.nn import Linear
import torch.optim as optim
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import to_dense_batch
from sklearn.model_selection import KFold
import numpy as np
import random
from typing import Callable, Optional, Union
dataset = Planetoid(root="/data/ /Pooling", name='Cora')
graph = dataset_sparse
num_classes = dataset.num_classes
in_channels = dataset.num_features
hidden_channels = 64
out_channels = num_classes
depth = 2
pool_ratios = [0.7, 0.7]  
class HierarchicalGCN_GSA(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, depth, pool_ratios, act=F.relu, sum_res=False):
        super(HierarchicalGCN_GSA, self).__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = pool_ratios
        self.act = act
        self.sum_res = sum_res
        channels = self.hidden_channels
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(self.in_channels, channels))
        for i in range(self.depth):
            self.pools.append(GSAPool(64, pooling_ratio=pool_ratios[i], alpha = 0.6, cus_drop_ratio = 0))
            self.down_convs.append(GCNConv(channels, channels))
        in_channels = channels if sum_res else 2 * channels
        self.up_convs = torch.nn.ModuleList()
        for i in range(self.depth):
            self.up_convs.append(GCNConv(in_channels, channels))
        self.up_convs.append(GCNConv(channels, self.out_channels))
    def forward(self, x, edge_index, batch=None):
        x, edge_index = x.to(device), edge_index.to(device)
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        if batch is not None:
            batch = batch.to(device)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.down_convs[0](x, edge_index)
        x = self.act(x)
        xs = [x]
        edge_indices = [edge_index]
        perms = []
        for i in range(1, self.depth + 1):
            x, edge_index, edge_attr, batch, perm, _ = self.pools[i - 1](x, edge_index, None, batch)
            x = self.down_convs[i](x, edge_index)
            x = self.act(x)
            if i < self.depth:
                xs.append(x)
                edge_indices.append(edge_index)
            perms.append(perm)
        for i in range(self.depth):
            j = self.depth - 1 - i
            res = xs[j]
            edge_index = edge_indices[j]
            perm = perms[j]
            up = torch.zeros_like(res)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)
            x = self.up_convs[i](x, edge_index)
            x = self.act(x)
        x = self.up_convs[-1](x, edge_index)
        return x
def train_node_classifier(model, graph, optimizer, criterion, train_mask, val_mask, n_epochs=200, patience=150, min_delta=0.0001):
    best_val_acc = 0
    patience_counter = 0
    model.to(device)
    graph = graph.to(device)  
    for epoch in range(1, n_epochs + 1):
        model.train()
        optimizer.zero_grad()
        out = model(graph.x, graph.edge_index)
        loss = criterion(out[train_mask], graph.y[train_mask])
        loss.backward()
        optimizer.step()
        val_acc = eval_node_classifier(model, graph, val_mask)
        if val_acc > best_val_acc + min_delta:
            best_val_acc = val_acc
            patience_counter = 0  
        else:
            patience_counter += 1  
        if patience_counter >= patience:
            break
    return model, best_val_acc
def eval_node_classifier(model, graph, mask):
    model.eval()
    pred = model(graph.x, graph.edge_index).argmax(dim=1)
    correct = (pred[mask] == graph.y[mask]).sum()
    acc = int(correct) / int(mask.sum())
    return acc
kf = KFold(n_splits=5, shuffle=True)
seeds = [42, 123, 456]
results = []
val_accuracies_list = []
times = []
memories = []
gpu_memories = []
for seed in seeds:
    graph = graph.to(device)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    val_accuracies = []
    start_time = time.time()
    tracemalloc.start()
    for fold, (train_index, test_index) in enumerate(kf.split(graph.x)):
        model = HierarchicalGCN_GSA(in_channels, hidden_channels, out_channels, depth, pool_ratios).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()
        train_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        test_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        train_mask[train_index] = True
        test_mask[test_index] = True
        val_mask = test_mask  
        model, best_val_acc = train_node_classifier(model, graph, optimizer, criterion, train_mask, val_mask)
        val_accuracies.append(best_val_acc)
        print(f'Seed {seed}, Fold {fold + 1} Val Acc: {best_val_acc:.3f}')
    mean_val_acc = np.mean(val_accuracies)
    end_time = time.time()
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    memory_usage = peak / 10**6  
    if torch.cuda.is_available():
        gpu_memory_usage = torch.cuda.memory_reserved(device) / 10**6  
    else:
        gpu_memory_usage = 0
    elapsed_time = end_time - start_time
    results.append({
        'seed': seed,
        'mean_val_acc': mean_val_acc,
        'time': elapsed_time,
        'memory': memory_usage,
        'gpu_memory': gpu_memory_usage
    })
    print(f'Seed {seed} Results: Mean Val Acc: {mean_val_acc:.3f}, Time: {elapsed_time:.3f} seconds, Memory: {memory_usage:.3f} MB, GPU Memory: {gpu_memory_usage:.3f} MB')
for result in results:
    print(result)
mean_val_acc_values = [result['mean_val_acc'] for result in results]
total_mean_val_acc = np.mean(mean_val_acc_values) * 100
standard_deviation = np.std(mean_val_acc_values) * 100
print(f"Total Mean Val Acc: {total_mean_val_acc:.2f}$\\pm${standard_deviation:.2f}")

Seed 42, Fold 1 Val Acc: 0.749
Seed 42, Fold 2 Val Acc: 0.734
Seed 42, Fold 3 Val Acc: 0.749
Seed 42, Fold 4 Val Acc: 0.726
Seed 42, Fold 5 Val Acc: 0.762
Seed 42 Results: Mean Val Acc: 0.744, Time: 28.365 seconds, Memory: 0.482 MB, GPU Memory: 55815.700 MB
Seed 123, Fold 1 Val Acc: 0.769
Seed 123, Fold 2 Val Acc: 0.753
Seed 123, Fold 3 Val Acc: 0.729
Seed 123, Fold 4 Val Acc: 0.734
Seed 123, Fold 5 Val Acc: 0.734
Seed 123 Results: Mean Val Acc: 0.744, Time: 27.543 seconds, Memory: 0.452 MB, GPU Memory: 55815.700 MB
Seed 456, Fold 1 Val Acc: 0.745
Seed 456, Fold 2 Val Acc: 0.742
Seed 456, Fold 3 Val Acc: 0.758
Seed 456, Fold 4 Val Acc: 0.732
Seed 456, Fold 5 Val Acc: 0.726
Seed 456 Results: Mean Val Acc: 0.741, Time: 28.906 seconds, Memory: 0.455 MB, GPU Memory: 55815.700 MB
{'seed': 42, 'mean_val_acc': 0.7440915074585126, 'time': 28.364673852920532, 'memory': 0.481577, 'gpu_memory': 55815.70048}
{'seed': 123, 'mean_val_acc': 0.7437150009208041, 'time': 27.542988061904907, 'memory': 0.

### HGPSLPooling with HierarchicalGCN (2019)

In [23]:
import torch
import torch.nn as nn
from torch.autograd import Function
from torch_scatter import scatter_add, scatter_max
from torch_geometric.utils import softmax, dense_to_sparse, add_remaining_self_loops
def topk(x, ratio, batch, min_score=None, tol=1e-7):
    if min_score is not None:
        scores_max = scatter_max(x, batch)[0][batch] - tol
        scores_min = scores_max.clamp(max=min_score)
        perm = torch.nonzero(x > scores_min).view(-1)
    else:
        num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
        batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()
        cum_num_nodes = torch.cat(
            [num_nodes.new_zeros(1),
            num_nodes.cumsum(dim=0)[:-1]], dim=0)
        index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
        index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)
        dense_x = x.new_full((batch_size * max_num_nodes, ), -2)
        dense_x[index] = x
        dense_x = dense_x.view(batch_size, max_num_nodes)
        _, perm = dense_x.sort(dim=-1, descending=True)
        perm = perm + cum_num_nodes.view(-1, 1)
        perm = perm.view(-1)
        k = (ratio * num_nodes.to(torch.float)).ceil().to(torch.long)
        mask = [
            torch.arange(k[i], dtype=torch.long, device=x.device) +
            i * max_num_nodes for i in range(batch_size)
        ]
        mask = torch.cat(mask, dim=0)
        perm = perm[mask]
    return perm
def filter_adj(edge_index, edge_weight, perm, num_nodes=None):
        num_nodes = maybe_num_nodes(edge_index, num_nodes)
        mask = perm.new_full((num_nodes, ), -1)
        i = torch.arange(perm.size(0), dtype=torch.long, device=perm.device)
        mask[perm] = i
        row, col = edge_index
        row, col = mask[row], mask[col]
        mask = (row >= 0) & (col >= 0)
        row, col = row[mask], col[mask]
        if edge_weight is not None:
            edge_weight = edge_weight[mask]
        return torch.stack([row, col], dim=0), edge_weight
def scatter_sort(x, batch, fill_value=-1e16):
    num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
    batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()
    cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)
    index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
    index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)
    dense_x = x.new_full((batch_size * max_num_nodes,), fill_value)
    dense_x[index] = x
    dense_x = dense_x.view(batch_size, max_num_nodes)
    sorted_x, _ = dense_x.sort(dim=-1, descending=True)
    cumsum_sorted_x = sorted_x.cumsum(dim=-1)
    cumsum_sorted_x = cumsum_sorted_x.view(-1)
    sorted_x = sorted_x.view(-1)
    filled_index = sorted_x != fill_value
    sorted_x = sorted_x[filled_index]
    cumsum_sorted_x = cumsum_sorted_x[filled_index]
    return sorted_x, cumsum_sorted_x
def _make_ix_like(batch):
    num_nodes = scatter_add(batch.new_ones(batch.size(0)), batch, dim=0)
    idx = [torch.arange(1, i + 1, dtype=torch.long, device=batch.device) for i in num_nodes]
    idx = torch.cat(idx, dim=0)
    return idx
def _threshold_and_support(x, batch):
    """Sparsemax building block: compute the threshold
    Args:
        x: input tensor to apply the sparsemax
        batch: group indicators
    Returns:
        the threshold value
    """
    num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
    cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)
    sorted_input, input_cumsum = scatter_sort(x, batch)
    input_cumsum = input_cumsum - 1.0
    rhos = _make_ix_like(batch).to(x.dtype)
    support = rhos * sorted_input > input_cumsum
    support_size = scatter_add(support.to(batch.dtype), batch)
    idx = support_size + cum_num_nodes - 1
    mask = idx < 0
    idx[mask] = 0
    tau = input_cumsum.gather(0, idx)
    tau /= support_size.to(x.dtype)
    return tau, support_size
class SparsemaxFunction(Function):
    @staticmethod
    def forward(ctx, x, batch):
        """sparsemax: normalizing sparse transform
        Parameters:
            ctx: context object
            x (Tensor): shape (N, )
            batch: group indicator
        Returns:
            output (Tensor): same shape as input
        """
        max_val, _ = scatter_max(x, batch)
        x -= max_val[batch]
        tau, supp_size = _threshold_and_support(x, batch)
        output = torch.clamp(x - tau[batch], min=0)
        ctx.save_for_backward(supp_size, output, batch)
        return output
    @staticmethod
    def backward(ctx, grad_output):
        supp_size, output, batch = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[output == 0] = 0
        v_hat = scatter_add(grad_input, batch) / supp_size.to(output.dtype)
        grad_input = torch.where(output != 0, grad_input - v_hat[batch], grad_input)
        return grad_input, None
sparsemax = SparsemaxFunction.apply
class Sparsemax(nn.Module):
    def __init__(self):
        super(Sparsemax, self).__init__()
    def forward(self, x, batch):
        return sparsemax(x, batch)
if __name__ == '__main__':
    sparse_attention = Sparsemax()
    input_x = torch.tensor([1.7301, 0.6792, -1.0565, 1.6614, -0.3196, -0.7790, -0.3877, -0.4943, 0.1831, -0.0061])
    input_batch = torch.cat([torch.zeros(4, dtype=torch.long), torch.ones(6, dtype=torch.long)], dim=0)
    res = sparse_attention(input_x, input_batch)
    print(res)
class TwoHopNeighborhood(object):
    def __call__(self, data):
        edge_index, edge_attr = data.edge_index, data.edge_attr
        n = data.num_nodes
        fill = 1e16
        value = edge_index.new_full((edge_index.size(1),), fill, dtype=torch.float)
        index, value = spspmm(edge_index, value, edge_index, value, n, n, n, True)
        edge_index = torch.cat([edge_index, index], dim=1)
        if edge_attr is None:
            data.edge_index, _ = coalesce(edge_index, None, n, n)
        else:
            value = value.view(-1, *[1 for _ in range(edge_attr.dim() - 1)])
            value = value.expand(-1, *list(edge_attr.size())[1:])
            edge_attr = torch.cat([edge_attr, value], dim=0)
            data.edge_index, edge_attr = coalesce(edge_index, edge_attr, n, n, op='min', fill_value=fill)
            edge_attr[edge_attr >= fill] = 0
            data.edge_attr = edge_attr
        return data
    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)
class GCN(MessagePassing):
    def __init__(self, in_channels, out_channels, cached=False, bias=True, **kwargs):
        super(GCN, self).__init__(aggr='add', **kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cached = cached
        self.cached_result = None
        self.cached_num_edges = None
        self.weight = Parameter(torch.Tensor(in_channels, out_channels))
        nn.init.xavier_uniform_(self.weight.data)
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
            nn.init.zeros_(self.bias.data)
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
    def reset_parameters(self):
        self.cached_result = None
        self.cached_num_edges = None
    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, dtype=None):
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device)
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
    def forward(self, x, edge_index, edge_weight=None):
        x = torch.matmul(x, self.weight)
        if self.cached and self.cached_result is not None:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    'Cached {} number of edges, but found {}'.format(self.cached_num_edges, edge_index.size(1)))
        if not self.cached or self.cached_result is None:
            self.cached_num_edges = edge_index.size(1)
            edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, x.dtype)
            self.cached_result = edge_index, norm
        edge_index, norm = self.cached_result
        return self.propagate(edge_index, x=x, norm=norm)
    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j
    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out
    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, self.out_channels)
class NodeInformationScore(MessagePassing):
    def __init__(self, improved=False, cached=False, **kwargs):
        super(NodeInformationScore, self).__init__(aggr='add', **kwargs)
        self.improved = improved
        self.cached = cached
        self.cached_result = None
        self.cached_num_edges = None
    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, dtype=None):
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device)
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        edge_index, edge_weight = add_remaining_self_loops(edge_index, edge_weight, 0, num_nodes)
        row, col = edge_index
        expand_deg = torch.zeros((edge_weight.size(0),), dtype=dtype, device=edge_index.device)
        expand_deg[-num_nodes:] = torch.ones((num_nodes,), dtype=dtype, device=edge_index.device)
        return edge_index, expand_deg - deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
    def forward(self, x, edge_index, edge_weight):
        if self.cached and self.cached_result is not None:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    'Cached {} number of edges, but found {}'.format(self.cached_num_edges, edge_index.size(1)))
        if not self.cached or self.cached_result is None:
            self.cached_num_edges = edge_index.size(1)
            edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, x.dtype)
            self.cached_result = edge_index, norm
        edge_index, norm = self.cached_result
        return self.propagate(edge_index, x=x, norm=norm)
    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j
    def update(self, aggr_out):
        return aggr_out
class HGPSLPool(torch.nn.Module):
    def __init__(self, in_channels, ratio=0.5, sample=False, sparse=False, sl=True, lamb=1.0, negative_slop=0.2):
        super(HGPSLPool, self).__init__()
        self.in_channels = in_channels
        self.ratio = ratio
        self.sample = sample
        self.sparse = sparse
        self.sl = sl
        self.negative_slop = negative_slop
        self.lamb = lamb
        self.att = Parameter(torch.Tensor(1, self.in_channels * 2))
        nn.init.xavier_uniform_(self.att.data)
        self.sparse_attention = Sparsemax()
        self.neighbor_augment = TwoHopNeighborhood()
        self.calc_information_score = NodeInformationScore()
    def forward(self, x, edge_index, edge_attr, batch):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        x_information_score = self.calc_information_score(x, edge_index, edge_attr)
        score = torch.sum(torch.abs(x_information_score), dim=1)
        original_x = x
        perm = topk(score, self.ratio, batch)
        x = x[perm]
        batch = batch[perm]
        induced_edge_index, induced_edge_attr = filter_adj(edge_index, edge_attr, perm, num_nodes=score.size(0))
        if self.sl is False:
            return x, induced_edge_index, induced_edge_attr, batch
        if self.sample:
            k_hop = 3
            if edge_attr is None:
                edge_attr = torch.ones((edge_index.size(1),), dtype=torch.float, device=edge_index.device)
            hop_data = Data(x=original_x, edge_index=edge_index, edge_attr=edge_attr)
            for _ in range(k_hop - 1):
                hop_data = self.neighbor_augment(hop_data)
            hop_edge_index = hop_data.edge_index
            hop_edge_attr = hop_data.edge_attr
            new_edge_index, new_edge_attr = filter_adj(hop_edge_index, hop_edge_attr, perm, num_nodes=score.size(0))
            new_edge_index, new_edge_attr = add_remaining_self_loops(new_edge_index, new_edge_attr, 0, x.size(0))
            row, col = new_edge_index
            weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1)
            weights = F.leaky_relu(weights, self.negative_slop) + new_edge_attr * self.lamb
            adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float, device=x.device)
            adj[row, col] = weights
            new_edge_index, weights = dense_to_sparse(adj)
            row, col = new_edge_index
            if self.sparse:
                new_edge_attr = self.sparse_attention(weights, row)
            else:
                new_edge_attr = softmax(weights, row, x.size(0))
            adj[row, col] = new_edge_attr
            new_edge_index, new_edge_attr = dense_to_sparse(adj)
            del adj
            torch.cuda.empty_cache()
        else:
            if edge_attr is None:
                induced_edge_attr = torch.ones((induced_edge_index.size(1),), dtype=x.dtype,
                                               device=induced_edge_index.device)
            num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
            shift_cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)
            cum_num_nodes = num_nodes.cumsum(dim=0)
            adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float, device=x.device)
            for idx_i, idx_j in zip(shift_cum_num_nodes, cum_num_nodes):
                adj[idx_i:idx_j, idx_i:idx_j] = 1.0
            new_edge_index, _ = dense_to_sparse(adj)
            row, col = new_edge_index
            weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1)
            weights = F.leaky_relu(weights, self.negative_slop)
            adj[row, col] = weights
            induced_row, induced_col = induced_edge_index
            adj[induced_row, induced_col] += induced_edge_attr * self.lamb
            weights = adj[row, col]
            if self.sparse:
                new_edge_attr = self.sparse_attention(weights, row)
            else:
                new_edge_attr = softmax(weights, row, num_nodes=x.size(0))
            adj[row, col] = new_edge_attr
            new_edge_index, new_edge_attr = dense_to_sparse(adj)
            del adj
            torch.cuda.empty_cache()
        return x, new_edge_index, new_edge_attr, batch, perm

tensor([0.5344, 0.0000, 0.0000, 0.4656, 0.0613, 0.0000, 0.0000, 0.0000, 0.5640,
        0.3748])


In [24]:
import time
import tracemalloc
import torch
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import torch.nn as nn
from sklearn.model_selection import KFold
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TopKPooling
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import ToUndirected
from torch.nn import Linear
import torch.optim as optim
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import to_dense_batch
from sklearn.model_selection import KFold
import numpy as np
import random
from typing import Callable, Optional, Union
dataset = Planetoid(root="/data/ /Pooling", name='Cora')
graph = dataset_sparse
num_classes = dataset.num_classes
in_channels = dataset.num_features
hidden_channels = 64
out_channels = num_classes
depth = 2
pool_ratios = [0.7, 0.7]  
class HierarchicalGCN_HGPSL(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, depth, pool_ratios, act=F.relu, sum_res=False):
        super(HierarchicalGCN_HGPSL, self).__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = pool_ratios
        self.act = act
        self.sum_res = sum_res
        channels = self.hidden_channels
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(self.in_channels, channels))
        for i in range(self.depth):
            self.pools.append(HGPSLPool(hidden_channels, ratio=pool_ratios[i], sample=False, sparse=False, sl=True, lamb=1.0, negative_slop=0.2))
            self.down_convs.append(GCNConv(channels, channels))
        in_channels = channels if sum_res else 2 * channels
        self.up_convs = torch.nn.ModuleList()
        for i in range(self.depth):
            self.up_convs.append(GCNConv(in_channels, channels))
        self.up_convs.append(GCNConv(channels, self.out_channels))
    def forward(self, x, edge_index, batch=None):
        x, edge_index = x.to(device), edge_index.to(device)
        edge_attr = None
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        if batch is not None:
            batch = batch.to(device)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.down_convs[0](x, edge_index)
        x = self.act(x)
        xs = [x]
        edge_indices = [edge_index]
        perms = []
        for i in range(1, self.depth + 1):
            x, edge_index, _, batch, perm = self.pools[i - 1](x, edge_index, edge_attr, batch)
            x = self.down_convs[i](x, edge_index)
            x = self.act(x)
            if i < self.depth:
                xs.append(x)
                edge_indices.append(edge_index)
            perms.append(perm)
        for i in range(self.depth):
            j = self.depth - 1 - i
            res = xs[j]
            edge_index = edge_indices[j]
            perm = perms[j]
            up = torch.zeros_like(res)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)
            x = self.up_convs[i](x, edge_index)
            x = self.act(x)
        x = self.up_convs[-1](x, edge_index)
        return x
def train_node_classifier(model, graph, optimizer, criterion, train_mask, val_mask, n_epochs=200, patience=150, min_delta=0.0001):
    best_val_acc = 0
    patience_counter = 0
    model.to(device)
    graph = graph.to(device)  
    for epoch in range(1, n_epochs + 1):
        model.train()
        optimizer.zero_grad()
        out = model(graph.x, graph.edge_index)
        loss = criterion(out[train_mask], graph.y[train_mask])
        loss.backward()
        optimizer.step()
        val_acc = eval_node_classifier(model, graph, val_mask)
        if val_acc > best_val_acc + min_delta:
            best_val_acc = val_acc
            patience_counter = 0  
        else:
            patience_counter += 1  
        if patience_counter >= patience:
            break
    return model, best_val_acc
def eval_node_classifier(model, graph, mask):
    model.eval()
    pred = model(graph.x, graph.edge_index).argmax(dim=1)
    correct = (pred[mask] == graph.y[mask]).sum()
    acc = int(correct) / int(mask.sum())
    return acc
kf = KFold(n_splits=5, shuffle=True)
seeds = [42, 123, 456]
results = []
val_accuracies_list = []
times = []
memories = []
gpu_memories = []
for seed in seeds:
    graph = graph.to(device)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    val_accuracies = []
    start_time = time.time()
    tracemalloc.start()
    for fold, (train_index, test_index) in enumerate(kf.split(graph.x)):
        model = HierarchicalGCN_HGPSL(in_channels, hidden_channels, out_channels, depth, pool_ratios).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()
        train_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        test_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        train_mask[train_index] = True
        test_mask[test_index] = True
        val_mask = test_mask  
        model, best_val_acc = train_node_classifier(model, graph, optimizer, criterion, train_mask, val_mask)
        val_accuracies.append(best_val_acc)
        print(f'Seed {seed}, Fold {fold + 1} Val Acc: {best_val_acc:.3f}')
    mean_val_acc = np.mean(val_accuracies)
    end_time = time.time()
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    memory_usage = peak / 10**6  
    if torch.cuda.is_available():
        gpu_memory_usage = torch.cuda.memory_reserved(device) / 10**6  
    else:
        gpu_memory_usage = 0
    elapsed_time = end_time - start_time
    results.append({
        'seed': seed,
        'mean_val_acc': mean_val_acc,
        'time': elapsed_time,
        'memory': memory_usage,
        'gpu_memory': gpu_memory_usage
    })
    print(f'Seed {seed} Results: Mean Val Acc: {mean_val_acc:.3f}, Time: {elapsed_time:.3f} seconds, Memory: {memory_usage:.3f} MB, GPU Memory: {gpu_memory_usage:.3f} MB')
for result in results:
    print(result)
mean_val_acc_values = [result['mean_val_acc'] for result in results]
total_mean_val_acc = np.mean(mean_val_acc_values) * 100
standard_deviation = np.std(mean_val_acc_values) * 100
print(f"Total Mean Val Acc: {total_mean_val_acc:.2f}$\\pm${standard_deviation:.2f}")

Seed 42, Fold 1 Val Acc: 0.762
Seed 42, Fold 2 Val Acc: 0.744
Seed 42, Fold 3 Val Acc: 0.755
Seed 42, Fold 4 Val Acc: 0.747
Seed 42, Fold 5 Val Acc: 0.780
Seed 42 Results: Mean Val Acc: 0.757, Time: 113.429 seconds, Memory: 0.397 MB, GPU Memory: 29787.947 MB
Seed 123, Fold 1 Val Acc: 0.780
Seed 123, Fold 2 Val Acc: 0.760
Seed 123, Fold 3 Val Acc: 0.738
Seed 123, Fold 4 Val Acc: 0.762
Seed 123, Fold 5 Val Acc: 0.739
Seed 123 Results: Mean Val Acc: 0.756, Time: 109.022 seconds, Memory: 0.360 MB, GPU Memory: 29787.947 MB
Seed 456, Fold 1 Val Acc: 0.747
Seed 456, Fold 2 Val Acc: 0.758
Seed 456, Fold 3 Val Acc: 0.768
Seed 456, Fold 4 Val Acc: 0.750
Seed 456, Fold 5 Val Acc: 0.747
Seed 456 Results: Mean Val Acc: 0.754, Time: 109.003 seconds, Memory: 0.359 MB, GPU Memory: 29787.947 MB
{'seed': 42, 'mean_val_acc': 0.7573899639181235, 'time': 113.42894172668457, 'memory': 0.397067, 'gpu_memory': 29787.947008}
{'seed': 123, 'mean_val_acc': 0.755904400079121, 'time': 109.02213883399963, 'memory':