In [None]:
from torch.nn import AvgPool1d, BatchNorm1d, MaxPool1d, Module


class ResidualBlock(Module):
    """
    Residual block built from two Chebyshev convolutions and a skip connection.

    The block computes ``relu(conv2(bn(relu(conv1(x)))) + shortcut(x))`` where
    ``shortcut`` is the identity when the input and output widths match, and a
    per-vertex linear projection of the channels otherwise.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, K):
        """
        Build the layers of the block.

        Args:
            in_channels (int): number of channels of the block input.
            out_channels (int): number of channels of the block output.
            hidden_channels (int): number of channels between the two convolutions.
            K (int): forwarded as third argument to both ChebConv layers
                (presumably the order of the Chebyshev polynomials).
        """
        super().__init__()
        self.conv1 = ChebConv(in_channels, hidden_channels, K)
        self.relu1 = nn.ReLU()
        self.bn2 = BatchNorm1d(hidden_channels)
        # conv2 is fed with the output of conv1/bn2, which has hidden_channels
        # channels (not in_channels).
        self.conv2 = ChebConv(hidden_channels, out_channels, K)
        self.relu2 = nn.ReLU()

        # The skip connection adds the block input to the block output. When the
        # widths differ, the input is first projected with a kernel-size-1
        # convolution, which mixes channels independently at every vertex.
        # Assumes x is laid out as (B, C, V), as annotated in forward.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x, laplacian):
        """
        Apply the residual block.

        Args:
            x: input signal, annotated as (B, C, V) with C = in_channels.
            laplacian: laplacian passed unchanged to both ChebConv layers.

        Returns:
            The activated sum of the convolutional branch and the shortcut,
            with out_channels channels.
        """
        out = self.conv1(x, laplacian)  # (B, C_hidden, V)
        out = self.relu1(out)  # (B, C_hidden, V)
        out = self.bn2(out)  # (B, C_hidden, V)
        out = self.conv2(out, laplacian)  # (B, C_out, V)
        return self.relu2(out + self.shortcut(x))  # (B, C_out, V)

In [None]:
class ResGEChebNet(Module):
    """
    Residual ChebNet: three residual blocks separated by batch normalization,
    followed by a global pooling over the graph nodes and a softmax.
    """

    def __init__(
        self,
        graph: Graph,
        K: int,
        in_channels: int,
        out_channels: int,
        hidden_channels: int = 20,
        pooling: Optional[str] = "max",
        device: Optional[Device] = None,
    ):
        """
        Initialize a residual ChebNet with residual blocks and batch normalization.

        Args:
            graph (Graph): graph.
            K (int): degree of the Chebyschev polynomials, the sum goes from indices 0 to K-1.
            in_channels (int): number of dimensions of the input layer.
            out_channels (int): number of dimensions of the output layer.
            hidden_channels (int, optional): number of dimensions of the hidden layers. Defaults to 20.
            pooling (str, optional): global pooling function. Defaults to 'max'.
            device (Device, optional): computation device. Defaults to None.

        Raises:
            ValueError: pooling must be 'avg' or 'max'
        """
        super().__init__()

        # Validate the cheap argument first so that an invalid value fails
        # before the laplacian is built.
        if pooling not in {"avg", "max"}:
            raise ValueError(f"{pooling} is not a valid value for pooling: must be 'avg' or 'max'")

        # NOTE(review): lmax is hard-coded to 2.0 - presumably an upper bound on
        # the eigenvalues of the graph laplacian; confirm against Graph.laplacian.
        self.laplacian = self._normlaplacian(
            graph.laplacian(device), lmax=2.0, num_nodes=graph.num_nodes
        )

        # Channel flow: in_channels -> hidden_channels -> hidden_channels -> out_channels.
        # Keyword arguments are used because ResidualBlock takes out_channels
        # before hidden_channels, which is easy to get wrong positionally.
        self.resblock1 = ResidualBlock(
            in_channels=in_channels,
            out_channels=hidden_channels,
            hidden_channels=hidden_channels,
            K=K,
        )

        self.bn2 = BatchNorm1d(hidden_channels)
        self.resblock2 = ResidualBlock(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            hidden_channels=hidden_channels,
            K=K,
        )

        self.bn3 = BatchNorm1d(hidden_channels)
        self.resblock3 = ResidualBlock(
            in_channels=hidden_channels,
            out_channels=out_channels,
            hidden_channels=hidden_channels,
            K=K,
        )

        if pooling == "avg":
            self.pool = AvgPool1d(graph.num_nodes)  # theoretical equivariance
        else:
            self.pool = MaxPool1d(graph.num_nodes)  # adds some non linearities, better in practice

        # NOTE(review): despite its name, this attribute holds a Softmax, so
        # forward returns probabilities rather than log-probabilities; confirm
        # that this matches the loss used for training.
        self.logsoftmax = Softmax(dim=1)

    def forward(self, x: FloatTensor) -> FloatTensor:
        """
        Forward function receiving as input a batch and outputing a prediction on this batch

        Args:
            x (FloatTensor): the batch to feed the network with.

        Returns:
            (FloatTensor): the predictions on the batch.
        """
        # Input layer
        out = self.resblock1(x, self.laplacian)  # (B, C, V)

        # Hidden layers
        out = self.bn2(out)
        out = self.resblock2(out, self.laplacian)  # (B, C, V)
        out = self.bn3(out)
        out = self.resblock3(out, self.laplacian)  # (B, C, V)

        # Output layer
        # Only the pooled node axis is squeezed: a bare squeeze() would also
        # drop the batch axis when B == 1 (or the channel axis when C == 1) and
        # break the softmax over dim 1.
        out = self.pool(out).squeeze(dim=2)  # (B, C)
        return self.logsoftmax(out)  # (B, C)

    def _normlaplacian(
        self, laplacian: SparseFloatTensor, lmax: float, num_nodes: int
    ) -> SparseFloatTensor:
        """
        Scale the eigenvalues from [0, lmax] to [-1, 1].

        Args:
            laplacian (SparseFloatTensor): laplacian to rescale.
            lmax (float): upper bound of the laplacian's eigenvalues.
            num_nodes (int): size passed to sparse_tensor_diag (presumably to
                build a sparse identity of matching shape).

        Returns:
            (SparseFloatTensor): 2 * laplacian / lmax minus the result of sparse_tensor_diag.
        """
        return 2 * laplacian / lmax - sparse_tensor_diag(num_nodes, device=laplacian.device)

    @property
    def capacity(self) -> int:
        """
        Return the capacity of the network, i.e. its number of trainable parameters.

        Returns:
            (int): number of trainable parameters of the network.
        """
        # Frozen parameters (requires_grad=False) are not trainable and are skipped.
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
