## Download PyG

In [1]:
# Install the PyG extension packages; -f points pip at the PyG wheel index for torch 1.12.0 + CUDA 11.3 (see URL).
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cu113.html

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.12.0+cu113.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_scatter-2.1.0%2Bpt112cu113-cp38-cp38-linux_x86_64.whl (8.9 MB)
[K     |████████████████████████████████| 8.9 MB 43.3 MB/s 
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_sparse-0.6.15%2Bpt112cu113-cp38-cp38-linux_x86_64.whl (3.5 MB)
[K     |████████████████████████████████| 3.5 MB 14.1 MB/s 
[?25hCollecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_cluster-1.6.0%2Bpt112cu113-cp38-cp38-linux_x86_64.whl (2.5 MB)
[K     |████████████████████████████████| 2.5 MB 19.1 MB/s 
[?25hCollecting torch-spline-conv
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_spline_conv-1.2.1%2Bpt112cu113-cp38-cp38-linux_x86_64.whl (722 kB)
[K     |█

In [2]:
import torch
from torch_geometric.utils import add_self_loops, degree
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.datasets import TUDataset

# Model definitions

## Laplacian utils

### Builders

In [3]:
# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0

import itertools
import math

import torch
import torch_sparse

from torch_geometric.utils import degree


def remove_duplicate_edges(edge_index):
    """
    Remove duplicate undirected edges from an edge index.

    (u, v) and (v, u) are treated as the same edge. Each kept edge is emitted once, with its endpoints
    sorted, in order of first appearance. The number of removed edges is printed.

    Args:
        edge_index: Long tensor of shape [2, E].
    Returns:
        Long tensor of shape [2, E'] (E' <= E). It is always 2-D, even when no edges remain.
    """
    processed_edges = set()
    new_edge_index = []

    for source, target in zip(edge_index[0].tolist(), edge_index[1].tolist()):
        edge = (source, target) if source <= target else (target, source)
        if edge in processed_edges:
            continue
        processed_edges.add(edge)
        new_edge_index.append(edge)
    print(f"Removed {edge_index.size(1) - len(new_edge_index)} edges")

    if not new_edge_index:
        # torch.tensor([]).t() would be 1-D with shape [0]; keep the [2, 0] edge-index layout instead.
        return torch.empty((2, 0), dtype=torch.long)
    return torch.tensor(new_edge_index, dtype=torch.long).t()


def build_sheaf_laplacian(N, K, edge_index, maps):
    """
    Assemble the (unnormalised) sheaf Laplacian L = delta^T delta as a sparse COO matrix.

    The coboundary delta has one block-row of K rows per edge e. The block in column
    `edge_index[0, e]` (source) holds -maps[e, 0], and the block in column `edge_index[1, e]`
    (target) holds maps[e, 1].

    Args:
        N: Number of nodes in the graph.
        K: Stalk dimension.
        edge_index: Edge index without duplicate edges; edge i is oriented edge_index[0, i] --> edge_index[1, i].
        maps: Tensor of shape [edge_index.size(1), 2 (source/target), K, K] holding the restriction maps.
    Returns:
        (index, value): the sheaf Laplacian as a coalesced sparse matrix of size (N*K, N*K).
    """
    num_edges = edge_index.size(1)
    coords = []
    entries = []

    for e in range(num_edges):
        row_offset = e * K
        # (slot in `maps`, incident node, sign of that block in the coboundary)
        for side, node, sign in ((0, edge_index[0, e], -1), (1, edge_index[1, e], 1)):
            col_offset = node * K
            for i, j in itertools.product(range(K), range(K)):
                coords.append([row_offset + i, col_offset + j])
                entries.append(sign * maps[e, side, i, j])

    delta_index = torch.tensor(coords, dtype=torch.long).T
    delta_value = torch.tensor(entries)

    # L = delta^T @ delta as a sparse-sparse product.
    delta_t_index, delta_t_value = torch_sparse.transpose(delta_index, delta_value, num_edges * K, N * K)
    lap_index, lap_value = torch_sparse.spspmm(delta_t_index, delta_t_value, delta_index, delta_value,
                                               N * K, num_edges * K, N * K, coalesced=True)
    return torch_sparse.coalesce(lap_index, lap_value, N * K, N * K)


def sym_matrix_pow(matrix: torch.Tensor, p: float) -> torch.Tensor:
    r"""
    Power of a symmetric matrix (or a batch of them) via eigendecomposition.

    Only strictly positive eigenvalues are raised to ``p``; zero and negative eigenvalues are left
    unchanged. For ``p < 0`` this means null directions map to 0 instead of ``inf``.

    Args:
        matrix: Symmetric matrix of shape [..., n, n]; leading dimensions are treated as a batch.
        p: Power.
    Returns:
        ``matrix ** p``, with the same shape as ``matrix``.
    """
    vals, vecs = torch.linalg.eigh(matrix)
    positive = vals > 0
    vals[positive] = vals[positive].pow(p)
    # V diag(vals) V^T written with broadcasting so batched inputs work too
    # (torch.diag / .T only handle a single 2-D matrix).
    return (vecs * vals.unsqueeze(-2)) @ vecs.transpose(-2, -1)


def build_norm_sheaf_laplacian(N, K, edge_index, maps, augmented=True):
    """
    Builds a normalised sheaf laplacian D^{-1/2} L D^{-1/2} given the edge_index and the restriction maps.

    D is the block diagonal of L (one K x K block per node). With ``augmented`` it is replaced by D + I.

    Args:
        N: The number of nodes in the graph
        K: The dimensionality of the Stalks
        edge_index: Edge index of the graph without duplicate edges. We assume that edge i has orientation
            edge_index[0, i] --> edge_index[1, i].
        maps: Tensor of shape [edge_index.size(1), 2 (source/target), K, K] containing the restriction maps of the sheaf
        augmented: Use D* = D + I instead of D.
    Returns:
        (index, value): The normalised sheaf Laplacian as a sparse matrix of size (N*K, N*K)
    """
    index, values = build_sheaf_laplacian(N, K, edge_index, maps)
    block_diag_indices = []
    block_diag_values = []

    for i in range(N):
        low = i * K
        high = low + K

        # Pick out the entries of the i-th K x K diagonal block of L.
        in_rows = torch.logical_and(low <= index[0, :], index[0, :] < high)
        in_cols = torch.logical_and(low <= index[1, :], index[1, :] < high)
        mask = torch.logical_and(in_rows, in_cols)

        d_index = index[:, mask] - low
        d_values = values[mask]

        # Give the size explicitly. Inferring it from the indices yields a smaller (or empty) matrix
        # when the block has all-zero trailing rows/columns, e.g. for an isolated node.
        Dv = torch.sparse_coo_tensor(d_index, d_values, size=(K, K)).to_dense()
        assert list(Dv.size()) == [K, K]
        if augmented:
            Dv = Dv + torch.eye(K, K, dtype=Dv.dtype)
        Dv_sqrt_inv = sym_matrix_pow(Dv, -0.5).to_sparse()

        # Shift the block's local indices back to its global position on the diagonal.
        block_diag_indices.append(Dv_sqrt_inv.indices() + low)
        block_diag_values.append(Dv_sqrt_inv.values())

    D_sqrt_inv_idx = torch.cat(block_diag_indices, dim=1)
    D_sqrt_inv_val = torch.cat(block_diag_values, dim=0)

    # D^{-1/2} @ L @ D^{-1/2} as two sparse-sparse products.
    tmp_idx, tmp_val = torch_sparse.spspmm(D_sqrt_inv_idx, D_sqrt_inv_val, index, values, N * K, N * K, N * K,
                                           coalesced=True)
    index, value = torch_sparse.spspmm(tmp_idx, tmp_val, D_sqrt_inv_idx, D_sqrt_inv_val, N * K, N * K, N * K,
                                       coalesced=True)
    return torch_sparse.coalesce(index, value, N * K, N * K)


def build_sheaf_difussion_matrix(N, K, edge_index, maps, augmented=True, return_laplacian=False):
    """
    Builds the diffusion matrix P := I - D*^{-1/2} L D*^{-1/2}, where D* = D + I (or D if ``augmented`` is False).

    Args:
        N: The number of nodes in the graph
        K: The dimensionality of the Stalks
        edge_index: Edge index of the graph without duplicate edges. We assume that edge i has orientation
            edge_index[0, i] --> edge_index[1, i].
        maps: Tensor of shape [edge_index.size(1), 2 (source/target), K, K] containing the restriction maps of the sheaf
        augmented: Use the augmented sheaf Laplacian.
        return_laplacian: Also return the normalised Laplacian as a second element.
    Returns:
        (P_index, P_val), or ((P_index, P_val), (L_index, L_val)) when ``return_laplacian`` is True.
    """
    L_index, L_val = build_norm_sheaf_laplacian(N, K, edge_index, maps, augmented=augmented)

    # Build the sparse identity in the Laplacian's dtype/device so the concatenation below never mixes
    # dtypes. L_val's dtype comes from `maps`, which need not be float32.
    I_index = torch.arange(0, N * K, device=L_index.device).view(1, -1).tile(2, 1)
    I_val = torch.ones((N * K,), dtype=L_val.dtype, device=L_val.device)

    index = torch.cat((L_index, I_index), dim=1)
    value = torch.cat((-L_val, I_val), dim=0)

    # op='add' merges the identity into the diagonal entries of -L.
    P_index, P_val = torch_sparse.coalesce(index, value, N * K, N * K, op='add')
    if return_laplacian:
        L_index, L_val = torch_sparse.coalesce(L_index, L_val, N * K, N * K, op='add')
        return (P_index, P_val), (L_index, L_val)
    return P_index, P_val


def dirichlet_energy(L, f, size):
    """
    Dirichlet energy f^T L f of the signal ``f`` under the sparse sheaf Laplacian ``L``.

    Args:
        L: Sparse matrix given as ``(index, value)`` (indexable as L[0], L[1]) of shape (size, size).
        f: Signal. ``f.t() @ (L @ f)`` must contain exactly one element (e.g. ``f`` of shape [size, 1]).
        size: Number of rows/columns of ``L``.
    Returns:
        The energy as a Python number.
    """
    lap_index, lap_value = L[0], L[1]
    lap_f = torch_sparse.spmm(lap_index, lap_value, size, size, f)
    return (f.t() @ lap_f).item()


def get_edge_index_dict(edge_index, undirected=True):
    """
    Assign consecutive integer IDs to the edges of ``edge_index``, in order of first appearance.

    Args:
        edge_index: Edge index of shape [2, E]; E must be even.
        undirected: If True, (u, v) and (v, u) share a single ID, keyed by the sorted tuple.
            Otherwise every ordered pair gets its own ID.
    Returns:
        dict mapping edge tuples to the IDs 0, 1, 2, ...
    """
    assert edge_index.size(1) % 2 == 0

    ids = {}
    for src, dst in zip(edge_index[0].tolist(), edge_index[1].tolist()):
        key = (min(src, dst), max(src, dst)) if undirected else (src, dst)
        # setdefault only inserts unseen keys, and len(ids) is exactly the next free ID.
        ids.setdefault(key, len(ids))
    return ids


def compute_incidence_index(edge_index, d):
    """
    Index pattern (no values) of a sheaf coboundary matrix built from d x d blocks.

    For every directed edge (source, target), a d x d block is placed at block-row
    ``ID(undirected edge)`` and block-column ``source``. The two directions of an undirected edge
    therefore share one block-row.

    Returns:
        Long tensor of shape [2, E * d^2].
    """
    assert edge_index.size(1) % 2 == 0

    edge_ids = get_edge_index_dict(edge_index)
    block_offsets = list(itertools.product(range(d), range(d)))
    positions = [
        [edge_ids[(min(src, dst), max(src, dst))] * d + i, src * d + j]
        for src, dst in zip(edge_index[0].tolist(), edge_index[1].tolist())
        for i, j in block_offsets
    ]

    incidence_index = torch.tensor(positions, dtype=torch.long).T
    assert list(incidence_index.size()) == [2, edge_index.size(1) * (d ** 2)]
    return incidence_index


def build_dense_laplacian(size, edge_index, maps, d, normalised=False, diagonal_maps=False, values=None,
                          edge_weights=None):
    """
    Builds a sheaf laplacian from a given graph using naive dense computations (used for testing).

    Args:
        size: Number of nodes N.
        edge_index: Edge index of shape [2, 2E]; E is taken as half the number of columns.
        maps: One map per column of ``edge_index``, written into the block of that edge's source node:
            shape [2E, d, d], or [2E, d] when ``diagonal_maps`` is True.
        d: Stalk dimension.
        normalised: If True, return D*^{-1/2} L D*^{-1/2} with D* = (block diagonal of L) + I.
        diagonal_maps: Whether ``maps`` stores only the diagonals of the maps.
        values: Optional extra diagonal entries appended via ``append_diag_maps_to_existent_laplacian``;
            this also enlarges the stalk dimension.
        edge_weights: Optional per-directed-edge weights (asserted symmetric) that scale the maps.
            ``maps`` itself is not modified.
    Returns:
        Dense float64 Laplacian.
    """
    assert edge_index.size(1) % 2 == 0
    if diagonal_maps:
        assert len(maps.size()) == 2
        assert maps.size(1) == d

    E = edge_index.size(1) // 2
    N = size
    Delta = torch.zeros(size=(E*d, N*d), dtype=torch.float64)
    undirected_edge_idx_dict = get_edge_index_dict(edge_index)

    if edge_weights is not None:
        directed_edge_idx_dict = get_edge_index_dict(edge_index, undirected=False)
        # Scale a copy: the write into maps[e] below would otherwise silently modify the caller's
        # tensor (and fail on a leaf tensor that requires grad).
        maps = maps.clone()

    for e in range(edge_index.size(1)):
        source = edge_index[0, e].item()
        target = edge_index[1, e].item()
        edge_key = tuple(sorted([source, target]))

        # Block-row of the undirected edge, block-column of the source node.
        top_x = undirected_edge_idx_dict[edge_key] * d
        top_y = source * d

        # The endpoint with the smaller id enters the coboundary with a minus sign.
        orient = -1 if edge_key[0] == source else 1
        if edge_weights is not None:
            factor1_idx, factor2_idx = (
                directed_edge_idx_dict[(source, target)], directed_edge_idx_dict[(target, source)])
            assert edge_weights[factor1_idx] == edge_weights[factor2_idx]
            maps[e] = maps[e] * edge_weights[factor1_idx]
        if diagonal_maps:
            diag_idx = torch.arange(0, d)
            Delta[top_x + diag_idx, top_y + diag_idx] = orient * maps[e]
        else:
            Delta[top_x: top_x+d, top_y: top_y+d] = orient * maps[e]

    # Non-normalised Laplacian L = Delta^T Delta.
    L_dense = Delta.T @ Delta

    if values is not None:
        # Append extra entries to the diagonal of the parallel transport maps and update the stalk dimension.
        L_dense, d = append_diag_maps_to_existent_laplacian(size, d, L_dense, edge_index, values)

    if not normalised:
        return L_dense

    # Normalised Laplacian D*^{-1/2} L D*^{-1/2}.
    D_sqrt_inv = torch.zeros((N*d, N*d), dtype=torch.float64)
    for i in range(N):
        low = i * d
        high = low + d

        D_i = L_dense[low:high, low:high] + torch.eye(d, dtype=L_dense.dtype)
        D_sqrt_inv[low:high, low:high] = sym_matrix_pow(D_i, -0.5)

    return D_sqrt_inv @ L_dense @ D_sqrt_inv


def append_diag_maps_to_existent_laplacian(size, learnable_d, L, edge_index, values):
    """
    Embed a Laplacian with stalk dimension ``learnable_d`` into one with ``learnable_d + len(values)``.

    Every stalk gains ``len(values)`` extra coordinates:
    - The existing blocks of ``L`` are copied into the top-left ``learnable_d x learnable_d`` corner of
      the corresponding new blocks.
    - For every edge (i, j) in ``edge_index``, the extra diagonal entries of block (i, j) are set to ``values``.
    - The extra diagonal entries of each block (i, i) are set to the degree of node i, counted over
      ``edge_index[0]``.

    Returns:
        (new_L, total_d): the enlarged dense Laplacian and the new stalk dimension.
    """
    extra_d = len(values)
    total_d = learnable_d + extra_d

    deg = degree(edge_index[0], num_nodes=size, dtype=L.dtype)
    values = torch.tensor(values, dtype=L.dtype)
    new_L = torch.zeros((size * total_d, size * total_d), dtype=L.dtype)
    # Offsets of the appended coordinates inside a stalk (the same for every block).
    extra = torch.arange(learnable_d, total_d)

    def old_span(node):
        start = node * learnable_d
        return slice(start, start + learnable_d)

    def new_span(node):
        start = node * total_d
        return slice(start, start + learnable_d)

    for k in range(edge_index.size(1)):
        i, j = edge_index[0][k], edge_index[1][k]
        assert i != j
        new_L[new_span(i), new_span(j)] = L[old_span(i), old_span(j)]
        new_L[i * total_d + extra, j * total_d + extra] = values

    for i in range(size):
        new_L[new_span(i), new_span(i)] = L[old_span(i), old_span(i)]
        new_L[i * total_d + extra, i * total_d + extra] = deg[i]

    return new_L, total_d


def compute_left_right_map_index(edge_index, full_matrix=False):
    """
    Pair each kept directed edge with the position of its reverse edge.

    Args:
        edge_index: Edge index of shape [2, E] that contains the reverse of every edge.
        full_matrix: If False, keep only edges with source < target (the lower-triangular part).
            If True, keep every edge.
    Returns:
        (left_right_index, new_edge_index):
        - ``left_right_index`` is [2, M]. Row 0 is each kept edge's column in ``edge_index``; row 1 is
          the column of its reverse.
        - ``new_edge_index`` is the [2, M] edge index of the kept edges.
    """
    device = edge_index.device
    sources = edge_index[0].tolist()
    targets = edge_index[1].tolist()
    position = {(s, t): e for e, (s, t) in enumerate(zip(sources, targets))}

    kept = [(e, s, t) for e, (s, t) in enumerate(zip(sources, targets)) if full_matrix or s < t]

    left_index = torch.tensor([e for e, _, _ in kept], dtype=torch.long, device=device)
    right_index = torch.tensor([position[(t, s)] for _, s, t in kept], dtype=torch.long, device=device)
    rows = torch.tensor([s for _, s, _ in kept], dtype=torch.long, device=device)
    cols = torch.tensor([t for _, _, t in kept], dtype=torch.long, device=device)

    left_right_index = torch.vstack([left_index, right_index])
    new_edge_index = torch.vstack([rows, cols])

    expected = edge_index.size(1) if full_matrix else edge_index.size(1) // 2
    assert len(left_index) == expected

    return left_right_index, new_edge_index


def compute_learnable_laplacian_indices(size, edge_index, learned_d, total_d):
    """
    Sparse indices of the learnable ``learned_d x learned_d`` top-left part of each Laplacian block.

    Stalks have ``total_d`` coordinates; only the first ``learned_d`` are covered.

    Args:
        size: Number of nodes.
        edge_index: [2, E] edge index with edge_index[0] < edge_index[1] in every column.
        learned_d: Size of the learnable sub-block.
        total_d: Full stalk dimension.
    Returns:
        (diag_indices, non_diag_indices): [2, size * learned_d^2] indices for the blocks (i, i) and
        [2, E * learned_d^2] indices for the blocks (row, col). Both are flattened in (block, i, j) order.
    """
    assert torch.all(edge_index[0] < edge_index[1])

    device = edge_index.device
    offsets = torch.arange(0, learned_d, device=device)
    # Local coordinates inside a block: local_rows[0, i, j] = i, local_cols[0, i, j] = j.
    local_rows = offsets.view(1, -1, 1).expand(1, learned_d, learned_d)
    local_cols = offsets.view(1, 1, -1).expand(1, learned_d, learned_d)

    def block_indices(block_rows, block_cols):
        rows = local_rows + total_d * block_rows.reshape(-1, 1, 1)
        cols = local_cols + total_d * block_cols.reshape(-1, 1, 1)
        return torch.cat((rows.reshape(1, -1), cols.reshape(1, -1)), dim=0)

    row, col = edge_index
    nodes = torch.arange(0, size, device=device)
    return block_indices(nodes, nodes), block_indices(row, col)


def compute_learnable_diag_laplacian_indices(size, edge_index, learned_d, total_d):
    """
    Sparse indices of the learnable diagonal coordinates ``0 .. learned_d - 1`` of each Laplacian block.

    Args:
        size: Number of nodes.
        edge_index: [2, E] edge index with edge_index[0] < edge_index[1] in every column.
        learned_d: Number of learnable diagonal coordinates per stalk.
        total_d: Full stalk dimension.
    Returns:
        (diag_indices, non_diag_indices): [2, size * learned_d] indices for blocks (i, i) and
        [2, E * learned_d] indices for blocks (row, col), flattened node/edge-major.
    """
    assert torch.all(edge_index[0] < edge_index[1])

    device = edge_index.device
    coords = torch.arange(0, learned_d, device=device).view(1, -1)
    source, target = edge_index

    non_diag_indices = torch.stack((
        (coords + total_d * source.unsqueeze(1)).reshape(-1),
        (coords + total_d * target.unsqueeze(1)).reshape(-1),
    ))

    nodes = torch.arange(0, size, device=device)
    diag_positions = (coords + total_d * nodes.unsqueeze(1)).reshape(-1)
    # Diagonal blocks: the row and column coordinates coincide.
    diag_indices = torch.stack((diag_positions, diag_positions))

    return diag_indices, non_diag_indices


def compute_fixed_diag_laplacian_indices(size, edge_index, learned_d, total_d):
    """
    Sparse indices of the fixed (non-learnable) diagonal coordinates ``learned_d .. total_d - 1`` of each block.

    Args:
        size: Number of nodes.
        edge_index: [2, E] edge index with edge_index[0] < edge_index[1] in every column.
        learned_d: Number of learnable coordinates that precede the fixed ones in each stalk.
        total_d: Full stalk dimension.
    Returns:
        (diag_indices, non_diag_indices): [2, size * (total_d - learned_d)] indices for blocks (i, i) and
        [2, E * (total_d - learned_d)] indices for blocks (row, col), flattened node/edge-major.
    """
    assert torch.all(edge_index[0] < edge_index[1])

    device = edge_index.device
    fixed = torch.arange(learned_d, total_d, device=device)

    def positions(nodes):
        # Global row/column of every fixed coordinate of every node's stalk.
        return (nodes.unsqueeze(1) * total_d + fixed.unsqueeze(0)).flatten()

    row, col = edge_index
    non_diag_indices = torch.stack([positions(row), positions(col)], dim=0)
    diag = positions(torch.arange(0, size, device=device))
    diag_indices = torch.stack([diag, diag], dim=0)

    return diag_indices, non_diag_indices


def batched_sym_matrix_pow(matrices: torch.Tensor, p: float) -> torch.Tensor:
    r"""
    Power of a batch of symmetric positive semi-definite matrices.

    Uses an SVD rather than ``torch.linalg.eigh`` because it is much faster for large batches. For a
    symmetric PSD matrix, the left singular vectors are eigenvectors and the singular values equal the
    eigenvalues, so ``U diag(s ** p) U^T`` is the matrix power. For symmetric matrices with negative
    eigenvalues the singular values are absolute values, and the result is NOT the matrix power.

    Singular values at or below ``max(s) * n * eps`` are treated as zero and map to 0 (not ``inf``)
    for negative ``p``.

    Args:
        matrices: Tensor of shape [..., n, n].
        p: Power.
    Returns:
        Tensor of shape [..., n, n] holding the power of each matrix in the batch.
    """
    vecs, vals, _ = torch.linalg.svd(matrices)
    tolerance = vals.max(-1, True).values * vals.size(-1) * torch.finfo(vals.dtype).eps
    good = vals > tolerance
    vals = vals.pow(p).where(good, torch.zeros((), device=matrices.device, dtype=matrices.dtype))
    return (vecs * vals.unsqueeze(-2)) @ torch.transpose(vecs, -2, -1)


def mergesp(index1, value1, index2, value2):
    """
    Concatenate two COO sparse matrices whose index sets are disjoint.

    No coalescing is done, so overlapping indices would produce duplicate entries.

    Returns:
        (index, value) with ``index`` of shape [2, n1 + n2] and ``value`` of shape [n1 + n2].
    """
    for idx, val in ((index1, value1), (index2, value2)):
        assert idx.dim() == 2 and val.dim() == 1
        assert idx.size(1) == val.numel()
        assert idx.size(0) == 2

    return torch.cat((index1, index2), dim=1), torch.cat((value1, value2))


def get_random_edge_weights(edge_index):
    """
    Draw one weight per directed edge from U(0, 1), then make the weights symmetric so that (u, v)
    and (v, u) end up with the same value. Every edge's reverse must be present in ``edge_index``.

    Returns:
        Float tensor of shape [E, 1].
    """
    position = get_edge_index_dict(edge_index, undirected=False)
    weights = torch.FloatTensor(size=(edge_index.size(1), 1)).uniform_(0.0, 1.0)

    for src, dst in zip(edge_index[0].tolist(), edge_index[1].tolist()):
        # Copy the (dst, src) weight onto (src, dst). When the reverse edge is visited later it copies
        # the same value back, so both directions agree after the pass.
        weights[position[(src, dst)]] = weights[position[(dst, src)]]
    return weights


def get_2d_oracle_rotation_angles(edge_index, y, theta=None):
    """
    Returns the class rotation angles for an oracle 2D orthogonal sheaf.

    A directed edge (v, u) gets ``theta * |y[u] - y[v]| / 2`` when v < u, and the negated angle otherwise.

    Args:
        edge_index: Edge index of shape [2, E].
        y: Node labels with minimum 0.
        theta: Angle per unit of class difference. Defaults to ``2 * pi / (y.max() + 1)``.
    Returns:
        Float32 tensor of shape [E, 1]. Empty edge_index yields shape [0, 1].
    """
    assert y.min() == 0
    if theta is None:
        # NOTE(review): an earlier comment said this value is multiplied by 2pi again in the Connection
        # Laplacian builder, but 2pi is already included here. Confirm against the builder.
        theta = 2.0 * math.pi / (y.max() + 1)

    angles = torch.empty(edge_index.size(1), dtype=torch.float32)
    for i in range(edge_index.size(1)):
        v = edge_index[0, i].item()
        u = edge_index[1, i].item()
        cdiff = abs(float(y[u].item() - y[v].item()))
        if v < u:
            angles[i] = theta * cdiff / 2.0
        else:
            angles[i] = -theta * cdiff / 2.0
    # max() of an empty tensor raises, so only check the bound when there is at least one edge.
    assert angles.numel() == 0 or angles.max() < 2 * math.pi
    return angles.view(-1, 1)


def get_1d_oracle_maps(edge_index, y):
    """
    Returns the +/-1 restriction maps of an oracle 1D sheaf built from the node labels.

    A directed edge (v, u) gets +1 when v < u or when both endpoints have the same label, and -1
    otherwise. Within an undirected edge that joins different classes, the two directions therefore
    carry opposite signs.

    Args:
        edge_index: Edge index of shape [2, E].
        y: One label per node (compared via ``.item()``).
    Returns:
        Tensor of shape [E, 1] with the same (integer) dtype as ``edge_index``.
    """
    maps = torch.empty(edge_index.size(1), dtype=edge_index.dtype)
    for i in range(edge_index.size(1)):
        v = edge_index[0, i].item()
        u = edge_index[1, i].item()
        if v < u or y[v].item() == y[u].item():
            maps[i] = 1
        else:
            maps[i] = -1
    return maps.view(-1, 1)

### Permutation utils

In [4]:
# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch
import numpy as np

from scipy import sparse as sp
from torch_geometric.data import Data


def permute_graph(graph: Data, P: np.ndarray) -> Data:
    """
    Relabel the nodes of ``graph`` with the permutation matrix ``P``.

    With P[i, perm[i]] == 1, new node i is old node perm[i]:
    - Node features become ``P @ x`` (as float32).
    - Per-node labels become ``P @ y``, keeping their original dtype.
    - The adjacency becomes ``P A P^T``.
    - Labels that are not per-node are copied unchanged.

    Raises:
        AssertionError: if the graph has edge attributes or ``is_valid_permutation_matrix`` rejects ``P``.
    """
    assert graph.edge_attr is None

    # Check validity of permutation matrix
    n = graph.x.size(0)
    if not is_valid_permutation_matrix(P, n):
        raise AssertionError

    # Row i of P has its single 1 at column perm[i], so (P @ z)[i] == z[perm[i]]. Indexing avoids the
    # float64 round trip (which turned integer labels into floats) and NaN/inf leaking across rows via 0 * inf.
    perm = torch.as_tensor(np.argmax(P, axis=1), dtype=torch.long)

    x_perm = graph.x[perm].float()

    # Apply perm to labels, if per-node
    if graph.y is None:
        y_perm = None
    elif graph.y.size(0) == n:
        y_perm = graph.y[perm]
    else:
        y_perm = graph.y.clone().detach()

    # Apply permutation to adjacencies, if any
    if graph.edge_index.size(1) > 0:
        src = graph.edge_index[0].numpy()
        dst = graph.edge_index[1].numpy()
        A = sp.csr_matrix((np.ones(src.shape[0]), (src, dst)), shape=(n, n))
        P_sparse = sp.csr_matrix(P)
        A_perm = P_sparse.dot(A).dot(P_sparse.transpose()).tocoo()
        # scipy returns int32 coordinates; convert explicitly to an int64 edge index.
        edge_index_perm = torch.tensor(np.vstack((A_perm.row, A_perm.col)), dtype=torch.long)
    else:
        edge_index_perm = graph.edge_index.clone().detach()

    return Data(x=x_perm, edge_index=edge_index_perm, y=y_perm)


def is_valid_permutation_matrix(P: np.ndarray, n: int):
    """
    Check that ``P`` is an ``n x n`` permutation matrix.

    For ``n > 1`` the identity is deliberately rejected, so callers only get non-trivial permutations.
    The empty 0 x 0 matrix is accepted for ``n == 0``.

    Returns:
        bool. Inputs with the wrong number of dimensions or the wrong shape return False instead of raising.
    """
    if P.ndim != 2 or P.shape != (n, n):
        return False
    if n == 0:
        # max()/min() reductions below are undefined on empty arrays.
        return True

    ones = np.ones(n)
    # Entries in [0, 1] with every row/column summing to 1 and peaking at 1 => exactly one 1 per row/column.
    valid = (np.all(P.sum(0) == ones) and np.all(P.sum(1) == ones)
             and np.all(P.max(0) == ones) and np.all(P.max(1) == ones))
    if n > 1:
        zeros = np.zeros(n)
        valid = (valid and np.all(P.min(0) == zeros) and np.all(P.min(1) == zeros)
                 and not np.array_equal(P, np.eye(n)))
    return bool(valid)


def generate_permutation_matrices(size, amount=10, seed=None):
    """
    Sample ``amount`` random permutation matrices of shape [size, size].

    Candidates rejected by ``is_valid_permutation_matrix`` are redrawn until ``amount`` are accepted.

    Args:
        size: Number of nodes.
        amount: Number of matrices to return.
        seed: Optional seed for ``np.random.RandomState``, for reproducible draws. ``None`` (the default)
            keeps the previous unseeded behaviour.
    Returns:
        list of ``amount`` numpy arrays.
    """
    random_state = np.random.RandomState(seed)
    identity = np.eye(size)
    permutations = []
    while len(permutations) < amount:
        # Fancy indexing returns a new array, so each P is an independent copy.
        P = identity[random_state.permutation(size)]
        if is_valid_permutation_matrix(P, size):
            permutations.append(P)
    return permutations

## Models


### Sheaf diffusion base model

In [5]:
# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch import nn


class SheafDiffusion(nn.Module):
    """
    Base class for sheaf diffusion models.

    Reads the hyper-parameters from the ``args`` dict and stores them as attributes; subclasses build the
    actual layers. ``final_d`` is the stalk dimension ``d`` plus one extra coordinate for each of
    ``add_hp`` and ``add_lp`` that is enabled.
    """

    def __init__(self, args):
        super(SheafDiffusion, self).__init__()

        assert args['d'] > 0
        self.d = args['d']
        # add low pass filters/high pass filters (each adds one stalk coordinate below)
        self.add_lp = args['add_lp']
        self.add_hp = args['add_hp']

        self.final_d = self.d
        if self.add_hp:
            self.final_d += 1
        if self.add_lp:
            self.final_d += 1

        # hidden_channels per stalk coordinate
        self.hidden_dim = args['hidden_channels'] * self.final_d
        self.device = args['device']
        self.layers = args['layers']
        self.normalised = args['normalised']
        self.deg_normalised = args['deg_normalised']
        self.nonlinear = not args['linear']
        self.input_dropout = args['input_dropout']
        self.dropout = args['dropout']
        self.left_weights = args['left_weights']
        self.right_weights = args['right_weights']
        self.use_act = args['use_act']
        self.input_dim = args['input_dim']
        self.hidden_channels = args['hidden_channels']
        self.output_dim = args['output_dim']
        self.sheaf_act = args['sheaf_act']
        self.second_linear = args['second_linear']
        self.orth_trans = args['orth']
        self.use_edge_weights = args['edge_weights']
        self.laplacian_builder = None
        self.edges_feat = args['edges_feat']
        self.readout = args['readout']
        self.input_dim_edge = args['input_dim_edge']
        self.dense_intermediate_dim = args['dense_intermediate_dim']
        self.dense_output_graph_dim = args['dense_output_graph_dim']
        self.max_num_nodes_in_graph = args['max_num_nodes_in_graph']
        self.output_nn_intermediate_dim = args['output_nn_intermediate_dim']
        self.set_transformer_k = args['set_transformer_k']

    def grouped_parameters(self):
        """
        Split the parameters into those whose name contains "sheaf_learner" and all the others.

        Returns:
            (sheaf_learners, others): two lists of parameters. At least one sheaf-learner parameter must exist.
        """
        sheaf_learners, others = [], []
        for name, param in self.named_parameters():
            if "sheaf_learner" in name:
                sheaf_learners.append(param)
            else:
                others.append(param)
        assert len(sheaf_learners) > 0
        assert len(sheaf_learners) + len(others) == len(list(self.parameters()))
        return sheaf_learners, others


### Sheaf learners 

In [6]:
# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn.functional as F
import numpy as np

from typing import Tuple
from abc import abstractmethod
from torch import nn


class SheafLearner(nn.Module):
    """Abstract base for modules that predict sheaf restriction maps from node features and graph structure."""

    def __init__(self):
        super().__init__()
        # Optional detached snapshot, filled in by `set_L`.
        self.L = None

    @abstractmethod
    def forward(self, x, edge_index):
        raise NotImplementedError()

    def set_L(self, weights):
        """Store a detached copy of ``weights``; no gradient flows back through ``self.L``."""
        self.L = weights.detach().clone()


class LocalConcatSheafLearner(SheafLearner):
    """
    Predict restriction maps by concatenating the features of both edge endpoints, pooling, applying
    a bias-free linear layer and an activation.
    """

    def __init__(self, d: int, hidden_channels: int, out_shape: Tuple[int, ...], sheaf_act="tanh"):
        super().__init__()
        assert len(out_shape) in [1, 2]
        self.out_shape = out_shape
        self.d = d
        self.hidden_channels = hidden_channels
        self.linear1 = torch.nn.Linear(hidden_channels * 2, int(np.prod(out_shape)), bias=False)

        activations = {'id': lambda t: t, 'tanh': torch.tanh, 'elu': F.elu}
        if sheaf_act not in activations:
            raise ValueError(f"Unsupported act {sheaf_act}")
        self.act = activations[sheaf_act]

    def forward(self, x, edge_attr, edge_index):
        """Return the maps with shape [-1, *out_shape]; ``edge_attr`` is accepted but unused."""
        src, dst = edge_index
        endpoint_features = torch.cat([x[src], x[dst]], dim=-1)
        # Sum each group of d consecutive rows (presumably the d stalk rows of one edge; confirm with caller).
        pooled = endpoint_features.reshape(-1, self.d, 2 * self.hidden_channels).sum(dim=1)
        maps = self.act(self.linear1(pooled))
        return maps.view(-1, *self.out_shape)

class LocalConcatSheafLearnerVariant1(SheafLearner):
    """
    Like ``LocalConcatSheafLearner``, but appends ``edge_attr`` to the concatenated endpoint features
    before pooling. ``linear1`` expects ``2 * hidden_channels + 1`` inputs, i.e. a single edge-attribute column.
    """

    def __init__(self, d: int, hidden_channels: int, out_shape: Tuple[int, ...], sheaf_act="tanh"):
        super().__init__()
        assert len(out_shape) in [1, 2]
        self.out_shape = out_shape
        self.d = d
        self.hidden_channels = hidden_channels
        self.linear1 = torch.nn.Linear(hidden_channels * 2 + 1, int(np.prod(out_shape)), bias=False)

        activations = {'id': lambda t: t, 'tanh': torch.tanh, 'elu': F.elu}
        if sheaf_act not in activations:
            raise ValueError(f"Unsupported act {sheaf_act}")
        self.act = activations[sheaf_act]

    def forward(self, x, edge_attr, edge_index):
        """Return the maps with shape [-1, *out_shape]."""
        src, dst = edge_index
        features = torch.cat([x[src], x[dst], edge_attr], dim=-1)
        width = 2 * self.hidden_channels + 1
        # Sum each group of d consecutive rows before the linear layer.
        pooled = features.reshape(-1, self.d, width).sum(dim=1)
        maps = self.act(self.linear1(pooled))
        return maps.view(-1, *self.out_shape)


class LocalConcatSheafLearnerVariant2(SheafLearner):
    """
    Learns a sheaf from the concatenated endpoint features, then refines the maps using the edge attributes.

    maps = a + act(linear2([a, edge_attr])), where a = act(linear1(pooled endpoint features)).
    ``linear2`` maps ``2 * d`` inputs to ``d`` outputs and its output is added to ``a``. This therefore
    only works when ``prod(out_shape) == d`` and ``edge_attr`` has ``d`` columns (diagonal maps);
    TODO confirm with the callers.
    """

    def __init__(self, d: int, hidden_channels: int, out_shape: Tuple[int, ...], sheaf_act="tanh"):
        super(LocalConcatSheafLearnerVariant2, self).__init__()
        assert len(out_shape) in [1, 2]
        self.out_shape = out_shape
        self.d = d
        self.hidden_channels = hidden_channels
        self.linear1 = torch.nn.Linear(hidden_channels * 2, int(np.prod(out_shape)), bias=False)
        self.linear2 = torch.nn.Linear(self.d * 2, self.d, bias=False)

        if sheaf_act == 'id':
            self.act = lambda x: x
        elif sheaf_act == 'tanh':
            self.act = torch.tanh
        elif sheaf_act == 'elu':
            self.act = F.elu
        else:
            raise ValueError(f"Unsupported act {sheaf_act}")

    def forward(self, x, edge_attr, edge_index):
        """Return the maps with shape [-1, *out_shape]."""
        row, col = edge_index

        x_row = torch.index_select(x, dim=0, index=row)
        x_col = torch.index_select(x, dim=0, index=col)
        x_cat = torch.cat([x_row, x_col], dim=-1)
        # Sum each group of d consecutive rows before the linear layer.
        x_cat = x_cat.reshape(-1, self.d, self.hidden_channels * 2).sum(dim=1)

        x_cat = self.linear1(x_cat)
        x_cat_act = self.act(x_cat)

        # Residual refinement conditioned on the edge attributes.
        x_cat_2 = self.linear2(torch.cat([x_cat_act, edge_attr], dim=-1))
        maps = self.act(x_cat_2)
        maps = x_cat_act + maps

        if len(self.out_shape) == 2:
            return maps.view(-1, self.out_shape[0], self.out_shape[1])
        else:
            return maps.view(-1, self.out_shape[0])

class LocalConcatSheafLearnerVariant3(SheafLearner):
    """Learns restriction maps from endpoint node features, refined with edge features by a bilinear layer.

    For every edge (row, col):
      1. the two endpoint feature vectors are concatenated, reshaped to
         ``(d, 2 * hidden_channels)`` and summed over the stalk axis;
      2. ``linear1`` + activation projects the result to ``prod(out_shape)`` values;
      3. a bilinear form of those values and the edge features is added back as a
         residual.

    Args:
        d: Number of stalk dimensions the node features are split into. The edge
            features given to ``forward`` are expected to have ``d`` channels (the
            callers in this notebook project edge features to ``final_d`` channels and
            pass ``final_d`` as ``d``).
        hidden_channels: Feature channels per stalk dimension.
        out_shape: Per-edge output shape, either ``(k,)`` or ``(k, m)``.
        sheaf_act: Activation applied after ``linear1``: 'id', 'tanh' or 'elu'.

    Raises:
        ValueError: If ``sheaf_act`` is not supported.
    """

    def __init__(self, d: int, hidden_channels: int, out_shape: Tuple[int, ...], sheaf_act="tanh"):
        super(LocalConcatSheafLearnerVariant3, self).__init__()
        assert len(out_shape) in [1, 2]
        self.out_shape = out_shape
        self.d = d
        self.hidden_channels = hidden_channels
        num_maps = int(np.prod(out_shape))
        self.linear1 = torch.nn.Linear(hidden_channels * 2, num_maps, bias=False)
        # Inputs: map features (num_maps) and edge features (d). The output is added to the
        # map features as a residual, so it has num_maps entries. For out_shape == (d,)
        # this is exactly Bilinear(d, d, d).
        self.bilinear = torch.nn.Bilinear(num_maps, self.d, num_maps, bias=False)

        if sheaf_act == 'id':
            self.act = lambda x: x
        elif sheaf_act == 'tanh':
            self.act = torch.tanh
        elif sheaf_act == 'elu':
            self.act = F.elu
        else:
            raise ValueError(f"Unsupported act {sheaf_act}")

    def forward(self, x, edge_attr, edge_index):
        """Return one restriction map per edge.

        Args:
            x: Node features, one row per node (expected width ``d * hidden_channels``).
            edge_attr: Edge features, one row per column of ``edge_index``, with ``d`` channels.
            edge_index: Tensor of shape ``(2, E)`` holding (row, col) node indices.

        Returns:
            Tensor of shape ``(E, *out_shape)``.
        """
        row, col = edge_index

        x_row = torch.index_select(x, dim=0, index=row)
        x_col = torch.index_select(x, dim=0, index=col)
        x_cat = torch.cat([x_row, x_col], dim=-1)
        x_cat = x_cat.reshape(-1, self.d, self.hidden_channels * 2).sum(dim=1)

        x_cat = self.linear1(x_cat)
        x_cat_act = self.act(x_cat)

        # Edge-feature refinement (bilinear interaction) with a residual connection.
        maps = x_cat_act + self.bilinear(x_cat_act, edge_attr)

        if len(self.out_shape) == 2:
            return maps.view(-1, self.out_shape[0], self.out_shape[1])
        else:
            return maps.view(-1, self.out_shape[0])


class EdgeWeightLearner(SheafLearner):
    """Learns one scalar weight in (0, 1) per entry of ``edge_index``.

    Each directed edge gets a score ``sigmoid(W [x_row | x_col])``. The returned weight of
    an edge is its score multiplied by the score found at its position in the "right"
    index produced by ``compute_left_right_map_index(..., full_matrix=True)`` (helper defined
    elsewhere; presumably the index of the reverse edge, which makes the weights
    symmetric — confirm against that helper).

    Args:
        in_channels: Width of the node feature vectors.
    """

    def __init__(self, in_channels: int):
        super(EdgeWeightLearner, self).__init__()
        self.in_channels = in_channels
        self.linear1 = torch.nn.Linear(in_channels*2, 1, bias=False)
        self.full_left_right_idx = None

    def forward(self, x, edge_index):
        """Return a ``(E, 1)`` tensor of edge weights for ``edge_index``.

        The index lookup is rebuilt (and cached on ``self``) on every call, since the
        batched graph can change between calls.
        """
        lookup, _ = compute_left_right_map_index(edge_index, full_matrix=True)
        self.full_left_right_idx = lookup
        _, partner_idx = lookup

        src, dst = edge_index
        pair_features = torch.cat([x.index_select(0, src), x.index_select(0, dst)], dim=1)
        scores = torch.sigmoid(self.linear1(pair_features))

        return scores * scores.index_select(0, partner_idx)

    def update_edge_index(self, edge_index):
        """Refresh the cached index lookup for a new ``edge_index``."""
        lookup, _ = compute_left_right_map_index(edge_index, full_matrix=True)
        self.full_left_right_idx = lookup


class QuadraticFormSheafLearner(SheafLearner):
    """Learns a sheaf from learnable bilinear forms of the two endpoint features.

    For every edge (row, col) and every output entry ``k`` in ``range(prod(out_shape))``,
    the map entry is ``tanh(x_row^T A_k x_col)``. Each ``A_k`` is a learnable
    ``(in_channels, in_channels)`` matrix initialised to the identity.

    Args:
        in_channels: Width of the node feature vectors.
        out_shape: Per-edge output shape, either ``(k,)`` or ``(k, m)``.
    """

    def __init__(self, in_channels: int, out_shape: Tuple[int]):
        super(QuadraticFormSheafLearner, self).__init__()
        assert len(out_shape) in [1, 2]
        self.out_shape = out_shape

        # One identity-initialised (in_channels x in_channels) form per output entry.
        tensor = torch.eye(in_channels).unsqueeze(0).tile(int(np.prod(out_shape)), 1, 1)
        self.tensor = nn.Parameter(tensor)

    def forward(self, x, edge_index):
        """Return a ``(E, *out_shape)`` tensor of restriction maps, one per edge.

        Args:
            x: Node features of shape ``(N, in_channels)``.
            edge_index: Tensor of shape ``(2, E)`` holding (row, col) node indices.
        """
        row, col = edge_index
        x_row = torch.index_select(x, dim=0, index=row)
        x_col = torch.index_select(x, dim=0, index=col)
        # maps[e, k] = x_row[e] @ tensor[k] @ x_col[e]; uses the learnable forms held in
        # self.tensor (there is no separate map-builder layer).
        maps = torch.einsum('ei,kij,ej->ek', x_row, self.tensor, x_col)

        if len(self.out_shape) == 2:
            return torch.tanh(maps).view(-1, self.out_shape[0], self.out_shape[1])
        else:
            return torch.tanh(maps).view(-1, self.out_shape[0])

### Orthogonal module 

In [None]:
!pip install torch-householder==1.0.1


In [8]:
# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0

import math
import torch

from torch import nn
from torch_householder import torch_householder_orgqr


class Orthogonal(nn.Module):
    """Turns unconstrained per-row parameters into a batch of ``d x d`` orthogonal matrices.

    Based on https://pytorch.org/docs/stable/_modules/torch/nn/utils/parametrizations.html#orthogonal

    ``orthogonal_map`` selects the construction:
      * ``"matrix_exp"`` / ``"cayley"``: each row of ``params`` fills the lower triangle
        (diagonal included, d(d+1)/2 values) of a matrix ``P``. ``A = P - P^T`` is
        skew-symmetric, and ``Q = exp(A)`` or ``Q = (I - A/2)^{-1} (I + A/2)``.
      * ``"householder"``: each row fills the strictly lower triangle (d(d-1)/2 values).
        ``Q`` is produced by ``torch_householder_orgqr`` from ``I + strict_lower``.
      * ``"euler"``: ``d`` must be 2 (one value per row) or 3 (three values per row, each
        in [-1, 1]). Values are fractions of a full turn.
    """

    def __init__(self, d, orthogonal_map):
        super().__init__()
        assert orthogonal_map in ["matrix_exp", "cayley", "householder", "euler"]
        self.d = d
        self.orthogonal_map = orthogonal_map

    def get_2d_rotation(self, params):
        """Return ``(n, 2, 2)`` rotations by the angle ``2*pi*params`` for ``params`` of shape ``(n, 1)``."""
        assert params.size(-1) == 1
        theta = params * 2 * math.pi
        s = torch.sin(theta)
        c = torch.cos(theta)
        return torch.cat([c, -s, s, c], dim=1).view(-1, 2, 2)

    def get_3d_rotation(self, params):
        """Return ``(n, 3, 3)`` rotations ``R_z(alpha) @ R_y(beta) @ R_x(gamma)``.

        ``params`` has shape ``(n, 3)``. Its columns (alpha, beta, gamma) are fractions of a
        full turn in [-1, 1].
        """
        assert params.min() >= -1.0 and params.max() <= 1.0
        assert params.size(-1) == 3

        turns = params * 2 * math.pi
        alpha = turns[:, 0].view(-1, 1)
        beta = turns[:, 1].view(-1, 1)
        gamma = turns[:, 2].view(-1, 1)

        sa, ca = torch.sin(alpha), torch.cos(alpha)
        sb, cb = torch.sin(beta), torch.cos(beta)
        sg, cg = torch.sin(gamma), torch.cos(gamma)

        entries = [
            ca*cb, ca*sb*sg - sa*cg, ca*sb*cg + sa*sg,
            sa*cb, sa*sb*sg + ca*cg, sa*sb*cg - ca*sg,
            -sb, cb*sg, cb*cg,
        ]
        return torch.cat(entries, dim=1).view(-1, 3, 3)

    def forward(self, params: torch.Tensor) -> torch.Tensor:
        kind = self.orthogonal_map

        if kind == "euler":
            # Angles are used directly; no triangular embedding is needed.
            assert 2 <= self.d <= 3
            return self.get_2d_rotation(params) if self.d == 2 else self.get_3d_rotation(params)

        # Scatter each row of params into a (d, d) lower-triangular matrix
        # (strictly lower for Householder, diagonal included otherwise).
        diag_offset = -1 if kind == 'householder' else 0
        rows, cols = torch.tril_indices(row=self.d, col=self.d, offset=diag_offset, device=params.device)
        lower = params.new_zeros((params.size(0), self.d, self.d))
        lower[:, rows, cols] = params

        if kind in ("matrix_exp", "cayley"):
            lower = lower.tril()
            skew = lower - lower.transpose(-2, -1)
            if kind == "matrix_exp":
                return torch.matrix_exp(skew)
            # Cayley map: (I - A/2) and (I + A/2) commute, so this equals (I + A/2)(I - A/2)^{-1}.
            eye = torch.eye(self.d, dtype=skew.dtype, device=skew.device)
            return torch.linalg.solve(torch.add(eye, skew, alpha=-0.5), torch.add(eye, skew, alpha=0.5))

        if kind == 'householder':
            batch_eye = torch.eye(self.d, device=params.device).unsqueeze(0).repeat(params.size(0), 1, 1)
            return torch_householder_orgqr(lower.tril(diagonal=-1) + batch_eye)

        raise ValueError(f"Unsupported transformations {self.orthogonal_map}")

### Laplacian builders 

In [9]:
# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch

from torch import nn
from torch_scatter import scatter_add
from torch_geometric.utils import degree


class LaplacianBuilder(nn.Module):
    """Base class for modules that assemble a sparse sheaf Laplacian from restriction maps.

    The constructor precomputes the sparse index structures for ``edge_index`` and the
    node degrees, using helpers defined elsewhere in this notebook. Subclasses implement
    ``forward`` and use the helpers below to normalise entries and to append the fixed
    extra stalk dimensions.

    Args:
        size: Number of nodes; node indices in ``edge_index`` must lie in ``[0, size)``.
        edge_index: Tensor of shape ``(2, E)`` holding node indices. ``self.edges`` is
            ``E // 2`` (edges are presumably stored in both directions).
        d: Number of learnable stalk dimensions.
        normalised: Use diagonal-based normalisation (also applied to the fixed dimensions).
        deg_normalised: Use degree-based normalisation. Mutually exclusive with ``normalised``.
        add_hp: Append one fixed stalk dimension whose off-diagonal entries are -1.
        add_lp: Append one fixed stalk dimension whose off-diagonal entries are +1.
        augmented: Normalise with ``(D + 1)^{-1/2}`` instead of ``D^{-1/2}``.
    """

    def __init__(self, size, edge_index, d, normalised=False, deg_normalised=False, add_hp=False, add_lp=False,
                 augmented=True):
        super(LaplacianBuilder, self).__init__()
        assert not (normalised and deg_normalised)

        self.d = d
        # Total stalk dimension: learnable dimensions plus any fixed lp/hp ones.
        self.final_d = d
        if add_hp:
            self.final_d += 1
        if add_lp:
            self.final_d += 1
        self.size = size
        self.edges = edge_index.size(1) // 2
        self.edge_index = edge_index
        self.normalised = normalised
        self.deg_normalised = deg_normalised
        self.device = edge_index.device
        self.add_hp = add_hp
        self.add_lp = add_lp
        self.augmented = augmented

        # Preprocess the sparse indices required to compute the Sheaf Laplacian.
        self.full_left_right_idx, _ = compute_left_right_map_index(edge_index, full_matrix=True)
        self.left_right_idx, self.vertex_tril_idx = compute_left_right_map_index(edge_index)
        if self.add_lp or self.add_hp:
            self.fixed_diag_indices, self.fixed_tril_indices = compute_fixed_diag_laplacian_indices(
                size, self.vertex_tril_idx, self.d, self.final_d)
        self.deg = degree(self.edge_index[0], num_nodes=self.size)

    def get_fixed_maps(self, size, dtype):
        """Return ``(fixed_diag, fixed_non_diag)`` for the appended fixed dimensions.

        ``fixed_diag`` holds the node degree, one column per fixed dimension.
        ``fixed_non_diag`` has ``size`` rows (one per lower-triangular entry) holding +1 for
        the low-pass and -1 for the high-pass dimension (low-pass column first).
        """
        assert self.add_lp or self.add_hp

        fixed_diag, fixed_non_diag = [], []
        if self.add_lp:
            fixed_diag.append(self.deg.view(-1, 1))
            fixed_non_diag.append(torch.ones(size=(size, 1), device=self.device, dtype=dtype))
        if self.add_hp:
            fixed_diag.append(self.deg.view(-1, 1))
            fixed_non_diag.append(-torch.ones(size=(size, 1), device=self.device, dtype=dtype))

        fixed_diag = torch.cat(fixed_diag, dim=1)
        fixed_non_diag = torch.cat(fixed_non_diag, dim=1)

        assert self.fixed_tril_indices.size(1) == fixed_non_diag.numel()
        assert self.fixed_diag_indices.size(1) == fixed_diag.numel()

        return fixed_diag, fixed_non_diag

    def scalar_normalise(self, diag, tril, row, col):
        """Symmetrically normalise entries with ``D^{-1/2}`` (``(D + 1)^{-1/2}`` if augmented).

        Returns ``(D^{-1/2} diag D^{-1/2}, D^{-1/2}[row] * tril * D^{-1/2}[col])``. When
        ``tril`` holds matrix blocks (``dim() > 2``), ``diag`` must hold one scalar per
        node. Without augmentation, zero diagonal entries get a factor of 0 instead of inf.
        """
        if tril.dim() > 2:
            assert tril.size(-1) == tril.size(-2)
            assert diag.dim() == 2
        d = diag.size(-1)

        if self.augmented:
            diag_sqrt_inv = (diag + 1).pow(-0.5)
        else:
            diag_sqrt_inv = diag.pow(-0.5)
            diag_sqrt_inv.masked_fill_(diag_sqrt_inv == float('inf'), 0)
        diag_sqrt_inv = diag_sqrt_inv.view(-1, 1, 1) if tril.dim() > 2 else diag_sqrt_inv.view(-1, d)
        left_norm = diag_sqrt_inv[row]
        right_norm = diag_sqrt_inv[col]
        non_diag_maps = left_norm * tril * right_norm

        diag_sqrt_inv = diag_sqrt_inv.view(-1, 1, 1) if diag.dim() > 2 else diag_sqrt_inv.view(-1, d)
        diag_maps = diag_sqrt_inv**2 * diag

        return diag_maps, non_diag_maps

    def append_fixed_maps(self, size, diag_indices, diag_maps, tril_indices, tril_maps):
        """Merge the fixed lp/hp entries into the learnable sparse entries.

        This is a no-op when neither ``add_lp`` nor ``add_hp`` is set. The fixed entries are
        normalised only when ``normalised`` is set. Merging is done by ``mergesp``, which is
        defined elsewhere.
        """
        if not self.add_lp and not self.add_hp:
            return (diag_indices, diag_maps), (tril_indices, tril_maps)

        fixed_diag, fixed_non_diag = self.get_fixed_maps(size, tril_maps.dtype)
        tril_row, tril_col = self.vertex_tril_idx

        # Normalise the fixed parts.
        if self.normalised:
            fixed_diag, fixed_non_diag = self.scalar_normalise(fixed_diag, fixed_non_diag, tril_row, tril_col)
        fixed_diag, fixed_non_diag = fixed_diag.view(-1), fixed_non_diag.view(-1)
        # Combine the learnable and fixed parts.
        tril_indices, tril_maps = mergesp(self.fixed_tril_indices, fixed_non_diag, tril_indices, tril_maps)
        diag_indices, diag_maps = mergesp(self.fixed_diag_indices, fixed_diag, diag_indices, diag_maps)

        return (diag_indices, diag_maps), (tril_indices, tril_maps)

    def create_with_new_edge_index(self, edge_index):
        """Return a builder of the same class and configuration for a different ``edge_index``.

        Node indices must lie in ``[0, size)``. The new builder inherits this module's
        train/eval mode.
        """
        # Node ids are 0-based, so the largest valid index is size - 1.
        assert edge_index.max() < self.size
        new_builder = self.__class__(
            self.size, edge_index, self.d,
            normalised=self.normalised, deg_normalised=self.deg_normalised, add_hp=self.add_hp, add_lp=self.add_lp,
            augmented=self.augmented)
        new_builder.train(self.training)
        return new_builder


class DiagLaplacianBuilder(LaplacianBuilder):
    """Builds a sparse sheaf Laplacian from learned diagonal restriction maps.

    Each edge carries a length-``d`` vector (the diagonal of its restriction map), so
    every Laplacian block is diagonal as well.
    """

    def __init__(self, size, edge_index, d, normalised=False, deg_normalised=False, add_hp=False, add_lp=False,
                 augmented=True):
        super(DiagLaplacianBuilder, self).__init__(
            size, edge_index, d, normalised, deg_normalised, add_hp, add_lp, augmented)

        self.diag_indices, self.tril_indices = compute_learnable_diag_laplacian_indices(
            size, self.vertex_tril_idx, self.d, self.final_d)

    def normalise(self, diag, tril, row, col):
        """Normalise Laplacian entries per node and per stalk dimension.

        With ``normalised``, the node's own diagonal entries D are used; with
        ``deg_normalised``, the node degrees are used. In both cases D is replaced by
        D + 1 when ``augmented``. Nodes whose normaliser would be infinite (zero
        diagonal / degree) get a factor of 0. Returns ``(diag, tril)`` unchanged when
        neither option is set.
        """
        if self.normalised:
            if self.augmented:
                d_sqrt_inv = (diag + 1).pow(-0.5)
            else:
                # A zero diagonal entry (e.g. all maps of a node are zero) would give inf
                # here, and then inf * 0 = nan in both values and gradients. Raise to the
                # power only on positive entries and use 0 elsewhere.
                positive = diag > 0
                safe_diag = torch.where(positive, diag, torch.ones_like(diag))
                d_sqrt_inv = torch.where(positive, safe_diag.pow(-0.5), torch.zeros_like(diag))
            left_norm, right_norm = d_sqrt_inv[row], d_sqrt_inv[col]
            tril = left_norm * tril * right_norm
            diag = d_sqrt_inv * diag * d_sqrt_inv
        elif self.deg_normalised:
            deg_sqrt_inv = (self.deg + 1).pow(-0.5) if self.augmented else self.deg.pow(-0.5)
            deg_sqrt_inv = deg_sqrt_inv.unsqueeze(-1)
            deg_sqrt_inv.masked_fill_(deg_sqrt_inv == float('inf'), 0)
            left_norm, right_norm = deg_sqrt_inv[row], deg_sqrt_inv[col]
            tril = left_norm * tril * right_norm
            diag = deg_sqrt_inv * diag * deg_sqrt_inv
        return diag, tril

    def forward(self, maps):
        """Assemble the sparse Laplacian from per-edge diagonal restriction maps.

        Args:
            maps: Tensor of shape ``(E, d)``, one length-``d`` diagonal map per column of
                ``edge_index``.

        Returns:
            ``((index, values), saved_tril_maps)``: the sparse Laplacian as index/value
            pairs (lower, upper and diagonal parts merged), plus a detached copy of the
            un-normalised lower-triangular entries.
        """
        assert len(maps.size()) == 2
        assert maps.size(1) == self.d
        left_idx, right_idx = self.left_right_idx
        tril_row, tril_col = self.vertex_tril_idx
        row, _ = self.edge_index

        # Compute the un-normalised Laplacian entries.
        left_maps = torch.index_select(maps, index=left_idx, dim=0)
        right_maps = torch.index_select(maps, index=right_idx, dim=0)
        tril_maps = -left_maps * right_maps
        saved_tril_maps = tril_maps.detach().clone()
        diag_maps = scatter_add(maps**2, row, dim=0, dim_size=self.size)

        # Normalise the entries if the normalised Laplacian is used.
        diag_maps, tril_maps = self.normalise(diag_maps, tril_maps, tril_row, tril_col)
        tril_indices, diag_indices = self.tril_indices, self.diag_indices
        tril_maps, diag_maps = tril_maps.view(-1), diag_maps.view(-1)

        # Append fixed diagonal values in the non-learnable dimensions.
        (diag_indices, diag_maps), (tril_indices, tril_maps) = self.append_fixed_maps(
            len(left_maps), diag_indices, diag_maps, tril_indices, tril_maps)

        # Add the upper triangular part (transpose of the lower part).
        triu_indices = torch.empty_like(tril_indices)
        triu_indices[0], triu_indices[1] = tril_indices[1], tril_indices[0]
        non_diag_indices, non_diag_values = mergesp(tril_indices, tril_maps, triu_indices, tril_maps)

        # Merge diagonal and non-diagonal
        edge_index, weights = mergesp(non_diag_indices, non_diag_values, diag_indices, diag_maps)

        return (edge_index, weights), saved_tril_maps


class NormConnectionLaplacianBuilder(LaplacianBuilder):
    """Builds a normalised sheaf (connection) Laplacian from orthogonal restriction maps.

    Per-edge parameters are turned into ``d x d`` orthogonal matrices by ``Orthogonal``,
    optionally scaled by per-edge weights. Off-diagonal blocks are
    ``-(left map)^T (right map)`` for the edge pairs selected by
    ``compute_left_right_map_index``. Diagonal entries are node degrees (or summed
    squared edge weights). Everything is then normalised with one scalar per node.

    Args:
        size: Number of nodes.
        edge_index: Tensor of shape ``(2, E)`` holding node indices.
        d: Stalk dimension of the orthogonal maps.
        add_hp / add_lp / augmented: see ``LaplacianBuilder``.
        orth_map: One of ``"matrix_exp"``, ``"cayley"``, ``"householder"``, ``"euler"``.
            The default ``None`` is rejected by ``Orthogonal``'s assertion, so a value
            must be passed.
    """

    def __init__(self, size, edge_index, d, add_hp=False, add_lp=False, orth_map=None, augmented=True):
        super(NormConnectionLaplacianBuilder, self).__init__(
            size, edge_index, d, add_hp=add_hp, add_lp=add_lp, normalised=True, augmented=augmented)
        self.orth_transform = Orthogonal(d=self.d, orthogonal_map=orth_map)
        self.orth_map = orth_map

        _, self.tril_indices = compute_learnable_laplacian_indices(
            size, self.vertex_tril_idx, self.d, self.final_d)
        self.diag_indices, _ = compute_learnable_diag_laplacian_indices(
            size, self.vertex_tril_idx, self.d, self.final_d)

    def create_with_new_edge_index(self, edge_index):
        """Return a builder with the same configuration for a different ``edge_index``.

        Node indices must lie in ``[0, size)``. The new builder inherits this module's
        train/eval mode.
        """
        # Node ids are 0-based, so the largest valid index is size - 1.
        assert edge_index.max() < self.size
        new_builder = self.__class__(
            self.size, edge_index, self.d, add_hp=self.add_hp, add_lp=self.add_lp, augmented=self.augmented,
            orth_map=self.orth_map)
        new_builder.train(self.training)
        return new_builder

    def normalise(self, diag, tril, row, col):
        """Scalar per-node normalisation; identical to ``LaplacianBuilder.scalar_normalise``."""
        return self.scalar_normalise(diag, tril, row, col)

    def forward(self, map_params, edge_weights=None):
        """Assemble the normalised connection Laplacian.

        Args:
            map_params: Tensor of shape ``(E, k)``, one row per column of ``edge_index``.
                ``k = d(d+1)/2`` for ``"matrix_exp"``/``"cayley"`` and ``d(d-1)/2``
                otherwise.
            edge_weights: Optional ``(E, 1)`` tensor. When given, the maps are scaled by
                it and the diagonal uses the per-node sum of squared weights instead of
                the degree.

        Returns:
            ``((index, values), saved_tril_maps)``: the sparse Laplacian as index/value
            pairs, plus a detached copy of the un-normalised off-diagonal blocks.
        """
        if edge_weights is not None:
            assert edge_weights.size(1) == 1
        assert len(map_params.size()) == 2
        if self.orth_map in ["matrix_exp", "cayley"]:
            assert map_params.size(1) == self.d * (self.d + 1) // 2
        else:
            assert map_params.size(1) == self.d * (self.d - 1) // 2

        left_idx, right_idx = self.left_right_idx
        tril_row, tril_col = self.vertex_tril_idx
        tril_indices, diag_indices = self.tril_indices, self.diag_indices
        row, _ = self.edge_index

        # Convert the parameters to orthogonal matrices.
        maps = self.orth_transform(map_params)
        if edge_weights is None:
            diag_maps = self.deg.unsqueeze(-1)
        else:
            diag_maps = scatter_add(edge_weights ** 2, row, dim=0, dim_size=self.size)
            maps = maps * edge_weights.unsqueeze(-1)

        # Compute the transport maps.
        left_maps = torch.index_select(maps, index=left_idx, dim=0)
        right_maps = torch.index_select(maps, index=right_idx, dim=0)
        tril_maps = -torch.bmm(torch.transpose(left_maps, -1, -2), right_maps)
        saved_tril_maps = tril_maps.detach().clone()

        # Normalise the entries (this builder is always normalised).
        diag_maps, tril_maps = self.scalar_normalise(diag_maps, tril_maps, tril_row, tril_col)
        # The scalar diagonal entry is repeated for each of the d stalk dimensions.
        tril_maps, diag_maps = tril_maps.view(-1), diag_maps.expand(-1, self.d).reshape(-1)

        # Append fixed diagonal values in the non-learnable dimensions.
        (diag_indices, diag_maps), (tril_indices, tril_maps) = self.append_fixed_maps(
            len(left_maps), diag_indices, diag_maps, tril_indices, tril_maps)

        # Add the upper triangular part (transpose of the lower part).
        triu_indices = torch.empty_like(tril_indices)
        triu_indices[0], triu_indices[1] = tril_indices[1], tril_indices[0]
        non_diag_indices, non_diag_values = mergesp(tril_indices, tril_maps, triu_indices, tril_maps)

        # Merge diagonal and non-diagonal
        edge_index, weights = mergesp(non_diag_indices, non_diag_values, diag_indices, diag_maps)

        return (edge_index, weights), saved_tril_maps

### Discrete models 

#### Imports 

In [10]:
from IPython.core.display import display_html
# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn.functional as F
import torch_sparse
from torch_geometric.nn import SAGPooling
from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool
from torch import nn
from torch_geometric.utils import to_dense_batch

#### Diagonal sheaf diffusion

In [14]:
class DiscreteDiagSheafDiffusion(SheafDiffusion):
    """Discrete sheaf diffusion with diagonal restriction maps, followed by a graph-level readout.

    Configuration attributes are set from ``args`` by ``SheafDiffusion`` (defined
    elsewhere). Options used here include ``layers``, ``nonlinear``, ``edges_feat`` (how
    edge features enter the sheaf learner: 'none', 'concat', 'linear' or 'bilinear'),
    ``left_weights`` / ``right_weights`` and ``readout`` ('mean', 'sum', 'max', 'concat',
    'sag' or 'mlp').

    Raises:
        ValueError: If at least one sheaf learner is needed and ``edges_feat`` is not one
            of the supported values.
    """

    def __init__(self, args):
        super(DiscreteDiagSheafDiffusion, self).__init__(args)
        assert args['d'] > 0

        self.lin_right_weights = nn.ModuleList()
        self.lin_left_weights = nn.ModuleList()

        self.batch_norms = nn.ModuleList()
        if self.right_weights:
            for i in range(self.layers):
                self.lin_right_weights.append(nn.Linear(self.hidden_channels, self.hidden_channels, bias=False))
                nn.init.orthogonal_(self.lin_right_weights[-1].weight.data)
        if self.left_weights:
            for i in range(self.layers):
                self.lin_left_weights.append(nn.Linear(self.final_d, self.final_d, bias=False))
                nn.init.eye_(self.lin_left_weights[-1].weight.data)

        self.sheaf_learners = nn.ModuleList()

        # Define the restriction map learner(s) with the selected edge-feature handling.
        # One learner per layer when maps are recomputed every layer (nonlinear),
        # otherwise a single learner used at layer 0.
        num_sheaf_learners = min(self.layers, self.layers if self.nonlinear else 1)
        for i in range(num_sheaf_learners):
            if self.edges_feat == "none":
                self.sheaf_learners.append(LocalConcatSheafLearner(self.final_d,
                    self.hidden_channels, out_shape=(self.d,), sheaf_act=self.sheaf_act))
            elif self.edges_feat == "concat":
                self.sheaf_learners.append(LocalConcatSheafLearnerVariant1(self.final_d,
                    self.hidden_channels, out_shape=(self.d,), sheaf_act=self.sheaf_act))
            elif self.edges_feat == "linear":
                self.sheaf_learners.append(LocalConcatSheafLearnerVariant2(self.final_d,
                    self.hidden_channels, out_shape=(self.d,), sheaf_act=self.sheaf_act))
            elif self.edges_feat == "bilinear":
                self.sheaf_learners.append(LocalConcatSheafLearnerVariant3(self.final_d,
                    self.hidden_channels, out_shape=(self.d,), sheaf_act=self.sheaf_act))
            else:
                # Fail at construction time rather than with an IndexError in forward().
                raise ValueError(f"Unsupported edges_feat {self.edges_feat}")

        # Project edge features to final_d channels only when a learner consumes them.
        if self.input_dim_edge > 0 and self.edges_feat != "none":
            self.lin3 = nn.Linear(self.input_dim_edge, self.final_d)

        # Per-layer, per-stalk-dimension residual coefficients, used as 1 + tanh(eps).
        self.epsilons = nn.ParameterList()
        for i in range(self.layers):
            self.epsilons.append(nn.Parameter(torch.zeros((self.final_d, 1))))

        # input_dim = number of input node channels.
        # hidden_dim = hidden_channels * final_d (node features are later viewed as
        #              final_d rows of hidden_channels each).
        # lin1 --> linear layer producing hidden_dim features from input_dim features.
        self.lin1 = nn.Linear(self.input_dim, self.hidden_dim)
        if self.second_linear:
            # Additional linear layer
            self.lin12 = nn.Linear(self.hidden_dim, self.hidden_dim)

        # Mean, max, sum readout
        if self.readout in ["mean", "max", "sum"]:

            # Output linear layer
            self.lin2 = nn.Linear(self.hidden_dim, self.output_dim)

        # Concat readout
        if self.readout == "concat":

            # Output linear layer
            self.lin2 = nn.Linear(self.hidden_dim, self.output_dim)
            self.lin4 = nn.Linear(self.hidden_dim*3, self.hidden_dim)

        # SAG readout layer
        if self.readout == "sag":

            self.pooling_layers = SAGPooling(self.hidden_channels * self.final_d, 0.3)

            # Output linear layer
            self.lin4 = nn.Linear(self.hidden_dim*3, self.hidden_dim)
            self.lin2 = nn.Linear(self.hidden_dim, self.output_dim)

        # MLP readout layer
        if self.readout == "mlp":

            self.gnn_output_node_dim = self.hidden_dim

            # MLP readout over the zero-padded, flattened node features of each graph.
            self.dense_agg = torch.nn.Sequential(
                nn.Linear(in_features=self.max_num_nodes_in_graph * self.gnn_output_node_dim, out_features=self.dense_intermediate_dim),
                nn.BatchNorm1d(self.dense_intermediate_dim),
                nn.ReLU(),
                nn.Dropout(p=0.7), # Originally there was not
                nn.Linear(in_features=self.dense_intermediate_dim, out_features=self.dense_output_graph_dim),
                nn.BatchNorm1d(self.dense_output_graph_dim),
                nn.ReLU(),
                nn.Dropout(p=0.7)  # Original p = 0.4
            )
            # Output NN for MLP
            self.output_nn = torch.nn.Sequential(
                nn.Linear(in_features=self.dense_output_graph_dim, out_features=self.output_nn_intermediate_dim),
                nn.ReLU(),
                nn.Dropout(p=0.4),
                nn.Linear(in_features=self.output_nn_intermediate_dim, out_features=self.output_dim)
            )

    def forward(self, x, edge_index, edge_attr, batch):
        """Run sheaf diffusion on a batch of graphs and return one output row per graph.

        Args:
            x: Node features, one row per node across the batch.
            edge_index: Tensor of shape ``(2, E)`` holding node indices.
            edge_attr: Edge features (projected by ``lin3`` when edge features are used).
            batch: Graph id of every node.
        """
        x = F.dropout(x, p=self.input_dropout, training=self.training)
        # (n * b, input features) -> (n * b, hidden feat * d)
        x = self.lin1(x)
        if self.use_act:
            x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        if self.second_linear:
            x = self.lin12(x)
        dim = x.size(dim=0)
        # One row per (node, stalk dimension).
        x = x.view(dim * self.final_d, -1)

        if self.input_dim_edge > 0 and self.edges_feat != "none":
            # (e * b, input_edge features) -> (e * b, final_d)
            edge_attr = self.lin3(edge_attr)
            if self.use_act:
                edge_attr = torch.tanh(edge_attr)

        # The batched graph changes per call, so the Laplacian builder is rebuilt here.
        laplacian_builder = DiagLaplacianBuilder(dim, edge_index, d=self.d,
                                                 normalised=self.normalised,
                                                 deg_normalised=self.deg_normalised,
                                                 add_hp=self.add_hp, add_lp=self.add_lp)
        x0 = x
        for layer in range(self.layers):

            if layer == 0 or self.nonlinear:
                x_maps = F.dropout(x, p=self.dropout if layer > 0 else 0., training=self.training)
                maps = self.sheaf_learners[layer](x_maps.reshape(dim, -1), edge_attr, edge_index)
                L, trans_maps = laplacian_builder(maps)
                self.sheaf_learners[layer].set_L(trans_maps)

            x = F.dropout(x, p=self.dropout, training=self.training)

            if self.left_weights:
                x = x.t().reshape(-1, self.final_d)
                x = self.lin_left_weights[layer](x)
                x = x.reshape(-1, dim * self.final_d).t()

            if self.right_weights:
                x = self.lin_right_weights[layer](x)

            x = torch_sparse.spmm(L[0], L[1], x.size(0), x.size(0), x)

            if self.use_act:
                x = F.elu(x)

            # Residual diffusion step: x <- (1 + tanh(eps)) * x0 - L x.
            coeff = (1 + torch.tanh(self.epsilons[layer]).tile(dim, 1))
            x0 = coeff * x0 - x
            x = x0

        # Readout layers
        x = x.reshape(dim, -1)

        if self.readout == "mean":
            # Simple mean readout
            x = global_mean_pool(x, batch)  # [batch_size, hidden_channels = hidden feat * d]

            # Linear classifier
            x = self.lin2(x)

        if self.readout == "sum":
            # Simple sum readout
            x = global_add_pool(x, batch)

            # Linear classifier
            x = self.lin2(x)

        if self.readout == "max":
            # Simple max readout
            x = global_max_pool(x, batch)

            # Linear classifier
            x = self.lin2(x)

        if self.readout == "concat":
            # Concatenation of [mean, sum, max] readout
            x = torch.cat([global_mean_pool(x, batch), global_add_pool(x, batch), global_max_pool(x, batch)], dim=1)
            x = self.lin4(x)
            x = F.elu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

            # Linear classifier
            x = self.lin2(x)

        if self.readout == "sag":
            # SAG pooling layer
            x, edge_index, _, batch, _, _ = self.pooling_layers(x, edge_index, batch = batch)
            x = torch.cat([global_mean_pool(x, batch), global_add_pool(x, batch), global_max_pool(x, batch)], dim=1)
            x = self.lin4(x)
            x = F.elu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

            # Linear classifier
            x = self.lin2(x)

        if self.readout == "mlp":
            graph_x, _ = to_dense_batch(x, batch, fill_value=0, max_num_nodes=self.max_num_nodes_in_graph)
            graph_x = self.dense_agg(graph_x.view(-1, graph_x.shape[1] * graph_x.shape[2]))

            # MLP classifier
            x = self.output_nn(graph_x)

        return x

#### Diagonal sheaf diffusion with alternate sheaf diffusion/pooling layers

In [35]:
class DiscreteDiagPoolSheafDiffusion(SheafDiffusion):
    """Diagonal sheaf diffusion alternating with SAG pooling, followed by a [mean | sum | max] readout.

    Each layer runs one diffusion step and then a ``SAGPooling`` layer (ratio 0.4).
    Pooling changes the graph, so the Laplacian is rebuilt at every layer. When
    ``nonlinear`` is False, the single sheaf learner is reused for every layer.

    Raises:
        ValueError: If at least one sheaf learner is needed and ``edges_feat`` is not one
            of 'none', 'concat', 'linear' or 'bilinear'.
    """

    def __init__(self, args):
        super(DiscreteDiagPoolSheafDiffusion, self).__init__(args)
        assert args['d'] > 0

        self.lin_right_weights = nn.ModuleList()
        self.lin_left_weights = nn.ModuleList()

        self.batch_norms = nn.ModuleList()
        if self.right_weights:
            for i in range(self.layers):
                self.lin_right_weights.append(nn.Linear(self.hidden_channels, self.hidden_channels, bias=False))
                nn.init.orthogonal_(self.lin_right_weights[-1].weight.data)
        if self.left_weights:
            for i in range(self.layers):
                self.lin_left_weights.append(nn.Linear(self.final_d, self.final_d, bias=False))
                nn.init.eye_(self.lin_left_weights[-1].weight.data)

        # Definition of SAG Pooling layers, one for each layer
        self.pooling_layers = nn.ModuleList()
        for i in range(self.layers):
            self.pooling_layers.append(SAGPooling(self.hidden_channels * self.final_d, 0.4))

        # Readout results
        self.readouts = []

        # Define the restriction map learner(s) with the selected edge-feature handling.
        self.sheaf_learners = nn.ModuleList()
        num_sheaf_learners = min(self.layers, self.layers if self.nonlinear else 1)
        for i in range(num_sheaf_learners):
            if self.edges_feat == "none":
                self.sheaf_learners.append(LocalConcatSheafLearner(self.final_d,
                    self.hidden_channels, out_shape=(self.d,), sheaf_act=self.sheaf_act))
            elif self.edges_feat == "concat":
                self.sheaf_learners.append(LocalConcatSheafLearnerVariant1(self.final_d,
                    self.hidden_channels, out_shape=(self.d,), sheaf_act=self.sheaf_act))
            elif self.edges_feat == "linear":
                self.sheaf_learners.append(LocalConcatSheafLearnerVariant2(self.final_d,
                    self.hidden_channels, out_shape=(self.d,), sheaf_act=self.sheaf_act))
            elif self.edges_feat == "bilinear":
                self.sheaf_learners.append(LocalConcatSheafLearnerVariant3(self.final_d,
                    self.hidden_channels, out_shape=(self.d,), sheaf_act=self.sheaf_act))
            else:
                # Fail at construction time rather than with an IndexError in forward().
                raise ValueError(f"Unsupported edges_feat {self.edges_feat}")

        # Project edge features to final_d channels only when a learner consumes them.
        if self.input_dim_edge > 0 and self.edges_feat != "none":
            self.lin3 = nn.Linear(self.input_dim_edge, self.final_d)

        # Per-layer, per-stalk-dimension residual coefficients, used as 1 + tanh(eps).
        self.epsilons = nn.ParameterList()
        for i in range(self.layers):
            self.epsilons.append(nn.Parameter(torch.zeros((self.final_d, 1))))

        self.lin1 = nn.Linear(self.input_dim, self.hidden_dim)
        if self.second_linear:
            self.lin12 = nn.Linear(self.hidden_dim, self.hidden_dim)
        # Classifier over the concatenated [mean, sum, max] readout.
        self.lin2 = nn.Linear(self.hidden_dim * 3, self.output_dim)

    def forward(self, x, edge_index, edge_attr, batch):
        """Run diffusion + pooling layers on a batch of graphs and return one output row per graph."""
        x = F.dropout(x, p=self.input_dropout, training=self.training)
        # (n * b, input features) -> (n * b, hidden feat * d)
        x = self.lin1(x)
        if self.use_act:
            x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        if self.second_linear:
            x = self.lin12(x)
        dim = x.size(dim=0)
        # One row per (node, stalk dimension).
        x = x.view(dim * self.final_d, -1)

        if self.input_dim_edge > 0 and self.edges_feat != "none":
            # (e * b, input_edge features) -> (e * b, final_d)
            edge_attr = self.lin3(edge_attr)
            if self.use_act:
                edge_attr = torch.tanh(edge_attr)

        x0 = x
        for layer in range(self.layers):

            # Pooling shrinks the graph, so the Laplacian (and the maps it is built from)
            # must be recomputed at every layer. Reusing the layer-0 Laplacian would index
            # nodes that no longer exist. Without per-layer learners (nonlinear=False),
            # the single learner is shared across layers.
            laplacian_builder = DiagLaplacianBuilder(dim, edge_index, d=self.d,
                                                     normalised=self.normalised,
                                                     deg_normalised=self.deg_normalised,
                                                     add_hp=self.add_hp, add_lp=self.add_lp)
            sheaf_learner = self.sheaf_learners[layer if self.nonlinear else 0]
            x_maps = F.dropout(x, p=self.dropout if layer > 0 else 0., training=self.training)
            maps = sheaf_learner(x_maps.reshape(dim, -1), edge_attr, edge_index)
            L, trans_maps = laplacian_builder(maps)
            sheaf_learner.set_L(trans_maps)

            x = F.dropout(x, p=self.dropout, training=self.training)

            if self.left_weights:
                x = x.t().reshape(-1, self.final_d)
                x = self.lin_left_weights[layer](x)
                x = x.reshape(-1, dim * self.final_d).t()

            if self.right_weights:
                x = self.lin_right_weights[layer](x)

            x = torch_sparse.spmm(L[0], L[1], x.size(0), x.size(0), x)

            if self.use_act:
                x = F.elu(x)

            # Residual diffusion step: x <- (1 + tanh(eps)) * x0 - L x.
            coeff = (1 + torch.tanh(self.epsilons[layer]).tile(dim, 1))
            x = coeff * x0 - x

            # Pooling layer: operates on one row per node, then back to per-stalk rows.
            x = x.reshape(dim, -1)
            x, edge_index, edge_attr, batch, _, _ = self.pooling_layers[layer](x, edge_index, edge_attr, batch = batch)

            dim = x.size(dim=0)
            x = x.view(dim * self.final_d, -1)

            x0 = x

        # Readout layer
        x = x.reshape(dim, -1)
        x = torch.cat([global_mean_pool(x, batch), global_add_pool(x, batch), global_max_pool(x, batch)], dim=1)

        # Linear classifier
        x = self.lin2(x)

        return x

#### Bundle sheaf diffusion

In [16]:
class DiscreteBundleSheafDiffusion(SheafDiffusion):
    """Discrete sheaf diffusion with learned connection ("bundle") maps, for graph classification.

    Node features are projected by ``lin1`` to ``hidden_dim`` values per node, viewed as
    ``final_d`` stalk dimensions per node, diffused for ``layers`` steps with a learned
    normalised connection Laplacian (``NormConnectionLaplacianBuilder``), and finally
    pooled into one prediction per graph by the readout named in ``args['readout']``.

    Args:
        args: configuration dict (e.g. ``vars(Parameters())``); requires ``args['d'] > 1``.

    Raises:
        ValueError: if ``readout`` or ``edges_feat`` is not one of the supported options.
    """

    # Readout strategies handled by __init__/forward.
    _READOUTS = ("mean", "sum", "max", "concat", "sag", "mlp")
    # Edge-feature strategies; each selects a different sheaf-learner class.
    _EDGES_FEATS = ("none", "concat", "linear", "bilinear")

    def __init__(self, args):
        super(DiscreteBundleSheafDiffusion, self).__init__(args)
        assert args['d'] > 1
        assert not self.deg_normalised
        # Reject unsupported options up front: otherwise the model would be built
        # without its sheaf learners / output layers and only fail inside forward().
        if self.readout not in self._READOUTS:
            raise ValueError(
                f"Unsupported readout {self.readout!r}; expected one of {self._READOUTS}")
        if self.edges_feat not in self._EDGES_FEATS:
            raise ValueError(
                f"Unsupported edges_feat {self.edges_feat!r}; expected one of {self._EDGES_FEATS}")

        self.lin_right_weights = nn.ModuleList()
        self.lin_left_weights = nn.ModuleList()

        # NOTE(review): never populated or used in this class; kept so the set of
        # module attributes stays unchanged.
        self.batch_norms = nn.ModuleList()

        # Right weights act on the hidden channels (orthogonal init); left weights act
        # on the final_d stalk dimensions (identity init). One of each per layer.
        if self.right_weights:
            for _ in range(self.layers):
                self.lin_right_weights.append(nn.Linear(self.hidden_channels, self.hidden_channels, bias=False))
                nn.init.orthogonal_(self.lin_right_weights[-1].weight.data)
        if self.left_weights:
            for _ in range(self.layers):
                self.lin_left_weights.append(nn.Linear(self.final_d, self.final_d, bias=False))
                nn.init.eye_(self.lin_left_weights[-1].weight.data)

        self.sheaf_learners = nn.ModuleList()
        self.weight_learners = nn.ModuleList()

        # Restriction-map learners using the selected edge-feature handling. With
        # `nonlinear` there is one learner per layer; otherwise a single learner whose
        # Laplacian (built at layer 0 in forward) is reused by every layer.
        num_sheaf_learners = min(self.layers, self.layers if self.nonlinear else 1)
        for _ in range(num_sheaf_learners):
            if self.use_edge_weights:
                self.weight_learners.append(EdgeWeightLearner(self.hidden_dim))
            if self.edges_feat == "none":
                learner_cls = LocalConcatSheafLearner
            elif self.edges_feat == "concat":
                learner_cls = LocalConcatSheafLearnerVariant1
            elif self.edges_feat == "linear":
                learner_cls = LocalConcatSheafLearnerVariant2
            else:  # "bilinear" (validated above)
                learner_cls = LocalConcatSheafLearnerVariant3
            # NOTE(review): learners emit `self.d` values per edge, whereas get_param_size()
            # reports d*(d-1)//2 (d*(d+1)//2 for 'matrix_exp'/'cayley'); the two agree for
            # d=3 with 'householder'/'euler' — confirm for other settings.
            self.sheaf_learners.append(learner_cls(self.final_d,
                self.hidden_channels, out_shape=(self.d,), sheaf_act=self.sheaf_act))

        # Edge features are projected to final_d only when a sheaf learner consumes them.
        if self.input_dim_edge > 0 and self.edges_feat != "none":
            self.lin3 = nn.Linear(self.input_dim_edge, self.final_d)

        # Per-layer, per-stalk-dimension coefficients: forward() scales the residual x0
        # by (1 + tanh(epsilon)).
        self.epsilons = nn.ParameterList()
        for _ in range(self.layers):
            self.epsilons.append(nn.Parameter(torch.zeros((self.final_d, 1))))

        self.lin1 = nn.Linear(self.input_dim, self.hidden_dim)
        if self.second_linear:
            self.lin12 = nn.Linear(self.hidden_dim, self.hidden_dim)

        # Readout-specific layers. The creation order within each branch is kept as it
        # was, so parameter initialisation consumes the RNG in the same sequence.
        if self.readout in ["mean", "max", "sum"]:
            # Output linear layer
            self.lin2 = nn.Linear(self.hidden_dim, self.output_dim)

        elif self.readout == "concat":
            # Output linear layer + projection of the [mean, sum, max] concatenation
            self.lin2 = nn.Linear(self.hidden_dim, self.output_dim)
            self.lin4 = nn.Linear(self.hidden_dim*3, self.hidden_dim)

        elif self.readout == "sag":
            # SAG pooling keeping 30% of the nodes, then [mean, sum, max] readout
            self.pooling_layers = SAGPooling(self.hidden_channels * self.final_d, 0.3)
            self.lin4 = nn.Linear(self.hidden_dim*3, self.hidden_dim)
            self.lin2 = nn.Linear(self.hidden_dim, self.output_dim)

        elif self.readout == "mlp":
            self.gnn_output_node_dim = self.hidden_dim

            # Dense readout over the zero-padded (max_num_nodes_in_graph) node features
            # of each graph, flattened into one vector per graph.
            self.dense_agg = torch.nn.Sequential(
                nn.Linear(in_features=self.max_num_nodes_in_graph * self.gnn_output_node_dim, out_features=self.dense_intermediate_dim),
                nn.BatchNorm1d(self.dense_intermediate_dim),
                nn.ReLU(),
                nn.Dropout(p=0.7), # Originally there was not
                nn.Linear(in_features=self.dense_intermediate_dim, out_features=self.dense_output_graph_dim),
                nn.BatchNorm1d(self.dense_output_graph_dim),
                nn.ReLU(),
                nn.Dropout(p=0.7)  # Original p = 0.4
            )
            # Output NN for MLP
            self.output_nn = torch.nn.Sequential(
                nn.Linear(in_features=self.dense_output_graph_dim, out_features=self.output_nn_intermediate_dim),
                nn.ReLU(),
                nn.Dropout(p=0.4),
                nn.Linear(in_features=self.output_nn_intermediate_dim, out_features=self.output_dim)
            )

    def get_param_size(self):
        """Number of free parameters of a d x d orthogonal map under `orth_trans`."""
        if self.orth_trans in ['matrix_exp', 'cayley']:
            return self.d * (self.d + 1) // 2
        else:
            return self.d * (self.d - 1) // 2

    def update_edge_index(self, edge_index):
        """Propagate a new edge_index to the base class and every edge-weight learner."""
        super().update_edge_index(edge_index)
        for weight_learner in self.weight_learners:
            weight_learner.update_edge_index(edge_index)

    def forward(self, x, edge_index, edge_attr, batch):
        """Diffuse node features over the learned sheaf and pool them per graph.

        Args:
            x: node features, one row per node of the batched graphs.
            edge_index: connectivity used to build the sheaf Laplacian.
            edge_attr: edge features, projected by ``lin3`` when ``input_dim_edge > 0``
                and ``edges_feat != "none"``, then passed to the sheaf learners
                (presumably None when the dataset has none — TODO confirm).
            batch: node-to-graph assignment vector used by the global pooling ops.

        Returns:
            Tensor with one row per graph and ``output_dim`` columns (class scores).
        """
        x = F.dropout(x, p=self.input_dropout, training=self.training)
        #(n * b, input features) -> (n* b, hidden feat * d)
        x = self.lin1(x)
        if self.use_act:
            x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        if self.second_linear:
            x = self.lin12(x)
        dim = x.size(dim=0)  # number of nodes in the batch
        # One row per (node, stalk dimension) pair.
        x = x.view(dim * self.final_d, -1)

        if self.input_dim_edge > 0 and self.edges_feat != "none":
            #(e * b, input_edge features) -> (e* b,hidden feat * d)
            edge_attr = self.lin3(edge_attr)
            if self.use_act:
                edge_attr = torch.tanh(edge_attr)

        laplacian_builder = NormConnectionLaplacianBuilder(
            dim, edge_index, d=self.d, add_hp=self.add_hp,
            add_lp=self.add_lp, orth_map=self.orth_trans)

        x0, L = x, None
        for layer in range(self.layers):
            # Recompute the Laplacian every layer when nonlinear, otherwise only once.
            if layer == 0 or self.nonlinear:
                x_maps = F.dropout(x, p=self.dropout if layer > 0 else 0., training=self.training)
                x_maps = x_maps.reshape(dim, -1)
                maps = self.sheaf_learners[layer](x_maps, edge_attr, edge_index)
                edge_weights = self.weight_learners[layer](x_maps, edge_index) if self.use_edge_weights else None
                L, trans_maps = laplacian_builder(maps, edge_weights)
                self.sheaf_learners[layer].set_L(trans_maps)

            x = F.dropout(x, p=self.dropout, training=self.training)

            if self.left_weights:
                x = x.t().reshape(-1, self.final_d)
                x = self.lin_left_weights[layer](x)
                x = x.reshape(-1, dim * self.final_d).t()

            if self.right_weights:
                x = self.lin_right_weights[layer](x)

            # Use the adjacency matrix rather than the diagonal
            x = torch_sparse.spmm(L[0], L[1], x.size(0), x.size(0), x)

            if self.use_act:
                x = F.elu(x)

            # Residual diffusion step: x <- (1 + tanh(eps)) * x0 - L x
            coeff = (1 + torch.tanh(self.epsilons[layer]).tile(dim, 1))
            x0 = coeff * x0 - x
            x = x0

        # Readout layers: back to one row per node.
        x = x.reshape(dim, -1)

        if self.readout == "mean":
            x = global_mean_pool(x, batch)  # [batch_size, hidden_channels = hidden feat * d]
            x = self.lin2(x)

        elif self.readout == "sum":
            x = global_add_pool(x, batch)
            x = self.lin2(x)

        elif self.readout == "max":
            x = global_max_pool(x, batch)
            x = self.lin2(x)

        elif self.readout == "concat":
            # Concatenation of [mean, sum, max] readout
            x = torch.cat([global_mean_pool(x, batch), global_add_pool(x, batch), global_max_pool(x, batch)], dim=1)
            x = self.lin4(x)
            x = F.elu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.lin2(x)

        elif self.readout == "sag":
            # SAG pooling layer, then [mean, sum, max] readout
            x, edge_index, _, batch, _, _ = self.pooling_layers(x, edge_index, batch = batch)
            x = torch.cat([global_mean_pool(x, batch), global_add_pool(x, batch), global_max_pool(x, batch)], dim=1)
            x = self.lin4(x)
            x = F.elu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.lin2(x)

        elif self.readout == "mlp":
            # Zero-pad every graph to max_num_nodes_in_graph nodes and flatten.
            graph_x, _ = to_dense_batch(x, batch, fill_value=0, max_num_nodes=self.max_num_nodes_in_graph)
            graph_x = self.dense_agg(graph_x.view(-1, graph_x.shape[1] * graph_x.shape[2]))
            x = self.output_nn(graph_x)

        return x

# Experiments

#### Imports and common parameters

In [17]:
import sys
import os
import random
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

# Seed every random number generator this notebook touches, all with the same
# value: Python's `random`, NumPy, and PyTorch (CPU, current CUDA device, all CUDA devices).
random.seed(43)
np.random.seed(43)
torch.manual_seed(43)
torch.cuda.manual_seed(43)
torch.cuda.manual_seed_all(43)

Parameters for model construction, training and testing.

In [18]:
class Parameters():
    """Default hyper-parameters for model construction, training and evaluation.

    Every setting is a plain attribute, so ``instance.__dict__`` can be handed to the
    model constructors as their config dict. Keyword arguments override defaults or
    add extra settings, e.g. ``Parameters(d=2, input_dim=3)``; this lets a cell build
    a fresh configuration instead of mutating a shared one.
    """

    def __init__(self, **overrides):
        # Optimisation params
        self.epochs = 200
        self.lr = 0.01
        self.weight_decay = 0.0005
        self.sheaf_decay = None         # None: falls back to weight_decay once the dataset args are filled in
        self.patience = 15              # Early stopping on validation set

        # Model configuration
        self.second_linear = False
        self.d = 3
        self.layers = 3
        self.normalised = True
        self.deg_normalised = False
        self.linear = False
        self.hidden_channels = 15
        self.input_dropout = 0.0
        self.dropout = 0.1
        self.left_weights = True
        self.right_weights = True
        self.add_lp = False
        self.add_hp = False
        self.use_act = True
        self.sheaf_act = "tanh"
        self.edge_weights = True
        self.orth = "householder"       # choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization used for the orthogonal group
        self.edges_feat = "none"        # choices=['none', 'concat', 'linear', 'bilinear']
        self.readout = "mean"           # choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

        # Dense readout layer
        self.dense_intermediate_dim = 256
        self.dense_output_graph_dim = 128
        self.output_nn_intermediate_dim = 64

        # Transformer layer
        self.set_transformer_k = 8

        # Caller-supplied overrides / additional settings, applied last.
        for name, value in overrides.items():
            setattr(self, name, value)

args = Parameters()

#### Training phase

In [19]:
def train(model, optimizer, criterion, train_loader, device):
    """Run one training epoch over `train_loader`.

    Args:
        model: module called as ``model(x, edge_index, edge_attr, batch)``.
        optimizer: optimizer over the model's parameters.
        criterion: loss ``criterion(out, y)`` returning a scalar tensor.
        train_loader: iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``batch``, ``y`` and ``.to(device)``.
        device: device every batch is moved to.

    Returns:
        float: mean of the per-batch training losses (0.0 for an empty loader).
    """
    model.train()

    total_loss = 0.0
    num_batches = 0
    for data in train_loader:  # Iterate in batches over the training dataset.
        data = data.to(device)
        # Clear gradients *before* backward so anything accumulated outside this
        # loop cannot leak into the first update.
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)  # Perform a single forward pass.
        loss = criterion(out, data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else 0.0

#### Test phase 

In [20]:
def test(model, criterion, test_loader, device):
     model.eval()

     correct = 0
     tot_loss = 0

     with torch.no_grad():
        for data in test_loader:  # Iterate in batches over the training/test dataset.
            data = data.to(device)
            out = model(data.x, data.edge_index, data.edge_attr, data.batch) 
            tot_loss = tot_loss + criterion(out, data.y).item()  # Compute the loss.
            pred = out.argmax(dim=1)  # Use the class with highest probability.
            correct += int((pred == data.y).sum())  # Check against ground-truth labels.
     acc = correct / len(test_loader.dataset) * 100  # Derive ratio of correct predictions.
     final_loss = tot_loss / len(test_loader.dataset)
     return acc, final_loss

## ENZYMES

### Download of the dataset

In [38]:
import torch_geometric.transforms as T
from torch_geometric.utils.undirected import to_undirected
from torch_geometric.utils import remove_self_loops

# The sheaf Laplacian builders need undirected graphs without self-loops.
# Indexing an in-memory PyG dataset returns a *copy* of each graph, so editing graphs
# in a `for data in dataset:` loop has no lasting effect (the DataLoaders would still
# see the raw graphs). The cleaning therefore runs as part of the dataset transform,
# applied every time a graph is fetched.
def _drop_self_loops_and_symmetrise(data):
    """Remove self-loops and make `edge_index` undirected (graphs without edge_attr)."""
    # NOTE(review): edge_attr is not carried along; ENZYMES has no edge features.
    data.edge_index, _ = remove_self_loops(data.edge_index)
    data.edge_index = to_undirected(data.edge_index, num_nodes=data.num_nodes)
    return data


dataset = TUDataset(root='data/TUDataset', name='ENZYMES',
                    transform=T.Compose([T.NormalizeFeatures(), _drop_self_loops_and_symmetrise]))

print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of node features: {dataset.num_node_features}')
print(f'Number of edge features: {dataset.num_edge_features}')
print(f'Number of classes: {dataset.num_classes}')

# Largest graph size; used by the dense ("mlp") readout to pad every graph.
max_num_nodes_in_graph = max(graph.x.size(dim=0) for graph in dataset)

print('====================')
graph1 = dataset[0]  # Get the first graph object.
print(graph1)
print()

print(f'Contains isolated nodes: {graph1.has_isolated_nodes()}')
print(f'Contains self-loops: {graph1.has_self_loops()}')
print(f'Is undirected: {graph1.is_undirected()}')

# The transform must have produced Laplacian-ready graphs.
assert not graph1.has_self_loops() and graph1.is_undirected()

Dataset: ENZYMES(600):
Number of graphs: 600
Number of node features: 3
Number of edge features: 0
Number of classes: 6
Data(edge_index=[2, 168], x=[37, 3], y=[1])

Contains isolated nodes: False
Contains self-loops: False
Is undirected: True


### Split into training, validation and test

In [39]:
# Fixed seed so the shuffle (and hence the split) is reproducible.
torch.manual_seed(12345)
dataset = dataset.shuffle()

# 440 / 80 / 80 graphs for training / validation / test.
train_dataset, valid_dataset, test_dataset = dataset[:440], dataset[440:520], dataset[520:]

print(f'Number of training graphs: {len(train_dataset)}',
      f'Number of validation graphs: {len(valid_dataset)}',
      f'Number of test graphs: {len(test_dataset)}',
      sep='\n')

Number of training graphs: 440
Number of validation graphs: 80
Number of test graphs: 80


### Creation of Dataloader 


In [40]:
from torch_geometric.loader import DataLoader

batch_size = 128

# Only the training split is reshuffled each epoch; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=False)
    for split in (valid_dataset, test_dataset)
)

# Show the composition of each training batch.
for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:',
          '=======',
          f'Number of graphs in the current batch: {data.num_graphs}',
          sep='\n')
    print(data, end='\n\n')

Step 1:
Number of graphs in the current batch: 128
DataBatch(edge_index=[2, 15760], x=[4119, 3], y=[128], batch=[4119], ptr=[129])

Step 2:
Number of graphs in the current batch: 128
DataBatch(edge_index=[2, 16030], x=[4167, 3], y=[128], batch=[4167], ptr=[129])

Step 3:
Number of graphs in the current batch: 128
DataBatch(edge_index=[2, 15794], x=[4298, 3], y=[128], batch=[4298], ptr=[129])

Step 4:
Number of graphs in the current batch: 56
DataBatch(edge_index=[2, 7002], x=[1857, 3], y=[56], batch=[1857], ptr=[57])



### Running of experiments 

Since the ENZYMES graphs have no edge features, the model extensions that handle edge features are not tested here; hence we always set

```
args.edges_feat = "none"
```

In [41]:
# Dataset-dependent settings, filled in once the dataset has been loaded.
vars(args).update(
    input_dim=dataset.num_features,
    output_dim=dataset.num_classes,
    input_dim_edge=dataset.num_edge_features,
    max_num_nodes_in_graph=max_num_nodes_in_graph,
    device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
)

# At least one of the two Laplacian normalisation flags must be enabled.
assert args.normalised or args.deg_normalised

# An unset sheaf weight decay falls back to the general weight decay.
args.sheaf_decay = args.weight_decay if args.sheaf_decay is None else args.sheaf_decay

#### Baseline scalar model (**d=1**), [mean, sum, max] readout 

In [None]:
#Model-specific arguments

# Optimisation params
args.lr = 0.005
args.weight_decay = 0.0005

# Model configuration
args.d = 1
args.layers = 4
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.1
args.orth = "householder"   #choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "none"    #choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "concat"     #choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# `config` aliases args.__dict__, so it reflects every assignment above.
config = args.__dict__
print(config)

model = DiscreteDiagSheafDiffusion(config)
#model = DiscreteDiagPoolSheafDiffusion(config)
#model = DiscreteBundleSheafDiffusion(config)
model = model.to(args.device)

# Sheaf-learner parameters get their own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()

optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])

# Training criterion averages over the batch; the evaluation criterion sums, so that
# `test` can divide by the dataset size and report a batch-size-independent loss.
criterion_train = torch.nn.CrossEntropyLoss()
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

best_test_acc = 0            # max test accuracy over all epochs (selected on the test set: optimistic)
best_epoch = 0
best_loss = float('inf')     # lowest validation loss so far (early-stopping criterion)
best_valid_epoch = 0
test_acc_at_best_valid = 0   # test accuracy at the epoch with the lowest validation loss
patience = args.patience

# Epochs 1..args.epochs inclusive, unless early stopping triggers.
for epoch in range(1, args.epochs + 1):
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    # Log before the early-stopping check so the final epoch is reported too.
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        best_valid_epoch = epoch
        test_acc_at_best_valid = test_acc
        patience = args.patience
    else:
        patience -= 1

    if patience <= 0:
        # Early stopping: no validation-loss improvement for `args.patience` epochs.
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.005, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 1, 'layers': 4, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.1, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'none', 'readout': 'concat', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 3, 'output_dim': 6, 'input_dim_edge': 0, 'max_num_nodes_in_graph': 126, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 20.9091, Train Loss: 1.8667, Valid Loss: 1.9287, Test Acc: 20.0000
Epoch: 002, Train Acc: 18.6364, Train Loss: 1.8867, Valid Loss: 1.9315, Test Acc: 17.5000
Epoch: 003, Train Acc: 19.3182, Train Loss: 1.8330, Valid Loss: 1.8289, Test Acc: 17.5000
Epoch: 004, Train Ac

#### **Diagonal** sheaf diffusion,  **[mean, sum, max]** readout

In [None]:
#Model-specific arguments

# Optimisation params
args.lr = 0.005
args.weight_decay = 0.0005

# Model configuration
args.d = 3
args.layers = 4
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.1
args.orth = "householder"   #choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "none"    #choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "concat"     #choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# `config` aliases args.__dict__, so it reflects every assignment above.
config = args.__dict__
print(config)

model = DiscreteDiagSheafDiffusion(config)
#model = DiscreteDiagPoolSheafDiffusion(config)
#model = DiscreteBundleSheafDiffusion(config)
model = model.to(args.device)

# Sheaf-learner parameters get their own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()

optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])

