## 1. 形参设定

In [849]:
import argparse

parser = argparse.ArgumentParser()

parser.add_argument('--input', type=str,
                    default='C:\\Users\\Wei Zhou\\Desktop\\test\\图神经网络几个算法\\异构图网络算法\\GATNE-master\\data\\example',
                    help='Input dataset path')

parser.add_argument('--features', type=str, default=None,
                    help='Input node features')

parser.add_argument('--walk-file', type=str, default=None,
                    help='Input random walks')

parser.add_argument('--epoch', type=int, default=100,
                    help='Number of epoch. Default is 100.')

parser.add_argument('--batch-size', type=int, default=64,
                    help='Number of batch_size. Default is 64.')

parser.add_argument('--eval-type', type=str, default='all',
                    help='The edge type(s) for evaluation.')

parser.add_argument('--schema', type=str, default=None,
                    help='The metapath schema (e.g., U-I-U,I-U-I).')

parser.add_argument('--dimensions', type=int, default=200,
                    help='Number of dimensions. Default is 200.')

parser.add_argument('--edge-dim', type=int, default=10,
                    help='Number of edge embedding dimensions. Default is 10.')

parser.add_argument('--att-dim', type=int, default=20,
                    help='Number of attention dimensions. Default is 20.')

parser.add_argument('--walk-length', type=int, default=5,
                    help='Length of walk per source. Default is 10.')

parser.add_argument('--num-walks', type=int, default=20,
                    help='Number of walks per source. Default is 20.')

parser.add_argument('--window-size', type=int, default=5,
                    help='Context size for optimization. Default is 5.')

parser.add_argument('--negative-samples', type=int, default=5,
                    help='Negative samples for optimization. Default is 5.')

parser.add_argument('--neighbor-samples', type=int, default=10,
                    help='Neighbor samples for aggregation. Default is 10.')

parser.add_argument('--patience', type=int, default=5,
                    help='Early stopping patience. Default is 5.')

parser.add_argument('--num-workers', type=int, default=16,
                    help='Number of workers for generating random walks. Default is 16.')
args, _ = parser.parse_known_args()
file_name = args.input
print(args)

Namespace(input='C:\\Users\\Wei Zhou\\Desktop\\test\\图神经网络几个算法\\异构图网络算法\\GATNE-master\\data\\example', features=None, walk_file=None, epoch=100, batch_size=64, eval_type='all', schema=None, dimensions=200, edge_dim=10, att_dim=20, walk_length=5, num_walks=20, window_size=5, negative_samples=5, neighbor_samples=10, patience=5, num_workers=16)


## 2. 数据加载

## 加载训练集数据

In [850]:
def load_training_data(f_name):
    print('We are loading data from:', f_name)
    ## 边类型字典
    edge_data_by_type = dict()
    all_nodes = list()
    with open(f_name, 'r') as f:
        for line in f:
            ##这里最后一行是空的所以只取到倒数第二行
            words = line[:-1].split(' ')
            ## 第一列是表示的边种类，如果字典中没该边种类的键则将该键添加
            if words[0] not in edge_data_by_type:
                edge_data_by_type[words[0]] = list()
            ##在对应的边类型下添加该边类型下的节点和节点的连接关系
            x, y = words[1], words[2]
            edge_data_by_type[words[0]].append((x, y))
            all_nodes.append(x)
            all_nodes.append(y)
    all_nodes = list(set(all_nodes))
    print('Total training nodes: ' + str(len(all_nodes)))
    return edge_data_by_type

## 加载节点特征

In [851]:
def load_feature_data(f_name):
    feature_dic = {}
    with open(f_name, 'r') as f:
        first = True
        for line in f:
            if first:
                first = False
                continue
            items = line.strip().split()
            feature_dic[items[0]] = items[1:]
    return feature_dic


## 加载test和valid训练集

In [852]:
def load_testing_data(f_name):
    print('We are loading data from:', f_name)
    ## 节点和节点之间的边连接正确的字典
    true_edge_data_by_type = dict()
    ## 节点和节点之间的边连接错误的字典
    false_edge_data_by_type = dict()
    all_nodes = list()

    with open(f_name, 'r') as f:
        for line in f:
            words = line[:-1].split(' ')
            x, y = words[1], words[2]
            ##如果连接正确
            if int(words[3]) == 1:
                ##判断是否有该类型的边连接，如果没有则添加新的边连接类型列表
                if words[0] not in true_edge_data_by_type:
                    true_edge_data_by_type[words[0]] = list()
                ##在该类型的边连接下添加节点连接
                true_edge_data_by_type[words[0]].append((x, y))
            ##反之亦然
            else:
                if words[0] not in false_edge_data_by_type:
                    false_edge_data_by_type[words[0]] = list()
                false_edge_data_by_type[words[0]].append((x, y))
            all_nodes.append(x)
            all_nodes.append(y)
    all_nodes = list(set(all_nodes))
    return true_edge_data_by_type, false_edge_data_by_type

In [853]:
if args.features is not None:
    feature_dic = load_feature_data(args.features)
else:
    feature_dic = None

