In [64]:
import findspark
findspark.init()

import pyspark
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()

from pyspark.sql.functions import udf, col, broadcast
import random

spark.sparkContext.setCheckpointDir("./spark_tmp/")

In [2]:
num_context = 30

## Create Dictionary Data

In [75]:
data = []
for context in range(num_context):
    for word_id in range(20000):
        word = str(hash(random.random()))
        data.append((context, word_id, word))
df = spark.createDataFrame(data, ['context', 'word_id', 'word'])

In [76]:
df.createOrReplaceTempView("dictionary")

## Create Coded Data

In [106]:
#data = [(6, (i*3+5)%566, (i*60+6)%13) for i in range(10000)] # some random data
data = [(6, (i*3+5)%5000, (i*60+6)%13) for i in range(100000)] # some random data
coded_df = spark.createDataFrame(data, ['context', 'word_id', 'word_id2']).distinct()

In [107]:
coded_df.createOrReplaceTempView("coded")

## Let the joining begin

### 1st try - SQL Syntax

In [108]:
%%time
sql_df = spark.sql("""
select a.context, a.word_id, b.word, a.word_id2, c.word as word2
from coded a
join dictionary b
on a.context = b.context and a.word_id = b.word_id 
join dictionary c
on a.context = c.context and a.word_id2 = c.word_id
""")
sql_count = sql_df.count()

Wall time: 24 s


In [109]:
sql_count

65000

### 2nd try - Spark SQL Syntax

In [20]:
coded_df.join(df, (coded_df['word_id'] == df['word_id']) & (coded_df['context'] == df['context'])).select(coded_df['word_id2'], df['*']).count()

7358

In [67]:
%%time
spark_sql_df = coded_df.join(df, (coded_df['word_id'] == df['word_id']) & (coded_df['context'] == df['context'])).select(coded_df['word_id2'], df['*'])
spark_sql_df = spark_sql_df.checkpoint()

Wall time: 3.7 s


In [68]:
spark_sql_df.explain()

== Physical Plan ==
*(1) Scan ExistingRDD[word_id2#8L,context#0L,word_id#1L,word#2]




In [69]:
# .select(spark_sql_df['*'], df['word'].alias('word2'))
spark_sql_df.where('word_id2 = 1').join(df, ((spark_sql_df['word_id2'] == df['word_id']) & (spark_sql_df['context'] == df['context']))).show()


+--------+-------+-------+-------------------+-------+-------+-------------------+
|word_id2|context|word_id|               word|context|word_id|               word|
+--------+-------+-------+-------------------+-------+-------+-------------------+
|       1|      6|      1|1968718652097126144|      6|      0|1166852827755916544|
|       1|      6|      1|1968718652097126144|      6|      1|1968718652097126144|
|       1|      6|      1|1968718652097126144|      6|      2|1714623463938671872|
|       1|      6|      1|1968718652097126144|      6|      3|2140051419228967168|
|       1|      6|      1|1968718652097126144|      6|      4| 417743325427254784|
|       1|      6|      1|1968718652097126144|      6|      5| 225461857228028928|
|       1|      6|      1|1968718652097126144|      6|      6|1855172547184932608|
|       1|      6|      1|1968718652097126144|      6|      7| 150798719757318400|
|       1|      6|      1|1968718652097126144|      6|      8| 865651495621323520|
|   

In [29]:
spark_sql_count

65000

### 3rd try - the plain old dict

In [111]:
# the dictionary data structure is indexed, therefore fetches data fast
indexed_data = {ctx:dict() for ctx in range(num_context)}
for context, word_id, word in data:
    indexed_data[context][word_id] = word

In [112]:
broadcast_indexed_data = spark.sparkContext.broadcast(indexed_data)

In [113]:
join_udf = udf(lambda ctx, word_id: broadcast_indexed_data.value[ctx][word_id])

In [114]:
%%time
udf_df = coded_df.withColumn('word', join_udf(col("context"), col("word_id")))
udf_df = udf_df.withColumn('word2', join_udf(col("context"), col("word_id2")))
udf_count = udf_df.count()

Wall time: 3.2 s


In [115]:
udf_count

65000

## Digging deeper

In [116]:
sql_df.explain()

== Physical Plan ==
*(10) Project [context#2116L, word_id#2117L, word#1954, word_id2#2118L, word#2128 AS word2#2125]
+- *(10) SortMergeJoin [context#2116L, word_id2#2118L], [context#2126L, word_id#2127L], Inner
   :- *(7) Sort [context#2116L ASC NULLS FIRST, word_id2#2118L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(context#2116L, word_id2#2118L, 200), true, [id=#4493]
   :     +- *(6) Project [context#2116L, word_id#2117L, word_id2#2118L, word#1954]
   :        +- *(6) SortMergeJoin [context#2116L, word_id#2117L], [context#1952L, word_id#1953L], Inner
   :           :- *(3) Sort [context#2116L ASC NULLS FIRST, word_id#2117L ASC NULLS FIRST], false, 0
   :           :  +- Exchange hashpartitioning(context#2116L, word_id#2117L, 200), true, [id=#4479]
   :           :     +- *(2) HashAggregate(keys=[context#2116L, word_id#2117L, word_id2#2118L], functions=[])
   :           :        +- Exchange hashpartitioning(context#2116L, word_id#2117L, word_id2#2118L, 200), true, [

In [None]:
spark_sql_df.explain()

In [117]:
udf_df.explain()

== Physical Plan ==
*(3) Project [context#2116L, word_id#2117L, word_id2#2118L, pythonUDF0#2175 AS word#2153, pythonUDF1#2176 AS word2#2159]
+- BatchEvalPython [<lambda>(context#2116L, word_id#2117L), <lambda>(context#2116L, word_id2#2118L)], [pythonUDF0#2175, pythonUDF1#2176]
   +- *(2) HashAggregate(keys=[context#2116L, word_id#2117L, word_id2#2118L], functions=[])
      +- Exchange hashpartitioning(context#2116L, word_id#2117L, word_id2#2118L, 200), true, [id=#4564]
         +- *(1) HashAggregate(keys=[context#2116L, word_id#2117L, word_id2#2118L], functions=[])
            +- *(1) Scan ExistingRDD[context#2116L,word_id#2117L,word_id2#2118L]




In [None]:
# how does it change as the amount of colunms to translate?