# Training criterion averages over the batch; the evaluation criterion sums, so that
# `test` can divide by the dataset size and report a batch-size-independent loss.
criterion_train = torch.nn.CrossEntropyLoss()
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

best_test_acc = 0            # max test accuracy over all epochs (selected on the test set: optimistic)
best_epoch = 0
best_loss = float('inf')     # lowest validation loss so far (early-stopping criterion)
best_valid_epoch = 0
test_acc_at_best_valid = 0   # test accuracy at the epoch with the lowest validation loss
patience = args.patience

# Epochs 1..args.epochs inclusive, unless early stopping triggers.
for epoch in range(1, args.epochs + 1):
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    # Log before the early-stopping check so the final epoch is reported too.
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        best_valid_epoch = epoch
        test_acc_at_best_valid = test_acc
        patience = args.patience
    else:
        patience -= 1

    if patience <= 0:
        # Early stopping: no validation-loss improvement for `args.patience` epochs.
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.005, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 3, 'layers': 4, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.1, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'none', 'readout': 'concat', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 3, 'output_dim': 6, 'input_dim_edge': 0, 'max_num_nodes_in_graph': 126, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 14.7727, Train Loss: 1.9131, Valid Loss: 1.8577, Test Acc: 12.5000
Epoch: 002, Train Acc: 18.1818, Train Loss: 1.8258, Valid Loss: 1.9215, Test Acc: 21.2500
Epoch: 003, Train Acc: 20.2273, Train Loss: 1.7886, Valid Loss: 1.8388, Test Acc: 21.2500
Epoch: 004, Train Ac

