In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.nn as dglnn
import dgl.function as fn

In [9]:
# Load the saved (graph, type_ord) pair produced by an earlier preprocessing step.
# torch.load unpickles arbitrary Python objects, so only point it at trusted files.
# NOTE(review): torch>=2.6 defaults to weights_only=True, which will likely refuse a
# pickled DGL graph; on such versions pass weights_only=False explicitly.
(hetero_graph, type_ord) = torch.load(
    '../tyc_cm/heter_graph.pt'
)

In [5]:
class HeteroDotProductPredictor(nn.Module):
    """Score the edges of one edge type by the dot product of endpoint embeddings.

    ``forward(graph, h, etype)`` writes the node representations ``h`` into a
    ``local_scope`` of ``graph``, so the temporary ``'h'``/``'score'`` features
    do not persist on the caller's graph. It then applies ``fn.u_dot_v`` over
    the edges of ``etype`` and returns the resulting per-edge ``'score'`` tensor.
    """

    def forward(self, graph, h, etype):
        # h: node representations computed by the encoder for each node type of
        # the heterograph (presumably a dict {ntype: tensor}; confirm at caller).
        with graph.local_scope():
            graph.ndata['h'] = h
            dot_op = fn.u_dot_v('h', 'h', 'score')
            graph.apply_edges(dot_op, etype=etype)
            edge_scores = graph.edges[etype].data['score']
        return edge_scores

In [6]:
def construct_negative_graph(graph, k, etype):
    """Build a negative-sample graph for one canonical edge type.

    For every positive edge (u, v) of ``etype`` in ``graph``, emit ``k``
    corrupted edges (u, v') whose destination v' is drawn uniformly at random
    from all nodes of the destination type. Sources are repeated with
    ``repeat_interleave``, so the k negatives of positive edge i are the
    contiguous edges i*k .. i*k + k - 1. A loss that reshapes the negative
    scores to (n_edges, k) depends on this layout.

    Args:
        graph: DGL heterograph containing the positive edges.
        k: number of negatives sampled per positive edge.
        etype: canonical edge type ``(src_type, relation, dst_type)``.

    Returns:
        A DGL heterograph containing only ``etype`` edges. It has the same node
        types and node counts as ``graph``, so node embeddings computed on
        ``graph`` can be reused on it directly.
    """
    _, _, vtype = etype
    src, _ = graph.edges(etype=etype)
    neg_src = src.repeat_interleave(k)
    # Match the id dtype/device of the positive edges. torch.randint defaults to
    # int64 on the CPU, which would mismatch an int32-indexed or GPU-resident graph.
    neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,),
                            dtype=src.dtype, device=src.device)
    return dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})

In [12]:
class RGCN(nn.Module):
    """Two-layer relational GCN over a heterograph.

    Each layer is a ``dglnn.HeteroGraphConv`` holding one ``dglnn.GraphConv``
    per relation name. Messages from different relations that arrive at the
    same destination node type are summed (``aggregate='sum'``). A ReLU is
    applied between the two layers.

    Args:
        in_feats: input feature size.
        hid_feats: hidden size produced by the first layer.
        out_feats: output embedding size produced by the second layer.
        rel_names: iterable of relation (edge-type) names, e.g. ``graph.etypes``.
            If a name repeats (``graph.etypes`` does when one relation name is
            used by several canonical edge types), the dict keeps a single
            GraphConv for it. Those edge types then presumably share weights.
        allow_zero_in_degree: forwarded to every ``GraphConv``. When this is
            False, GraphConv raises a DGLError if any destination node of a
            relation has no incoming edge, which is near-certain in a large
            sparse heterograph. It defaults to True so such nodes simply
            receive no message from that relation.
    """

    def __init__(self, in_feats, hid_feats, out_feats, rel_names,
                 allow_zero_in_degree=True):
        super().__init__()
        # Materialize once: rel_names is consumed by both layers, so a one-shot
        # iterator would otherwise leave the second layer with no relations.
        rel_names = list(rel_names)
        self.conv1 = self._hetero_layer(rel_names, in_feats, hid_feats,
                                        allow_zero_in_degree)
        self.conv2 = self._hetero_layer(rel_names, hid_feats, out_feats,
                                        allow_zero_in_degree)

    @staticmethod
    def _hetero_layer(rel_names, in_dim, out_dim, allow_zero_in_degree):
        """Build one HeteroGraphConv with a GraphConv per relation, summed per dst type."""
        return dglnn.HeteroGraphConv(
            {rel: dglnn.GraphConv(in_dim, out_dim,
                                  allow_zero_in_degree=allow_zero_in_degree)
             for rel in rel_names},
            aggregate='sum')

    def forward(self, graph, inputs):
        """Return a dict of node embeddings after two rounds of message passing.

        ``inputs`` is a dict of input node features keyed by node type.
        """
        h = self.conv1(graph, inputs)
        h = {ntype: F.relu(feat) for ntype, feat in h.items()}
        return self.conv2(graph, h)

