In [54]:
import copy
import pickle
import numpy as np
import networkx as nx
from sklearn.metrics import f1_score

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
from torch.utils.data import DataLoader
from torch_sparse import SparseTensor, matmul

import deepsnap
from deepsnap.hetero_gnn import (
    HeteroSAGEConv,
    HeteroConv,
    forward_op
)
from deepsnap.hetero_graph import HeteroGraph
from deepsnap.dataset import GraphDataset
from deepsnap.batch import Batch

## Define GraphSAGE for Link Prediction

In [2]:
class HeteroGNNConv(pyg_nn.MessagePassing):
    def __init__(self, in_channels_src, in_channels_dst, out_channels):
        super(HeteroGNNConv, self).__init__(aggr="mean")

        self.in_channels_src = in_channels_src
        self.in_channels_dst = in_channels_dst
        self.out_channels = out_channels
        self.lin_dst = nn.Linear(in_channels_dst, out_channels)
        self.lin_src = nn.Linear(in_channels_src, out_channels)
        self.lin_update = nn.Linear(2 * out_channels, out_channels)

    def forward(
        self,
        node_feature_src,
        node_feature_dst,
        edge_index,
        size=None
    ):
        return self.propagate(edge_index=edge_index, size=size, node_feature_src=node_feature_src,
                       node_feature_dst=node_feature_dst)

    def message_and_aggregate(self, edge_index, node_feature_src):
        return matmul(edge_index, node_feature_src, reduce="mean")

    def update(self, aggr_out, node_feature_dst):
        aggr_out = self.lin_src(aggr_out)
        node_feature_dst = self.lin_dst(node_feature_dst)
        aggr_out = torch.cat([node_feature_dst, aggr_out], dim = -1)
        aggr_out = self.lin_update(aggr_out)
        return aggr_out

In [3]:
class HeteroGNNWrapperConv(deepsnap.hetero_gnn.HeteroConv):
    def __init__(self, convs, args, aggr="mean"):
        super(HeteroGNNWrapperConv, self).__init__(convs, None)
        self.aggr = aggr
        self.mapping = {}  # Map the index and message type
        self.alpha = None  # A numpy array that stores the final attention probability
        self.attn_proj = None

        if self.aggr == "attn":
            self.attn_proj = nn.Sequential(
                nn.Linear(args['hidden_size'], args['attn_size']),
                nn.Tanh(),
                nn.Linear(args['attn_size'], 1, bias=False)
            )

    def reset_parameters(self):
        super(HeteroGNNWrapperConv, self).reset_parameters()
        if self.aggr == "attn":
            for layer in self.attn_proj.children():
                layer.reset_parameters()

    def forward(self, node_features, edge_indices):
        message_type_emb = {}
        for message_key, message_type in edge_indices.items():
            src_type, edge_type, dst_type = message_key
            node_feature_src = node_features[src_type]
            node_feature_dst = node_features[dst_type]
            edge_index = edge_indices[message_key]
            message_type_emb[message_key] = (
                self.convs[message_key](
                    node_feature_src,
                    node_feature_dst,
                    edge_index,
                )
            )
        node_emb = {dst: [] for _, _, dst in message_type_emb.keys()}
        mapping = {}
        for (src, edge_type, dst), item in message_type_emb.items():
            mapping[len(node_emb[dst])] = (src, edge_type, dst)
            node_emb[dst].append(item)
        self.mapping = mapping
        for node_type, embs in node_emb.items():
            if len(embs) == 1:
                node_emb[node_type] = embs[0]
            else:
                node_emb[node_type] = self.aggregate(embs)
        return node_emb

    def aggregate(self, xs):
        if self.aggr == "mean":
            x = torch.stack(xs, dim=-1)
            return x.mean(dim=-1)

        elif self.aggr == "attn":
            N = xs[0].shape[0] # Number of nodes for that node type
            M = len(xs) # Number of message types for that node type

            x = torch.cat(xs, dim=0).view(M, N, -1) # M * N * D
            z = self.attn_proj(x).view(M, N) # M * N * 1
            z = z.mean(1) # M * 1
            alpha = torch.softmax(z, dim=0) # M * 1

            # Store the attention result to self.alpha as np array
            self.alpha = alpha.view(-1).data.cpu().numpy()

            alpha = alpha.view(M, 1, 1)
            x = x * alpha
            return x.sum(dim=0)

In [57]:
def generate_convs(hetero_graph, conv, hidden_size):
    convs1 = {}
    convs2 = {}
    for message_type in hetero_graph.message_types:
        n_type = message_type[0]
        s_type = message_type[2]
        n_feat_dim = hetero_graph.num_node_features(n_type)
        s_feat_dim = hetero_graph.num_node_features(s_type)
        convs1[message_type] = conv(n_feat_dim, hidden_size, s_feat_dim)
        convs2[message_type] = conv(hidden_size, hidden_size, hidden_size)
    return convs1, convs2

