In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.data
import numpy as np
from dgl import save_graphs, load_graphs
import torch as th
import dgl.function as fn
import dgl.nn as dglnn
from dgl.dataloading.negative_sampler import _BaseNegativeSampler
from dgl import backend as b

In [2]:
class PerSourceUniformCustom(_BaseNegativeSampler):

    def __init__(self, k):
        self.k = k

    def _generate(self, g, eids, canonical_etype):
        unique_authors = torch.unique(g.edges(etype = "authored")[1])
        #print(len(unique_authors))
        _, _, vtype = canonical_etype
        shape = b.shape(eids)
        dtype = b.dtype(eids)
        ctx = b.context(eids)
        shape = (shape[0] * self.k,)
        src, _ = g.find_edges(eids, etype=canonical_etype)
        src = b.repeat(src, self.k, 0)
        dst_indexes = th.randint(0, len(unique_authors), shape, dtype=dtype, device=ctx)
        dst = unique_authors[dst_indexes]
        return src, dst

In [69]:
class HeteroDotProductPredictor(nn.Module):
    def forward(self, graph, h, etype):
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            return graph.edges[etype].data['score']

    def returnScore(self, graph, neg_graph, h, etype):
        return self(graph, h, etype), self(neg_graph, h, etype)

class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, num_classes_papers, num_classes_authors):

        super().__init__()

        self.embedding_paper = nn.Embedding(num_classes_papers, 64)
        self.embedding_author = nn.Embedding(num_classes_authors, 64)
        #self.conv_writes = dglnn.GraphConv(64, 64, allow_zero_in_degree=True)
        #self.conv_cites = dglnn.GraphConv(64, 64, allow_zero_in_degree=True)
        self.conv1 = dglnn.HeteroGraphConv({
            "writes": dglnn.GraphConv(64, 64),
        }, aggregate="stack")

        self.conv2 = dglnn.HeteroGraphConv({
            "cites": dglnn.GraphConv(64, 64),
        })


    def forward(self, graph, inputs):
        classes_paper = inputs["paper"]
        classes_author = inputs["author"]

        # Put all classes into an embedding
        embedded_papers = self.embedding_paper(classes_paper)
        print(embedded_papers[0])
        embedded_authors = self.embedding_author(classes_author)

        #paper_feat = torch.cat([embedded_papers, author_agg], dim=-1)
        #h = self.conv_cites(paper_feat)
        h_writes = self.conv1(graph["writes"], {"author": embedded_authors, "paper": 0})
        h_authored_papers = torch.cat((h_writes, embedded_papers.unsqueeze(2)), dim=1)
        h_cites = self.conv2(graph["cites"], {"paper": h_authored_papers, "author": 0}).squeeze(1)
        return {"paper": h_cites, "author": embedded_authors}

class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, rel_names, num_classes_paper, num_classes_author):
        super().__init__()
        # Encoder
        self.sage = RGCN(in_features, hidden_features, out_features, num_classes_paper, num_classes_author)
        # Decoder
        self.pred = HeteroDotProductPredictor()
    def forward(self, g, x):
        h = self.sage(g, x)
        return h

    def scores(self, g, neg_g, x, etype):
      h = self(g, neg_g, x, etype)
      return self.pred(g, h, etype), self.pred(neg_g, h, etype)


def accuracy(logits, graph):
  with torch.no_grad():
    all_papers = torch.unique(graph.edges(etype="authored")[0])
    src, dst = graph.edges(etype="authored")
    tst = 0
    author_logits = logits["author"]
    #print(len(all_papers))
    for idx, index_paper in enumerate(all_papers):
      #if idx % 25000 == 0:
        #print(f"{idx}/{len(all_papers)}")
      current_logits = logits["paper"][index_paper]
      max = torch.argmax(torch.sum(current_logits * author_logits, dim=-1))
      filter_acc = src == index_paper
      if max in dst[filter_acc]:
        tst += 1
    return tst/len(all_papers)

def compute_loss_logits(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score])
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
    )
    return F.binary_cross_entropy_with_logits(scores.squeeze(1), labels)

def construct_negative_graph(graph, k, etype):
    utype, _, vtype = etype
    src, dst = graph.edges(etype=etype)
    neg_src = src.repeat_interleave(k)
    neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,))
    return dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})

In [133]:
from dgl import sum_nodes