#### **Bundle** sheaf diffusion,  **[mean, sum, max]** readout

In [None]:
#Model-specific arguments

# Optimisation params
args.lr = 0.01
args.weight_decay = 0.0005

# Model configuration. Every setting is explicit (same values the previous cells
# left behind) so this cell does not depend on which cells ran before it.
args.d = 3
args.layers = 4
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.1
args.orth = "householder"   #choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "none"    #choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "concat"     #choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# `config` aliases args.__dict__, so it reflects every assignment above.
config = args.__dict__
print(config)

#model = DiscreteDiagSheafDiffusion(config)
#model = DiscreteDiagPoolSheafDiffusion(config)
model = DiscreteBundleSheafDiffusion(config)
model = model.to(args.device)

# Sheaf-learner parameters get their own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()

optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])

# Training criterion averages over the batch; the evaluation criterion sums, so that
# `test` can divide by the dataset size and report a batch-size-independent loss.
criterion_train = torch.nn.CrossEntropyLoss()
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

best_test_acc = 0            # max test accuracy over all epochs (selected on the test set: optimistic)
best_epoch = 0
best_loss = float('inf')     # lowest validation loss so far (early-stopping criterion)
best_valid_epoch = 0
test_acc_at_best_valid = 0   # test accuracy at the epoch with the lowest validation loss
patience = args.patience