training_data_by_type = load_training_data(file_name + "/train.txt")
valid_true_data_by_edge, valid_false_data_by_edge = load_testing_data(
    file_name + "/valid.txt"
)
testing_true_data_by_edge, testing_false_data_by_edge = load_testing_data(
    file_name + "/test.txt"
)

We are loading data from: C:\Users\Wei Zhou\Desktop\test\图神经网络几个算法\异构图网络算法\GATNE-master\data\example/train.txt
Total training nodes: 511
We are loading data from: C:\Users\Wei Zhou\Desktop\test\图神经网络几个算法\异构图网络算法\GATNE-master\data\example/valid.txt
We are loading data from: C:\Users\Wei Zhou\Desktop\test\图神经网络几个算法\异构图网络算法\GATNE-master\data\example/test.txt


## 3. 获取训练语料

In [854]:
def generate(network_data, num_walks, walk_length, schema, file_name, window_size, num_workers, walk_file):
    ##如果已经有游走序列则直接导入并加载——3-1-1 load_walks
    if walk_file is not None:
        all_walks = load_walks(walk_file)
    ##如果没有则建立随机游走——3-1-2 generate_walks
    else:
        all_walks = generate_walks(network_data, num_walks, walk_length, schema, file_name, num_workers)
        ##保存随机游走序列结果——3-1-3 save_walks
        save_walks(file_name + '/walks.txt', all_walks)
    ##生成训练语料——3-2 generate_vocab
    vocab, index2word = generate_vocab(all_walks)
    ##生成节点对——3-3 generate_pairs
    train_pairs = generate_pairs(all_walks, vocab, window_size, num_workers)

    return vocab, index2word, train_pairs

## 3-1-1 load_walks

In [855]:
def load_walks(walk_file):
    print('Loading walks')
    all_walks = []
    with open(walk_file, 'r') as f:
        for line in f:
            content = line.strip().split()
            layer_id = int(content[0])
            if layer_id >= len(all_walks):
                all_walks.append([])
            all_walks[layer_id].append(content[1:])
    return all_walks

## 3-1-2 生成随机游走序列——generate_walks

In [856]:
import random
import multiprocessing
from tqdm import tqdm


def generate_walks(network_data, num_walks, walk_length, schema, file_name, num_workers):
    ## 如果有节点种类则导入节点种类——3-1-2-1 load_node_type
    if schema is not None:
        node_type = load_node_type(file_name + '/node_type.txt')
    else:
        node_type = None
    ##建立储存随机游走序列的列表
    all_walks = []
    ##遍历储存不同边类型连接的字典
    for layer_id, layer_name in enumerate(network_data):
        tmp_data = network_data[layer_name]
        ##开始对每个边类型层上进行随机游走——RWGraph 3-1-2-3、get_G_from_edge 3-1-2-2
        layer_walker = RWGraph(get_G_from_edges(tmp_data), node_type, num_workers)
        print('Generating random walks for layer', layer_id)
        layer_walks = layer_walker.simulate_walks(num_walks, walk_length, schema=schema)

        all_walks.append(layer_walks)

    print('Finish generating the walks')

    return all_walks

## 3-1-2-1 获取节点种类——load_node_type

In [857]:
def load_node_type(f_name):
    print('We are loading node type from:', f_name)
    node_type = {}
    with open(f_name, 'r') as f:
        for line in f:
            items = line.strip().split()
            node_type[items[0]] = items[1]
    return node_type

## 3-1-2-2 根据节点连接关系建立边得到的是节点和哪些节点相连的字典——get_G_from_edge

In [858]:
from collections import defaultdict
def get_G_from_edges(edges):
    edge_dict = defaultdict(set)
    for edge in edges:
        u, v = str(edge[0]), str(edge[1])
        edge_dict[u].add(v)
        edge_dict[v].add(u)
    return edge_dict

## 3-1-2-3 随机游走图——RWGraph

In [859]:
class RWGraph():
    def __init__(self, nx_G, node_type_arr=None, num_workers=20):
        ##图结构数据
        self.G = nx_G
        ##节点类型数组
        self.node_type = node_type_arr
        ##游走次数
        self.num_workers = num_workers

    ##为节点生成多次游走的迭代器
    def node_list(self, nodes, num_walks):
        for loop in range(num_walks):
            for node in nodes:
                yield node

    ##模拟随机游走并收集结果
    def simulate_walks(self, num_walks, walk_length, schema=None):
        ##收集游走序列
        all_walks = []
        ##获取该边类型层的节点
        nodes = list(self.G.keys())
        random.shuffle(nodes)
    
        if schema is None:
            for node in tqdm(self.node_list(nodes, num_walks)):
                all_walks.append(self.walk((walk_length, node, '')))
        else:
            schema_list = schema.split(',')
            for schema_iter in schema_list:
                for node in tqdm(self.node_list(nodes, num_walks)):
                    if schema_iter.split('-')[0] == self.node_type[node]:
                        all_walks.append(self.walk((walk_length, node, schema_iter)))
        return all_walks