class HeteroSumPooling(nn.Module):

    def __init__(self, node_type):
        super(HeteroSumPooling, self).__init__()
        self.node_type = node_type

    def forward(self, graph, feat):

        with graph.local_scope():
            graph.nodes[self.node_type].data["h"] = feat
            readout = dgl.sum_nodes(graph, "h", ntype=self.node_type)
            return readout


# Little example

In [75]:
hetero_graph_example = dgl.heterograph(
    {
        ('paper', 'authored', 'author'): (np.array([2,0]), np.array([0,1])),
        ('paper', 'cites', 'paper'): (np.array([4,1,1,4,1,4,6]), np.array([1,2,3,0,0,5,6])),
        ('author', 'writes', 'paper'): (np.array([2,3,2,3,1,3]), np.array([1,1,6,6,4,4]))
    }
)

hetero_graph_example_eval = dgl.heterograph(
    {
        ('paper', 'authored', 'author'): (np.array([3,5]), np.array([0,1])),
        ('paper', 'cites', 'paper'): (np.array([0,1,1,0,1,0,6]), np.array([1,2,3,4,4,5,5])),
        ('author', 'writes', 'paper'): (np.array([2,2,3]), np.array([0,1,1]))
    }
)

num_papers = hetero_graph_example.number_of_nodes('paper')
num_authors = hetero_graph_example.number_of_nodes('author')
hetero_graph_example.nodes['paper'].data['feature'] = th.arange(num_papers).view(-1,1)
hetero_graph_example.nodes['author'].data['feature'] = th.arange(num_authors).view(-1,1)
author_feat = F.one_hot(th.arange(num_papers))


In [190]:
hetero_graph_example = dgl.heterograph(
    {
        ('author', 'writes', 'paper'): (np.array([0,1]), np.array([0,0]))
    }
)
num_papers = hetero_graph_example.number_of_nodes('paper')
num_authors = hetero_graph_example.number_of_nodes('author')
hetero_graph_example.nodes['paper'].data['feature'] = th.arange(num_papers).view(-1,1)
hetero_graph_example.nodes['author'].data['feature'] = th.arange(num_authors).view(-1,1)
author_feat = F.one_hot(th.arange(num_authors)).to(torch.float32)
paper_feat = th.tensor([[123,0]])

In [196]:
th.zeros((len(paper_feat), author_feat.shape[0]))

tensor([[0., 0.]])

In [192]:
hetero_graph_example.nodes["author"].data["h"] = author_feat
hetero_graph_example.nodes["paper"].data["h"] = paper_feat
hetero_graph_example["writes"].update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
hetero_graph_example.nodes['paper'].data['h']

tensor([[1., 1.]])

In [175]:
g = dgl.heterograph({
    ('user', 'follows', 'game'): ([0], [1]),
    ('game', 'attracts', 'user'): ([0], [1])
})

In [184]:
g.nodes['user'].data['h'] = torch.tensor([[1., 0.], [0., 1.]])
g.nodes['game'].data['h'] = torch.tensor([[1, 0], [10, 0]])
g["follows"].update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
g.nodes['game'].data['h']

tensor([[0., 0.],
        [1., 0.]])

In [128]:
conv1 = dglnn.HeteroGraphConv({
    "writes": dglnn.GraphConv(2, 2, weight=False),
}, aggregate="sum")

In [129]:
conv1(hetero_graph_example["writes"], {"author": author_feat, "paper": paper_feat})

{'paper': tensor([[0.7071, 0.7071]], grad_fn=<SumBackward1>)}

In [71]:
conv1 = dglnn.HeteroGraphConv({
    "writes": dglnn.GraphConv(4, 4),
}, aggregate="stack")

conv2 = dglnn.HeteroGraphConv({
    "cites": dglnn.GraphConv(4, 4),
})

embedding_paper = nn.Embedding(7, 4)
embedding_author = nn.Embedding(4, 4)

In [72]:
author_feats = embedding_author(hetero_graph_example.nodes['author'].data['feature'])
paper_feats = embedding_paper(hetero_graph_example.nodes['paper'].data['feature'])
node_features = {'author': author_feats, 'paper': paper_feats}

In [73]:
paper = conv1(hetero_graph_example["writes"], node_features)["paper"].squeeze(1)
paper_feats.shape

torch.Size([7, 1, 4])

In [48]:
paper_fea = torch.cat((paper, paper_feats), dim=1)
paper_fea.shape

