In [1]:
import networkx as nx
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from dgl.nn.pytorch import GraphConv #内置的GCNlayer
import dgl
import matplotlib.pyplot as plt
import random
import time
import tqdm
import sys
import os

In [47]:
def construct_graph(file_edge='../data/data_action.csv'):
    """Build a user-item heterograph from the interaction log.

    Reads the interaction CSV (columns ``user_id`` and ``sku_id``, one row per
    interaction) and assigns every distinct user and item a contiguous 0-based
    node index. It then builds a heterograph with the relation
    ``('user', 'ui', 'item')`` and its reverse ``('item', 'iu', 'user')``.

    Args:
        file_edge: path to the interaction CSV.

    Returns:
        A tuple ``(hg, user_index_id_map, item_index_id_map)``. The two dicts
        map graph node indices back to the raw user / item IDs.
    """
    f_edge = pd.read_csv(file_edge)
    # Rows with a missing endpoint cannot be mapped to a node index (they
    # would otherwise produce None indices and break graph construction).
    f_edge = f_edge.dropna(subset=['user_id', 'sku_id'])

    # Read whole columns rather than using iterrows(). iterrows() returns each
    # row as a Series, which upcasts ints to float when any column is float,
    # and it is very slow on large logs.
    user_col = f_edge['user_id'].tolist()
    item_col = f_edge['sku_id'].tolist()

    # Series.unique() keeps first-appearance order, so the ID -> index
    # assignment is reproducible across runs. Iterating a set is not (e.g. str
    # IDs under hash randomization).
    users = f_edge['user_id'].unique().tolist()
    items = f_edge['sku_id'].unique().tolist()

    user_ids_index_map = {x: i for i, x in enumerate(users)}
    item_ids_index_map = {x: i for i, x in enumerate(items)}
    user_index_id_map = dict(enumerate(users))
    item_index_id_map = dict(enumerate(items))

    # One edge per interaction row, expressed in node-index space.
    user_item_src = [user_ids_index_map[x] for x in user_col]
    user_item_dst = [item_ids_index_map[x] for x in item_col]

    # Build the graph: a forward relation user -> item plus its reverse, so
    # metapath walks can alternate between the two node types.
    ui = dgl.bipartite((user_item_src, user_item_dst), 'user', 'ui', 'item')
    iu = dgl.bipartite((user_item_dst, user_item_src), 'item', 'iu', 'user')

    hg = dgl.hetero_from_relations([ui, iu])
    return hg, user_index_id_map, item_index_id_map

In [None]:
def parse_trace(trace, user_index_id_map, item_index_id_map):
    """Turn an alternating user/item walk of node indices into a CSV line.

    Even positions of ``trace`` are treated as user node indices. Odd
    positions are treated as item node indices.

    Args:
        trace: 1-D sequence (e.g. a numpy array) of node indices.
        user_index_id_map: dict mapping user node index -> raw user ID.
        item_index_id_map: dict mapping item node index -> raw item ID.

    Returns:
        The raw IDs converted to str and joined with ',', e.g. ``'10,20,10'``.

    Raises:
        KeyError: if an index is missing from the corresponding map.
    """
    ids = []
    for pos, node_idx in enumerate(trace):
        id_map = user_index_id_map if pos % 2 == 0 else item_index_id_map
        # IDs loaded from CSV are typically numeric, and str.join only accepts
        # str, so convert each one explicitly.
        ids.append(str(id_map[node_idx]))
    return ','.join(ids)

In [None]:
def main(output_path="../output/output_path.txt", num_walks_per_node=1):
    """Sample user-rooted metapath random walks and write them to a text file.

    For every user node, samples ``num_walks_per_node`` walks along the
    metapath ui, iu, ui, iu, ui, iu. The seed user is prepended to each walk,
    node indices are translated to raw IDs with ``parse_trace``, and each walk
    is written to ``output_path`` as one comma-separated line.

    Args:
        output_path: destination file. Its parent directory is created if it
            does not exist.
        num_walks_per_node: number of walks sampled per user node.
    """
    hg, user_index_id_map, item_index_id_map = construct_graph()
    meta_path = ['ui', 'iu', 'ui', 'iu', 'ui', 'iu']

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # `with` guarantees the file is closed even if sampling raises mid-loop.
    with open(output_path, "w") as f:
        # Metapaths rooted at each user node.
        for user_idx in tqdm.trange(hg.number_of_nodes('user')):
            traces = dgl.contrib.sampling.metapath_random_walk(
                hg=hg, etypes=meta_path, seeds=[user_idx], num_traces=num_walks_per_node)
            # traces[0] holds the walks for the single seed. Write all of them
            # so that num_walks_per_node actually takes effect.
            for trace in traces[0]:
                # NOTE(review): prepending the seed assumes the returned trace
                # excludes it; confirm against DGL's metapath_random_walk docs.
                tr = np.insert(trace.numpy(), 0, user_idx)
                f.write(parse_trace(tr, user_index_id_map, item_index_id_map) + '\n')

In [48]:
# Entry point. The original compared against '__man__' (a typo), so main()
# never ran. In a notebook kernel and as a script, __name__ is '__main__'.
if __name__ == '__main__':
    main()