##3-1-2-3-1 根据参数模拟单次随机游走——walk    
    def walk(self,args):
        print(args)
        walk_length, start, schema = args
        # Simulate a random walk starting from start node.
        rand = random.Random()
    
        if schema:
            schema_items = schema.split('-')
            assert schema_items[0] == schema_items[-1]
    
        walk = [start]
        while len(walk) < walk_length:
            print(len(walk))
            print(walk_length)
            cur = walk[-1]
            candidates = []
            for node in self.G[cur]:
                if schema == '' or self.node_type[node] == schema_items[len(walk) % (len(schema_items) - 1)]:
                    candidates.append(node)
            if candidates:
                walk.append(rand.choice(candidates))
            else:
                break
        return [str(node) for node in walk]


## 3-1-3 保存随机游走

In [860]:
def save_walks(walk_file, all_walks):
    with open(walk_file, 'w') as f:
        for layer_id, walks in enumerate(all_walks):
            print('Saving walks for layer', layer_id)
            for walk in tqdm(walks):
                f.write(' '.join([str(layer_id)] + [str(x) for x in walk]) + '\n')

## 3-2 建立节点索引映射——generate_vocab

In [861]:
from six import iteritems
##基于随机游走来生成训练语料
def generate_vocab(all_walks):
##创建空列表来储存单词
    index2word = []
    ##计算每个单词出现的次数
    raw_vocab = defaultdict(int)
    ##这部分可以看看walk.txt，这部分是对随机游走的结果进行一个遍历
    for layer_id, walks in enumerate(all_walks):
        print('Counting vocab for layer', layer_id)
        ##对每个层中的每个路径遍历它的节点
        for walk in tqdm(walks):
            ##路径中出现的节点进行计数
            for word in walk:
                raw_vocab[word] += 1
##创建新字典
    vocab = {}
    ##对之前收集到了所有路径中的节点计数进行遍历
    for word, v in iteritems(raw_vocab):
        ##利用vocab字典记录每个节点的数量和它的索引并将其添加到index2word中
        vocab[word] = Vocab(count=v, index=len(index2word))
        index2word.append(word)
    ##进行排序
    index2word.sort(key=lambda word: vocab[word].count, reverse=True)
    ##更新索引，使其与排序后的index2word相匹配
    for i, word in enumerate(index2word):
        vocab[word].index = i

    return vocab, index2word


class Vocab(object):

    def __init__(self, count, index):
        self.count = count
        self.index = index

## 3-3 建立训练的节点对——generate_pairs

In [862]:
def generate_pairs(all_walks, vocab, window_size, num_workers):
## 训练节点对列表
    pairs = []
## 窗口大小的一般包含的目标节点左右的节点
    skip_window = window_size // 2
## 对每层的随机游走序列进行遍历
    for layer_id, walks in enumerate(all_walks):
        print('Generating training pairs for layer', layer_id)
        ##对每个层的每个路径进行遍历
        for walk in tqdm(walks):
            ##对目标节点的左右两侧生成节点对
            for i in range(len(walk)):
                for j in range(1, skip_window + 1):
                    ##目标节点和左侧邻居的节点对
                    if i - j >= 0:
                        pairs.append((vocab[walk[i]].index, vocab[walk[i - j]].index, layer_id))
                    ##目标单词和右侧的邻居的节点对
                    if i + j < len(walk):
                        pairs.append((vocab[walk[i]].index, vocab[walk[i + j]].index, layer_id))
                ##这些节点对是通过vocab来找其索引然后附加当层的ID
    return pairs

## 3-具体使用

In [863]:
vocab, index2word, train_pairs = generate(training_data_by_type, args.num_walks, args.walk_length, args.schema,
                                          file_name, args.window_size, args.num_workers, args.walk_file)

Generating random walks for layer 0


0it [00:00, ?it/s]

(5, '17266', '')
1
5
2
5
3
5
4
5
(5, '29765', '')
1
5
2
5
3
5
4
5
(5, '86501', '')
1
5
2
5
3
5
4
5
(5, '420720', '')
1
5
2
5
3
5
4
5
(5, '8832', '')
1
5
2
5
3
5
4
5
(5, '321498', '')
1
5
2
5
3
5
4
5
(5, '32335', '')
1
5
2
5
3
5
4
5
(5, '22526', '')
1
5
2
5
3
5
4
5
(5, '13325', '')
1
5
2
5
3
5
4
5
(5, '302946', '')
1
5
2
5
3
5
4
5
(5, '454764', '')
1
5
2
5
3
5
4
5
(5, '438438', '')
1
5
2
5
3
5
4
5
(5, '86917', '')
1
5
2
5
3
5
4
5
(5, '82578', '')
1
5
2
5
3
5
4
5
(5, '154770', '')
1
5
2
5
3
5
4
5
(5, '82011', '')
1
5
2
5
3
5
4
5
(5, '15195', '')
1
5
2
5
3
5
4
5
(5, '57219', '')
1
5
2
5
3
5
4
5
(5, '20708', '')
1
5
2
5
3
5
4
5
(5, '206148', '')
1
5
2
5
3
5
4
5
(5, '394495', '')
1
5
2
5
3
5
4
5
(5, '267374', '')
1
5
2
5
3
5
4
5
(5, '8347', '')
1
5
2
5
3
5
4
5
(5, '84823', '')
1
5
2
5
3
5
4
5
(5, '8687', '')
1
5
2
5
3
5
4
5
(5, '128580', '')
1
5
2
5
3
5
4
5
(5, '196373', '')
1
5
2
5
3
5
4
5
(5, '13931', '')
1
5
2
5
3
5
4
5
(5, '459812', '')
1
5
2
5
3
5
4
5
(5, '232514', '')
1
5
2
5
3
5
4
5


