Asynchronous initialization of Spark session (IPython)
- initialize the Spark variables to a generic wrapper object
- start a thread to initialize the SparkContext/SparkSession just before
  calling embed_kernel()
- the notebook will be connected (and ready) before the Spark session
  gets initialized (the YARN application status is still ACCEPTED)
- non-Spark-related cells can be executed in the meantime
- tab completion on any Spark variable will show the attribute
  WAITING_FOR_SPARK_SESSION_TO_BE_INITIALIZED (with value "Spark Session
  not yet initialized ...")
- running a notebook cell that references any of the Spark variables
  will wait (blocking) for the initialization thread to complete and
  then delegate execution to the actual Spark objects (see the sketch
  after the commit metadata below)
- the YARN application status is RUNNING once the Spark session is initialized

Closes #64
ckadner committed Jul 18, 2017
1 parent 4ef03fd commit 6b0c2b7
Showing 1 changed file with 59 additions and 12 deletions.
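
Before the diff itself, here is a minimal, self-contained sketch of the lazy-proxy pattern this commit introduces, runnable without Spark. All names in it (WaitingForResourceToBeInitialized, Resource, initialize_resource) are hypothetical stand-ins for the Spark equivalents in the diff, and it raises AttributeError for meta-attribute probes instead of returning None, which is a common variant rather than exactly what the commit does:

import time
from threading import Thread


class WaitingForResourceToBeInitialized(object):
    """Placeholder that blocks on the initialization thread, then forwards attribute access to the real object."""

    WAITING_FOR_RESOURCE_TO_BE_INITIALIZED = 'resource not yet initialized ...'

    def __init__(self, global_variable_name):
        self._variable_name = global_variable_name

    def __getattr__(self, name):
        # let tab-completion and repr probes through without blocking
        if name.startswith('__') or name.startswith('_ipython_') or name.startswith('_repr_'):
            raise AttributeError(name)
        # block until the background thread has replaced the placeholder ...
        thread_to_initialize_resource.join()
        # ... then delegate to the real object that now sits in the global variable
        return getattr(globals()[self._variable_name], name)


class Resource(object):
    def status(self):
        return 'RUNNING'


def initialize_resource():
    global resource
    time.sleep(2)  # simulate a slow startup, e.g. a YARN application still in ACCEPTED state
    resource = Resource()


resource = WaitingForResourceToBeInitialized('resource')
thread_to_initialize_resource = Thread(target=initialize_resource)
thread_to_initialize_resource.start()

print(resource.status())  # blocks for about two seconds, then prints 'RUNNING'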
@@ -7,6 +7,62 @@
 from jupyter_client.connect import write_connection_file
 from IPython import embed_kernel
 from pyspark.sql import SparkSession
+from threading import Thread
+
+
+class WaitingForSparkSessionToBeInitialized(object):
+    """Wrapper object for SparkContext and other Spark session variables while the real Spark session is being
+    initialized in a background thread. The class name is intentionally verbose, since it will show up when
+    executing a cell that contains only a Spark session variable like ``sc`` or ``sqlContext``.
+    """
+
+    # private and public attributes that show up for tab completion, to indicate pending initialization of the Spark session
+    _WAITING_FOR_SPARK_SESSION_TO_BE_INITIALIZED = 'Spark Session not yet initialized ...'
+    WAITING_FOR_SPARK_SESSION_TO_BE_INITIALIZED = 'Spark Session not yet initialized ...'
+
+    # the same wrapper class is used for all Spark session variables, so we need to record the name of the variable
+    def __init__(self, global_variable_name):
+        self._spark_session_variable = global_variable_name
+
+    # we intercept all method and attribute references on our temporary Spark session variable, wait for the thread
+    # that initializes the Spark session to complete, and then forward the call to the real Spark object
+    def __getattr__(self, name):
+        # ignore tab-completion requests for __members__ or __methods__ and ignore meta property requests
+        if name.startswith("__"):
+            pass
+        elif name.startswith("_ipython_"):
+            pass
+        elif name.startswith("_repr_"):
+            pass
+        else:
+            # wait on the thread to initialize the Spark session variables in global variable scope
+            thread_to_initialize_spark_session.join(timeout=None)
+            # now return the attribute/function reference from the actual Spark object
+            return getattr(globals()[self._spark_session_variable], name)
+
+
+def initialize_spark_session():
+    """Initialize the Spark session and replace the global placeholder variables with references to the real Spark session objects."""
+    global spark, sc, sql, sqlContext, sqlCtx
+    spark = SparkSession.builder.getOrCreate()
+    sc = spark.sparkContext
+    sql = spark.sql
+    sqlContext = spark._wrapped
+    sqlCtx = sqlContext
+
+
+def sql(query):
+    """Placeholder function. When called, it waits for the Spark session to be initialized and then calls ``spark.sql(query)``."""
+    return spark.sql(query)
+
+
+# placeholder objects for the Spark session variables, which are initialized in a background thread
+spark = WaitingForSparkSessionToBeInitialized(global_variable_name='spark')
+sc = WaitingForSparkSessionToBeInitialized(global_variable_name='sc')
+sqlContext = WaitingForSparkSessionToBeInitialized(global_variable_name='sqlContext')
+sqlCtx = WaitingForSparkSessionToBeInitialized(global_variable_name='sqlCtx')
+
+thread_to_initialize_spark_session = Thread(target=initialize_spark_session)
+
+
 def return_connection_info(connection_file, ip, response_addr):
@@ -45,18 +101,6 @@ def return_connection_info(connection_file, ip, response_addr):
 arguments = vars(parser.parse_args())
 connection_file = arguments['connection_file']
 response_addr = arguments['response_address'] # Although argument uses dash, argparse converts to underscore.
-
-# create a Spark session
-spark = SparkSession.builder.getOrCreate()
-
-# setup Spark session variables
-sc = spark.sparkContext
-sql = spark.sql
-
-# setup Spark legacy variables for compatibility
-sqlContext = spark._wrapped
-sqlCtx = sqlContext
-
 ip = "0.0.0.0"
 
 # If the connection file doesn't exist, then we're using 'pull' or 'socket' mode - otherwise 'push' mode.
@@ -68,6 +112,9 @@ def return_connection_info(connection_file, ip, response_addr):
 if response_addr:
     return_connection_info(connection_file, ip, response_addr)
 
+# start to initialize the Spark session in the background
+thread_to_initialize_spark_session.start()
+
 # launch the IPython kernel instance
 embed_kernel(connection_file=connection_file, ip=ip)
 
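Assuming a notebook attached to a kernel launched this way, the behavior listed in the commit message could be observed with cells roughly like the following (hypothetical cells; actual timing depends on when YARN moves the application from ACCEPTED to RUNNING):

# cell 1 - executes immediately; no Spark variable is touched
print(1 + 1)

# cell 2 - tab completion on the placeholder surfaces the attribute
# WAITING_FOR_SPARK_SESSION_TO_BE_INITIALIZED while initialization is pending
sc

# cell 3 - the first real use of a Spark variable blocks until the
# initialization thread completes, then delegates to the actual SparkContext
sc.parallelize(range(10)).count()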