# Epochs 1..args.epochs inclusive, unless early stopping triggers.
for epoch in range(1, args.epochs + 1):
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    # Log before the early-stopping check so the final epoch is reported too.
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        best_valid_epoch = epoch
        test_acc_at_best_valid = test_acc
        patience = args.patience
    else:
        patience -= 1

    if patience <= 0:
        # Early stopping: no validation-loss improvement for `args.patience` epochs.
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.01, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 2, 'layers': 4, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.1, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'none', 'readout': 'concat', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 3, 'output_dim': 6, 'input_dim_edge': 0, 'max_num_nodes_in_graph': 126, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 15.4545, Train Loss: 2.0803, Valid Loss: 1.9413, Test Acc: 20.0000
Epoch: 002, Train Acc: 20.0000, Train Loss: 1.8537, Valid Loss: 1.9046, Test Acc: 23.7500
Epoch: 003, Train Acc: 20.4545, Train Loss: 1.7904, Valid Loss: 1.8588, Test Acc: 23.7500
Epoch: 004, Train Acc

#### Diagonal sheaf diffusion, **mean** readout 

In [None]:
#Model-specific arguments

# Optimisation params
args.lr = 0.01
args.weight_decay = 0.0005

# Model configuration
args.d = 3
args.layers = 4
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.1
args.orth = "householder"   #choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "none"    #choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "mean"       #choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# `config` aliases args.__dict__, so it reflects every assignment above.
config = args.__dict__
print(config)