2644it [00:00, 26325.79it/s]

1
5
2
5
3
5
4
5
(5, '218752', '')
1
5
2
5
3
5
4
5
(5, '10433', '')
1
5
2
5
3
5
4
5
(5, '6588', '')
1
5
2
5
3
5
4
5
(5, '149992', '')
1
5
2
5
3
5
4
5
(5, '283920', '')
1
5
2
5
3
5
4
5
(5, '18341', '')
1
5
2
5
3
5
4
5
(5, '298016', '')
1
5
2
5
3
5
4
5
(5, '117468', '')
1
5
2
5
3
5
4
5
(5, '14665', '')
1
5
2
5
3
5
4
5
(5, '46303', '')
1
5
2
5
3
5
4
5
(5, '69924', '')
1
5
2
5
3
5
4
5
(5, '22306', '')
1
5
2
5
3
5
4
5
(5, '356326', '')
1
5
2
5
3
5
4
5
(5, '60145', '')
1
5
2
5
3
5
4
5
(5, '4155', '')
1
5
2
5
3
5
4
5
(5, '4147', '')
1
5
2
5
3
5
4
5
(5, '270690', '')
1
5
2
5
3
5
4
5
(5, '464207', '')
1
5
2
5
3
5
4
5
(5, '213856', '')
1
5
2
5
3
5
4
5
(5, '25687', '')
1
5
2
5
3
5
4
5
(5, '486089', '')
1
5
2
5
3
5
4
5
(5, '126510', '')
1
5
2
5
3
5
4
5
(5, '21659', '')
1
5
2
5
3
5
4
5
(5, '2816', '')
1
5
2
5
3
5
4
5
(5, '6806', '')
1
5
2
5
3
5
4
5
(5, '386286', '')
1
5
2
5
3
5
4
5
(5, '218753', '')
1
5
2
5
3
5
4
5
(5, '15674', '')
1
5
2
5
3
5
4
5
(5, '90308', '')
1
5
2
5
3
5
4
5
(5, '86754', '')
1


5289it [00:00, 26395.41it/s]

1
5
2
5
3
5
4
5
(5, '20708', '')
1
5
2
5
3
5
4
5
(5, '206148', '')
1
5
2
5
3
5
4
5
(5, '394495', '')
1
5
2
5
3
5
4
5
(5, '267374', '')
1
5
2
5
3
5
4
5
(5, '8347', '')
1
5
2
5
3
5
4
5
(5, '84823', '')
1
5
2
5
3
5
4
5
(5, '8687', '')
1
5
2
5
3
5
4
5
(5, '128580', '')
1
5
2
5
3
5
4
5
(5, '196373', '')
1
5
2
5
3
5
4
5
(5, '13931', '')
1
5
2
5
3
5
4
5
(5, '459812', '')
1
5
2
5
3
5
4
5
(5, '232514', '')
1
5
2
5
3
5
4
5
(5, '59397', '')
1
5
2
5
3
5
4
5
(5, '40413', '')
1
5
2
5
3
5
4
5
(5, '24532', '')
1
5
2
5
3
5
4
5
(5, '60487', '')
1
5
2
5
3
5
4
5
(5, '1511', '')
1
5
2
5
3
5
4
5
(5, '34697', '')
1
5
2
5
3
5
4
5
(5, '285429', '')
1
5
2
5
3
5
4
5
(5, '4064', '')
1
5
2
5
3
5
4
5
(5, '20107', '')
1
5
2
5
3
5
4
5
(5, '28726', '')
1
5
2
5
3
5
4
5
(5, '57735', '')
1
5
2
5
3
5
4
5
(5, '24492', '')
1
5
2
5
3
5
4
5
(5, '351190', '')
1
5
2
5
3
5
4
5
(5, '87928', '')
1
5
2
5
3
5
4
5
(5, '139873', '')
1
5
2
5
3
5
4
5
(5, '15526', '')
1
5
2
5
3
5
4
5
(5, '12626', '')
1
5
2
5
3
5
4
5
(5, '56915', '')
1
5


8640it [00:00, 26304.00it/s]


