# ERNIESage Code Walkthrough

This project provides a runnable code walkthrough of the ERNIESage model so that readers can get a direct feel for how it works, and it also explains some of the key code in ERNIESage. Let's enjoy!

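**ERNIESage** is easy to implement within PGL's message-passing paradigm. PGL currently provides three versions of the ERNIESage model on GitHub:
- **ERNIESage v1**: ERNIE is applied to the nodes of the text graph;
- **ERNIESage v2**: ERNIE is applied to the edges of the text graph;
- **ERNIESage v3**: ERNIE is applied to first-order neighbors and their edges;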

### Outline
- Data
- **Model**
- Training

In [1]:
# Clone the PGL code. Cloning from GitHub is slow, so this has already been done in advance.
# !git clone https://github.com/PaddlePaddle/PGL
# !cd PGL/example/erniesage
# To run the code, first install the following dependencies
!pip install pgl
!pip install easydict
!python3 -m pip install --no-deps paddle-propeller
!pip install paddle-ernie
!pip uninstall -y colorlog
# `!export` runs in a throwaway subshell, so set the variable with %env instead
%env CUDA_VISIBLE_DEVICES=0

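## Data
### Example input dataset
1. example_data/link_predict/graph_data.txt - a simple input file in which each line has the format query \t answer. It can be used as a simple runnable example; a link prediction task typically uses the edges of the graph directly as training targets.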

In [1]:
! head -n 3 example_data/link_predict/graph_data.txt
! wc -l example_data/link_predict/graph_data.txt

head: cannot open 'example_data/link_predict/graph_data.txt' for reading: No such file or directory
wc: example_data/link_predict/graph_data.txt: No such file or directory


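### How to represent a text graph
- Every text segment that appears is treated as a node; for example, “黑缘粗角肖叶甲触角有多大？” (a question about the antenna size of a leaf beetle species) is one node
- The two nodes on one line form an edge
- Each node's text segment is converted character by character into ids, and the resulting id sequence serves as the **node feature**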
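
The three rules above can be sketched in plain Python. This is only an illustration, not the logic of `dump_graph`: the function name `build_text_graph`, the on-the-fly character vocabulary, and the use of `0` as the padding id are all assumptions made for this example.

```python
# Minimal sketch: turn "query \t answer" lines into a text graph.
# Assumptions (not taken from PGL): ids are assigned on the fly,
# 0 is the padding id, and sequences are cut or padded to max_seqlen.

def build_text_graph(lines, max_seqlen=8):
    node_index = {}   # text segment -> node id
    char_vocab = {}   # character -> token id (hypothetical vocabulary)
    edges = []

    def node_id(text):
        # Every distinct text segment becomes exactly one node.
        if text not in node_index:
            node_index[text] = len(node_index)
        return node_index[text]

    def char_ids(text):
        # Convert the text character by character, then pad with 0.
        ids = [char_vocab.setdefault(ch, len(char_vocab) + 1) for ch in text]
        ids = ids[:max_seqlen]
        return ids + [0] * (max_seqlen - len(ids))

    for line in lines:
        # The two texts on one line form one edge.
        query, answer = line.rstrip("\n").split("\t")
        edges.append((node_id(query), node_id(answer)))

    # Dict insertion order matches node id order.
    term_ids = [char_ids(text) for text in node_index]
    return edges, term_ids


lines = ["ab\tcd", "ab\tef"]
edges, term_ids = build_text_graph(lines, max_seqlen=4)
print(edges)     # one (src, dst) pair per input line
print(term_ids)  # one fixed-length id sequence per node
```

The repeated query `ab` maps to a single node, so the two lines yield three nodes and two edges.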

In [2]:
from preprocessing.dump_graph import dump_graph
from preprocessing.dump_graph import dump_node_feat
from preprocessing.dump_graph import download_ernie_model
from preprocessing.dump_graph import load_config
from pgl.graph_wrapper import BatchGraphWrapper
import propeller.paddle as propeller
import paddle.fluid as F
import paddle.fluid.layers as L
import numpy as np
from models.pretrain_model_loader import PretrainedModelLoader
from pgl.graph import MemmapGraph
from models.encoder import linear
from ernie import ErnieModel
np.random.seed(123)
config = load_config("./config/erniesage_link_predict.yaml")

In [3]:
# Build a text graph from the raw QA data and store it under the workdir directory with graph.dump
dump_graph(config)
dump_node_feat(config)


In [4]:
# MemmapGraph reloads a graph that was saved with PGL's graph.dump
graph = MemmapGraph("./workdir/")
# Look at basic graph information
print("nodes:", graph.num_nodes)
print("edges:", graph.edges, graph.edges.shape)

In [5]:
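# Look at the node features
print([("%s shape is %s" % (key, str(graph.node_feat[key].shape))) for key in graph.node_feat])
print(graph.node_feat) # Text is converted to ids character by character; each text segment is one node, and every sequence is kept at length 40
# 1021 nodes, each with an id sequence of length 40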

## Model
### ERNIESage V1 core flow
- ERNIE extracts node semantics -> GNN aggregation

<img src="https://ai-studio-static-online.cdn.bcebos.com/0ab25a4f0c1647acbcfacc1be2066d47e98ec4f1931d4dcebd209347dc1b5448" alt="image alt text" width="300" height="313" align="bottom" />


In [7]:
# ERNIESage V1: ERNIE is applied to nodes
class ERNIESageV1Encoder():
    def __init__(self, config):
        self.config = config

    def __call__(self, graph_wrappers, inputs):
        
        # step1. ERNIE extracts node semantics
        # Input: the id sequence of each node's text
        term_ids = graph_wrappers[0].node_feat["term_ids"]

        cls = L.fill_constant_batch_size_like(term_ids, [-1, 1], "int64",
                                              self.config.cls_id) # cls [B, 1]
        term_ids = L.concat([cls, term_ids], 1) # term_ids [B, S+1]
        # [CLS], id1, id2, id3 .. [SEP]

        ernie_model = ErnieModel(self.config.ernie_config)
        # Get ERNIE's representation at the [CLS] position
        cls_feat, _ = ernie_model(term_ids) # cls_feat [B, F]

        # step2. GNN aggregation
        feature = graphsage_sum(cls_feat, graph_wrappers[0], self.config.hidden_size, "v1_graphsage_sum", "leaky_relu")
        
        final_feats = [
            self.take_final_feature(feature, i, "v1_final_fc") for i in inputs
        ]
        return final_feats
    
    def take_final_feature(self, feature, index, name):
        """take final feature"""
        feat = L.gather(feature, index, overwrite=False)
        feat = linear(feat, self.config.hidden_size, name)
        feat = L.l2_normalize(feat, axis=1)
        return feat


def graphsage_sum(feature, gw, hidden_size, name, act):
    # copy_send
    msg = gw.send(lambda src, dst, edge: src["h"], nfeat_list=[("h", feature)])
    # sum_recv
    neigh_feature = gw.recv(msg, lambda feat: L.sequence_pool(feat, pool_type="sum"))

    self_feature = linear(feature, hidden_size, name+"_l", act)
    neigh_feature = linear(neigh_feature, hidden_size, name+"_r", act)
    output = L.concat([self_feature, neigh_feature], axis=1) # [B, 2H]
    output = L.l2_normalize(output, axis=1)
    return output
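
The `copy_send` / `sum_recv` pair inside `graphsage_sum` can be illustrated on plain Python lists. This is a sketch under stated assumptions, not PGL's implementation: the helper name `sum_neighbors` is hypothetical, and each edge is assumed to be a `(src, dst)` pair whose message flows from `src` to `dst`. The linear layers, the concat and the L2 normalization of the real function are left out.

```python
# Minimal sketch of copy_send + sum_recv on plain Python lists.
# Assumption: each edge is (src, dst) and the message flows src -> dst.

def sum_neighbors(features, edges):
    dim = len(features[0])
    # Nodes that receive no message keep an all-zero neighbor vector.
    neigh = [[0.0] * dim for _ in features]
    for src, dst in edges:
        # copy_send: the message is simply the source node's feature.
        msg = features[src]
        # sum_recv: the destination adds up everything it receives.
        neigh[dst] = [a + b for a, b in zip(neigh[dst], msg)]
    return neigh


# A star graph: node 0 is linked to nodes 1, 2 and 3 in both directions.
edges = [(0, 1), (1, 0), (0, 2), (2, 0), (0, 3), (3, 0)]
features = [[1.0, 0.0], [0.0, 1.0], [0.0, 2.0], [3.0, 0.0]]
neigh = sum_neighbors(features, edges)
print(neigh)
```

Node 0 ends up with the sum of the three leaf features, while each leaf only receives the feature of node 0.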

In [5]:
# Construct some random data
feat_size = 40
feed_dict = {
    "num_nodes": np.array([4]),
    "num_edges": np.array([6]),
    "edges": np.array([[0,1],[1,0],[0,2],[2,0],[0,3],[3,0]]),
    "term_ids": np.random.randint(4, 10000, size=(4, feat_size)),
    "inputs": np.array([0])}
place = F.CUDAPlace(0)
exe = F.Executor(place)

<img src="https://ai-studio-static-online.cdn.bcebos.com/94ab49de20ec4574a5d27e7ad3d23354df5ade177666450ba2f6d4cde11c33b6" alt="image alt text" width="300" height="313" align="bottom" />


In [8]:
# Model v1
erniesage_v1_encoder = ERNIESageV1Encoder(config)

main_prog, start_prog = F.Program(), F.Program()
with F.program_guard(main_prog, start_prog):
    with F.unique_name.guard():
        num_nodes = L.data("num_nodes", [1], False, 'int64')
        num_edges = L.data("num_edges", [1], False, 'int64')
        edges = L.data("edges", [-1, 2], False, 'int64')
        node_feat = L.data("term_ids", [-1, 40], False, 'int64')
        inputs = L.data("inputs", [-1], False, 'int64')

        # Build a graph from the basic graph information (edges, nodes, features)
        gw = BatchGraphWrapper(num_nodes, num_edges, edges, {"term_ids": node_feat})
        outputs = erniesage_v1_encoder([gw], [inputs])

exe.run(start_prog)
# outputs is already a list of variables, so it is passed directly as fetch_list
outputs_np = exe.run(main_prog, feed=feed_dict, fetch_list=outputs)[0]
print(outputs_np)

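### ERNIESage V2 core code
- GNN send text ids -> ERNIE extracts edge semantics -> GNN recv aggregates neighbor semantics -> ERNIE extracts center-node semantics and concat

<img src="https://ai-studio-static-online.cdn.bcebos.com/24d5cca257624cc6bb94eeea7a7c3f84512534070c5949a5a9aca8fc8455f52e" alt="image alt text" width="500" height="313" align="bottom" />

To help readers follow the ERNIE-related parts below, the main architecture diagram of the ERNIE model is shown here first.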

![](https://ai-studio-static-online.cdn.bcebos.com/8b2bf7e82042474e904e867b415b83fed436281fe75e46dca1f9cb97189172bc)




In [9]:
# ERNIESage V2: ERNIE is applied to edges
class ERNIESageV2Encoder():
    def __init__(self, config):
        self.config = config

    def __call__(self, graph_wrappers, inputs):
        gw = graph_wrappers[0]
        term_ids = gw.node_feat["term_ids"] # term_ids [B, S]
        
        # step1. GNN send: text ids
        def ernie_send(src_feat, dst_feat, edge_feat):
            def build_position_ids(term_ids):
                input_mask = L.cast(term_ids > 0, "int64")
                position_ids = L.cumsum(input_mask, axis=1) - 1
                return position_ids

            # src_ids and dst_ids are the text id sequences of the sending (src) and receiving (dst) nodes
            src_ids, dst_ids = src_feat["term_ids"], dst_feat["term_ids"]

            # Create the id column for [CLS] and concat it in front of the first half
            cls = L.fill_constant_batch_size_like(
                src_feat["term_ids"], [-1, 1], "int64", self.config.cls_id) # cls [B, 1]
            src_ids = L.concat([cls, src_ids], 1) # src_ids [B, S+1]

            # Concat src and dst into the complete token id sequence
            term_ids = L.concat([src_ids, dst_ids], 1) # term_ids [B, 2S+1]
            # [CLS], src_id1, src_id2.. [SEP], dst_id1, dst_id2..[SEP]

            sent_ids = L.concat([L.zeros_like(src_ids), L.ones_like(dst_ids)], 1)
            #   0, 0, 0 .. 0, 1, 1 .. 1

            position_ids = build_position_ids(term_ids)
            #   0, 1, 2, 3 ..

            # step2. ERNIE extracts edge semantics
            ernie_model = ErnieModel(self.config.ernie_config)
            cls_feat, _ = ernie_model(term_ids, sent_ids, position_ids)
            # cls_feat is the sentence-level hidden representation extracted by ERNIE
            return cls_feat

        msg = gw.send(ernie_send, nfeat_list=[("term_ids", term_ids)])
        
        # step3. GNN recv aggregates neighbor semantics
        # Receive the neighbors' CLS representations and sum them together
        neigh_feature = gw.recv(msg, lambda feat: L.sequence_pool(feat, pool_type="sum"))

        # Prepend a [CLS] token to each node's own sequence as well
        cls = L.fill_constant_batch_size_like(term_ids, [-1, 1],
                                              "int64", self.config.cls_id)

        term_ids = L.concat([cls, term_ids], 1)
        # [CLS], id1, id2, ... [SEP]

        # step4. ERNIE extracts center-node semantics and concat
        # Run ERNIE once on the center node
        ernie_model = ErnieModel(self.config.ernie_config)

        # Get the CLS representation of the center node
        self_cls_feat, _ = ernie_model(term_ids)

        hidden_size = self.config.hidden_size        
        self_feature = linear(self_cls_feat, hidden_size, "erniesage_v2_l", "leaky_relu")
        neigh_feature = linear(neigh_feature, hidden_size, "erniesage_v2_r", "leaky_relu")
        output = L.concat([self_feature, neigh_feature], axis=1)
        output = L.l2_normalize(output, axis=1)

        final_feats = [
            self.take_final_feature(output, i, "v2_final_fc") for i in inputs
        ]
        return final_feats

    def take_final_feature(self, feature, index, name):
        """take final feature"""
        feat = L.gather(feature, index, overwrite=False)
        feat = linear(feat, self.config.hidden_size, name)
        feat = L.l2_normalize(feat, axis=1)
        return feat
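
What `ernie_send` feeds into ERNIE for a single edge can be traced by hand with ordinary lists. The sketch below mirrors the concat, segment-id and cumulative-sum position-id logic of the cell above, but the function name `build_edge_input` and the concrete ids (`1` for [CLS], `2` for [SEP], `0` for padding) are assumptions chosen for illustration.

```python
# Minimal sketch of the per-edge input built in ernie_send.
# Hypothetical ids: 1 = [CLS], 2 = [SEP], 0 = padding.

def build_edge_input(src_ids, dst_ids, cls_id=1):
    # [CLS] + src tokens, then dst tokens: one sequence per edge.
    src_part = [cls_id] + src_ids
    term_ids = src_part + dst_ids
    # Segment ids: 0 for the source half, 1 for the destination half.
    sent_ids = [0] * len(src_part) + [1] * len(dst_ids)
    # Position ids: running count of non-padding tokens, minus one.
    position_ids, seen = [], 0
    for tok in term_ids:
        seen += 1 if tok > 0 else 0
        position_ids.append(seen - 1)
    return term_ids, sent_ids, position_ids


term_ids, sent_ids, position_ids = build_edge_input([5, 6, 2, 0], [7, 2, 0, 0])
print(term_ids)      # [1, 5, 6, 2, 0, 7, 2, 0, 0]
print(sent_ids)      # [0, 0, 0, 0, 0, 1, 1, 1, 1]
print(position_ids)  # [0, 1, 2, 3, 3, 4, 5, 5, 5]
```

Because the position counter only advances on non-padding tokens, the destination text continues numbering right after the last real source token, even though padding sits between them.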

In [10]:
# Just run it
erniesage_v2_encoder = ERNIESageV2Encoder(config)

main_prog, start_prog = F.Program(), F.Program()
with F.program_guard(main_prog, start_prog):
    with F.unique_name.guard():
        num_nodes = L.data("num_nodes", [1], False, 'int64')
        num_edges = L.data("num_edges", [1], False, 'int64')
        edges = L.data("edges", [-1, 2], False, 'int64')
        # Use dynamic shapes so that the 4-node feed_dict defined earlier fits
        node_feat = L.data("term_ids", [-1, 40], False, 'int64')
        inputs = L.data("inputs", [-1], False, 'int64')

        gw = BatchGraphWrapper(num_nodes, num_edges, edges, {"term_ids": node_feat})
        outputs = erniesage_v2_encoder([gw], [inputs])

exe = F.Executor(place)
exe.run(start_prog)
# outputs is already a list of variables, so it is passed directly as fetch_list
outputs_np = exe.run(main_prog, feed=feed_dict, fetch_list=outputs)[0]
print(outputs_np)

### ERNIESage V3 core flow
- GNN send text id sequences -> GNN recv concatenates text id sequences -> ERNIE extracts the center and multiple neighbor representations at once

<img src="https://ai-studio-static-online.cdn.bcebos.com/b18ab78738764e88b624d1db8ce5e95c72a4161cd4b845cb80bdf0d5e914cfbc" alt="image alt text" width="500" height="313" align="bottom" />


In [13]:
from models.encoder import v3_build_sentence_ids
from models.encoder import v3_build_position_ids

class ERNIESageV3Encoder():
    def __init__(self, config):
        self.config = config

    def __call__(self, graph_wrappers, inputs):
        gw = graph_wrappers[0]
        term_ids = gw.node_feat["term_ids"]

        # step1. GNN send: text id sequences
        # copy_send
        msg = gw.send(lambda src, dst, edge: src["h"], nfeat_list=[("h", term_ids)])

        # step2. GNN recv: concatenate text id sequences
        def ernie_recv(term_ids):
            """doc"""
            num_neighbor = self.config.samples[0]
            pad_value = L.zeros([1], "int64")

            # sequence_pad stacks the text id sequences of up to num_neighbor neighbors;
            # nodes with fewer than num_neighbor neighbors are padded up to num_neighbor
            neighbors_term_ids, _ = L.sequence_pad(
                term_ids, pad_value=pad_value, maxlen=num_neighbor) # [B, N, S]

            neighbors_term_ids = L.reshape(neighbors_term_ids, [0, self.config.max_seqlen * num_neighbor]) # [B, N*S]
            return neighbors_term_ids
    
        neigh_term_ids = gw.recv(msg, ernie_recv)
        neigh_term_ids = L.cast(neigh_term_ids, "int64")

        # step3. ERNIE extracts the center and multiple neighbor representations at once
        cls = L.fill_constant_batch_size_like(term_ids, [-1, 1], "int64",
                                              self.config.cls_id) # [B, 1]

        # Concatenate the text of the center and all neighbors into one very long sequence of length (num_neighbor+1) * seqlen
        multi_term_ids = L.concat([cls, term_ids[:, :-1], neigh_term_ids], 1) # multi_term_ids [B, (N+1)*S]
        # [CLS], center_id1, center_id2..[SEP]n1_id1, n1_id2..[SEP]n2_id1, n2_id2..[SEP]..[SEP]
        slot_seqlen = self.config.max_seqlen
        final_feats = []
        for index in inputs:
            term_ids = L.gather(multi_term_ids, index, overwrite=False)
            position_ids = v3_build_position_ids(term_ids, slot_seqlen)
            sent_ids = v3_build_sentence_ids(term_ids, slot_seqlen)

            # Use ERNIE to extract the [CLS]-position representation of the long sequences that need to be computed
            ernie_model = ErnieModel(self.config.ernie_config)
            cls_feat, _ = ernie_model(term_ids, sent_ids, position_ids)

            feature = linear(cls_feat, self.config.hidden_size, "v3_final_fc")
            feature = L.l2_normalize(feature, axis=1)
            final_feats.append(feature)
        return final_feats
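
The padding and concatenation done in the V3 encoder can be followed for one node with plain lists. This is a sketch, not the PGL code path: `build_multi_term_ids` is a hypothetical helper, `1` stands in for the [CLS] id and `0` for padding, and it assumes neighbor sampling has already limited the neighbor count to `num_neighbor`.

```python
# Minimal sketch of the V3 input for a single center node.
# Hypothetical ids: 1 = [CLS], 2 = [SEP], 0 = padding.

def build_multi_term_ids(center_ids, neighbor_ids, num_neighbor, cls_id=1):
    seqlen = len(center_ids)
    if len(neighbor_ids) > num_neighbor:
        # Assumption: sampling never returns more than num_neighbor neighbors.
        raise ValueError("more neighbors than num_neighbor")
    # Pad the neighbor list with all-zero sequences up to num_neighbor.
    padded = neighbor_ids + [[0] * seqlen] * (num_neighbor - len(neighbor_ids))
    # Flatten [N, S] into one row of length N * S.
    flat = [tok for seq in padded for tok in seq]
    # Drop the center's last token to make room for [CLS].
    return [cls_id] + center_ids[:-1] + flat


multi = build_multi_term_ids([5, 2, 0], [[7, 2, 0]], num_neighbor=2)
print(multi)       # [1, 5, 2, 7, 2, 0, 0, 0, 0]
print(len(multi))  # (num_neighbor + 1) * seqlen = 9
```

Every node yields a row of the same length, which is what allows a single ERNIE pass over the center text and all neighbor texts.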

In [14]:
# Just run it
erniesage_v3_encoder = ERNIESageV3Encoder(config)

main_prog, start_prog = F.Program(), F.Program()
with F.program_guard(main_prog, start_prog):
    num_nodes = L.data("num_nodes", [1], False, 'int64')
    num_edges = L.data("num_edges", [1], False, 'int64')
    edges = L.data("edges", [-1, 2], False, 'int64')
    node_feat = L.data("term_ids", [-1, 40], False, 'int64')
    inputs = L.data("inputs", [-1], False, 'int64')

    gw = BatchGraphWrapper(num_nodes, num_edges, edges, {"term_ids": node_feat})
    outputs = erniesage_v3_encoder([gw], [inputs])

exe.run(start_prog)
# outputs is already a list of variables, so it is passed directly as fetch_list
outputs_np = exe.run(main_prog, feed=feed_dict, fetch_list=outputs)[0]
print(outputs_np)

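## Training
### Link prediction task
Taking a link prediction task as an example, we read a semantic graph and train without supervision, using the edges of that graph as targets.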

In [17]:
class ERNIESageLinkPredictModel(propeller.train.Model):
    def __init__(self, hparam, mode, run_config):
        self.hparam = hparam
        self.mode = mode
        self.run_config = run_config

    def forward(self, features):
        num_nodes, num_edges, edges, node_feat_index, node_feat_term_ids, user_index, \
            pos_item_index, neg_item_index, user_real_index, pos_item_real_index = features

        node_feat = {"index": node_feat_index, "term_ids": node_feat_term_ids}
        graph_wrapper = BatchGraphWrapper(num_nodes, num_edges, edges,
                                          node_feat)

        #encoder = ERNIESageV1Encoder(self.hparam)
        encoder = ERNIESageV2Encoder(self.hparam)
        #encoder = ERNIESageV3Encoder(self.hparam)

        # Extract features separately for center nodes, neighbor nodes and randomly sampled nodes
        outputs = encoder([graph_wrapper],
                          [user_index, pos_item_index, neg_item_index])
        user_feat, pos_item_feat, neg_item_feat = outputs
    
        if self.mode is not propeller.RunMode.PREDICT:
            return user_feat, pos_item_feat, neg_item_feat
        else:
            return user_feat, user_real_index

    def loss(self, predictions, labels):
        user_feat, pos_item_feat, neg_item_feat = predictions
        pos = L.reduce_sum(user_feat * pos_item_feat, -1, keep_dim=True)
        #neg = L.reduce_sum(user_feat * neg_item_feat, -1, keep_dim=True)# 60.
        neg = L.matmul(user_feat, neg_item_feat, transpose_y=True) # 80.
        # Hinge loss: score(center, neighbor) should exceed score(center, random negative) by at least margin
        loss = L.reduce_mean(L.relu(neg - pos + self.hparam.margin))
        return loss

    def backward(self, loss):
        adam = F.optimizer.Adam(learning_rate=self.hparam['learning_rate'])
        adam.minimize(loss)

    def metrics(self, predictions, label):
        return {}
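
The loss above can be checked by hand on a tiny batch. The sketch below reproduces the same arithmetic in plain Python (row-wise positive score, every user scored against every negative in the batch, then the mean of the hinge terms); the function name `hinge_loss` and the toy embeddings are made up for illustration.

```python
# Minimal sketch of the margin (hinge) loss on plain Python lists.

def hinge_loss(user, pos, neg, margin):
    # Positive score: row-wise dot product, one value per user.
    pos_score = [sum(u * p for u, p in zip(ur, pr)) for ur, pr in zip(user, pos)]
    total, count = 0.0, 0
    for i, ur in enumerate(user):
        for nr in neg:
            # Negative score: every user against every negative in the batch.
            neg_score = sum(u * n for u, n in zip(ur, nr))
            total += max(0.0, neg_score - pos_score[i] + margin)
            count += 1
    return total / count


user = [[1.0, 0.0], [0.0, 1.0]]
pos = [[1.0, 0.0], [0.0, 1.0]]
neg = [[0.0, 1.0], [1.0, 0.0]]
loss = hinge_loss(user, pos, neg, margin=0.1)
print(loss)
```

Only the two (user, negative) pairs whose score ties with the positive score contribute here, each by the margin, so the mean over the four pairs is about 0.05.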

In [18]:
from link_predict import train
from link_predict import predict

train(config, ERNIESageLinkPredictModel)

In [27]:
predict(config, ERNIESageLinkPredictModel)

In [28]:
! head output/part-0

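### How to evaluate

To get a clear picture of embedding quality, we use MRR as a simple check on the embeddings computed from graph_data.txt. Here graph_data.txt serves as both the training set and the validation set.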

In [29]:
!python build_dev.py --path "./example_data/link_predict/graph_data.txt" # Converts the training data into the required format; the output file is dev_out.txt

In [30]:
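# Next, compute the MRR score.
# Note: this assumes the input_data parameter in the yaml config file used by config has been changed to "data.txt",
# and that the model was trained on data.txt. If not, retrain the model.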
!python mrr.py --emb_path output/part-0
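
For readers unfamiliar with MRR (mean reciprocal rank), the metric itself is simple to compute. The sketch below is not the content of `mrr.py`; it assumes dot-product scoring and a hypothetical helper `mean_reciprocal_rank`, and it ranks the true answer of each query among all candidate embeddings.

```python
# Minimal sketch of MRR with dot-product scoring (not the code of mrr.py).

def mean_reciprocal_rank(query_embs, cand_embs, true_index):
    reciprocal_ranks = []
    for q, target in zip(query_embs, true_index):
        scores = [sum(a * b for a, b in zip(q, c)) for c in cand_embs]
        # Rank = 1 + number of candidates scoring strictly higher than the target.
        rank = 1 + sum(s > scores[target] for s in scores)
        reciprocal_ranks.append(1.0 / rank)
    return sum(reciprocal_ranks) / len(reciprocal_ranks)


cand_embs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
query_embs = [[1.0, 0.0], [1.0, 2.0]]
true_index = [0, 1]  # index of the correct candidate for each query
mrr = mean_reciprocal_rank(query_embs, cand_embs, true_index)
print(mrr)
```

The first query ranks its answer first and the second ranks it second, so the result is (1 + 1/2) / 2 = 0.75. An MRR close to 1 means the true answers are ranked near the top.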

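## Summary
From the brief walkthrough of the three model versions above, we can see that they differ mainly in the message-passing part. ERNIESageV1 applies ERNIE only to the nodes of the text graph, so during message passing (the Send stage) it considers only the text of the neighbor itself. ERNIESageV2 applies ERNIE to the edges, so the Send stage considers the text of both the current node and its neighbor, which gives better interaction.
ERNIESageV3 applies ERNIE to the center node and all of its neighbors together, so the nodes can attend to one another.

We hope this runnable example helps readers understand ERNIESage better. Give it a try!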