In [5]:
import dgl
import torch
import numpy as np
n_users = 1000 #user 1000명
n_items = 500 #item 500개
n_follows = 3000 #follows 3000
n_clicks = 5000 #클릭데이터 5000개
n_dislikes = 500 #싫어요 500개
n_hetero_features = 10
n_user_classes = 5 # user label
n_max_clicks = 10 # clik label

follow_src = np.random.randint(0,n_users, n_follows)
follow_dst = np.random.randint(0,n_users, n_follows)
click_src = np.random.randint(0,n_users,n_clicks) #user idx에서 click 어디에 했는지 즉 5000개
click_dst = np.random.randint(0,n_items,n_clicks) #위에 user에서 어떤 아이템인지
dislike_src = np.random.randint(0,n_users,n_dislikes)
dislike_dst = np.random.randint(0,n_items,n_dislikes)
hetero_graph = dgl.heterograph({
    ('user', 'follow', 'user'): (follow_src, follow_dst),
    ('user', 'followed-by', 'user'): (follow_dst, follow_src),
    ('user', 'click', 'item'): (click_src, click_dst),
    ('item', 'clicked-by', 'user'): (click_dst, click_src),
    ('user', 'dislike', 'item'): (dislike_src, dislike_dst),
    ('item', 'disliked-by', 'user'): (dislike_dst, dislike_src)})

print(hetero_graph)

Graph(num_nodes={'item': 500, 'user': 1000},
      num_edges={('item', 'clicked-by', 'user'): 5000, ('item', 'disliked-by', 'user'): 500, ('user', 'click', 'item'): 5000, ('user', 'dislike', 'item'): 500, ('user', 'follow', 'user'): 3000, ('user', 'followed-by', 'user'): 3000},
      metagraph=[('item', 'user', 'clicked-by'), ('item', 'user', 'disliked-by'), ('user', 'item', 'click'), ('user', 'item', 'dislike'), ('user', 'user', 'follow'), ('user', 'user', 'followed-by')])


In [24]:
#heterogeneous graph의 각 노드 feature와 user의 클래스(노드분류), click edge의 라벨을 설정해준다.
hetero_graph.nodes['user'].data['feature'] = torch.randn(n_users,n_hetero_features)
hetero_graph.nodes['item'].data['feature'] = torch.randn(n_items,n_hetero_features)
hetero_graph.nodes['user'].data['label'] = torch.randint(0,n_user_classes,(n_users,))
hetero_graph.edges['click'].data['label'] = torch.randint(1,n_max_clicks,(n_clicks,)).float()
#hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(n_users, dtype=torch.bool).bernoulli(0.6)
#hetero_graph.edges['click'].data['train_mask'] = torch.zeros(n_clicks, dtype=torch.bool).bernoulli(0.6)

In [25]:
import torch.nn as nn
import dgl.function as fn

class HeteroDotProductPredictor(nn.Module):
    def forward(self, graph, h, etype):
        # h contains the node representations for each node type computed from
        # the GNN defined in the previous section (Section 5.1).
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            return graph.edges[etype].data['score']

In [37]:
# heterogeneous graph에서 link prediction 또한 homogeneous graph와 크게 다르지않다.
# 예측을 하고 싶은 엣지에 대해서 수행해준다고 생각하면 된다. 즉, edge type을 하나 정하고 그 edge에 대해서 negative sampling을 수행해주면 되는 것.
# 주어진 그래프에서 엣지 타입은 user - click - item, user - dislike - item, user - follow - user 등이 있다.
target_edge = hetero_graph.edges(etype= ('user','click','item'))
target_edge_src , target_edge_dst = target_edge[0], target_edge[1] #target edge를 click으로 잡고 src, dst로 구분해줌
print(target_edge_src)
print(target_edge_dst)
eids = np.arange(hetero_graph.num_edges('click'))
np.random.shuffle(eids)
test_size = int(len(eids)*0.2)
train_size = len(eids)-test_size

test_pos_src, test_pos_dst = target_edge_src[eids[:test_size]], target_edge_dst[eids[:test_size]]
train_pos_src, train_pos_dst = target_edge_src[eids[test_size:]], target_edge_dst[eids[test_size:]]
print(test_pos_src,test_pos_dst)