In [7]:
class Model(nn.Module):
    """Link-prediction model: an RGCN encoder followed by a dot-product edge scorer.

    ``forward`` encodes the nodes of ``g`` once. It then scores the ``etype``
    edges of both the positive graph ``g`` and the negative graph ``neg_g``
    using those same node embeddings.
    """

    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        # Attribute names are kept as-is because they determine state_dict keys.
        # Note that ``sage`` holds an RGCN encoder, not a GraphSAGE model.
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()

    def forward(self, g, neg_g, x, etype):
        """Return ``(positive_scores, negative_scores)`` for edge type ``etype``."""
        node_emb = self.sage(g, x)
        pos_scores = self.pred(g, node_emb, etype)
        neg_scores = self.pred(neg_g, node_emb, etype)
        return pos_scores, neg_scores

In [None]:
def compute_loss(pos_score, neg_score, margin=1.0):
    """Margin (hinge) ranking loss for link prediction.

    Each positive edge i is compared only with its own k negatives:
    ``mean(max(0, margin - pos_i + neg_ij))`` over all i and j.

    Args:
        pos_score: scores of the n positive edges, shape (n,) or (n, 1).
        neg_score: scores of the n*k negative edges, shape (n*k,) or (n*k, 1).
            The k negatives of positive edge i must be stored contiguously,
            in rows i*k .. i*k + k - 1.
        margin: hinge margin. The default of 1.0 matches the original loss.

    Returns:
        A scalar tensor.
    """
    n_edges = pos_score.shape[0]
    # Fold to (n, 1) and (n, k) so each positive broadcasts only against its own
    # negatives. Using pos_score.unsqueeze(1) on (n, 1) scores would give
    # (n, 1, 1), which broadcasts to (n, n, k): a wrong loss and O(n^2) memory.
    pos = pos_score.reshape(n_edges, 1)
    neg = neg_score.reshape(n_edges, -1)
    return (margin - pos + neg).clamp(min=0).mean()

# Training configuration. Sampling and feature init are random, so seed for reproducibility.
torch.manual_seed(0)
k = 5                                         # negatives sampled per positive edge
in_feats, hid_feats, out_feats = 10, 20, 5    # RGCN layer sizes
n_epochs = 10
target_etype = ('company', 'comp_touzi', 'company')  # edge type trained for link prediction

model = Model(in_feats, hid_feats, out_feats, hetero_graph.etypes)
# Fixed random input features, one row per node of each type. They are sized from
# the graph itself rather than hard-coded node counts, so they stay valid if the graph changes.
node_features = {ntype: torch.rand(hetero_graph.num_nodes(ntype), in_feats)
                 for ntype in hetero_graph.ntypes}
opt = torch.optim.Adam(model.parameters())
for epoch in range(n_epochs):
    # Draw fresh negatives every epoch.
    negative_graph = construct_negative_graph(hetero_graph, k, target_etype)
    pos_score, neg_score = model(hetero_graph, negative_graph, node_features, target_etype)
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(f'epoch {epoch}: loss {loss.item():.4f}')

In [10]:
# Summary of the loaded heterograph: node and edge counts per type.
# (`g` is not defined in this notebook; the graph is bound to `hetero_graph`.)
hetero_graph

Graph(num_nodes={'brand': 55056, 'company': 92524, 'organize': 14290},
      num_edges={('brand', 'brand_belong', 'company'): 51039, ('brand', 'brand_jingpin', 'brand'): 1147103, ('brand', 'brand_rongzi', 'organize'): 150568, ('company', 'comp_gongying', 'company'): 67486, ('company', 'comp_gudong', 'company'): 70566, ('company', 'comp_jingpin', 'brand'): 275658, ('company', 'comp_jingpin', 'company'): 252377, ('company', 'comp_lsgudong', 'company'): 15120, ('company', 'comp_touzi', 'company'): 70047, ('organize', 'org_gktouzi', 'company'): 135706, ('organize', 'org_gongshang', 'company'): 49446, ('organize', 'org_wgktouzi', 'company'): 103839},
      metagraph=[('brand', 'company', 'brand_belong'), ('brand', 'brand', 'brand_jingpin'), ('brand', 'organize', 'brand_rongzi'), ('company', 'company', 'comp_gongying'), ('company', 'company', 'comp_gudong'), ('company', 'company', 'comp_jingpin'), ('company', 'company', 'comp_lsgudong'), ('company', 'company', 'comp_touzi'), ('company', 'bra