torch.Size([7, 2, 4])

In [53]:
new_paper_fea = conv2(hetero_graph_example["cites"], {'author': author_feats, 'paper': paper_fea})["paper"]
new_paper_fea.flatten(1).shape

torch.Size([7, 8])

In [78]:
pred = HeteroDotProductPredictor()
model = Model(256, 64, 64, hetero_graph_example.etypes, len(hetero_graph_example.nodes("paper")), len(hetero_graph_example.nodes("author")))
author_feats = hetero_graph_example.nodes['author'].data['feature']
paper_feats = hetero_graph_example.nodes['paper'].data['feature']
node_features = {'author': author_feats, 'paper': paper_feats}
opt = torch.optim.Adam(model.parameters())
k = 1

In [81]:


for epoch in range(50):
    negative_graph = construct_negative_graph(hetero_graph_example, k, ('paper', 'authored', 'author'))
    h = model(hetero_graph_example, node_features)
    pos_score, neg_score = pred.returnScore(hetero_graph_example, negative_graph, h, ('paper', 'authored', 'author'))
    loss = compute_loss_logits(pos_score.squeeze(1), neg_score.squeeze(1))
    opt.zero_grad()
    loss.backward()
    opt.step()
    if epoch % 10 == 0:
        with torch.no_grad():
            acc = accuracy(h, hetero_graph_example)
            h_eval = model(hetero_graph_example_eval, node_features)
            acc_eval = accuracy(h_eval, hetero_graph_example_eval)
            print(acc, acc_eval)


tensor([[ 0.0327, -0.2471,  1.3498, -0.4649,  0.5063,  1.1272, -1.4630, -0.9207,
         -0.2087,  0.8228, -0.2701,  0.5944,  1.9127, -0.5640, -1.8044, -0.8177,
         -0.3373, -0.3140, -0.7660, -1.8931, -1.2541, -2.7131,  0.9276, -1.8957,
          0.0403,  0.0832,  1.2483,  0.2404, -0.0809,  1.8444,  1.1190, -1.1284,
         -0.4637, -0.0424,  0.7864,  0.2597, -0.4501,  0.0497,  1.5300,  0.1888,
         -1.1421, -1.5466,  0.6383, -0.7819, -0.0973,  0.9306, -1.3108, -1.1159,
         -0.5026, -0.3738, -1.3866, -0.1051, -0.4184,  1.7215, -0.6648, -1.2275,
          1.2156,  0.3436,  0.7285, -0.8094, -1.1248, -0.7959,  0.8455,  1.2729]],
       grad_fn=<SelectBackward0>)


RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

# Train on the complete graph

In [24]:
glist, label_dict = load_graphs("./graphs/val_graph.bin")
train_hetero_graph= glist[0]
val_pos_hetero_graph = glist[0]

In [34]:
paper_ids = val_pos_hetero_graph.nodes("paper")
paper_feats = F.one_hot(paper_ids, num_classes = len(paper_ids))

In [20]:
pred = HeteroDotProductPredictor()
model = Model(8638, 256, 256, val_pos_hetero_graph.etypes)
author_feats = val_pos_hetero_graph.nodes['author'].data['feature']
#paper_feats = val_pos_hetero_graph.nodes['paper'].data['feature']
node_features = {'author': author_feats, 'paper': paper_feats}
opt = torch.optim.Adam(model.parameters())
k = 1

for epoch in range(15000):
    print(epoch)
    negative_graph = construct_negative_graph(val_pos_hetero_graph, k, ('paper', 'authored', 'author'))
    h = model(val_pos_hetero_graph, node_features)
    pos_score, neg_score = pred.returnScore(val_pos_hetero_graph, negative_graph, h, ('paper', 'authored', 'author'))
    loss = compute_loss_logits(pos_score.squeeze(1), neg_score.squeeze(1))
    opt.zero_grad()
    loss.backward()
    opt.step()
    if epoch % 10 == 0:
        with torch.no_grad():
            acc = accuracy(h, val_pos_hetero_graph)
            print(acc)

NameError: name 'val_pos_hetero_graph' is not defined

In [43]:
pos_score

tensor([[6.6322],
        [6.6322],
        [5.6405],
        ...,
        [6.4707],
        [6.4707],
        [6.4707]], grad_fn=<GSDDMMBackward>)

In [45]:
neg_score.shape

torch.Size([1443, 1])