model = DiscreteDiagSheafDiffusion(config)
#model = DiscreteDiagPoolSheafDiffusion(config)
#model = DiscreteBundleSheafDiffusion(config)
model = model.to(args.device)

# Sheaf-learner parameters get their own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()

optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])

# Training criterion averages over the batch; the evaluation criterion sums, so that
# `test` can divide by the dataset size and report a batch-size-independent loss.
criterion_train = torch.nn.CrossEntropyLoss()
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

best_test_acc = 0            # max test accuracy over all epochs (selected on the test set: optimistic)
best_epoch = 0
best_loss = float('inf')     # lowest validation loss so far (early-stopping criterion)
best_valid_epoch = 0
test_acc_at_best_valid = 0   # test accuracy at the epoch with the lowest validation loss
patience = args.patience

# Epochs 1..args.epochs inclusive, unless early stopping triggers.
for epoch in range(1, args.epochs + 1):
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    # Log before the early-stopping check so the final epoch is reported too.
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        best_valid_epoch = epoch
        test_acc_at_best_valid = test_acc
        patience = args.patience
    else:
        patience -= 1

    if patience <= 0:
        # Early stopping: no validation-loss improvement for `args.patience` epochs.
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.01, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 3, 'layers': 4, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.1, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'none', 'readout': 'mean', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 3, 'output_dim': 6, 'input_dim_edge': 0, 'max_num_nodes_in_graph': 126, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 17.7273, Train Loss: 1.7887, Valid Loss: 1.8142, Test Acc: 10.0000
Epoch: 002, Train Acc: 19.3182, Train Loss: 1.7792, Valid Loss: 1.8028, Test Acc: 15.0000
Epoch: 003, Train Acc: 24.0909, Train Loss: 1.7713, Valid Loss: 1.7811, Test Acc: 22.5000
Epoch: 004, Train Acc: 