(5, '2', '')
1
5
2
5
3
5
4
5
(5, '7662', '')
1
5
2
5
3
5
4
5
(5, '126515', '')
1
5
2
5
3
5
4
5
(5, '158569', '')
1
5
2
5
3
5
4
5
(5, '164772', '')
1
5
2
5
3
5
4
5
(5, '53030', '')
1
5
2
5
3
5
4
5
(5, '339709', '')
1
5
2
5
3
5
4
5
(5, '87188', '')
1
5
2
5
3
5
4
5
(5, '119867', '')
1
5
2
5
3
5
4
5
(5, '296568', '')
1
5
2
5
3
5
4
5
(5, '128852', '')
1
5
2
5
3
5
4
5
(5, '20058', '')
1
5
2
5
3
5
4
5
(5, '8738', '')
1
5
2
5
3
5
4
5
(5, '322313', '')
1
5
2
5
3
5
4
5
(5, '15690', '')
1
5
2
5
3
5
4
5
(5, '302002', '')
1
5
2
5
3
5
4
5
(5, '16279', '')
1
5
2
5
3
5
4
5
(5, '46320', '')
1
5
2
5
3
5
4
5
(5, '379232', '')
1
5
2
5
3
5
4
5
(5, '44641', '')
1
5
2
5
3
5
4
5
(5, '191450', '')
1
5
2
5
3
5
4
5
(5, '6586', '')
1
5
2
5
3
5
4
5
(5, '81753', '')
1
5
2
5
3
5
4
5
(5, '218752', '')
1
5
2
5
3
5
4
5
(5, '10433', '')
1
5
2
5
3
5
4
5
(5, '6588', '')
1
5
2
5
3
5
4
5
(5, '149992', '')
1
5
2
5
3
5
4
5
(5, '283920', '')
1
5
2
5
3
5
4
5
(5, '18341', '')
1
5
2
5
3
5
4
5
(5, '298016', '')
1
5
2
5
3
5
4
5
(5,

3038it [00:00, 30377.35it/s]

(5, '466668', '')
1
5
2
5
3
5
4
5
(5, '206148', '')
1
5
2
5
3
5
4
5
(5, '53030', '')
1
5
2
5
3
5
4
5
(5, '11628', '')
1
5
2
5
3
5
4
5
(5, '355892', '')
1
5
2
5
3
5
4
5
(5, '33761', '')
1
5
2
5
3
5
4
5
(5, '206599', '')
1
5
2
5
3
5
4
5
(5, '4142', '')
1
5
2
5
3
5
4
5
(5, '14704', '')
1
5
2
5
3
5
4
5
(5, '98579', '')
1
5
2
5
3
5
4
5
(5, '101226', '')
1
5
2
5
3
5
4
5
(5, '86754', '')
1
5
2
5
3
5
4
5
(5, '163180', '')
1
5
2
5
3
5
4
5
(5, '87481', '')
1
5
2
5
3
5
4
5
(5, '46325', '')
1
5
2
5
3
5
4
5
(5, '75311', '')
1
5
2
5
3
5
4
5
(5, '90308', '')
1
5
2
5
3
5
4
5
(5, '32655', '')
1
5
2
5
3
5
4
5
(5, '6277', '')
1
5
2
5
3
5
4
5
(5, '44634', '')
1
5
2
5
3
5
4
5
(5, '33210', '')
1
5
2
5
3
5
4
5
(5, '2631', '')
1
5
2
5
3
5
4
5
(5, '4024', '')
1
5
2
5
3
5
4
5
(5, '58209', '')
1
5
2
5
3
5
4
5
(5, '31832', '')
1
5
2
5
3
5
4
5
(5, '47173', '')
1
5
2
5
3
5
4
5
(5, '57219', '')
1
5
2
5
3
5
4
5
(5, '33066', '')
1
5
2
5
3
5
4
5
(5, '406546', '')
1
5
2
5
3
5
4
5
(5, '49250', '')
1
5
2
5
3
5
4
5
(5, '43

4660it [00:00, 29922.53it/s]

Finish generating the walks





Saving walks for layer 0


100%|██████████| 8640/8640 [00:00<00:00, 639041.87it/s]


Saving walks for layer 1


100%|██████████| 4660/4660 [00:00<00:00, 696311.24it/s]


Counting vocab for layer 0


100%|██████████| 8640/8640 [00:00<00:00, 1565863.83it/s]


Counting vocab for layer 1


100%|██████████| 4660/4660 [00:00<00:00, 1553075.62it/s]


Generating training pairs for layer 0


100%|██████████| 8640/8640 [00:00<00:00, 193873.24it/s]


Generating training pairs for layer 1


100%|██████████| 4660/4660 [00:00<00:00, 15411.94it/s]


## 4.模型训练的预准备

## 4-1 形参调用

In [864]:
import torch
import numpy as np

## 获取边类型
edge_types = list(training_data_by_type.keys())
##获取通过随机游走后的节点数量
num_nodes = len(index2word)
##边类型数量
edge_type_count = len(edge_types)
epochs = args.epoch
batch_size = args.batch_size
##设定base_embedding的size
embedding_size = args.dimensions
##设定edge_embedding的size
embedding_u_size = args.edge_dim
## 设定edge_embedding的数量
u_num = edge_type_count
## 设定负采样数量
num_sampled = args.negative_samples
## 注意力机制的维度
dim_a = args.att_dim
## 注意力头的数量
att_head = 1
## 邻居采样个数
neighbor_samples = args.neighbor_samples
## 设置设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## 4-2 创建基于训练语料的邻居节点

In [865]:
def generate_neighbors(network_data, vocab, num_nodes, edge_types, neighbor_samples):
    edge_type_count = len(edge_types)
    ## 创建三维列表，其大小为“num_nodes x edge_type_count x []。里面是对于每个节点对于每种边类型的边的邻居
    neighbors = [[[] for __ in range(edge_type_count)] for _ in range(num_nodes)]
    ##对每种边类型遍历
    for r in range(edge_type_count):
        print('Generating neighbors for layer', r)
        ##遍历图中所有的边
        g = network_data[edge_types[r]]
        for (x, y) in tqdm(g):
        ##对每个边进行操作
            ##找到节点’x‘和节点’y‘在vocab中的索引
            ix = vocab[x].index
            iy = vocab[y].index
            ##将作为到该边类型下的邻居节点
            neighbors[ix][r].append(iy)
            neighbors[iy][r].append(ix)
    ##对所有节点进行遍历
        for i in range(num_nodes):
            ##没有邻居节点则及那个自身添加为邻居节点
            if len(neighbors[i][r]) == 0:
                neighbors[i][r] = [i] * neighbor_samples
            ##如果邻居节点数量小于采样数量则进行随机重采样
            elif len(neighbors[i][r]) < neighbor_samples:
                neighbors[i][r].extend(
                    list(np.random.choice(neighbors[i][r], size=neighbor_samples - len(neighbors[i][r]))))
            ##如果邻居节点数量大于采样数量则随机选择被采样的邻居节点
            elif len(neighbors[i][r]) > neighbor_samples:
                neighbors[i][r] = list(np.random.choice(neighbors[i][r], size=neighbor_samples))
    return neighbors

In [866]:
neighbors = generate_neighbors(training_data_by_type, vocab, num_nodes, edge_types, neighbor_samples)

Generating neighbors for layer 0


100%|██████████| 2683/2683 [00:00<00:00, 2684474.63it/s]


Generating neighbors for layer 1


100%|██████████| 792/792 [00:00<00:00, 788672.55it/s]


## 4-3节点特征设定

In [867]:
features = None
if feature_dic is not None:
    feature_dim = len(list(feature_dic.values())[0])
    print('feature dimension: ' + str(feature_dim))
    features = np.zeros((num_nodes, feature_dim), dtype=np.float32)
    for key, value in feature_dic.items():
        if key in vocab:
            features[vocab[key].index, :] = np.array(value)
    features = torch.FloatTensor(features).to(device)

## 5 模型构建

In [839]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy import random
from torch.nn.parameter import Parameter
import math


class GATNEModel(nn.Module):
    def __init__(
            self, num_nodes, embedding_size, embedding_u_size, edge_type_count, dim_a, features
    ):
        super(GATNEModel, self).__init__()
        ## 节点数量
        self.num_nodes = num_nodes
        ## 嵌入维度
        self.embedding_size = embedding_size
        ## 嵌入转化维度
        self.embedding_u_size = embedding_u_size
        ## 边种类数量
        self.edge_type_count = edge_type_count
        ## edge type 注意力机制维度
        self.dim_a = dim_a
        self.features = None
 ##如果有节点属性即GATNE-I
        if features is not None:
            ##节点特征
            self.features = features
            feature_dim = self.features.shape[-1]
            ##base下的节点特征转化函数——hz
            self.embed_trans = Parameter(torch.FloatTensor(feature_dim, embedding_size))
            ## 不同边下的节点embedding转化函数——gz，r
            self.u_embed_trans = Parameter(torch.FloatTensor(edge_type_count, feature_dim, embedding_u_size))
            
##如果没有节点属性即GATNE-T
        else:
            ##创建base下的节点embedding，bi
            self.node_embeddings = Parameter(torch.FloatTensor(num_nodes, embedding_size))
            
            ##创建不同边的节点embedding，ui,r
            self.node_type_embeddings = Parameter(
                torch.FloatTensor(num_nodes, edge_type_count, embedding_u_size)
            )
           
##转换矩阵
        self.trans_weights = Parameter(
            torch.FloatTensor(edge_type_count, embedding_u_size, embedding_size)
        )
        self.trans_weights_s1 = Parameter(
            torch.FloatTensor(edge_type_count, embedding_u_size, dim_a)
        )
        self.trans_weights_s2 = Parameter(torch.FloatTensor(edge_type_count, dim_a, 1))

        self.reset_parameters()
        
## 初始化参数
    def reset_parameters(self):
        if self.features is not None:
            self.embed_trans.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
            self.u_embed_trans.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
        else:
            self.node_embeddings.data.uniform_(-1.0, 1.0)
            self.node_type_embeddings.data.uniform_(-1.0, 1.0)
        self.trans_weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
        self.trans_weights_s1.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
        self.trans_weights_s2.data.normal_(std=1.0 / math.sqrt(self.embedding_size))

    def forward(self, train_inputs, train_types, node_neigh):
        
    ##GATNE-T
        if self.features is None:
            ##node_embeddings是被索引张量，train_inputs是索引张量，即从这么多的node_embedding中取出该批次训练所用的节点的embedding
            ##其形状为[64,200]
            node_embed = self.node_embeddings[train_inputs]
            ##node_type_embedding是被索引张量，node_neigh是索引张量，即从节点不同种类的emebdding中选择训练节点的邻居节点的embedding
            ##其形状为[64,2,10,2,10] 
            node_embed_neighbors = self.node_type_embeddings[node_neigh]
    ##GATNE-I
        else:
            ##一样进行张量索引，但是这部分是拥有节点特征的，然后们这一部分是获取输入节点的特征，和embedding_trans来构成node_embed的矩阵
            node_embed = torch.mm(self.features[train_inputs], self.embed_trans)
            ##通过爱因斯坦求和约定得到其形状和上述GATNE一致
            node_embed_neighbors = torch.einsum('bijk,akm->bijam', self.features[node_neigh], self.u_embed_trans)
        
        ##通过提取对角线将代表edge_type_count的索引的1维和代表edge_type_count内容的3维进行结合从而得到[64,2,10,10]
        node_embed_tmp = torch.diagonal(node_embed_neighbors, dim1=1, dim2=3).permute(0, 3, 1, 2)
        ##直接将代表embedding_u_size的索引和代表embedding_u_size的内容结合得到最终的索引张量表示[64,2,10]
        node_type_embed = torch.sum(node_embed_tmp, dim=2)
        
        ##这里和上述类似都是进行张量索引，我们需要获得转化时候张量索引，按照初始化构造的被索引张量其shape分别为trans_w[2,10,200],trans_w_s1[2,10,20],trans_w_s2[2,20,1]
        ##但是我们使用索引张量后train_types其shape为[64]后得到的的是[64,10,200],[64,10,20],[64,20,1]。
        trans_w = self.trans_weights[train_types]
        trans_w_s1 = self.trans_weights_s1[train_types]
        trans_w_s2 = self.trans_weights_s2[train_types]
        
    ##得到不同edgetype的重要性
        attention = F.softmax(
            torch.matmul(
                torch.tanh(torch.matmul(node_type_embed, trans_w_s1)), trans_w_s2
                ##矩阵相乘的部分最终输出为（64，2，1）
            ).squeeze(2),
            ## 移除第2个维度（64，2）
            dim=1,
            ##softmax函数归一化为概率分布，在第1维度进行，得到的是不同
        ).unsqueeze(1)
        ##将展开得到的是不同edge type的重要性最终其attention，shape为（64，1，2）
        
    ##根据不同edgetype的重要性更新node_type_embed其shape为[64,1,10]
        node_type_embed = torch.matmul(attention, node_type_embed)
    
    ## 将node_type_embed, trans_w进行举证相乘并在第1维度进行压缩后得到的shape为[64,200]
        node_embed = node_embed + torch.matmul(node_type_embed, trans_w).squeeze(1)
    ##进行标准化
        last_node_embed = F.normalize(node_embed, dim=1)

        return last_node_embed

In [840]:
model = GATNEModel(
    num_nodes, embedding_size, embedding_u_size, edge_type_count, dim_a, features
)

## 6 定义损失计算

In [841]:
class NSLoss(nn.Module):
    def __init__(self, num_nodes, num_sampled, embedding_size):
        super(NSLoss, self).__init__()
        self.num_nodes = num_nodes
        self.num_sampled = num_sampled
        self.embedding_size = embedding_size
        self.weights = Parameter(torch.FloatTensor(num_nodes, embedding_size))
        self.sample_weights = F.normalize(
            torch.Tensor(
                [
                    (math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1)
                    for k in range(num_nodes)
                ]
            ),
            dim=0,
        )

        self.reset_parameters()

    def reset_parameters(self):
        self.weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))

    def forward(self, input, embs, label):
        n = input.shape[0]
        log_target = torch.log(
            torch.sigmoid(torch.sum(torch.mul(embs, self.weights[label]), 1))
        )
        negs = torch.multinomial(
            self.sample_weights, self.num_sampled * n, replacement=True
        ).view(n, self.num_sampled)
        noise = torch.neg(self.weights[negs])
        sum_log_sampled = torch.sum(
            torch.log(torch.sigmoid(torch.bmm(noise, embs.unsqueeze(2)))), 1
        ).squeeze()

        loss = log_target + sum_log_sampled
        return -loss.sum() / n