In [74]:
class HeteroGNN(torch.nn.Module):
    def __init__(self, hetero_graph, args, aggr="mean"):
        super(HeteroGNN, self).__init__()
        self.aggr = aggr
        self.hidden_size = args['hidden_size']
        convs1, convs2 = generate_convs(hetero_graph, HeteroSAGEConv, self.hidden_size)
        self.convs1 = HeteroConv(convs1)
        self.convs2 = HeteroConv(convs2)
        self.bns1 = nn.ModuleDict()
        self.bns2 = nn.ModuleDict()
        self.relus1 = nn.ModuleDict()
        self.relus2 = nn.ModuleDict()
        self.post_mps = nn.ModuleDict()
        self.loss_fn = torch.nn.BCEWithLogitsLoss()

        for node_type in hetero_graph.node_types:
            self.bns1[node_type] = nn.BatchNorm1d(self.hidden_size, eps=1)
            self.bns2[node_type] = nn.BatchNorm1d(self.hidden_size, eps=1)
            self.relus1[node_type] = nn.LeakyReLU()
            self.relus2[node_type] = nn.LeakyReLU()
            self.post_mps[node_type] = nn.Linear(self.hidden_size, hetero_graph.num_node_labels(node_type))

    def forward(self, data):
        x, edge_index = data.node_feature, data.edge_index
        x = self.convs1(x, edge_index)
        x = forward_op(x, self.bns1)
        x = forward_op(x, self.relus1)
        x = self.convs2(x, edge_index)
        x = forward_op(x, self.bns2)
        x = forward_op(x, self.relus2)
        
        pred = {}
        for message_type in edge_index:
            nodes_first = torch.index_select(x['n1'], 0, edge_index[message_type][0,:].long())
            nodes_second = torch.index_select(x['n1'], 0, edge_index[message_type][1,:].long())
            pred[message_type] = torch.sum(nodes_first * nodes_second, dim=-1)
        return pred

    def loss(self, pred, y):
        loss = 0
        breakpoint()
        for key in pred:
            p = torch.sigmoid(pred[key])
            loss += self.loss_fn(p, y[key].type(pred[key].dtype))
        return loss

In [75]:
def test(model, dataloaders, args):
    model.eval()
    accs = {}
    for mode, dataloader in dataloaders.items():
        acc = 0
        for i, batch in enumerate(dataloader):
            num = 0
            batch.to(args["device"])
            pred = model(batch)
            for key in pred:
                p = torch.sigmoid(pred[key]).cpu().detach().numpy()
                pred_label = np.zeros_like(p, dtype=np.int64)
                pred_label[np.where(p > 0.5)[0]] = 1
                pred_label[np.where(p <= 0.5)[0]] = 0
                acc += np.sum(pred_label == batch.edge_label[key].cpu().numpy())
                num += len(pred_label)
        accs[mode] = acc / num
    return accs

def train(model, dataloaders, optimizer, args):
    val_max = 0
    best_model = model
    t_accu = []
    v_accu = []
    e_accu = []
    for epoch in range(1, args["epochs"] + 1):
        for iter_i, batch in enumerate(dataloaders['train']):
            batch.to(args["device"])
            model.train()
            optimizer.zero_grad()
            pred = model(batch)
            loss = model.loss(pred, batch.edge_label)
            loss.backward()
            optimizer.step()

            log = 'Epoch: {:03d}, Train loss: {:.4f}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
            accs = test(model, dataloaders, args)
            t_accu.append(accs['train'])
            v_accu.append(accs['val'])
            e_accu.append(accs['test'])

            print(log.format(epoch, loss.item(), accs['train'], accs['val'], accs['test']))
            if val_max < accs['val']:
                val_max = accs['val']
                best_model = copy.deepcopy(model)

    log = 'Best: Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
    accs = test(best_model, dataloaders, args)
    print(log.format(accs['train'], accs['val'], accs['test']))

    return t_accu, v_accu, e_accu

# Begin Training

In [76]:
args = {
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'hidden_size': 64,
    'epochs': 100,
    'weight_decay': 1e-5,
    'lr': 0.003,
    'attn_size': 32,
    'edge_message_ratio': 0.8,
    'split_ratio': [0.8, 0.1, 0.1]
}

import data

In [77]:
with open("oncourse.pkl", 'rb') as f:
    G = pickle.load(f)
    
def WN_Transform(G, edge_types={'pin': 0}, input_dim=32):
    H = nx.MultiDiGraph()
    for node in G.nodes():
        # "user" if "user" in node else "course"
        H.add_node(node, node_type='n1', node_feature=torch.ones(input_dim))
    for u, v in G.edges:
        e_feat = torch.zeros(len(edge_types))
        e_feat[edge_types[G[u][v]['type']]] = 1.
        H.add_edge(u, v, edge_feature=e_feat, edge_type=G[u][v]['type'])
    return H


H = WN_Transform(G)
hetero = HeteroGraph(H)
hetero = HeteroGraph(
    edge_index=hetero.edge_index,
    edge_feature=hetero.edge_feature,
    node_feature=hetero.node_feature,
    directed=hetero.is_directed()
)

dataset = GraphDataset(
    [hetero],
    task='link_pred',
    edge_train_mode='disjoint',
    edge_message_ratio=args['edge_message_ratio']
)

create DataLoaders

In [78]:
dataset_train, dataset_val, dataset_test = dataset.split(
    transductive=True, split_ratio=args['split_ratio']
)
train_loader = DataLoader(
    dataset_train, collate_fn=Batch.collate(), batch_size=1
)
val_loader = DataLoader(
    dataset_val, collate_fn=Batch.collate(), batch_size=1
)
test_loader = DataLoader(
    dataset_test, collate_fn=Batch.collate(), batch_size=1
)
dataloaders = {
    'train': train_loader, 'val': val_loader, 'test': test_loader
}

In [80]:
train_loader

<torch.utils.data.dataloader.DataLoader at 0x7fc53358f410>

### Train

In [79]:
model = HeteroGNN(hetero, args, aggr="mean").to(args["device"])
optimizer = torch.optim.Adam(
    model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"]
)

t_accu, v_accu, e_accu = train(model, dataloaders, optimizer, args)

ValueError: Target size (torch.Size([9852])) must be the same as input size (torch.Size([19701]))