#### Diagonal sheaf diffusion, **sum** readout 

In [None]:
#Model-specific arguments

# Optimisation params
args.lr = 0.01
args.weight_decay = 0.0005

# Model configuration
args.d = 3
args.layers = 4
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.1
args.orth = "householder"   #choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "none"    #choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "sum"        #choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# `config` aliases args.__dict__, so it reflects every assignment above.
config = args.__dict__
print(config)

model = DiscreteDiagSheafDiffusion(config)
#model = DiscreteDiagPoolSheafDiffusion(config)
#model = DiscreteBundleSheafDiffusion(config)
model = model.to(args.device)

# Sheaf-learner parameters get their own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()

optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])

# Training criterion averages over the batch; the evaluation criterion sums, so that
# `test` can divide by the dataset size and report a batch-size-independent loss.
criterion_train = torch.nn.CrossEntropyLoss()
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

best_test_acc = 0            # max test accuracy over all epochs (selected on the test set: optimistic)
best_epoch = 0
best_loss = float('inf')     # lowest validation loss so far (early-stopping criterion)
best_valid_epoch = 0
test_acc_at_best_valid = 0   # test accuracy at the epoch with the lowest validation loss
patience = args.patience

# Epochs 1..args.epochs inclusive, unless early stopping triggers.
for epoch in range(1, args.epochs + 1):
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    # Log before the early-stopping check so the final epoch is reported too.
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        best_valid_epoch = epoch
        test_acc_at_best_valid = test_acc
        patience = args.patience
    else:
        patience -= 1

    if patience <= 0:
        # Early stopping: no validation-loss improvement for `args.patience` epochs.
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.01, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 3, 'layers': 4, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.1, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'none', 'readout': 'sum', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 3, 'output_dim': 6, 'input_dim_edge': 0, 'max_num_nodes_in_graph': 126, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 8.8636, Train Loss: 3.6313, Valid Loss: 3.2132, Test Acc: 15.0000
Epoch: 002, Train Acc: 17.9545, Train Loss: 3.1746, Valid Loss: 3.0441, Test Acc: 15.0000
Epoch: 003, Train Acc: 17.2727, Train Loss: 1.9633, Valid Loss: 2.0150, Test Acc: 18.7500
Epoch: 004, Train Acc: 20

#### Diagonal sheaf diffusion, **max** readout 

In [None]:
# Model-specific arguments. Every hyper-parameter this experiment relies on is
# set explicitly, so the cell does not depend on values left on `args` by
# whichever cell happened to run before it.

# Optimisation params
args.lr = 0.01
args.weight_decay = 0.0005

# Model configuration
args.d = 3
args.layers = 4
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.1
args.orth = "householder"   # choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "none"    # choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "max"        # choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# `vars(args)` is the live attribute dict of `args` (the same object as `args.__dict__`).
config = vars(args)
print(config)

model = DiscreteDiagSheafDiffusion(config)
#model = DiscreteDiagPoolSheafDiffusion(config)
#model = DiscreteBundleSheafDiffusion(config)
model = model.to(args.device)

# Sheaf-learner parameters get their own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()

# Optimizers
optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])


# Training and test criterions
criterion_train = torch.nn.CrossEntropyLoss()
# Summed rather than averaged, so the evaluation loss does not change with batch_size
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

# Highest test accuracy seen at any epoch (optimistic: it is selected on the test set itself)
best_test_acc = 0
best_epoch = 0
# Test accuracy at the epoch with the lowest validation loss (selection uses validation data only)
test_acc_at_best_valid = 0
best_valid_epoch = 0
# Lowest validation loss so far, used for early stopping. Starting at +inf makes the
# first epoch always count as an improvement, whatever the scale of the loss.
best_loss = float('inf')
patience = args.patience

# `args.epochs + 1` so that exactly `args.epochs` epochs run, numbered from 1.
for epoch in range(1, args.epochs + 1):
    # Training epoch
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy/loss
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    # Log before the early-stopping check so the final epoch is reported too
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        test_acc_at_best_valid = test_acc
        best_valid_epoch = epoch
        patience = args.patience
    else:
        patience -= 1

    if patience <= 0:
        # Early stopping: no validation improvement for `args.patience` consecutive epochs
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.01, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 3, 'layers': 4, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.1, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'none', 'readout': 'max', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 3, 'output_dim': 6, 'input_dim_edge': 0, 'max_num_nodes_in_graph': 126, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 16.5909, Train Loss: 1.8175, Valid Loss: 1.8384, Test Acc: 18.7500
Epoch: 002, Train Acc: 17.0455, Train Loss: 1.7918, Valid Loss: 1.7957, Test Acc: 13.7500
Epoch: 003, Train Acc: 18.1818, Train Loss: 1.7924, Valid Loss: 1.8002, Test Acc: 13.7500
Epoch: 004, Train Acc: 1

#### Diagonal sheaf diffusion, **MLP** readout 

In [None]:
# Model-specific arguments. Every hyper-parameter this experiment relies on is
# set explicitly, so the cell does not depend on values left on `args` by
# whichever cell happened to run before it.

# Optimisation params
args.lr = 0.001
args.weight_decay = 0.0005

# Model configuration
args.d = 3
args.layers = 4
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.1
# Previously inherited from an earlier cell (the line was commented out); now explicit.
args.orth = "householder"   # choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "none"    # choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "mlp"        # choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# `vars(args)` is the live attribute dict of `args` (the same object as `args.__dict__`).
config = vars(args)
print(config)

model = DiscreteDiagSheafDiffusion(config)
#model = DiscreteDiagPoolSheafDiffusion(config)
#model = DiscreteBundleSheafDiffusion(config)
model = model.to(args.device)

# Sheaf-learner parameters get their own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()

# Optimizers
optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])


# Training and test criterions
criterion_train = torch.nn.CrossEntropyLoss()
# Summed rather than averaged, so the evaluation loss does not change with batch_size
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

# Highest test accuracy seen at any epoch (optimistic: it is selected on the test set itself)
best_test_acc = 0
best_epoch = 0
# Test accuracy at the epoch with the lowest validation loss (selection uses validation data only)
test_acc_at_best_valid = 0
best_valid_epoch = 0
# Lowest validation loss so far, used for early stopping. Starting at +inf makes the
# first epoch always count as an improvement, whatever the scale of the loss.
best_loss = float('inf')
patience = args.patience

# `args.epochs + 1` so that exactly `args.epochs` epochs run, numbered from 1.
for epoch in range(1, args.epochs + 1):
    # Training epoch
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy/loss
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    # Log before the early-stopping check so the final epoch is reported too
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        test_acc_at_best_valid = test_acc
        best_valid_epoch = epoch
        patience = args.patience
    else:
        patience -= 1

    if patience <= 0:
        # Early stopping: no validation improvement for `args.patience` consecutive epochs
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.001, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 3, 'layers': 4, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.1, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'none', 'readout': 'mlp', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 3, 'output_dim': 6, 'input_dim_edge': 0, 'max_num_nodes_in_graph': 126, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 18.4091, Train Loss: 1.7894, Valid Loss: 1.7835, Test Acc: 13.7500
Epoch: 002, Train Acc: 21.1364, Train Loss: 1.7808, Valid Loss: 1.7771, Test Acc: 15.0000
Epoch: 003, Train Acc: 22.2727, Train Loss: 1.7704, Valid Loss: 1.7714, Test Acc: 17.5000
Epoch: 004, Train Acc: 

#### Diagonal sheaf diffusion, **global SAG** readout 

In [None]:
# Model-specific arguments. Every hyper-parameter this experiment relies on is
# set explicitly, so the cell does not depend on values left on `args` by
# whichever cell happened to run before it.

# Optimisation params
args.lr = 0.01
args.weight_decay = 0.0005

# Model configuration
args.d = 3
args.layers = 4
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.1
# Previously inherited from an earlier cell (the line was commented out); now explicit.
args.orth = "householder"   # choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "none"    # choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "sag"        # choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# `vars(args)` is the live attribute dict of `args` (the same object as `args.__dict__`).
config = vars(args)
print(config)

model = DiscreteDiagSheafDiffusion(config)
#model = DiscreteDiagPoolSheafDiffusion(config)
#model = DiscreteBundleSheafDiffusion(config)
model = model.to(args.device)

# Sheaf-learner parameters get their own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()

# Optimizers
optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])


# Training and test criterions
criterion_train = torch.nn.CrossEntropyLoss()
# Summed rather than averaged, so the evaluation loss does not change with batch_size
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

# Highest test accuracy seen at any epoch (optimistic: it is selected on the test set itself)
best_test_acc = 0
best_epoch = 0
# Test accuracy at the epoch with the lowest validation loss (selection uses validation data only)
test_acc_at_best_valid = 0
best_valid_epoch = 0
# Lowest validation loss so far, used for early stopping. Starting at +inf makes the
# first epoch always count as an improvement, whatever the scale of the loss.
best_loss = float('inf')
patience = args.patience

# `args.epochs + 1` so that exactly `args.epochs` epochs run, numbered from 1.
for epoch in range(1, args.epochs + 1):
    # Training epoch
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy/loss
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    # Log before the early-stopping check so the final epoch is reported too
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        test_acc_at_best_valid = test_acc
        best_valid_epoch = epoch
        patience = args.patience
    else:
        patience -= 1

    if patience <= 0:
        # Early stopping: no validation improvement for `args.patience` consecutive epochs
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.01, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 3, 'layers': 4, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.1, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'none', 'readout': 'sag', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 3, 'output_dim': 6, 'input_dim_edge': 0, 'max_num_nodes_in_graph': 126, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 14.5455, Train Loss: 1.8949, Valid Loss: 1.9199, Test Acc: 16.2500
Epoch: 002, Train Acc: 16.5909, Train Loss: 1.8280, Valid Loss: 1.8388, Test Acc: 18.7500
Epoch: 003, Train Acc: 20.4545, Train Loss: 1.7940, Valid Loss: 1.8154, Test Acc: 17.5000
Epoch: 004, Train Acc: 2

#### Diagonal sheaf diffusion, **hierarchical SAG** readout 

In [None]:
# Model-specific arguments. Every hyper-parameter this experiment relies on is
# set explicitly, so the cell does not depend on values left on `args` by
# whichever cell happened to run before it.

# Optimisation params
args.lr = 0.001
args.weight_decay = 0.0005

# Model configuration
args.d = 3
args.layers = 2
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.0
# `orth` and `readout` were previously inherited from earlier cells (their lines were
# commented out). They are set here to the values this cell's saved config output shows.
# NOTE(review): unclear from here whether DiscreteDiagPoolSheafDiffusion reads `readout`.
args.orth = "householder"   # choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "none"    # choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "sag"        # choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# `vars(args)` is the live attribute dict of `args` (the same object as `args.__dict__`).
config = vars(args)
print(config)

#model = DiscreteDiagSheafDiffusion(config)
model = DiscreteDiagPoolSheafDiffusion(config)
#model = DiscreteBundleSheafDiffusion(config)
model = model.to(args.device)

# Sheaf-learner parameters get their own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()

# Optimizers
optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])


# Training and test criterions
criterion_train = torch.nn.CrossEntropyLoss()
# Summed rather than averaged, so the evaluation loss does not change with batch_size
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

# Highest test accuracy seen at any epoch (optimistic: it is selected on the test set itself)
best_test_acc = 0
best_epoch = 0
# Test accuracy at the epoch with the lowest validation loss (selection uses validation data only)
test_acc_at_best_valid = 0
best_valid_epoch = 0
# Lowest validation loss so far, used for early stopping. Starting at +inf makes the
# first epoch always count as an improvement, whatever the scale of the loss.
best_loss = float('inf')
patience = args.patience

# `args.epochs + 1` so that exactly `args.epochs` epochs run, numbered from 1.
for epoch in range(1, args.epochs + 1):
    # Training epoch
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy/loss
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    # Log before the early-stopping check so the final epoch is reported too
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        test_acc_at_best_valid = test_acc
        best_valid_epoch = epoch
        patience = args.patience
    else:
        patience -= 1

    if patience <= 0:
        # Early stopping: no validation improvement for `args.patience` consecutive epochs
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.001, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 3, 'layers': 2, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.0, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'none', 'readout': 'sag', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 3, 'output_dim': 6, 'input_dim_edge': 0, 'max_num_nodes_in_graph': 126, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 17.0455, Train Loss: 1.7916, Valid Loss: 1.7909, Test Acc: 11.2500
Epoch: 002, Train Acc: 17.0455, Train Loss: 1.7889, Valid Loss: 1.7869, Test Acc: 11.2500
Epoch: 003, Train Acc: 17.0455, Train Loss: 1.7856, Valid Loss: 1.7809, Test Acc: 11.2500
Epoch: 004, Train Acc: 

## Mutagenicity

### Downloading the dataset

In [29]:
import torch_geometric.transforms as T
from torch_geometric.utils.undirected import to_undirected
from torch_geometric.utils import remove_self_loops


def make_undirected_without_self_loops(data):
    """Return `data` with self-loops removed and every edge present in both directions.

    The sheaf Laplacian must not be built from directed graphs or graphs with
    self-loops. This runs as a dataset transform rather than inside a
    `for data in dataset:` loop. A PyG dataset hands out a fresh copy of each
    graph on every access, so edits made to the loop variable are never
    written back to the dataset.

    `edge_attr`, when present, is passed to both utilities so it stays aligned
    with `edge_index`. Attributes of edges merged by `to_undirected` are averaged.
    Graphs that are already clean are returned unchanged.
    """
    if data.has_self_loops():
        data.edge_index, data.edge_attr = remove_self_loops(data.edge_index, data.edge_attr)
    if not data.is_undirected():
        if data.edge_attr is None:
            data.edge_index = to_undirected(data.edge_index, num_nodes=data.num_nodes)
        else:
            data.edge_index, data.edge_attr = to_undirected(
                data.edge_index, data.edge_attr, num_nodes=data.num_nodes, reduce='mean')
    return data


# Edge cleaning happens on every access, together with row-normalisation of node features.
dataset = TUDataset(
    root='data/TUDataset',
    name='Mutagenicity',
    transform=T.Compose([make_undirected_without_self_loops, T.NormalizeFeatures()]),
)

print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of node features: {dataset.num_node_features}')
print(f'Number of edge features: {dataset.num_edge_features}')
print(f'Number of classes: {dataset.num_classes}')

# Largest graph size (in nodes) across the dataset; stored later in `args.max_num_nodes_in_graph`.
max_num_nodes_in_graph = max((data.x.size(dim=0) for data in dataset), default=0)

print('====================')
graph1 = dataset[0]  # Get the first graph object.
print(graph1)
print()

# Sanity check that the transform produced graphs suitable for the Laplacian
print(f'Contains isolated nodes: {graph1.has_isolated_nodes()}')
print(f'Contains self-loops: {graph1.has_self_loops()}')
print(f'Is undirected: {graph1.is_undirected()}')


Downloading https://www.chrsmrrs.com/graphkerneldatasets/Mutagenicity.zip
Extracting data/TUDataset/Mutagenicity/Mutagenicity.zip
Processing...
Done!


Dataset: Mutagenicity(4337):
Number of graphs: 4337
Number of node features: 14
Number of edge features: 3
Number of classes: 2
Data(edge_index=[2, 32], x=[16, 14], edge_attr=[32, 3], y=[1])

Contains isolated nodes: False
Contains self-loops: False
Is undirected: True


### Split into training, validation and test sets

In [30]:
# Seed the global torch RNG so that the shuffle, and therefore the split, is reproducible.
# NOTE: `dataset` is overwritten with its shuffled copy. Re-running this cell without
# re-running the download cell shuffles an already-shuffled dataset and gives a different split.
torch.manual_seed(12345)
dataset = dataset.shuffle()

# Split boundaries: the first 3700 graphs train, the next 300 validate, the rest test.
TRAIN_END, VALID_END = 3700, 4000
train_dataset, valid_dataset, test_dataset = (
    dataset[:TRAIN_END],
    dataset[TRAIN_END:VALID_END],
    dataset[VALID_END:],
)

for split_name, split in (('training', train_dataset),
                          ('validation', valid_dataset),
                          ('test', test_dataset)):
    print(f'Number of {split_name} graphs: {len(split)}')

Number of training graphs: 3700
Number of validation graphs: 300
Number of test graphs: 337


### Creation of the DataLoaders


In [31]:
# `torch_geometric.data.DataLoader` is deprecated (and the PyG install at the top is
# unpinned); the graph-batching DataLoader lives in `torch_geometric.loader`.
from torch_geometric.loader import DataLoader

batch_size = 256

# Only the training loader is shuffled; evaluation order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Inspect how the training graphs are grouped into mini-batches
for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()

Step 1:
Number of graphs in the current batch: 256
DataBatch(edge_index=[2, 15054], x=[7536, 14], edge_attr=[15054, 3], y=[256], batch=[7536], ptr=[257])

Step 2:
Number of graphs in the current batch: 256
DataBatch(edge_index=[2, 15748], x=[7689, 14], edge_attr=[15748, 3], y=[256], batch=[7689], ptr=[257])

Step 3:
Number of graphs in the current batch: 256
DataBatch(edge_index=[2, 15150], x=[7470, 14], edge_attr=[15150, 3], y=[256], batch=[7470], ptr=[257])

Step 4:
Number of graphs in the current batch: 256
DataBatch(edge_index=[2, 16882], x=[8364, 14], edge_attr=[16882, 3], y=[256], batch=[8364], ptr=[257])

Step 5:
Number of graphs in the current batch: 256
DataBatch(edge_index=[2, 15272], x=[7455, 14], edge_attr=[15272, 3], y=[256], batch=[7455], ptr=[257])

