In [None]:
from torch_geometric.nn import GraphConv

class GCN_ALT(nn.Module):
    """Graph-level predictor: three GraphConv layers, mean pooling, linear head.

    Args:
        dim: number of input features per node.
        hidden_channels: width shared by all three GraphConv layers.
        outputsize: number of output units produced for each pooled graph.
    """

    def __init__(self, dim, hidden_channels, outputsize):
        super().__init__()
        # Seeds the *global* torch RNG so the weight initialisation below is
        # reproducible. Note: this also affects any later random draws in the
        # notebook session.
        torch.manual_seed(12345)
        # Creation order is significant: it fixes which RNG draws each layer's
        # initial weights consume, so keep conv1 -> conv2 -> conv3 -> lin.
        self.conv1 = GraphConv(dim, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)
        self.conv3 = GraphConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, outputsize)

    def forward(self, x, edge_index, batch):
        """Run message passing, pool per graph, and project to `outputsize`.

        Args:
            x: node feature matrix fed to the first GraphConv.
            edge_index: graph connectivity passed to every GraphConv.
            batch: node-to-graph assignment consumed by `global_mean_pool`
                (presumably the PyG batch vector; confirm against the loader).

        Returns:
            The linear head's output for each pooled graph.
        """
        # The first two convolutions are followed by ReLU; the third is not.
        for conv in (self.conv1, self.conv2):
            x = conv(x, edge_index).relu()
        node_repr = self.conv3(x, edge_index)

        graph_repr = global_mean_pool(node_repr, batch)
        # Flatten to (num_rows, features); kept from the original even though
        # the pooled tensor is presumably already 2-D.
        graph_repr = graph_repr.view(graph_repr.size(0), -1)

        # Dropout is only active in training mode.
        graph_repr = F.dropout(graph_repr, p=0.5, training=self.training)
        return self.lin(graph_repr)

In [None]:
class AE(nn.Module):
    """Fully connected autoencoder over DeepWalk node embeddings of a graph.

    `forward` does not use node features from `x`. It takes the graph in
    `x[0]`, fits a fresh DeepWalk model on it, and runs the per-node
    embeddings through an MLP mapping
    `dim -> hidden -> hidden -> hidden -> dim`. The per-node outputs are then
    flattened into a single row.

    Args:
        dim: DeepWalk embedding size. This is also the encoder input width and
            the decoder output width, so the two can no longer disagree.
        hidden_channels: width of the hidden layers.
    """

    # DeepWalk hyper-parameters (the embedding size comes from `dim`).
    # NOTE(review): `DeepWalk` is presumably karateclub.DeepWalk; confirm.
    DEEPWALK_KWARGS = dict(workers=12, walk_number=5, walk_length=10)

    def __init__(self, dim, hidden_channels):
        super().__init__()
        self.embedding_dim = dim
        self.encoder_hidden_layer = nn.Linear(
            dim, out_features=hidden_channels
        )
        self.encoder_output_layer = nn.Linear(
            in_features=hidden_channels, out_features=hidden_channels
        )
        self.decoder_hidden_layer = nn.Linear(
            in_features=hidden_channels, out_features=hidden_channels
        )
        self.decoder_output_layer = nn.Linear(
            in_features=hidden_channels, out_features=dim
        )

    def _embed(self, graph):
        """Fit DeepWalk on `graph` and return its node embeddings as a float tensor."""
        # NOTE(review): a new DeepWalk model is trained on every forward call.
        # This is slow, and no gradient flows into the embedding. Consider
        # precomputing the embeddings once per graph.
        walker = DeepWalk(dimensions=self.embedding_dim, **self.DEEPWALK_KWARGS)
        walker.fit(graph)
        # nn.Linear weights are float32 by default. Cast so that a float64
        # numpy embedding cannot cause a dtype mismatch.
        return torch.from_numpy(walker.get_embedding()).float()

    def forward(self, x, edge_index):
        """Embed the graph in `x[0]`, encode/decode it, and flatten the result.

        Args:
            x: indexable container whose first element is the graph to embed
                (presumably a networkx graph accepted by `DeepWalk.fit`).
            edge_index: unused. Kept only for call-site compatibility.

        Returns:
            A tensor of shape `(1, num_nodes * dim)`.
        """
        h = self._embed(x[0])
        h = torch.relu(self.encoder_hidden_layer(h))
        h = torch.relu(self.encoder_output_layer(h))
        h = torch.relu(self.decoder_hidden_layer(h))
        # Respect train/eval mode. Previously dropout also ran at inference time.
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.decoder_output_layer(h)
        # Flatten the (num_nodes, dim) output into one row.
        return h.reshape(1, h.shape[0] * h.shape[1])

In [None]:
class SimpleTextClassifier(nn.Module):
    """Two-layer MLP applied to DeepWalk node embeddings of a graph.

    `forward` ignores node features in `x`. It fits a fresh DeepWalk model on
    the graph in `x[0]`, maps each node embedding through
    `dim -> hidden_channels -> outputsize`, and flattens the per-node outputs
    into a single row.

    Args:
        dim: DeepWalk embedding size, and therefore `linear1`'s input width.
            Previously DeepWalk was fixed at 16 dimensions, so the default
            `dim=10` crashed.
        hidden_channels: width of the hidden layer.
        outputsize: number of outputs per node.
    """

    # DeepWalk hyper-parameters (the embedding size comes from `dim`).
    # NOTE(review): `DeepWalk` is presumably karateclub.DeepWalk; confirm.
    DEEPWALK_KWARGS = dict(workers=12, walk_number=5, walk_length=10)

    def __init__(self, dim=10, hidden_channels=64, outputsize=2):
        super(SimpleTextClassifier, self).__init__()
        self.embedding_dim = dim
        self.linear1 = nn.Linear(dim, hidden_channels)
        self.linear2 = nn.Linear(hidden_channels, outputsize)

    def _embed(self, graph):
        """Fit DeepWalk on `graph` and return its node embeddings as a float tensor."""
        # NOTE(review): DeepWalk is re-trained on every forward call. This is
        # slow and not differentiable; consider caching the embeddings.
        walker = DeepWalk(dimensions=self.embedding_dim, **self.DEEPWALK_KWARGS)
        walker.fit(graph)
        # Cast to float32 to match the nn.Linear weights.
        return torch.from_numpy(walker.get_embedding()).float()

    def forward(self, x, edge_index):
        """Embed the graph in `x[0]` and return flattened per-node outputs.

        Args:
            x: indexable container whose first element is the graph to embed
                (presumably a networkx graph accepted by `DeepWalk.fit`).
            edge_index: unused. Kept only for call-site compatibility.

        Returns:
            A tensor of shape `(1, num_nodes * outputsize)`.
        """
        h = self._embed(x[0])
        h = self.linear1(h).clamp(min=0)  # ReLU
        # Dropout is applied to the hidden layer (as in the other models here)
        # rather than to the output logits, and only in training mode.
        h = F.dropout(h, p=0.5, training=self.training)
        logits = self.linear2(h)
        # Flatten the (num_nodes, outputsize) logits into one row.
        return logits.reshape(1, logits.shape[0] * logits.shape[1])