In [842]:
nsloss = NSLoss(num_nodes, num_sampled, embedding_size)
model.to(device)
nsloss.to(device)

NSLoss()

## 7 优化器定义

In [843]:
optimizer = torch.optim.Adam(
    [{"params": model.parameters()}, {"params": nsloss.parameters()}], lr=1e-4
)

## 8 训练整体流程构建

## 8-1 getbatch

In [844]:
def get_batches(pairs, neighbors, batch_size):
    n_batches = (len(pairs) + (batch_size - 1)) // batch_size

    for idx in range(n_batches):
        x, y, t, neigh = [], [], [], []
        for i in range(batch_size):
            index = idx * batch_size + i
            if index >= len(pairs):
                break
            x.append(pairs[index][0])
            y.append(pairs[index][1])
            t.append(pairs[index][2])
            neigh.append(neighbors[pairs[index][0]])
        yield torch.tensor(x), torch.tensor(y), torch.tensor(t), torch.tensor(neigh)

## 8-2 评估函数设定

In [845]:
from sklearn.metrics import (auc, f1_score, precision_recall_curve,
                             roc_auc_score)


def get_score(local_model, node1, node2):
    try:
        vector1 = local_model[node1]
        vector2 = local_model[node2]
        return np.dot(vector1, vector2) / (np.linalg.norm(vector1) * np.linalg.norm(vector2))
    except Exception as e:
        pass