Step 6:
Number of graphs in the current batch: 256
DataBatch(edge_index=[2, 15138], x=[7358, 14], edge_attr=[15138, 3], y=[256], batch=[7358], ptr=[257])

Step 7:
Number of graphs in the current batch: 256
DataBatch(edge_inde



Step 10:
Number of graphs in the current batch: 256
DataBatch(edge_index=[2, 16368], x=[7943, 14], edge_attr=[16368, 3], y=[256], batch=[7943], ptr=[257])

Step 11:
Number of graphs in the current batch: 256
DataBatch(edge_index=[2, 15678], x=[7678, 14], edge_attr=[15678, 3], y=[256], batch=[7678], ptr=[257])

Step 12:
Number of graphs in the current batch: 256
DataBatch(edge_index=[2, 15672], x=[7643, 14], edge_attr=[15672, 3], y=[256], batch=[7643], ptr=[257])

Step 13:
Number of graphs in the current batch: 256
DataBatch(edge_index=[2, 15094], x=[7363, 14], edge_attr=[15094, 3], y=[256], batch=[7363], ptr=[257])

Step 14:
Number of graphs in the current batch: 256
DataBatch(edge_index=[2, 16054], x=[8023, 14], edge_attr=[16054, 3], y=[256], batch=[8023], ptr=[257])

Step 15:
Number of graphs in the current batch: 116
DataBatch(edge_index=[2, 6874], x=[3346, 14], edge_attr=[6874, 3], y=[116], batch=[3346], ptr=[117])



### Running the experiments 


In [32]:
# Dataset-dependent arguments, derived from the loaded dataset instead of being hard-coded.
args.input_dim = dataset.num_features             # node feature size
args.output_dim = dataset.num_classes             # number of target classes
args.input_dim_edge = dataset.num_edge_features   # edge feature size
args.max_num_nodes_in_graph = max_num_nodes_in_graph
args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# At least one of the two normalisation flags must be enabled.
assert args.normalised or args.deg_normalised

# Without an explicit sheaf decay, fall back to the generic weight decay.
# Evaluated once here: later cells that change `args.weight_decay` do not update it.
args.sheaf_decay = args.weight_decay if args.sheaf_decay is None else args.sheaf_decay

#### Baseline scalar model (**d=1**), [mean, sum, max] readout 

In [None]:
# Model-specific arguments. Every hyper-parameter this experiment relies on is
# set explicitly, so the cell does not depend on values left on `args` by
# whichever cell happened to run before it.

# Optimisation params
args.lr = 0.01
args.weight_decay = 0.0005

# Model configuration (d=1: scalar baseline)
args.d = 1
args.layers = 4
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.1
args.orth = "householder"   # choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "none"    # choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "concat"     # choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# `vars(args)` is the live attribute dict of `args` (the same object as `args.__dict__`).
config = vars(args)
print(config)

model = DiscreteDiagSheafDiffusion(config)
#model = DiscreteDiagPoolSheafDiffusion(config)
#model = DiscreteBundleSheafDiffusion(config)
model = model.to(args.device)

# Sheaf-learner parameters get their own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()

# Optimizers
optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])


# Training and test criterions
criterion_train = torch.nn.CrossEntropyLoss()
# Summed rather than averaged, so the evaluation loss does not change with batch_size
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

# Highest test accuracy seen at any epoch (optimistic: it is selected on the test set itself)
best_test_acc = 0
best_epoch = 0
# Test accuracy at the epoch with the lowest validation loss (selection uses validation data only)
test_acc_at_best_valid = 0
best_valid_epoch = 0
# Lowest validation loss so far, used for early stopping. Starting at +inf makes the
# first epoch always count as an improvement, whatever the scale of the loss.
best_loss = float('inf')
patience = args.patience

# `args.epochs + 1` so that exactly `args.epochs` epochs run, numbered from 1.
for epoch in range(1, args.epochs + 1):
    # Training epoch
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy/loss
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    # Log before the early-stopping check so the final epoch is reported too
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        test_acc_at_best_valid = test_acc
        best_valid_epoch = epoch
        patience = args.patience
    else:
        patience -= 1

    if patience <= 0:
        # Early stopping: no validation improvement for `args.patience` consecutive epochs
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.01, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 1, 'layers': 4, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.1, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'none', 'readout': 'concat', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 14, 'output_dim': 2, 'input_dim_edge': 3, 'max_num_nodes_in_graph': 417, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 65.7297, Train Loss: 0.6309, Valid Loss: 0.6401, Test Acc: 64.9852
Epoch: 002, Train Acc: 65.4054, Train Loss: 0.6041, Valid Loss: 0.5984, Test Acc: 65.2819
Epoch: 003, Train Acc: 69.3514, Train Loss: 0.5834, Valid Loss: 0.5977, Test Acc: 67.0623
Epoch: 004, Train Ac

#### Diagonal sheaf diffusion, **[mean, sum, max]** readout 

In [None]:
# Model-specific arguments. Every hyper-parameter this experiment relies on is
# set explicitly, so the cell does not depend on values left on `args` by
# whichever cell happened to run before it.

# Optimisation params
args.lr = 0.01
args.weight_decay = 0.0005

# Model configuration
args.d = 3
args.layers = 4
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.1
args.orth = "householder"   # choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "none"    # choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "concat"     # choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# `vars(args)` is the live attribute dict of `args` (the same object as `args.__dict__`).
config = vars(args)
print(config)

model = DiscreteDiagSheafDiffusion(config)
#model = DiscreteDiagPoolSheafDiffusion(config)
#model = DiscreteBundleSheafDiffusion(config)
model = model.to(args.device)

# Sheaf-learner parameters get their own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()

# Optimizers
optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])


# Training and test criterions
criterion_train = torch.nn.CrossEntropyLoss()
# Summed rather than averaged, so the evaluation loss does not change with batch_size
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

# Highest test accuracy seen at any epoch (optimistic: it is selected on the test set itself)
best_test_acc = 0
best_epoch = 0
# Test accuracy at the epoch with the lowest validation loss (selection uses validation data only)
test_acc_at_best_valid = 0
best_valid_epoch = 0
# Lowest validation loss so far, used for early stopping. Starting at +inf makes the
# first epoch always count as an improvement, whatever the scale of the loss.
best_loss = float('inf')
patience = args.patience

# `args.epochs + 1` so that exactly `args.epochs` epochs run, numbered from 1.
for epoch in range(1, args.epochs + 1):
    # Training epoch
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy/loss
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    # Log before the early-stopping check so the final epoch is reported too
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        test_acc_at_best_valid = test_acc
        best_valid_epoch = epoch
        patience = args.patience
    else:
        patience -= 1

    if patience <= 0:
        # Early stopping: no validation improvement for `args.patience` consecutive epochs
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.01, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 3, 'layers': 4, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.1, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'none', 'readout': 'concat', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 14, 'output_dim': 2, 'input_dim_edge': 3, 'max_num_nodes_in_graph': 417, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 55.1081, Train Loss: 0.6889, Valid Loss: 0.6746, Test Acc: 56.3798
Epoch: 002, Train Acc: 65.3243, Train Loss: 0.6118, Valid Loss: 0.6181, Test Acc: 63.7982
Epoch: 003, Train Acc: 68.3514, Train Loss: 0.5799, Valid Loss: 0.5827, Test Acc: 67.0623
Epoch: 004, Train Ac

#### **Bundle** sheaf diffusion, **[mean, sum, max]** readout 

In [None]:
# Model-specific arguments. Every hyper-parameter this experiment relies on is
# set explicitly, so the cell does not depend on values left on `args` by
# whichever cell happened to run before it.
# NOTE(review): the saved output of this cell reports 'd': 2 although d=3 is set
# below, so that output comes from an earlier version of the cell. Re-run to refresh it.

# Optimisation params
args.lr = 0.01
args.weight_decay = 0.0005

# Model configuration
args.d = 3
args.layers = 4
# The next four values were previously inherited from the preceding cell (their lines
# were commented out). They are set to the same values explicitly.
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.1
args.orth = "householder"   # choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "none"    # choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "concat"     # choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# `vars(args)` is the live attribute dict of `args` (the same object as `args.__dict__`).
config = vars(args)
print(config)

#model = DiscreteDiagSheafDiffusion(config)
#model = DiscreteDiagPoolSheafDiffusion(config)
model = DiscreteBundleSheafDiffusion(config)
model = model.to(args.device)

# Sheaf-learner parameters get their own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()

# Optimizers
optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])


# Training and test criterions
criterion_train = torch.nn.CrossEntropyLoss()
# Summed rather than averaged, so the evaluation loss does not change with batch_size
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

# Highest test accuracy seen at any epoch (optimistic: it is selected on the test set itself)
best_test_acc = 0
best_epoch = 0
# Test accuracy at the epoch with the lowest validation loss (selection uses validation data only)
test_acc_at_best_valid = 0
best_valid_epoch = 0
# Lowest validation loss so far, used for early stopping. Starting at +inf makes the
# first epoch always count as an improvement, whatever the scale of the loss.
best_loss = float('inf')
patience = args.patience

# `args.epochs + 1` so that exactly `args.epochs` epochs run, numbered from 1.
for epoch in range(1, args.epochs + 1):
    # Training epoch
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy/loss
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    # Log before the early-stopping check so the final epoch is reported too
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        test_acc_at_best_valid = test_acc
        best_valid_epoch = epoch
        patience = args.patience
    else:
        patience -= 1

    if patience <= 0:
        # Early stopping: no validation improvement for `args.patience` consecutive epochs
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.01, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 2, 'layers': 4, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.1, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'none', 'readout': 'concat', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 14, 'output_dim': 2, 'input_dim_edge': 3, 'max_num_nodes_in_graph': 417, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 62.8378, Train Loss: 0.6536, Valid Loss: 0.6572, Test Acc: 59.9407
Epoch: 002, Train Acc: 64.6486, Train Loss: 0.6216, Valid Loss: 0.6187, Test Acc: 63.2047
Epoch: 003, Train Acc: 65.8649, Train Loss: 0.5960, Valid Loss: 0.5994, Test Acc: 64.0950
Epoch: 004, Train Ac

#### Diagonal sheaf diffusion, **[mean, sum, max]** readout, edge features handling through **concatenation**

In [None]:
# Model-specific arguments. Every hyper-parameter this experiment relies on is
# set explicitly, so the cell does not depend on values left on `args` by
# whichever cell happened to run before it.

# Optimisation params
args.lr = 0.01
args.weight_decay = 0.0005

# Model configuration
args.d = 3
args.layers = 4
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.1
args.orth = "householder"   # choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "concat"  # choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "concat"     # choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# `vars(args)` is the live attribute dict of `args` (the same object as `args.__dict__`).
config = vars(args)
print(config)

model = DiscreteDiagSheafDiffusion(config)
#model = DiscreteDiagPoolSheafDiffusion(config)
#model = DiscreteBundleSheafDiffusion(config)
model = model.to(args.device)

# Sheaf-learner parameters get their own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()

# Optimizers
optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])


# Training and test criterions
criterion_train = torch.nn.CrossEntropyLoss()
# Summed rather than averaged, so the evaluation loss does not change with batch_size
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

# Highest test accuracy seen at any epoch (optimistic: it is selected on the test set itself)
best_test_acc = 0
best_epoch = 0
# Test accuracy at the epoch with the lowest validation loss (selection uses validation data only)
test_acc_at_best_valid = 0
best_valid_epoch = 0
# Lowest validation loss so far, used for early stopping. Starting at +inf makes the
# first epoch always count as an improvement, whatever the scale of the loss.
best_loss = float('inf')
patience = args.patience

# `args.epochs + 1` so that exactly `args.epochs` epochs run, numbered from 1.
for epoch in range(1, args.epochs + 1):
    # Training epoch
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy/loss
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    # Log before the early-stopping check so the final epoch is reported too
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        test_acc_at_best_valid = test_acc
        best_valid_epoch = epoch
        patience = args.patience
    else:
        patience -= 1

    if patience <= 0:
        # Early stopping: no validation improvement for `args.patience` consecutive epochs
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.01, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 3, 'layers': 4, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.1, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'concat', 'readout': 'concat', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 14, 'output_dim': 2, 'input_dim_edge': 3, 'max_num_nodes_in_graph': 417, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 55.4595, Train Loss: 0.6728, Valid Loss: 0.6687, Test Acc: 56.3798
Epoch: 002, Train Acc: 64.5405, Train Loss: 0.6191, Valid Loss: 0.6154, Test Acc: 62.3145
Epoch: 003, Train Acc: 66.3514, Train Loss: 0.5968, Valid Loss: 0.6011, Test Acc: 64.9852
Epoch: 004, Train 

#### Diagonal sheaf diffusion, **[mean, sum, max]** readout, edge features handling through **linear** transform




In [None]:
# Model-specific arguments: diagonal sheaf diffusion, concatenated readout,
# edge features handled through a linear transform.

# Optimisation params
args.lr = 0.01
args.weight_decay = 0.0005

# Model configuration
args.d = 3
args.layers = 4
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.1
args.orth = "householder"   # choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "linear"  # choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "concat"     # choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# Snapshot the arguments: `args.__dict__` is the live namespace that later
# cells keep mutating, so copy it to freeze this run's configuration.
config = dict(vars(args))
print(config)

# Other variants: DiscreteDiagPoolSheafDiffusion(config), DiscreteBundleSheafDiffusion(config)
model = DiscreteDiagSheafDiffusion(config)
model = model.to(args.device)

# The sheaf learner gets its own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()
optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])

# Training and test criterions
criterion_train = torch.nn.CrossEntropyLoss()
# Summed rather than averaged so the evaluation loss does not change with batch_size
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

# Best test accuracy over all epochs (optimistic: the epoch is picked on the test set)
best_test_acc = 0
best_epoch = 0
# Test accuracy at the epoch with the lowest validation loss (unbiased estimate)
test_acc_at_best_valid = 0
best_valid_epoch = 0
# Lowest validation loss seen so far, used for early stopping
best_loss = float('inf')
patience = args.patience

# `+ 1` so that exactly `args.epochs` epochs run, numbered from 1
for epoch in range(1, args.epochs + 1):
    # Training epoch
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        test_acc_at_best_valid = test_acc
        best_valid_epoch = epoch
        patience = args.patience
    else:
        patience -= 1

    # Log before the early-stopping check so the final epoch is reported too
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if patience <= 0:
        # Early stopping: no validation-loss improvement for `args.patience` epochs
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.01, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 3, 'layers': 4, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.1, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'linear', 'readout': 'concat', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 14, 'output_dim': 2, 'input_dim_edge': 3, 'max_num_nodes_in_graph': 417, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 59.2703, Train Loss: 0.6804, Valid Loss: 0.6930, Test Acc: 55.7864
Epoch: 002, Train Acc: 61.6757, Train Loss: 0.6428, Valid Loss: 0.6377, Test Acc: 61.4243
Epoch: 003, Train Acc: 66.4054, Train Loss: 0.6033, Valid Loss: 0.6137, Test Acc: 62.9080
Epoch: 004, Train 

#### Diagonal sheaf diffusion, **[mean, sum, max]** readout, edge-feature handling through a **bilinear** transform

In [None]:
# Model-specific arguments: diagonal sheaf diffusion, concatenated readout,
# edge features handled through a bilinear transform.

# Optimisation params
args.lr = 0.01
args.weight_decay = 0.0005

# Model configuration
args.d = 3
args.layers = 4
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.0
args.orth = "householder"     # choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "bilinear"  # choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "concat"       # choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# Snapshot the arguments: `args.__dict__` is the live namespace that later
# cells keep mutating, so copy it to freeze this run's configuration.
config = dict(vars(args))
print(config)

# Other variants: DiscreteDiagPoolSheafDiffusion(config), DiscreteBundleSheafDiffusion(config)
model = DiscreteDiagSheafDiffusion(config)
model = model.to(args.device)

# The sheaf learner gets its own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()
optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])

# Training and test criterions
criterion_train = torch.nn.CrossEntropyLoss()
# Summed rather than averaged so the evaluation loss does not change with batch_size
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

# Best test accuracy over all epochs (optimistic: the epoch is picked on the test set)
best_test_acc = 0
best_epoch = 0
# Test accuracy at the epoch with the lowest validation loss (unbiased estimate)
test_acc_at_best_valid = 0
best_valid_epoch = 0
# Lowest validation loss seen so far, used for early stopping
best_loss = float('inf')
patience = args.patience

# `+ 1` so that exactly `args.epochs` epochs run, numbered from 1
for epoch in range(1, args.epochs + 1):
    # Training epoch
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        test_acc_at_best_valid = test_acc
        best_valid_epoch = epoch
        patience = args.patience
    else:
        patience -= 1

    # Log before the early-stopping check so the final epoch is reported too
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if patience <= 0:
        # Early stopping: no validation-loss improvement for `args.patience` epochs
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.01, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 3, 'layers': 4, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.0, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'bilinear', 'readout': 'concat', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 14, 'output_dim': 2, 'input_dim_edge': 3, 'max_num_nodes_in_graph': 417, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 61.5676, Train Loss: 0.6784, Valid Loss: 0.6784, Test Acc: 58.4570
Epoch: 002, Train Acc: 65.3784, Train Loss: 0.6627, Valid Loss: 0.6660, Test Acc: 64.6884
Epoch: 003, Train Acc: 66.0270, Train Loss: 0.6223, Valid Loss: 0.6554, Test Acc: 64.0950
Epoch: 004, Trai

#### Diagonal sheaf diffusion, edge-feature handling through a bilinear transform, **mean** readout

In [None]:
# Model-specific arguments: diagonal sheaf diffusion, mean readout,
# edge features handled through a bilinear transform.

# Optimisation params
args.lr = 0.02
args.weight_decay = 0.0005

# Model configuration
args.d = 3
args.layers = 4
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.0
args.orth = "householder"     # choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "bilinear"  # choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "mean"         # choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# Snapshot the arguments: `args.__dict__` is the live namespace that later
# cells keep mutating, so copy it to freeze this run's configuration.
config = dict(vars(args))
print(config)

# Other variants: DiscreteDiagPoolSheafDiffusion(config), DiscreteBundleSheafDiffusion(config)
model = DiscreteDiagSheafDiffusion(config)
model = model.to(args.device)

# The sheaf learner gets its own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()
optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])

# Training and test criterions
criterion_train = torch.nn.CrossEntropyLoss()
# Summed rather than averaged so the evaluation loss does not change with batch_size
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

# Best test accuracy over all epochs (optimistic: the epoch is picked on the test set)
best_test_acc = 0
best_epoch = 0
# Test accuracy at the epoch with the lowest validation loss (unbiased estimate)
test_acc_at_best_valid = 0
best_valid_epoch = 0
# Lowest validation loss seen so far, used for early stopping
best_loss = float('inf')
patience = args.patience

# `+ 1` so that exactly `args.epochs` epochs run, numbered from 1
for epoch in range(1, args.epochs + 1):
    # Training epoch
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        test_acc_at_best_valid = test_acc
        best_valid_epoch = epoch
        patience = args.patience
    else:
        patience -= 1

    # Log before the early-stopping check so the final epoch is reported too
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if patience <= 0:
        # Early stopping: no validation-loss improvement for `args.patience` epochs
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.02, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 3, 'layers': 4, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.0, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'bilinear', 'readout': 'mean', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 14, 'output_dim': 2, 'input_dim_edge': 3, 'max_num_nodes_in_graph': 417, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 67.6757, Train Loss: 0.6268, Valid Loss: 0.6335, Test Acc: 64.9852
Epoch: 002, Train Acc: 68.5405, Train Loss: 0.5949, Valid Loss: 0.6064, Test Acc: 67.0623
Epoch: 003, Train Acc: 68.7027, Train Loss: 0.5823, Valid Loss: 0.5995, Test Acc: 66.4688
Epoch: 004, Train 

#### Diagonal sheaf diffusion, edge-feature handling through a bilinear transform, **sum** readout

In [None]:
# Model-specific arguments: diagonal sheaf diffusion, sum readout,
# edge features handled through a bilinear transform.

# Optimisation params
args.lr = 0.03
args.weight_decay = 0.0005

# Model configuration
args.d = 3
args.layers = 4
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.0
args.orth = "householder"     # choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "bilinear"  # choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "sum"          # choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# Snapshot the arguments: `args.__dict__` is the live namespace that later
# cells keep mutating, so copy it to freeze this run's configuration.
config = dict(vars(args))
print(config)

# Other variants: DiscreteDiagPoolSheafDiffusion(config), DiscreteBundleSheafDiffusion(config)
model = DiscreteDiagSheafDiffusion(config)
model = model.to(args.device)

# The sheaf learner gets its own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()
optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])

# Training and test criterions
criterion_train = torch.nn.CrossEntropyLoss()
# Summed rather than averaged so the evaluation loss does not change with batch_size
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

# Best test accuracy over all epochs (optimistic: the epoch is picked on the test set)
best_test_acc = 0
best_epoch = 0
# Test accuracy at the epoch with the lowest validation loss (unbiased estimate)
test_acc_at_best_valid = 0
best_valid_epoch = 0
# Lowest validation loss seen so far, used for early stopping
best_loss = float('inf')
patience = args.patience

# `+ 1` so that exactly `args.epochs` epochs run, numbered from 1
for epoch in range(1, args.epochs + 1):
    # Training epoch
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        test_acc_at_best_valid = test_acc
        best_valid_epoch = epoch
        patience = args.patience
    else:
        patience -= 1

    # Log before the early-stopping check so the final epoch is reported too
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if patience <= 0:
        # Early stopping: no validation-loss improvement for `args.patience` epochs
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.03, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 3, 'layers': 4, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.0, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'bilinear', 'readout': 'sum', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 14, 'output_dim': 2, 'input_dim_edge': 3, 'max_num_nodes_in_graph': 417, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 64.7027, Train Loss: 0.6477, Valid Loss: 0.6525, Test Acc: 63.2047
Epoch: 002, Train Acc: 59.0541, Train Loss: 0.6494, Valid Loss: 0.6410, Test Acc: 60.2374
Epoch: 003, Train Acc: 65.9730, Train Loss: 0.6074, Valid Loss: 0.6151, Test Acc: 64.9852
Epoch: 004, Train A

