In [1]:
import torch

In [10]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

In [22]:
class GCNLayer(nn.Module):
    """Implementation of a single graph convolutional layer.

    Computes ``relu(A_aug @ x @ W)``, where ``A_aug`` is derived from the input
    adjacency matrix according to ``aug_adj_type`` and ``W`` is a learned
    ``(in_channels, out_channels)`` weight matrix.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        aug_adj_type (str): Type of augmented adjacency matrix to use. One of
            ``"symmetric"`` (``D~^-1/2 (A + I) D~^-1/2`` with ``D~ = D + I``),
            ``"adjacency"`` (``A``), ``"degree"`` (``D``) or
            ``"random walk"`` (``D^-1 A``). Defaults to ``"symmetric"``.
    """

    def __init__(
        self, in_channels: int, out_channels: int, aug_adj_type: str = "symmetric"
    ):
        super().__init__()

        # Uniform initialisation in [-0.125, 0.125).
        self.weight = nn.Parameter(torch.rand((in_channels, out_channels)) / 4 - 0.125)
        self.aug_adj_type = aug_adj_type

    def _augmented_adjacency(self, adj: torch.Tensor) -> torch.Tensor:
        """Build the (num_nodes, num_nodes) augmented adjacency matrix for ``adj``."""
        num_nodes = adj.shape[0]
        # Row sums of A, i.e. the diagonal of the degree matrix D.
        degrees = adj.sum(dim=1)

        if self.aug_adj_type == "symmetric":
            identity = torch.eye(num_nodes, dtype=adj.dtype, device=adj.device)
            # The row sums of (A + I) are degrees + 1, which is always >= 1 for
            # non-negative edge weights, so the inverse square root is finite.
            # Scaling rows and columns by a vector is equivalent to
            # diag(v) @ M @ diag(v) and avoids raising the zero off-diagonal
            # entries of a dense diagonal matrix to a negative power.
            inv_sqrt = (degrees + 1).pow(-0.5)
            return inv_sqrt.unsqueeze(1) * (adj + identity) * inv_sqrt.unsqueeze(0)
        if self.aug_adj_type == "adjacency":
            return adj
        if self.aug_adj_type == "degree":
            return torch.diag(degrees)
        if self.aug_adj_type == "random walk":
            # Isolated nodes (degree 0) get an all-zero row instead of inf/nan.
            inv_degrees = torch.where(
                degrees > 0, degrees.reciprocal(), torch.zeros_like(degrees)
            )
            return inv_degrees.unsqueeze(1) * adj
        raise ValueError("Received invalid augmented adjacency matrix type.")

    def forward(self, x: torch.Tensor, adj_matrix: torch.Tensor) -> torch.Tensor:
        """Apply the graph convolution.

        Args:
            x: Node features. The last two dimensions must be
                ``(num_nodes, in_channels)`` for the matrix products to line up.
            adj_matrix: Square ``(num_nodes, num_nodes)`` adjacency matrix. A
                tensor or anything ``torch.as_tensor`` accepts, such as a NumPy
                array. It is converted to ``x``'s dtype and device.

        Returns:
            ``relu(A_aug @ x @ W)`` with last dimension ``out_channels``.

        Raises:
            ValueError: If ``aug_adj_type`` is not a supported type.
        """
        adj = torch.as_tensor(adj_matrix, dtype=x.dtype, device=x.device)
        aug_adj_matrix = self._augmented_adjacency(adj)
        return F.relu(aug_adj_matrix @ x @ self.weight)

In [None]:
class GCN(nn.Module):
    """Stack of ``GCNLayer``s followed by a linear output layer and a softmax.

    Args:
        in_channels (int): Size of each input node feature vector.
        hidden_channels (int): Output width of every graph convolutional layer.
        out_channels (int): Size of the final output (after ``output_layer``).
        num_layers (int): Number of graph convolutional layers. ``0`` replaces
            the graph convolutions with a single ReLU feed-forward projection
            that ignores ``adj_matrix``.
        dropout (float): Dropout probability applied after every graph
            convolutional layer except the last one. It is only active in
            training mode.
        aug_adj_type (str): Augmented adjacency matrix type passed to every
            ``GCNLayer``. Defaults to ``"symmetric"``.

    Raises:
        ValueError: If ``num_layers`` is negative.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        num_layers: int,
        dropout: float = 0.3,
        aug_adj_type: str = "symmetric",
    ):
        super().__init__()

        if num_layers < 0:
            raise ValueError(f"num_layers must be non-negative, got {num_layers}.")

        self.conv_layers = nn.ModuleList()
        self.num_layers = num_layers
        self.dropout = dropout
        self.aug_adj_type = aug_adj_type

        if num_layers >= 1:
            # All layers except the last live in conv_layers (each is followed
            # by dropout); the first maps in_channels -> hidden_channels.
            layer_in = in_channels
            for _ in range(num_layers - 1):
                self.conv_layers.append(
                    GCNLayer(layer_in, hidden_channels, aug_adj_type)
                )
                layer_in = hidden_channels

            self.final_conv_layer = GCNLayer(layer_in, hidden_channels, aug_adj_type)
        else:  # num_layers == 0, single feed-forward network
            self.weight = nn.Parameter(
                torch.rand((in_channels, hidden_channels)) / 4 - 0.125
            )

        self.output_layer = nn.Linear(hidden_channels, out_channels)

    def forward(self, x: torch.Tensor, adj_matrix: torch.Tensor) -> torch.Tensor:
        """Run the network and return softmax probabilities over ``dim=1``.

        Args:
            x: Node features; presumably ``(num_nodes, in_channels)``, since
                the softmax is taken over dim 1. TODO confirm with callers.
            adj_matrix: Adjacency matrix forwarded to every ``GCNLayer``.
                Unused when ``num_layers == 0``.

        Returns:
            Softmax of ``output_layer`` outputs along dim 1. NOTE(review): if
            this feeds ``nn.CrossEntropyLoss``, that loss applies its own
            log-softmax — confirm the double normalisation is intended.
        """
        if self.num_layers >= 1:
            for layer in self.conv_layers:
                x = layer(x, adj_matrix)
                # Tie dropout to the module's train/eval mode so eval is deterministic.
                x = F.dropout(x, p=self.dropout, training=self.training)

            x = self.final_conv_layer(x, adj_matrix)
        else:  # num_layers == 0, single feed-forward network
            x = F.relu(x @ self.weight)

        x = self.output_layer(x)
        return F.softmax(x, dim=1)