In [1]:
import numpy as np
from time import time
from datetime import timedelta

# Pytorch
import torch
from torch import nn
from torch.nn import init
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader

In [2]:
class NodePairDataset(Dataset):
    """Map-style dataset over pre-built node-pair records.

    Each record must support ``record['edge_embedding']`` and
    ``record['label']``; item ``i`` is returned as the tuple
    ``(edge_embedding, label)`` taken from ``instances[i]``.
    """

    def __init__(self, instances):
        # Keeps a reference to the caller's sequence; no copy is made.
        self.instances = instances

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, i):
        record = self.instances[i]
        return record['edge_embedding'], record['label']

def get_dataloader(instances, batch_size=1, num_workers=2, shuffle=True):
    """Wrap ``instances`` in a NodePairDataset and return a DataLoader over it.

    Args:
        instances: indexable sequence of records exposing 'edge_embedding'
            and 'label' keys (as read by NodePairDataset.__getitem__).
        batch_size: number of samples per batch.
        num_workers: number of loader worker processes; 0 loads in the
            main process.
        shuffle: reshuffle the data every epoch. Defaults to True (training
            behaviour); pass False for a deterministic evaluation/inference order.

    Returns:
        A torch DataLoader whose batches are collated from
        ``(edge_embedding, label)`` pairs.
    """
    # NOTE(review): with num_workers > 0 on spawn-based platforms (Windows,
    # macOS), worker processes may be unable to unpickle a Dataset class that
    # was defined inside a notebook; pass num_workers=0 if loading fails there.
    dataset = NodePairDataset(instances)
    return DataLoader(dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers)

In [3]:
class GraphConvolution(nn.Module):
    """Single graph-convolution layer computing ``A @ (X @ W^T) + b``.

    Args:
        input_features: size of each input node feature vector.
        hidden_features: size of each output node feature vector.
        bias: if True (default), add a learnable per-output-feature bias.
    """

    def __init__(self, input_features, hidden_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.input_features = input_features
        self.hidden_features = hidden_features
        # Stored as (out, in), like nn.Linear, so kaiming's fan_in (computed from
        # size(1)) is the true input dimension; forward() uses the transpose.
        self.weight = nn.Parameter(
            nn.init.kaiming_normal_(torch.empty(hidden_features, input_features),
                                    mode='fan_in', nonlinearity='relu'))
        if bias:
            # One bias per output feature, broadcast across nodes. Kaiming init
            # cannot compute fan-in for a 1-D tensor, so start from zero.
            self.bias = nn.Parameter(torch.zeros(hidden_features))
        else:
            self.register_parameter('bias', None)

    def forward(self, input_features, adj_matrix):
        """Transform node features and propagate them through the adjacency.

        Args:
            input_features: dense 2-D tensor (num_nodes, self.input_features).
            adj_matrix: 2-D adjacency (typically (num_nodes, num_nodes)),
                either dense or sparse COO.

        Returns:
            Tensor of shape (adj_matrix.size(0), self.hidden_features).
        """
        # Feature transform: (N, in) @ (in, out) -> (N, out).
        support = torch.mm(input_features, self.weight.t())
        # Neighbourhood aggregation; sparse COO adjacency needs torch.sparse.mm.
        if adj_matrix.is_sparse:
            propagation = torch.sparse.mm(adj_matrix, support)
        else:
            propagation = torch.mm(adj_matrix, support)
        # Compare with None: truth-testing a multi-element tensor raises.
        if self.bias is not None:
            return propagation + self.bias
        return propagation

In [4]:
class GCN(nn.Module):
    """Two stacked GraphConvolution layers: ReLU between them, sigmoid on output.

    Args:
        inputs_dim: size of each input node feature vector.
        hidden_dim: width of the hidden graph-convolution layer.
        class_num: number of output scores per node (default 2).
    """

    def __init__(self, inputs_dim, hidden_dim, class_num=2):
        super(GCN, self).__init__()
        self.gcn_layer1 = GraphConvolution(inputs_dim, hidden_dim)
        self.gcn_layer2 = GraphConvolution(hidden_dim, class_num)

    def forward(self, input_features, adj_matrix):
        """Return per-node scores squashed into (0, 1).

        The output presumably has shape (num_nodes, class_num), given how the
        two layers are sized; confirm against GraphConvolution.forward.
        """
        # torch.relu / torch.sigmoid apply the activation; nn.ReLU(x) and
        # nn.Sigmoid(x) would only try to construct module objects.
        hidden_state = torch.relu(self.gcn_layer1(input_features, adj_matrix))
        # Layer 2 expects hidden_dim-wide features, i.e. layer 1's output.
        logits = self.gcn_layer2(hidden_state, adj_matrix)
        # NOTE(review): sigmoid scores each class independently; if classes are
        # mutually exclusive, returning raw logits for CrossEntropyLoss may be intended.
        return torch.sigmoid(logits)