#### Diagonal sheaf diffusion, edge-feature handling through a bilinear transform, **max** readout

In [None]:
# Model-specific arguments: diagonal sheaf diffusion, max readout,
# edge features handled through a bilinear transform.

# Optimisation params
args.lr = 0.01
args.weight_decay = 0.0005

# Model configuration
args.d = 3
args.layers = 4
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.0
args.orth = "householder"     # choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "bilinear"  # choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "max"          # choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# Snapshot the arguments: `args.__dict__` is the live namespace that later
# cells keep mutating, so copy it to freeze this run's configuration.
config = dict(vars(args))
print(config)

# Other variants: DiscreteDiagPoolSheafDiffusion(config), DiscreteBundleSheafDiffusion(config)
model = DiscreteDiagSheafDiffusion(config)
model = model.to(args.device)

# The sheaf learner gets its own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()
optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])

# Training and test criterions
criterion_train = torch.nn.CrossEntropyLoss()
# Summed rather than averaged so the evaluation loss does not change with batch_size
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

# Best test accuracy over all epochs (optimistic: the epoch is picked on the test set)
best_test_acc = 0
best_epoch = 0
# Test accuracy at the epoch with the lowest validation loss (unbiased estimate)
test_acc_at_best_valid = 0
best_valid_epoch = 0
# Lowest validation loss seen so far, used for early stopping
best_loss = float('inf')
patience = args.patience

# `+ 1` so that exactly `args.epochs` epochs run, numbered from 1
for epoch in range(1, args.epochs + 1):
    # Training epoch
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        test_acc_at_best_valid = test_acc
        best_valid_epoch = epoch
        patience = args.patience
    else:
        patience -= 1

    # Log before the early-stopping check so the final epoch is reported too
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if patience <= 0:
        # Early stopping: no validation-loss improvement for `args.patience` epochs
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.01, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 3, 'layers': 4, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.0, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'bilinear', 'readout': 'max', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 14, 'output_dim': 2, 'input_dim_edge': 3, 'max_num_nodes_in_graph': 417, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 59.3514, Train Loss: 0.6625, Valid Loss: 0.6609, Test Acc: 58.4570
Epoch: 002, Train Acc: 65.8108, Train Loss: 0.6104, Valid Loss: 0.6279, Test Acc: 64.9852
Epoch: 003, Train Acc: 67.0270, Train Loss: 0.5980, Valid Loss: 0.6292, Test Acc: 65.5786
Epoch: 004, Train A

#### Diagonal sheaf diffusion, edge-feature handling through a bilinear transform, **MLP** readout

In [None]:
# Model-specific arguments: diagonal sheaf diffusion, MLP readout,
# edge features handled through a bilinear transform.

# Optimisation params
args.lr = 0.001
args.weight_decay = 0.0005

# Model configuration
args.d = 3
args.layers = 4
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.1
# Set explicitly (previously commented out, so the value leaked from an earlier
# cell); "householder" matches the recorded config of this run.
args.orth = "householder"     # choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "bilinear"  # choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "mlp"          # choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# Snapshot the arguments: `args.__dict__` is the live namespace that later
# cells keep mutating, so copy it to freeze this run's configuration.
config = dict(vars(args))
print(config)

# Other variants: DiscreteDiagPoolSheafDiffusion(config), DiscreteBundleSheafDiffusion(config)
model = DiscreteDiagSheafDiffusion(config)
model = model.to(args.device)

# The sheaf learner gets its own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()
optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])

# Training and test criterions
criterion_train = torch.nn.CrossEntropyLoss()
# Summed rather than averaged so the evaluation loss does not change with batch_size
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

# Best test accuracy over all epochs (optimistic: the epoch is picked on the test set)
best_test_acc = 0
best_epoch = 0
# Test accuracy at the epoch with the lowest validation loss (unbiased estimate)
test_acc_at_best_valid = 0
best_valid_epoch = 0
# Lowest validation loss seen so far, used for early stopping
best_loss = float('inf')
patience = args.patience

# `+ 1` so that exactly `args.epochs` epochs run, numbered from 1
for epoch in range(1, args.epochs + 1):
    # Training epoch
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        test_acc_at_best_valid = test_acc
        best_valid_epoch = epoch
        patience = args.patience
    else:
        patience -= 1

    # Log before the early-stopping check so the final epoch is reported too
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if patience <= 0:
        # Early stopping: no validation-loss improvement for `args.patience` epochs
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.001, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 3, 'layers': 4, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.1, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'bilinear', 'readout': 'mlp', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 14, 'output_dim': 2, 'input_dim_edge': 3, 'max_num_nodes_in_graph': 417, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 61.6486, Train Loss: 0.6646, Valid Loss: 0.6777, Test Acc: 60.8309
Epoch: 002, Train Acc: 63.0811, Train Loss: 0.6328, Valid Loss: 0.6487, Test Acc: 64.3917
Epoch: 003, Train Acc: 66.5946, Train Loss: 0.6095, Valid Loss: 0.6301, Test Acc: 64.3917
Epoch: 004, Train 

#### Diagonal sheaf diffusion, edge-feature handling through a bilinear transform, **global SAG** readout

In [None]:
# Model-specific arguments: diagonal sheaf diffusion, global SAG readout,
# edge features handled through a bilinear transform.

# Optimisation params
args.lr = 0.02
args.weight_decay = 0.0005

# Model configuration
args.d = 3
args.layers = 4
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.1
# Set explicitly (previously commented out, so the value leaked from an earlier
# cell); "householder" matches the recorded config of this run.
args.orth = "householder"     # choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "bilinear"  # choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "sag"          # choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# Snapshot the arguments: `args.__dict__` is the live namespace that later
# cells keep mutating, so copy it to freeze this run's configuration.
config = dict(vars(args))
print(config)

# Other variants: DiscreteDiagPoolSheafDiffusion(config), DiscreteBundleSheafDiffusion(config)
model = DiscreteDiagSheafDiffusion(config)
model = model.to(args.device)

# The sheaf learner gets its own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()
optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])

# Training and test criterions
criterion_train = torch.nn.CrossEntropyLoss()
# Summed rather than averaged so the evaluation loss does not change with batch_size
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

# Best test accuracy over all epochs (optimistic: the epoch is picked on the test set)
best_test_acc = 0
best_epoch = 0
# Test accuracy at the epoch with the lowest validation loss (unbiased estimate)
test_acc_at_best_valid = 0
best_valid_epoch = 0
# Lowest validation loss seen so far, used for early stopping
best_loss = float('inf')
patience = args.patience

# `+ 1` so that exactly `args.epochs` epochs run, numbered from 1
for epoch in range(1, args.epochs + 1):
    # Training epoch
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        test_acc_at_best_valid = test_acc
        best_valid_epoch = epoch
        patience = args.patience
    else:
        patience -= 1

    # Log before the early-stopping check so the final epoch is reported too
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if patience <= 0:
        # Early stopping: no validation-loss improvement for `args.patience` epochs
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.02, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 3, 'layers': 4, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.1, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'bilinear', 'readout': 'sag', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 14, 'output_dim': 2, 'input_dim_edge': 3, 'max_num_nodes_in_graph': 417, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 59.9730, Train Loss: 0.6564, Valid Loss: 0.6774, Test Acc: 59.0504
Epoch: 002, Train Acc: 58.8649, Train Loss: 0.6527, Valid Loss: 0.6715, Test Acc: 61.1276
Epoch: 003, Train Acc: 67.9189, Train Loss: 0.5994, Valid Loss: 0.6179, Test Acc: 66.4688
Epoch: 004, Train A

#### Diagonal sheaf diffusion, edge-feature handling through a bilinear transform, **hierarchical SAG** readout

In [None]:
# Model-specific arguments: diagonal sheaf diffusion with hierarchical SAG pooling
# (DiscreteDiagPoolSheafDiffusion), edge features handled through a bilinear transform.

# Optimisation params
args.lr = 0.005
args.weight_decay = 0.0005

# Model configuration
args.d = 3
args.layers = 2
args.hidden_channels = 15
args.input_dropout = 0.0
args.dropout = 0.0
# `orth` and `readout` were previously commented out, so their values leaked from
# earlier cells. They are set explicitly to the values in this run's recorded config.
# NOTE(review): whether the pooled model reads `readout` at all is not visible here.
args.orth = "householder"     # choices=['matrix_exp', 'cayley', 'householder', 'euler'], parametrization for the orthogonal group
args.edges_feat = "bilinear"  # choices=['none', 'concat', 'linear', 'bilinear']
args.readout = "sag"          # choices=['mean', 'sum', 'max', 'sag', 'mlp', 'concat']

# Snapshot the arguments: `args.__dict__` is the live namespace that later
# cells keep mutating, so copy it to freeze this run's configuration.
config = dict(vars(args))
print(config)

# Other variants: DiscreteDiagSheafDiffusion(config), DiscreteBundleSheafDiffusion(config)
model = DiscreteDiagPoolSheafDiffusion(config)
model = model.to(args.device)

# The sheaf learner gets its own weight decay (`sheaf_decay`).
sheaf_learner_params, other_params = model.grouped_parameters()
optimizer = torch.optim.Adam([
    {'params': sheaf_learner_params, 'weight_decay': config['sheaf_decay']},
    {'params': other_params, 'weight_decay': config['weight_decay']}
], lr=config['lr'])

# Training and test criterions
criterion_train = torch.nn.CrossEntropyLoss()
# Summed rather than averaged so the evaluation loss does not change with batch_size
criterion_test = torch.nn.CrossEntropyLoss(reduction='sum')

# Best test accuracy over all epochs (optimistic: the epoch is picked on the test set)
best_test_acc = 0
best_epoch = 0
# Test accuracy at the epoch with the lowest validation loss (unbiased estimate)
test_acc_at_best_valid = 0
best_valid_epoch = 0
# Lowest validation loss seen so far, used for early stopping
best_loss = float('inf')
patience = args.patience

# `+ 1` so that exactly `args.epochs` epochs run, numbered from 1
for epoch in range(1, args.epochs + 1):
    # Training epoch
    train(model, optimizer, criterion_train, train_loader, args.device)

    # Compute training, validation and test accuracy
    train_acc, train_loss = test(model, criterion_test, train_loader, args.device)
    valid_acc, valid_loss = test(model, criterion_test, valid_loader, args.device)
    test_acc, test_loss = test(model, criterion_test, test_loader, args.device)

    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_epoch = epoch

    if valid_loss < best_loss:
        best_loss = valid_loss
        test_acc_at_best_valid = test_acc
        best_valid_epoch = epoch
        patience = args.patience
    else:
        patience -= 1

    # Log before the early-stopping check so the final epoch is reported too
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Test Acc: {test_acc:.4f}')

    if patience <= 0:
        # Early stopping: no validation-loss improvement for `args.patience` epochs
        break

print("\n")
print("Best test accuracy: ", best_test_acc)
print("Best epoch: ", best_epoch)
print("Test accuracy at best validation loss: ", test_acc_at_best_valid)
print("Best validation epoch: ", best_valid_epoch)

{'epochs': 200, 'lr': 0.005, 'weight_decay': 0.0005, 'sheaf_decay': 0.0005, 'patience': 15, 'second_linear': False, 'd': 3, 'layers': 2, 'normalised': True, 'deg_normalised': False, 'linear': False, 'hidden_channels': 15, 'input_dropout': 0.0, 'dropout': 0.0, 'left_weights': True, 'right_weights': True, 'add_lp': False, 'add_hp': False, 'use_act': True, 'sheaf_act': 'tanh', 'edge_weights': True, 'orth': 'householder', 'edges_feat': 'bilinear', 'readout': 'sag', 'dense_intermediate_dim': 256, 'dense_output_graph_dim': 128, 'output_nn_intermediate_dim': 64, 'set_transformer_k': 8, 'input_dim': 14, 'output_dim': 2, 'input_dim_edge': 3, 'max_num_nodes_in_graph': 417, 'device': device(type='cuda', index=0)}
Epoch: 001, Train Acc: 65.2432, Train Loss: 0.6438, Valid Loss: 0.6474, Test Acc: 59.6439
Epoch: 002, Train Acc: 63.1351, Train Loss: 0.6265, Valid Loss: 0.6253, Test Acc: 64.0950
Epoch: 003, Train Acc: 66.4865, Train Loss: 0.6116, Valid Loss: 0.6301, Test Acc: 62.0178
Epoch: 004, Train 

# Benchmarks: SOTA for Graph Classification

References:
- https://paperswithcode.com/sota/graph-classification-on-enzymes
- https://paperswithcode.com/sota/graph-classification-on-mutagenicity 
- https://paperswithcode.com/paper/hierarchical-graph-pooling-with-structure
- https://github.com/qslim/gnn-spectrum/tree/main/tu
- https://github.com/cszhangzhen/HGP-SL 

## HGP-SL: Hierarchical Graph Pooling with Structure Learning

### Downloads and imports

In [None]:
# Clone the DGL repository; its examples/pytorch/hgp_sl directory holds the
# HGP-SL reference implementation (main.py) that is run in the cells below.
!git clone https://github.com/dmlc/dgl.git

Cloning into 'dgl'...
remote: Enumerating objects: 37060, done.[K
remote: Counting objects: 100% (54/54), done.[K
remote: Compressing objects: 100% (50/50), done.[K
remote: Total 37060 (delta 11), reused 20 (delta 4), pack-reused 37006[K
Receiving objects: 100% (37060/37060), 20.92 MiB | 20.74 MiB/s, done.
Resolving deltas: 100% (24446/24446), done.


In [None]:
%cd dgl
%cd examples
%cd pytorch
%cd hgp_sl/
!ls

/content/dgl
/content/dgl/examples
/content/dgl/examples/pytorch
/content/dgl/examples/pytorch/hgp_sl
functions.py  layers.py  main.py  networks.py  README.md  utils.py


In [None]:
!pip install dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html

### ENZYMES (HGP-SL)

In [None]:
# HGP-SL baseline on ENZYMES: GPU 0, lr 1e-3, batch size 128, pooling ratio 0.8,
# no dropout, 2 convolutional layers.
!python main.py --device 0 --dataset ENZYMES --lr 0.001 --batch_size 128 --pool_ratio 0.8 --dropout 0.0 --conv_layers 2

Trial 1/1
Downloading ./dataset/ENZYMES.zip from https://www.chrsmrrs.com/graphkerneldatasets/ENZYMES.zip...
Extracting file to ./dataset/ENZYMES
  cpuset_checked))
Epoch 10: loss=1.7064, val_acc=0.2167, final_test_acc=0.2333
Epoch 20: loss=1.6623, val_acc=0.3167, final_test_acc=0.3167
Epoch 30: loss=1.6423, val_acc=0.3333, final_test_acc=0.3167
Epoch 40: loss=1.5955, val_acc=0.2833, final_test_acc=0.3167
Epoch 50: loss=1.5488, val_acc=0.2667, final_test_acc=0.3167
Epoch 60: loss=1.5139, val_acc=0.3500, final_test_acc=0.3167
Epoch 70: loss=1.4553, val_acc=0.3000, final_test_acc=0.3500
Epoch 80: loss=1.4471, val_acc=0.3333, final_test_acc=0.4167
Epoch 90: loss=1.3703, val_acc=0.4667, final_test_acc=0.4500
Epoch 100: loss=1.3047, val_acc=0.4667, final_test_acc=0.4333
Epoch 110: loss=1.2902, val_acc=0.5000, final_test_acc=0.4667
Epoch 120: loss=1.2161, val_acc=0.2833, final_test_acc=0.4667
Epoch 130: loss=1.2111, val_acc=0.3667, final_test_acc=0.4167
Epoch 140: loss=1.1071, val_acc=0.5500

### Mutagenicity (HGP-SL)

In [None]:
# HGP-SL baseline on Mutagenicity: GPU 0, lr 1e-3, batch size 512, pooling ratio 0.8,
# no dropout, 3 convolutional layers.
!python main.py --device 0 --dataset Mutagenicity --lr 0.001 --batch_size 512 --pool_ratio 0.8 --dropout 0.0 --conv_layers 3

DGL backend not selected or invalid.  Assuming PyTorch for now.
Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)
Trial 1/1
Downloading ./dataset/Mutagenicity.zip from https://www.chrsmrrs.com/graphkerneldatasets/Mutagenicity.zip...
Extracting file to ./dataset/Mutagenicity
No Node Attribute Data
  cpuset_checked))
Epoch 10: loss=0.5848, val_acc=0.7159, final_test_acc=0.6253
Epoch 20: loss=0.5566, val_acc=0.7344, final_test_acc=0.6529
Epoch 30: loss=0.5433, val_acc=0.7390, final_test_acc=0.6920
Epoch 40: loss=0.5248, val_acc=0.7575, final_test_acc=0.7080
Epoch 50: loss=0.5177, val_acc=0.7644, final_test_acc=0.7149
Epoch 60: loss=0.4922, val_acc=0.7806, final_test_acc=0.7218
Epoch 70: loss=0.4845, val_acc=0.7737, final_test_acc=0.7471
Epoch 80: loss=0.4833, val_acc=0.7945, final_test_acc=0.7494
Epoch 90: loss=0.4784, val_acc=0.7921, final

## Norm-GN: A New Perspective on the Effects of Spectrum in Graph Neural Networks

### Downloads and imports

In [None]:
# Clone the Norm-GN reference implementation; its tu/ directory contains main.py
# and per-dataset JSON configs used below.
!git clone https://github.com/qslim/gnn-spectrum.git

Cloning into 'gnn-spectrum'...
remote: Enumerating objects: 38, done.[K
remote: Counting objects: 100% (38/38), done.[K
remote: Compressing objects: 100% (29/29), done.[K
remote: Total 38 (delta 8), reused 38 (delta 8), pack-reused 0[K
Unpacking objects: 100% (38/38), done.


In [None]:
!pip install dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html

In [None]:
!pip install ogb==1.3.1
!pip install numpy
!pip install easydict
!pip install tensorboard
!pip install tqdm
!pip install json5

In [None]:
# Move into the TU-dataset experiments of the Norm-GN repo and list them
# (relative to the directory where the repo was cloned).
%cd gnn-spectrum/tu
!ls

gnn-spectrum  sample_data
/content/gnn-spectrum/tu
configs  main.py  model.py  run_script.sh  tu_dataloader.py


### ENZYMES (Norm-GN)

In [None]:
# Norm-GN baseline on ENZYMES; hyperparameters come from the JSON config.
# `-u` disables Python output buffering so epoch logs stream into the notebook.
!python -u ./main.py --config="./configs/ENZYMES.json"

[1;30;43mOutput streaming troncato alle ultime 5000 righe.[0m
Epoch: 078, Train Loss: 0.0064897, Train Acc: 1.0000000, Test Acc: 0.7000000
Epoch: 079, Train Loss: 0.0062109, Train Acc: 1.0000000, Test Acc: 0.7166667
Epoch: 080, Train Loss: 0.0061906, Train Acc: 1.0000000, Test Acc: 0.6666667
Epoch: 081, Train Loss: 0.0056912, Train Acc: 1.0000000, Test Acc: 0.7000000
Epoch: 082, Train Loss: 0.0051489, Train Acc: 1.0000000, Test Acc: 0.6833333
Epoch: 083, Train Loss: 0.0048548, Train Acc: 1.0000000, Test Acc: 0.6833333
Epoch: 084, Train Loss: 0.0052254, Train Acc: 1.0000000, Test Acc: 0.6833333
Epoch: 085, Train Loss: 0.0049430, Train Acc: 1.0000000, Test Acc: 0.6833333
Epoch: 086, Train Loss: 0.0043917, Train Acc: 1.0000000, Test Acc: 0.6833333
Epoch: 087, Train Loss: 0.0052839, Train Acc: 1.0000000, Test Acc: 0.6666667
Epoch: 088, Train Loss: 0.0069308, Train Acc: 1.0000000, Test Acc: 0.7166667
Epoch: 089, Train Loss: 0.0049746, Train Acc: 1.0000000, Test Acc: 0.7166667
Epoch: 090, 

### Mutagenicity (Norm-GN)

In [None]:
# Norm-GN baseline on Mutagenicity; hyperparameters come from the JSON config.
# `-u` disables Python output buffering so epoch logs stream into the notebook.
!python -u ./main.py --config="./configs/Mutagenicity.json"

DGL backend not selected or invalid.  Assuming PyTorch for now.
Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)
{'dataset_name': 'Mutagenicity', 'basis': 'rho', 'epsilon': 0.5, 'power': 18, 'seeds': [0], 'num_folds': 10, 'num_workers': 8, 'hyperparams': {'batch_size': 64, 'epochs': 121, 'learning_rate': 0.001, 'step_size': 50, 'decay_rate': 0.6}, 'architecture': {'nonlinear': 'GELU', 'layers': 6, 'hidden': 256, 'pooling': 'X', 'dropout': 0}, 'commit_id': '', 'time_stamp': '', 'directory': ''}
{'dataset_name': 'Mutagenicity', 'basis': 'rho', 'epsilon': 0.5, 'power': 18, 'seeds': [0], 'num_folds': 10, 'num_workers': 8, 'hyperparams': {'batch_size': 64, 'epochs': 121, 'learning_rate': 0.001, 'step_size': 50, 'decay_rate': 0.6}, 'architecture': {'nonlinear': 'GELU', 'layers': 6, 'hidden': 256, 'pooling': 'X', 'dropout': 0}, 'commit_id': ''