In [1]:
import pandas as pd
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('max_colwidth', 1000)

import pyspark.sql.functions as F

In [2]:
import metaspore as ms
import subprocess

spark_confs={
        "spark.network.timeout":"500",
        "spark.sql.codegen.wholeStage": "false"
    }

spark = ms.spark.get_session(local=True,
                             app_name='soc-pokec Demo',
                             batch_size=256,
                             worker_count=2,
                             server_count=2,
                             worker_memory='10G',
                             server_memory='10G',
                             coordinator_memory='10G',
                             spark_confs=spark_confs)

updating: python/ (stored 0%)
updating: python/algos/ (stored 0%)
updating: python/algos/xdeepfm_net.py (deflated 71%)
updating: python/algos/widedeep_net.py (deflated 68%)
updating: python/algos/tuner/ (stored 0%)
updating: python/algos/tuner/base_tuner.py (deflated 70%)
updating: python/algos/multitask/ (stored 0%)
updating: python/algos/multitask/mmoe/ (stored 0%)
updating: python/algos/multitask/mmoe/mmoe_net.py (deflated 75%)
updating: python/algos/multitask/mmoe/mmoe_agent.py (deflated 70%)
updating: python/algos/multitask/mmoe/__pycache__/ (stored 0%)
updating: python/algos/multitask/mmoe/__pycache__/mmoe_net.cpython-38.pyc (deflated 43%)
updating: python/algos/multitask/mmoe/__pycache__/mmoe_agent.cpython-38.pyc (deflated 52%)
updating: python/algos/multitask/mmoe/.ipynb_checkpoints/ (stored 0%)
updating: python/algos/multitask/mmoe/.ipynb_checkpoints/mmoe_net-checkpoint.py (deflated 75%)
updating: python/algos/multitask/mmoe/.ipynb_checkpoints/mmoe_agent-checkpoint.py (deflate

22/07/20 09:04:57 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [3]:
edges = spark.createDataFrame([('1', '2', 1.0), 
                               ('2', '1', 1.0),
                              ('3', '1', 2.0),
                              ('1', '3', 2.0),
                               ('2', '3', 3.0),
                               ('3', '4', 3.0),
                               ('4', '3', 4.0),
                               ('5', '3', 4.0),
                               ('3', '5', 5.0),
                               ('4', '5', 5.0),
                              ('98', '99', 6.0),
                              ('99', '98', 6.0),
                              ('98', '100', 10.0)],
                              ['src', 'dst', 'weight'])

In [4]:
edges.show()

[Stage 0:>                                                          (0 + 1) / 1]

+---+---+------+
|src|dst|weight|
+---+---+------+
|  1|  2|   1.0|
|  2|  1|   1.0|
|  3|  1|   2.0|
|  1|  3|   2.0|
|  2|  3|   3.0|
|  3|  4|   3.0|
|  4|  3|   4.0|
|  5|  3|   4.0|
|  3|  5|   5.0|
|  4|  5|   5.0|
| 98| 99|   6.0|
| 99| 98|   6.0|
| 98|100|  10.0|
+---+---+------+



                                                                                

In [4]:
train_dataset = spark.read.parquet('s3://dmetasoul-bucket/demo/datasets/soc-pokec/demo_fg/train_dataset.parquet')
test_dataset = spark.read.parquet('s3://dmetasoul-bucket/demo/datasets/soc-pokec/demo_fg/test_dataset.parquet')
all_dataset = train_dataset.union(test_dataset)

edges = all_dataset.select(F.col('user_id').alias('src'), F.col('friend_id').alias('dst'), F.lit(1.0).alias('weight'))
edges.limit(10).toPandas()

                                                                                

Unnamed: 0,src,dst,weight
0,1,10,1.0
1,1,11,1.0
2,1,12,1.0
3,1,13,1.0
4,1,14,1.0
5,1,15,1.0
6,1,16,1.0
7,1,4,1.0
8,1,5,1.0
9,1,6,1.0


## Initialize lookup dataframe

In [6]:
df = edges.alias('t1').join(edges.alias('t2'), on=(F.col('t1.dst')==F.col('t2.src')), how='inner'). \
            select('t1.*', \
                   F.col('t2.dst').alias('next_dst'), \
                   F.col('t2.weight').alias('next_weight'))
df.show(10)

+---+---+------+--------+-----------+
|src|dst|weight|next_dst|next_weight|
+---+---+------+--------+-----------+
|  1| 10|   1.0|     305|        1.0|
|  1| 10|   1.0|     303|        1.0|
|  1| 10|   1.0|     301|        1.0|
|  1| 10|   1.0|     264|        1.0|
|  1| 10|   1.0|      62|        1.0|
|  1| 10|   1.0|      60|        1.0|
|  1| 10|   1.0|      33|        1.0|
|  1| 10|   1.0|     304|        1.0|
|  1| 10|   1.0|     302|        1.0|
|  1| 10|   1.0|     300|        1.0|
+---+---+------+--------+-----------+
only showing top 10 rows



In [7]:
src_neighbors = edges.groupBy(F.col('src')).agg(F.collect_list(F.col('dst')).alias('src_neighbors'))
src_neighbors.show()

[Stage 8:>                                                          (0 + 2) / 2]

+-----+--------------------+
|  src|       src_neighbors|
+-----+--------------------+
|10096|[10158, 10167, 10...|
|10351|[10714, 11645, 12...|
|10436|[12291, 12573, 16...|
| 1090|[1004, 10466, 107...|
|11078|[10925, 10928, 10...|
|11332|[1041, 11993, 157...|
|11563|[11484, 346, 4222...|
| 1159|[1154, 11574, 116...|
|11722|               [189]|
|11888|[10644, 10735, 11...|
|12394|               [189]|
|12529|    [349, 8815, 366]|
|12847|[12848, 12977, 12...|
|13192|[1135, 12844, 129...|
|13282|               [404]|
|13442|[10413, 11646, 13...|
|13610|               [189]|
|13772|[10026, 10163, 10...|
|13865|[12934, 15154, 41...|
|14157|[1080, 1086, 1132...|
+-----+--------------------+
only showing top 20 rows



                                                                                

In [8]:
df = df.join(src_neighbors, on='src', how='leftouter')
df.show(10)

                                                                                

+---+---+------+--------+-----------+--------------------+
|src|dst|weight|next_dst|next_weight|       src_neighbors|
+---+---+------+--------+-----------+--------------------+
|  1| 10|   1.0|     305|        1.0|[10, 11, 12, 13, ...|
|  1| 10|   1.0|     303|        1.0|[10, 11, 12, 13, ...|
|  1| 10|   1.0|     301|        1.0|[10, 11, 12, 13, ...|
|  1| 10|   1.0|     264|        1.0|[10, 11, 12, 13, ...|
|  1| 10|   1.0|      62|        1.0|[10, 11, 12, 13, ...|
|  1| 10|   1.0|      60|        1.0|[10, 11, 12, 13, ...|
|  1| 10|   1.0|      33|        1.0|[10, 11, 12, 13, ...|
|  1| 10|   1.0|     304|        1.0|[10, 11, 12, 13, ...|
|  1| 10|   1.0|     302|        1.0|[10, 11, 12, 13, ...|
|  1| 10|   1.0|     300|        1.0|[10, 11, 12, 13, ...|
+---+---+------+--------+-----------+--------------------+
only showing top 10 rows



In [9]:
df = df.groupBy([F.col('src'), F.col('dst')]).agg(F.struct(F.collect_list(F.struct(F.col('next_dst').alias('dst'),\
                                                                           F.col('next_weight').alias('weight'))\
                                                                          ).alias('dst_neighbors'),
                                                            F.first(F.col('src_neighbors')).alias('src_neighbors')\
                                                           ).alias('attributes'))
df.show(10, False)
df.printSchema()



+-----+-----+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

                                                                                

In [10]:
from collections import deque
from math import floor

def setup_alias(weights):
    N = len(weights)
    p = [-1.0] * N
    a = [-1] * N
    small = deque()
    large = deque()
    
    summation = sum(weights)
    for idx, weight in enumerate(weights):
        p[idx] = N * weight / summation
        small.append(idx) if p[idx] < 1.0 else large.append(idx)  
    
    while len(small) > 0 and len(large) > 0:
        s = small.pop()
        l = large.pop()
        a[s] = l
        p[l] = p[l] + p[s] - 1.0
        small.append(l) if p[l] < 1.0 else large.append(l)
    
    while len(large) > 0:
        p[large.pop()] = 1.0
        
    while len(small) > 0:
        p[small.pop()] = 1.0
        
    return p, a


def draw_alias(p, a):
    from random import Random
    from time import time

    rdg = Random(time())
    idx = floor(rdg.random() * len(p))
    return idx if rdg.random() < p[idx] else a[idx]

def verify(weights, p, a, sample_numb = 10000):
    N = len(weights)
    S = sum(weights)
    origin_probs = []
    for w in weights:
        origin_probs.append(w / S)
    print('Debug - original probs: ', origin_probs)
    
    count = [0] * N
    for i in range(sample_numb):
        idx = draw_alias(p, a)
        count[idx] = count[idx] + 1
    print('Debug - sampled probs: ', [c / sample_numb for c in count])
    

In [11]:
from pyspark.sql.types import Row

p = 2
q = 0.5
Z = 1.0

def setup_edges(row, p, q):
    src, dst, attributes = row['src'], row['dst'], row['attributes']
    dst_neighbors, src_neighbors = attributes['dst_neighbors'], attributes['src_neighbors']
    
    new_dst_neighbors = []
    pq_weights = []
    for dst_neighbor in dst_neighbors:
        neighbor_dst, neighbor_weight = dst_neighbor['dst'], dst_neighbor['weight']
        alpha = 1 / q
        if neighbor_dst in src_neighbors:
            alpha = 1
        elif neighbor_dst == src:
            alpha = 1 / p
        pq_weight = neighbor_weight * alpha / Z
        pq_weights.append(pq_weight)
        new_dst_neighbors.append(neighbor_dst)
    
    p, a = setup_alias(pq_weights)
    
    new_attributes = Row(dst_neighbors=new_dst_neighbors, p=p, a=a)
    new_row = Row(src=src, dst=dst, attributes=new_attributes)
    
    return new_row

edges_lookup = df.rdd.map(lambda row: setup_edges(row, p, q)).toDF(df.columns)
edges_lookup.printSchema()



root
 |-- src: string (nullable = true)
 |-- dst: string (nullable = true)
 |-- attributes: struct (nullable = true)
 |    |-- dst_neighbors: array (nullable = true)
 |    |    |-- element: string (containsNull = true)
 |    |-- p: array (nullable = true)
 |    |    |-- element: double (containsNull = true)
 |    |-- a: array (nullable = true)
 |    |    |-- element: long (containsNull = true)



                                                                                

In [12]:
edges_lookup.limit(100).toPandas()

Traceback (most recent call last):                               (15 + 4) / 200]
  File "/opt/spark/python/lib/pyspark.zip/pyspark/daemon.py", line 186, in manager
  File "/opt/spark/python/lib/pyspark.zip/pyspark/daemon.py", line 74, in worker
  File "/opt/spark/python/lib/pyspark.zip/pyspark/worker.py", line 643, in main
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
  File "/opt/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 564, in read_int
    raise EOFError
EOFError
Traceback (most recent call last):                               (76 + 4) / 200]
  File "/opt/spark/python/lib/pyspark.zip/pyspark/daemon.py", line 186, in manager
  File "/opt/spark/python/lib/pyspark.zip/pyspark/daemon.py", line 74, in worker
  File "/opt/spark/python/lib/pyspark.zip/pyspark/worker.py", line 643, in main
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
  File "/opt/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 564, in read_int
    raise EOFError
EOFError
Tr

Unnamed: 0,src,dst,attributes
0,10019,9525,"{'dst_neighbors': ['10435', '9733', '11180', '11288', '11437', '14287', '14588', '14780', '14935', '14936', '14941', '15236', '15247', '15258', '15743', '2330', '4299', '8071', '8764', '9626', '9628', '9723', '9916'], 'p': [1.0, 0.9047619047619075, 0.5476190476190477, 0.8095238095238122, 0.7142857142857169, 0.6190476190476215, 0.9761904761904783, 0.8809523809523829, 0.7857142857142876, 0.6904761904761922, 0.5952380952380969, 0.9523809523809537, 0.8571428571428583, 0.761904761904763, 0.6666666666666676, 0.5714285714285723, 0.928571428571429, 0.5476190476190477, 0.8333333333333337, 0.5476190476190477, 0.7380952380952384, 0.642857142857143, 0.5476190476190477], 'a': [-1, 0, 5, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 10, 16, 15, 18, 20, 21]}"
1,1004,2192,"{'dst_neighbors': ['1004', '11447', '2933', '3826', '3930', '6517', '10999', '11069', '12056', '12434', '1369', '15142', '15178', '15358', '1626', '168', '1777', '1874', '2002', '2119', '2144', '2796', '3282', '3523', '4064', '4666', '472', '485', '5169', '585', '626', '6270', '6424', '7056', '7413', '8121', '8169', '8307', '8426', '8538', '8768', '9027', '961', '9743'], 'p': [0.2603550295857988, 1.0, 0.9585798816568043, 0.917159763313609, 0.8757396449704138, 0.8343195266272185, 0.7928994082840233, 0.751479289940828, 0.7100591715976328, 0.6686390532544375, 0.6272189349112423, 0.585798816568047, 0.5443786982248517, 0.5029585798816565, 0.46153846153846123, 0.420118343195266, 0.3786982248520707, 0.33727810650887546, 0.2958579881656802, 0.9940828402366861, 0.9526627218934909, 0.9112426035502956, 0.8698224852071004, 0.8284023668639051, 0.7869822485207099, 0.7455621301775146, 0.7041420118343193, 0.6627218934911241, 0.6213017751479288, 0.5798816568047336, 0.5384615384615383, 0.97633136094..."
2,1006,9058,"{'dst_neighbors': ['1006', '60', '8243', '10565', '13493', '14432', '1838', '1920', '2371', '7172', '7981', '8540', '8544', '8898', '8899', '8906'], 'p': [0.26229508196721313, 1.0, 0.9508196721311484, 0.9016393442622959, 0.8524590163934433, 0.8032786885245908, 0.7540983606557383, 0.7049180327868858, 0.6557377049180333, 0.6065573770491808, 0.5573770491803283, 0.5081967213114758, 0.45901639344262324, 0.4098360655737707, 0.3606557377049182, 0.3114754098360657], 'a': [15, -1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]}"
3,10072,5736,"{'dst_neighbors': ['2582', '13225', '135', '143', '14322', '15771', '5399', '5616', '6652', '7073', '8169'], 'p': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 'a': [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]}"
4,10084,12144,"{'dst_neighbors': ['10193', '12286', '14962', '8780', '11432', '11787', '12298', '12511', '12993', '13041', '13584', '13711', '13739', '13818', '14202', '14760', '14816', '15137', '15248', '15347', '1783', '3065', '3074', '3217', '3693', '3700', '6861', '78', '7942', '8102', '8140', '8233', '8789', '9923'], 'p': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 'a': [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]}"
5,101,8158,"{'dst_neighbors': ['101', '1198', '13208', '14942', '15050', '15247', '15418', '15476', '4597', '5412', '9917'], 'p': [0.28205128205128205, 1.0, 0.8717948717948714, 0.7435897435897432, 0.615384615384615, 0.4871794871794868, 0.3589743589743586, 0.9487179487179487, 0.8205128205128205, 0.5641025641025641, 0.6923076923076923], 'a': [6, -1, 1, 2, 3, 4, 5, 6, 7, 10, 8]}"
6,10110,11921,"{'dst_neighbors': ['2101', '11069', '11413', '12853', '13480', '14363', '2637', '3453', '3740', '3741', '3819', '454', '7895', '8118'], 'p': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 'a': [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]}"
7,1014,3741,"{'dst_neighbors': ['12165', '13480', '2383', '465', '5373', '11003', '12056', '12160', '12170', '12789', '13253', '13343', '13422', '13423', '13482', '14274', '1542', '1749', '1781', '1838', '242', '3266', '3272', '3417', '3738', '3830', '454', '456', '60', '8040', '8105', '8163', '8246', '8899', '9261', '9813'], 'p': [1.0, 0.9565217391304339, 0.9130434782608687, 0.8695652173913035, 0.8260869565217384, 0.7826086956521732, 0.739130434782608, 0.6956521739130428, 0.6521739130434776, 0.6086956521739124, 0.5652173913043472, 0.9999999999999991, 0.9565217391304339, 0.9130434782608687, 0.8695652173913035, 0.8260869565217384, 0.7826086956521732, 0.739130434782608, 0.5217391304347826, 0.6956521739130428, 0.6521739130434776, 0.6086956521739124, 0.5652173913043472, 0.9999999999999996, 0.9565217391304344, 0.9130434782608692, 0.869565217391304, 0.8260869565217388, 0.7826086956521736, 0.7391304347826084, 0.6956521739130432, 0.652173913043478, 0.5217391304347826, 0.5217391304347826, 0.608695652173..."
8,10173,5479,"{'dst_neighbors': ['10173', '11391', '212', '334', '3683', '4257', '4441', '5211', '5466', '5472', '5478', '5614', '5869', '6371', '9184'], 'p': [0.30612244897959184, 1.0, 0.7755102040816335, 0.5510204081632661, 0.32653061224489877, 0.7959183673469397, 0.6122448979591837, 0.9591836734693884, 0.734693877551021, 0.8979591836734697, 0.6734693877551023, 0.6122448979591837, 0.8367346938775511, 0.6122448979591837, 0.6122448979591837], 'a': [4, -1, 1, 2, 3, 4, 5, 5, 7, 8, 9, 8, 10, 10, 12]}"
9,10177,8256,"{'dst_neighbors': ['8170', '8426', '10177', '11791', '1383', '4121', '487', '5623', '60', '634', '7977', '8119', '8121', '8331', '8335', '8342', '8352', '8468'], 'p': [1.0, 0.8196721311475406, 0.29508196721311475, 0.6393442622950816, 0.4590163934426226, 0.9836065573770487, 0.8032786885245897, 0.5901639344262295, 0.6229508196721307, 0.5901639344262295, 0.8524590163934425, 0.5901639344262295, 0.6721311475409835, 0.9016393442622952, 0.5901639344262295, 0.7213114754098362, 0.9508196721311475, 0.7704918032786885], 'a': [-1, 0, 4, 1, 3, 4, 5, 8, 6, 12, 8, 15, 10, 12, 17, 13, 15, 16]}"


In [13]:
edges.show()

+---+---+------+
|src|dst|weight|
+---+---+------+
|  1| 10|   1.0|
|  1| 11|   1.0|
|  1| 12|   1.0|
|  1| 13|   1.0|
|  1| 14|   1.0|
|  1| 15|   1.0|
|  1| 16|   1.0|
|  1|  4|   1.0|
|  1|  5|   1.0|
|  1|  6|   1.0|
|  1|  7|   1.0|
|  1|  8|   1.0|
| 10|  1|   1.0|
| 10|134|   1.0|
| 10|190|   1.0|
| 10|208|   1.0|
| 10|236|   1.0|
| 10|238|   1.0|
| 10|241|   1.0|
| 10|242|   1.0|
+---+---+------+
only showing top 20 rows



In [14]:
df = edges.groupBy(F.col('src')).agg(F.collect_list(F.struct(F.col('dst'),\
                                                             F.col('weight'))).alias('attributes'))
df.show(10, False)
df.printSchema()

+-----+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|src  |attributes                                                                                                                                                         

In [15]:
def setup_vertices(row):
    src, attributes = row['src'], row['attributes']

    neighbors = []
    weights = []
    for attribute in attributes:
        neighbors.append(attribute['dst'])
        weights.append(attribute['weight'])
    
    p, a = setup_alias(weights)
    new_attributes = Row(neighbors=neighbors, p=p, a=a)
    
    return src, new_attributes

vertices_lookup = df.rdd.map(lambda row: setup_vertices(row)).toDF(['src', 'attributes'])
vertices_lookup.printSchema()

root
 |-- src: string (nullable = true)
 |-- attributes: struct (nullable = true)
 |    |-- neighbors: array (nullable = true)
 |    |    |-- element: string (containsNull = true)
 |    |-- p: array (nullable = true)
 |    |    |-- element: double (containsNull = true)
 |    |-- a: array (nullable = true)
 |    |    |-- element: long (containsNull = true)



                                                                                

In [16]:
vertices_lookup.show(10, False)

+-----+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|src  |attributes                     

## Random walk

In [17]:
def first_step(row):
    src, attributes = row['src'], row['attributes']
    
    next_index = draw_alias(attributes['p'], attributes['a'])
    next_vertice = attributes['neighbors'][next_index]
    
    return src, [src, next_vertice]

walk_df = vertices_lookup.rdd.map(lambda row: first_step(row)).toDF(['origin', 'path'])
walk_df.printSchema()



root
 |-- origin: string (nullable = true)
 |-- path: array (nullable = true)
 |    |-- element: string (containsNull = true)



                                                                                

In [18]:
walk_df.show()

+------+--------------+
|origin|          path|
+------+--------------+
| 10096|[10096, 10167]|
| 10351|  [10351, 423]|
| 10436|[10436, 12291]|
|  1090|  [1090, 6830]|
| 11078| [11078, 7895]|
| 11332|[11332, 15793]|
| 11563| [11563, 4229]|
|  1159|  [1159, 1202]|
| 11722|  [11722, 189]|
| 11888| [11888, 2414]|
| 12394|  [12394, 189]|
| 12529|  [12529, 349]|
| 12847|[12847, 12977]|
| 13192|[13192, 13066]|
| 13282|  [13282, 404]|
| 13442|[13442, 13440]|
| 13610|  [13610, 189]|
| 13772| [13772, 8925]|
| 13865|  [13865, 417]|
| 14157| [14157, 4035]|
+------+--------------+
only showing top 20 rows



In [19]:
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType, ArrayType

num_walks = 10

def next_step(path, attributes):    
    if attributes is not None:
        next_index = draw_alias(attributes['p'], attributes['a'])
        next_vertice = attributes['dst_neighbors'][next_index]
        path.append(next_vertice)
    
    return path

next_path_udf = udf(lambda path, attributes: next_step(path, attributes), ArrayType(StringType()))


for i in range(num_walks - 2):
    walk_df = walk_df.withColumn('src', F.element_at(F.col('path'), -2))
    walk_df = walk_df.withColumn('dst', F.element_at(F.col('path'), -1))  
    walk_df = walk_df.join(edges_lookup, on=['src', 'dst'], how='leftouter')
    walk_df = walk_df.select('origin', next_path_udf('path', 'attributes').alias('path'))


In [20]:
walk_df.limit(10).toPandas()

                                                                                

Unnamed: 0,origin,path
0,4289,"[4289, 12173, 10147, 11220, 3597, 7196, 4917, 10389, 6695, 4169]"
1,4237,"[4237, 13085, 4331, 10992, 11115, 10739, 13410, 10556, 4254, 9280]"
2,5224,"[5224, 3969, 4241, 7322, 4087, 4275, 7789, 10857, 4171, 8670]"
3,2221,"[2221, 8679, 7281, 8895, 9805, 13185, 483, 11281, 14748, 545]"
4,3722,"[3722, 8083, 5992, 15127, 2996, 83, 450, 11531, 1627, 11985]"
5,11572,"[11572, 9385, 11572, 14322]"
6,9964,"[9964, 4290, 7857, 8724, 8597, 9370, 12925, 12, 359, 9272]"
7,1090,"[1090, 3563, 13000, 2950, 7417, 2950, 13000, 1208, 2981, 172]"
8,10011,"[10011, 4698, 8651, 8513, 13183, 3530, 5426, 12303, 14680, 7970]"
9,12642,"[12642, 11801, 11225, 11967, 12357, 12450, 12632, 12426, 12355, 12632]"


In [21]:
walk_df.cache()

DataFrame[origin: string, path: array<string>]

In [28]:
walk_df.count()

15271

In [29]:
all_dataset.count()

219458

In [30]:
edges.count()

219458

In [31]:
edges_lookup.count()

                                                                                

215931

In [32]:
vertices_lookup.count()

                                                                                

15271

## Word2Vec:

In [22]:
from pyspark.ml.feature import Word2Vec

In [23]:
sent = ("a b " * 100 + "a c " * 10).split(" ")
sent2 = ("a b " * 10 + "a c " * 1).split(" ")

In [24]:
doc = spark.createDataFrame([(sent,), (sent2,)], ["sentence"])
doc.toPandas()

Unnamed: 0,sentence
0,"[a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, ...]"
1,"[a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, b, a, c, ]"


In [25]:
word2Vec = Word2Vec(vectorSize=5, seed=42, inputCol="path", outputCol="model", windowSize=30, minCount=0, maxIter=10, numPartitions=1)

In [26]:
model = word2Vec.fit(walk_df)

22/07/19 10:40:43 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
22/07/19 10:40:43 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
                                                                                

In [27]:
model.getVectors().show(10, False)

+-----+-------------------------------------------------------------------------------------------------------+
|word |vector                                                                                                 |
+-----+-------------------------------------------------------------------------------------------------------+
|10292|[-1.9123085737228394,-0.2719021439552307,-0.30774733424186707,0.398619145154953,-0.34304240345954895]  |
|5451 |[0.2596725821495056,0.05336519330739975,0.6873347163200378,-0.25403642654418945,-0.39802688360214233]  |
|4018 |[0.1826438456773758,-0.1243230327963829,0.9898014068603516,-0.24240678548812866,0.00690585607662797]   |
|9936 |[-0.2681558132171631,0.8051718473434448,-0.7521481513977051,0.24533161520957947,-1.395830750465393]    |
|13172|[0.41216525435447693,0.06221426650881767,0.6139582991600037,-0.39063355326652527,0.26480531692504883]  |
|10304|[-1.1038172245025635,0.9332424998283386,-0.4859587550163269,-0.5751068592071533,-0.25783306360244

## Test estimator:

In [1]:
import metaspore as ms
import subprocess

subprocess.run(['zip', '-r', 'solutions/recommend/offline/social_network/python.zip', 'python'], cwd='../../../../')

spark_confs={
        "spark.network.timeout":"500",
        "spark.sql.codegen.wholeStage": "false",
        "spark.submit.pyFiles":"python.zip"
    }

spark = ms.spark.get_session(local=True,
                             app_name='soc-pokec Demo',
                             batch_size=256,
                             worker_count=2,
                             server_count=2,
                             worker_memory='10G',
                             server_memory='10G',
                             coordinator_memory='10G',
                             spark_confs=spark_confs)

updating: python/ (stored 0%)
updating: python/algos/ (stored 0%)
updating: python/algos/xdeepfm_net.py (deflated 71%)
updating: python/algos/widedeep_net.py (deflated 68%)
updating: python/algos/tuner/ (stored 0%)
updating: python/algos/tuner/base_tuner.py (deflated 70%)
updating: python/algos/multitask/ (stored 0%)
updating: python/algos/multitask/mmoe/ (stored 0%)
updating: python/algos/multitask/mmoe/mmoe_net.py (deflated 75%)
updating: python/algos/multitask/mmoe/mmoe_agent.py (deflated 70%)
updating: python/algos/multitask/mmoe/__pycache__/ (stored 0%)
updating: python/algos/multitask/mmoe/__pycache__/mmoe_net.cpython-38.pyc (deflated 43%)
updating: python/algos/multitask/mmoe/__pycache__/mmoe_agent.cpython-38.pyc (deflated 52%)
updating: python/algos/multitask/mmoe/.ipynb_checkpoints/ (stored 0%)
updating: python/algos/multitask/mmoe/.ipynb_checkpoints/mmoe_net-checkpoint.py (deflated 75%)
updating: python/algos/multitask/mmoe/.ipynb_checkpoints/mmoe_agent-checkpoint.py (deflate

22/07/20 09:49:10 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


edges = spark.createDataFrame([('1', '2', 1.0), 
                               ('2', '1', 1.0),
                              ('3', '1', 2.0),
                              ('1', '3', 2.0),
                               ('2', '3', 3.0),
                               ('3', '4', 3.0),
                               ('4', '3', 4.0),
                               ('5', '3', 4.0),
                               ('3', '5', 5.0),
                               ('4', '5', 5.0),
                              ('98', '99', 6.0),
                              ('99', '98', 6.0),
                              ('98', '100', 10.0)],
                              ['src', 'dst', 'weight'])


from python.algos.node2vec_retrieval import Node2VecEstimator

estimator = Node2VecEstimator(source_vertex_column_name='src',
                              destination_vertex_column_name='dst',
                              weight_column_name='weight',
                              trigger_vertex_column_name='dst',
                              random_walk_p = 0.5,
                              random_walk_q = 1.0,
                              debug=True)
model = estimator.fit(edges)

In [2]:
import pyspark.sql.functions as F

train_dataset = spark.read.parquet('s3://dmetasoul-bucket/demo/datasets/soc-pokec/demo_fg/train_dataset.parquet')
test_dataset = spark.read.parquet('s3://dmetasoul-bucket/demo/datasets/soc-pokec/demo_fg/test_dataset.parquet')
all_dataset = train_dataset.union(test_dataset)

edges = all_dataset.select(F.col('user_id'), F.col('friend_id'))


from python.algos.node2vec_retrieval import Node2VecEstimator

estimator = Node2VecEstimator(source_vertex_column_name='user_id',
                              destination_vertex_column_name='friend_id',
                              trigger_vertex_column_name='friend_id',
                              random_walk_p = 0.5,
                              random_walk_q = 1.0,
                              debug=True)
model = estimator.fit(edges)

Debug - edges:


                                                                                

+---+---+------+
|src|dst|weight|
+---+---+------+
|1  |10 |1.0   |
|1  |11 |1.0   |
|1  |12 |1.0   |
|1  |13 |1.0   |
|1  |14 |1.0   |
|1  |15 |1.0   |
|1  |16 |1.0   |
|1  |4  |1.0   |
|1  |5  |1.0   |
|1  |6  |1.0   |
+---+---+------+
only showing top 10 rows

Debug - attributes of vertices:


                                                                                

+-----+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|src  |attributes                                                                                                                                                         

                                                                                

Debug - vertices_lookup:
+-----+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|src  |attrib

                                                                                

+---+---+------+--------+-----------+
|src|dst|weight|next_dst|next_weight|
+---+---+------+--------+-----------+
|1  |10 |1.0   |305     |1.0        |
|1  |10 |1.0   |303     |1.0        |
|1  |10 |1.0   |301     |1.0        |
|1  |10 |1.0   |264     |1.0        |
|1  |10 |1.0   |62      |1.0        |
|1  |10 |1.0   |60      |1.0        |
|1  |10 |1.0   |33      |1.0        |
|1  |10 |1.0   |304     |1.0        |
|1  |10 |1.0   |302     |1.0        |
|1  |10 |1.0   |300     |1.0        |
+---+---+------+--------+-----------+
only showing top 10 rows

Debug - src_neighbors:
+-----+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|src  |src_neighbors          

                                                                                

+---+---+------+--------+-----------+-------------------------------------------------+
|src|dst|weight|next_dst|next_weight|src_neighbors                                    |
+---+---+------+--------+-----------+-------------------------------------------------+
|1  |10 |1.0   |305     |1.0        |[10, 11, 12, 13, 14, 15, 16, 4, 5, 6, 7, 8, 3, 9]|
|1  |10 |1.0   |303     |1.0        |[10, 11, 12, 13, 14, 15, 16, 4, 5, 6, 7, 8, 3, 9]|
|1  |10 |1.0   |301     |1.0        |[10, 11, 12, 13, 14, 15, 16, 4, 5, 6, 7, 8, 3, 9]|
|1  |10 |1.0   |264     |1.0        |[10, 11, 12, 13, 14, 15, 16, 4, 5, 6, 7, 8, 3, 9]|
|1  |10 |1.0   |62      |1.0        |[10, 11, 12, 13, 14, 15, 16, 4, 5, 6, 7, 8, 3, 9]|
|1  |10 |1.0   |60      |1.0        |[10, 11, 12, 13, 14, 15, 16, 4, 5, 6, 7, 8, 3, 9]|
|1  |10 |1.0   |33      |1.0        |[10, 11, 12, 13, 14, 15, 16, 4, 5, 6, 7, 8, 3, 9]|
|1  |10 |1.0   |304     |1.0        |[10, 11, 12, 13, 14, 15, 16, 4, 5, 6, 7, 8, 3, 9]|
|1  |10 |1.0   |302     |1.0    

                                                                                

+-----+-----+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

                                                                                

Debug - edges_lookup:
+-----+-----+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|src  |dst  |attributes                                                                                                                                                                                                                                                                                



+------+-----------------------------------------------------------------+
|origin|path                                                             |
+------+-----------------------------------------------------------------+
|11158 |[11158, 9240, 12393, 9250, 12795, 7977, 8251, 1019, 8893, 15359] |
|9244  |[9244, 6666, 4122, 2308, 4355, 8605, 4355, 10420, 1852, 1911]    |
|15191 |[15191, 1156, 14573]                                             |
|1683  |[1683, 1061, 29, 106, 57, 163, 3048, 1223, 134, 538]             |
|4005  |[4005, 229, 4088, 4620, 1235, 1223, 6095, 1237, 1260, 3088]      |
|4892  |[4892, 8467, 12870, 12576, 14741, 9813, 13015, 12539, 1190, 2447]|
|6512  |[6512, 250, 208, 7903, 7283, 12042, 3706, 1268, 11568, 1821]     |
|10028 |[10028, 9502, 407, 9552, 407, 10033, 11080, 12696, 366, 12538]   |
|10517 |[10517, 3355, 1516, 4213, 376, 1934, 7301, 12715, 12395, 14816]  |
|3922  |[3922, 9898, 15165, 15379, 1360, 12519, 165, 12715, 12395, 15355]|
+------+-----------------

                                                                                