In [846]:
def evaluate(model, true_edges, false_edges):
    true_list = list()
    prediction_list = list()
    true_num = 0
    for edge in true_edges:
        tmp_score = get_score(model, str(edge[0]), str(edge[1]))
        if tmp_score is not None:
            true_list.append(1)
            prediction_list.append(tmp_score)
            true_num += 1

    for edge in false_edges:
        tmp_score = get_score(model, str(edge[0]), str(edge[1]))
        if tmp_score is not None:
            true_list.append(0)
            prediction_list.append(tmp_score)

    sorted_pred = prediction_list[:]
    sorted_pred.sort()
    threshold = sorted_pred[-true_num]

    y_pred = np.zeros(len(prediction_list), dtype=np.int32)
    for i in range(len(prediction_list)):
        if prediction_list[i] >= threshold:
            y_pred[i] = 1

    y_true = np.array(true_list)
    y_scores = np.array(prediction_list)
    ps, rs, _ = precision_recall_curve(y_true, y_scores)
    return roc_auc_score(y_true, y_scores), f1_score(y_true, y_pred), auc(rs, ps)

## 8 正式进行训练流程

In [847]:
best_score = 0
test_score = (0.0, 0.0, 0.0)
patience = 0
for epoch in range(epochs):
    random.shuffle(train_pairs)
    batches = get_batches(train_pairs, neighbors, batch_size)

    data_iter = tqdm(
        batches,
        desc="epoch %d" % (epoch),
        total=(len(train_pairs) + (batch_size - 1)) // batch_size,
        bar_format="{l_bar}{r_bar}",
    )
    avg_loss = 0.0

    for i, data in enumerate(data_iter):
        # print(data[0].shape)
        # print('-----------------------------')
        # print(data[2].shape)
        # print('-----------------------------')
        # print(data[3].shape)
        # 
        optimizer.zero_grad()
        embs = model(data[0].to(device), data[2].to(device), data[3].to(device), )
        loss = nsloss(data[0].to(device), embs, data[1].to(device))
        loss.backward()
        optimizer.step()

        avg_loss += loss.item()

        if i % 5000 == 0:
            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "loss": loss.item(),
            }
            data_iter.write(str(post_fix))

    final_model = dict(zip(edge_types, [dict() for _ in range(edge_type_count)]))
    
    for i in range(num_nodes):
        train_inputs = torch.tensor([i for _ in range(edge_type_count)]).to(device)
        # print(train_inputs)
        # print('train_inputs.shape',train_inputs.shape)
        train_types = torch.tensor(list(range(edge_type_count))).to(device)
        # print('train_types.shape',train_types.shape)
        node_neigh = torch.tensor(
            [neighbors[i] for _ in range(edge_type_count)]
        ).to(device)
        # print('node_neigh.shape',node_neigh.shape)
        node_emb = model(train_inputs, train_types, node_neigh)
        # print('node_emb.shape',node_emb.shape)
        for j in range(edge_type_count):
            final_model[edge_types[j]][index2word[i]] = (
                node_emb[j].cpu().detach().numpy()
            )

    valid_aucs, valid_f1s, valid_prs = [], [], []
    test_aucs, test_f1s, test_prs = [], [], []
    for i in range(edge_type_count):
        if args.eval_type == "all" or edge_types[i] in args.eval_type.split(","):
            tmp_auc, tmp_f1, tmp_pr = evaluate(
                final_model[edge_types[i]],
                valid_true_data_by_edge[edge_types[i]],
                valid_false_data_by_edge[edge_types[i]],
            )
            valid_aucs.append(tmp_auc)
            valid_f1s.append(tmp_f1)
            valid_prs.append(tmp_pr)

            tmp_auc, tmp_f1, tmp_pr = evaluate(
                final_model[edge_types[i]],
                testing_true_data_by_edge[edge_types[i]],
                testing_false_data_by_edge[edge_types[i]],
            )
            test_aucs.append(tmp_auc)
            test_f1s.append(tmp_f1)
            test_prs.append(tmp_pr)
    print("valid auc:", np.mean(valid_aucs))
    print("valid pr:", np.mean(valid_prs))
    print("valid f1:", np.mean(valid_f1s))

    average_auc = np.mean(test_aucs)
    average_f1 = np.mean(test_f1s)
    average_pr = np.mean(test_prs)

    cur_score = np.mean(valid_aucs)
    if cur_score > best_score:
        best_score = cur_score
        test_score = (average_auc, average_f1, average_pr)
        patience = 0
    else:
        patience += 1
        if patience > args.patience:
            print("Early Stopping")
            break

epoch 0:   1%|| 36/2910 [00:00<00:15, 181.08it/s]

{'epoch': 0, 'iter': 0, 'avg_loss': 4.166138648986816, 'loss': 4.166138648986816}


epoch 0: 100%|| 2910/2910 [00:10<00:00, 289.53it/s]


valid auc: 0.549675774632141
valid pr: 0.5570909751947566
valid f1: 0.5370967741935484


epoch 1:   1%|| 27/2910 [00:00<00:10, 265.00it/s]

{'epoch': 1, 'iter': 0, 'avg_loss': 2.598921298980713, 'loss': 2.598921298980713}


epoch 1:  43%|| 1238/2910 [00:04<00:06, 274.25it/s]


KeyboardInterrupt: 