tensor([196, 451, 943,  ..., 338,  42, 567])
tensor([295,  99, 262,  ..., 153, 318, 462])
tensor([833,  63, 352, 238, 968, 207, 837, 446, 651, 350, 197, 299, 216, 225,
        375, 562, 130, 238, 990,  50, 296, 864, 144, 927, 705, 114, 205, 521,
        100, 288, 892, 804, 790, 602, 339, 531, 570, 508,  60, 972, 617, 514,
        443, 835, 835, 408, 995, 215, 170, 198, 105, 645, 984, 982, 675, 251,
        222, 639, 137, 668, 272, 631, 315, 677, 279,  50,  39,  35, 874, 563,
          4, 236, 878, 743, 391, 799, 734, 602, 971, 779, 325, 798, 753, 391,
        948, 761,  11, 987, 463, 402, 483, 586,  77, 874,  13,  60, 690,  65,
        341, 138,  34, 978, 921, 441, 480, 822, 116, 462, 701, 127, 881, 476,
        774, 247, 405, 824, 927, 386, 423, 356, 395, 369, 601, 810, 226,  53,
        924, 867, 623, 599, 248, 315, 638, 741, 509, 357, 937, 645, 427, 810,
        655, 501, 249, 169, 443, 416, 642, 489, 491, 864, 791, 109, 146, 619,
        867, 896, 709, 211, 414, 747, 476,  70, 318,

In [85]:
#위에서 test와 train을 구분해주었고 이제 negative sampling을 해준다
import scipy.sparse as sp

adj = sp.coo_matrix((np.ones(len(target_edge_src)),(target_edge_src.numpy(),target_edge_dst.numpy())))
#coo_matrix(data,(row,col)) row와 col에 동시에 나타나면 1로 만들어 인접행렬 생성
adj_neg = 1- adj.todense()
neg_src, neg_dst = np.where(adj_neg !=0) #연결 안되어있는 노드 쌍 추출

#negative sampling 숫자는 많을 수 밖에 없음 연결되어있지 않은 모든 상황을 고려하기 때문에
neg_eids = np.random.choice(len(neg_src),len(hetero_graph.edges(etype= ('user','click','item')))//2) #click에 대한 edge숫자의 절반을 샘플링
test_neg_u, test_neg_v = neg_src[neg_eids[:test_size]], neg_dst[neg_eids[:test_size]]
train_neg_u, train_neg_v = neg_src[neg_eids[test_size:]], neg_dst[neg_eids[test_size:]]

train_g = dgl.remove_edges(hetero_graph,etype=('click'),eids= eids[:test_size]) #해당 edge 타입만 지울것이므로 etype지정
test_g = dgl.remove_edges(hetero_graph,etype=('click'),eids= eids[test_size:])
print(train_g) 

Graph(num_nodes={'item': 500, 'user': 1000},
      num_edges={('item', 'clicked-by', 'user'): 5000, ('item', 'disliked-by', 'user'): 500, ('user', 'click', 'item'): 4000, ('user', 'dislike', 'item'): 500, ('user', 'follow', 'user'): 3000, ('user', 'followed-by', 'user'): 3000},
      metagraph=[('item', 'user', 'clicked-by'), ('item', 'user', 'disliked-by'), ('user', 'item', 'click'), ('user', 'item', 'dislike'), ('user', 'user', 'follow'), ('user', 'user', 'followed-by')])


In [87]:
import dgl.nn as dglnn
import torch.nn.functional as F
import torch.nn as nn
class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h
class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()
    def forward(self, g, neg_g, x, etype):
        h = self.sage(g, x)
        return self.pred(g, h, etype), self.pred(neg_g, h, etype)

In [88]:
import torch.optim as optim
from sklearn.metrics import roc_auc_score
def criterion(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score])
    labels = torch.cat([torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])])
    labels = labels.unsqueeze(1)
    return F.binary_cross_entropy_with_logits(scores, labels)

def accuracy(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score])
    labels = torch.cat([torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]).numpy()
    #labels = labels.unsqueeze(1)

    return roc_auc_score(labels, scores)


In [102]:
class HeteroDotProductPredictor(nn.Module):
    def forward(self, graph, h, etype):
        # h contains the node representations for each node type computed from
        # the GNN defined in the previous section (Section 5.1).
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            return graph.edges[etype].data['score']
pred = HeteroDotProductPredictor()
model = Model(10, 32,1,hetero_graph.etypes)
num_k = 5
user_feats = train_g.nodes['user'].data['feature']
item_feats = train_g.nodes['item'].data['feature']
node_features = {'user': user_feats, 'item': item_feats}
all_logits = []
def construct_negative_graph(graph, k, etype):
    utype, _, vtype = etype
    src, dst = graph.edges(etype=etype)
    neg_src = src.repeat_interleave(k)
    neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,))
    return dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})
print(train_g)
print(test_g)
optimizer = optim.Adam(model.parameters(),lr = 0.05)
for epoch in range(100):
    negative_graph = construct_negative_graph(train_g, num_k, ('user', 'click', 'item'))
    pos_score,neg_score = model(train_g,negative_graph, node_features,('user','click','item'))
    loss = criterion(pos_score, neg_score)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch+1) % 5 == 0 :
        print(f'In epoch {epoch+1}, loss: {loss:.4f}')

negative_graph = construct_negative_graph(test_g, num_k, ('user', 'click', 'item'))

pos_score,neg_score = model(test_g,negative_graph, node_features,('user','click','item'))
with torch.no_grad():
    #pos_score = pred(test_pos_g, h)
    #neg_score = pred(test_neg_g, h)
    print('AUC', accuracy(pos_score, neg_score))

Graph(num_nodes={'item': 500, 'user': 1000},
      num_edges={('item', 'clicked-by', 'user'): 5000, ('item', 'disliked-by', 'user'): 500, ('user', 'click', 'item'): 4000, ('user', 'dislike', 'item'): 500, ('user', 'follow', 'user'): 3000, ('user', 'followed-by', 'user'): 3000},
      metagraph=[('item', 'user', 'clicked-by'), ('item', 'user', 'disliked-by'), ('user', 'item', 'click'), ('user', 'item', 'dislike'), ('user', 'user', 'follow'), ('user', 'user', 'followed-by')])
Graph(num_nodes={'item': 500, 'user': 1000},
      num_edges={('item', 'clicked-by', 'user'): 5000, ('item', 'disliked-by', 'user'): 500, ('user', 'click', 'item'): 1000, ('user', 'dislike', 'item'): 500, ('user', 'follow', 'user'): 3000, ('user', 'followed-by', 'user'): 3000},
      metagraph=[('item', 'user', 'clicked-by'), ('item', 'user', 'disliked-by'), ('user', 'item', 'click'), ('user', 'item', 'dislike'), ('user', 'user', 'follow'), ('user', 'user', 'followed-by')])
In epoch 5, loss: 0.6260
In epoch 10, loss