101 changes: 62 additions & 39 deletions experimental/ssh/internal/server/jupyter-init.py
@@ -18,23 +18,6 @@ def wrapper(*args, **kwargs):
return wrapper


@_log_exceptions
def _setup_dedicated_session():
from dbruntime import UserNamespaceInitializer

_user_namespace_initializer = UserNamespaceInitializer.getOrCreate()
_entry_point = _user_namespace_initializer.get_spark_entry_point()
_globals = _user_namespace_initializer.get_namespace_globals()
for name, value in _globals.items():
print(f"Registering global: {name} = {value}")
if name not in globals():
globals()[name] = value

# 'display' from the runtime uses custom widgets that don't work in Jupyter.
# We use the IPython display instead (in combination with the html formatter for DataFrames).
globals()["display"] = ip_display


@_log_exceptions
def _register_runtime_hooks():
from dbruntime.monkey_patches import apply_dataframe_display_patch
@@ -167,7 +150,7 @@ def _register_common_magics():


@_log_exceptions
def _register_pip_magics(user_namespace_initializer: any, entry_point: any):
def _register_pip_magics():
"""Register the pip magic command parser with IPython."""
from dbruntime.DatasetInfo import UserNamespaceDict
from dbruntime.PipMagicOverrides import PipMagicOverrides
@@ -181,7 +164,15 @@ def _register_pip_magics(user_namespace_initializer: any, entry_point: any):
entry_point,
)
ip = get_ipython()
ip.register_magics(PipMagicOverrides(entry_point, globals["sc"]._conf, user_ns))

try:
# Older DBRs
pip_magic = PipMagicOverrides(entry_point, ip.user_ns["sc"]._conf, user_ns)
except Exception:
# Newer DBRs
pip_magic = PipMagicOverrides(entry_point, user_ns, ip)

ip.register_magics(pip_magic)
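
For context, `ip.register_magics` takes an instance of an `IPython.core.magic.Magics` subclass, so the try/except above only has to pick whichever `PipMagicOverrides` constructor signature the installed DBR exposes. A minimal sketch of the registration mechanism, with a hypothetical `shout` magic standing in for the real pip override (illustrative only, not part of this file):

from IPython import get_ipython
from IPython.core.magic import Magics, line_magic, magics_class

@magics_class
class ShoutMagics(Magics):
    """Hypothetical magic used only to illustrate register_magics()."""

    @line_magic
    def shout(self, line):
        # After registration, `%shout hello` prints HELLO in the notebook.
        print(line.upper())

_ip = get_ipython()
if _ip is not None:
    _ip.register_magics(ShoutMagics(shell=_ip))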


@_log_exceptions
@@ -198,34 +189,66 @@ def df_html(df: DataFrame) -> str:
html_formatter.for_type(DataFrame, df_html)
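
For reference, `html_formatter` here is presumably the `text/html` formatter from IPython's display formatter registry, and `for_type` maps a Python type to a rendering function. A rough, self-contained sketch of the same pattern using pandas (illustrative only, not the code in this file):

import pandas as pd
from IPython import get_ipython

ip = get_ipython()
html_formatter = ip.display_formatter.formatters["text/html"]

def frame_html(df: pd.DataFrame) -> str:
    # Render only a bounded preview so large frames don't flood the output area.
    return df.head(20).to_html()

html_formatter.for_type(pd.DataFrame, frame_html)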


@_log_exceptions
def _setup_serverless_session():
import IPython
def _create_spark_session(builder_fn):
from databricks.connect import DatabricksSession

user_ns = getattr(IPython.get_ipython(), "user_ns", {})
existing_session = getattr(user_ns, "spark", None)
user_ns = get_ipython().user_ns
existing_session = user_ns.get("spark")
# Clear the existing local spark session, otherwise DatabricksSession will re-use it.
user_ns["spark"] = None
try:
# Clear the existing local spark session, otherwise DatabricksSession will re-use it.
user_ns["spark"] = None
globals()["spark"] = None
# DatabricksSession will use the existing env vars for the connection.
spark_session = DatabricksSession.builder.serverless(True).getOrCreate()
user_ns["spark"] = spark_session
globals()["spark"] = spark_session
except Exception as e:
return builder_fn(DatabricksSession.builder).getOrCreate()
except Exception:
user_ns["spark"] = existing_session
globals()["spark"] = existing_session
raise e
raise
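
`builder_fn` receives `DatabricksSession.builder` and must return a configured builder; the two call shapes used further down look roughly like this (same env-var names as below):

# Serverless: the connection details come from the environment.
spark = _create_spark_session(lambda b: b.serverless(True))

# Classic cluster: point the builder at an explicit workspace and cluster.
spark = _create_spark_session(
    lambda b: b.remote(
        host=os.environ["DATABRICKS_HOST"],
        token=os.environ["DATABRICKS_TOKEN"],
        cluster_id=os.environ["DATABRICKS_CLUSTER_ID"],
    )
)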


def _initialize_spark(is_serverless: bool, existing_spark: any):
from pyspark.sql.session import SparkSession

# On serverless, always initialize a new remote Databricks Connect session.
if is_serverless:
return _create_spark_session(lambda b: b.serverless(True))
# On dedicated or standard clusters, initialize a new remote session if the existing spark session is local.
if existing_spark is None or isinstance(existing_spark, SparkSession):
return _create_spark_session(
lambda b: b.remote(
host=os.environ["DATABRICKS_HOST"],
token=os.environ["DATABRICKS_TOKEN"],
cluster_id=os.environ["DATABRICKS_CLUSTER_ID"],
)
)
# Otherwise re-use the existing remote session.
return existing_spark
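
Reading the branches above: a classic in-process session is an instance of `pyspark.sql.session.SparkSession`, whereas an already-remote Databricks Connect session is typically a different class, so it falls through to the re-use branch. A hedged usage sketch (env-var names as above):

is_serverless = os.environ.get("DATABRICKS_JUPYTER_SERVERLESS") == "true"

# Passing None (or a plain local SparkSession) triggers creation of a new
# remote session; passing an existing Databricks Connect session re-uses it.
spark = _initialize_spark(is_serverless, existing_spark=None)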


if os.environ.get("DATABRICKS_JUPYTER_SERVERLESS") == "true":
_setup_serverless_session()
else:
_setup_dedicated_session()
_register_pip_magics()
@_log_exceptions
def _setup_globals(is_serverless: bool):
from dbruntime import UserNamespaceInitializer

ns = UserNamespaceInitializer.getOrCreate()
ns_globals = ns.get_namespace_globals()
existing_spark = ns_globals.get("spark")
spark = _initialize_spark(is_serverless, existing_spark)
try:
ns.db_connection.spark_provider.set_spark(spark)
except Exception as e:
print(f"Error updating spark provider: {e}")
ns_globals["spark"] = spark
if spark is not None:
ns_globals["table"] = spark.table
ns_globals["sql"] = spark.sql
user_ns = get_ipython().user_ns
for name, value in ns_globals.items():
print(f"Registering global: {name} = {value}")
user_ns[name] = value
# 'display' from the runtime uses custom widgets that don't work in Jupyter.
# We use the IPython display instead (in combination with the html formatter for DataFrames).
user_ns["display"] = ip_display


_setup_globals(os.environ.get("DATABRICKS_JUPYTER_SERVERLESS") == "true")
_register_pip_magics()
_register_common_magics()
_register_formatters()
_register_runtime_hooks()